Mirror of https://github.com/OpenBMB/MiniCPM-V.git
Update to MiniCPM-o 2.6
@@ -14,7 +14,7 @@ from accelerate.utils import DistributedType
 from deepspeed import zero
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer, AutoProcessor
 from transformers.integrations import deepspeed
@@ -53,8 +53,6 @@ class TrainingArguments(transformers.TrainingArguments):
     llm_type: str = field(default="minicpm")
     use_lora: Optional[bool] = field(default=False)
     max_slice_nums: Optional[int] = field(default=9)
-    video_max_slice_nums: Optional[int] = field(default=2)
-    max_num_frames: Optional[int] = field(default=1)
 
 
 @dataclass
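For orientation, a minimal sketch of the dataclass pattern this hunk edits, using only the fields visible in the diff (the real class defines more):

from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Custom fields layered on top of the stock HF arguments; after this
    # commit the video_max_slice_nums / max_num_frames fields are gone.
    llm_type: str = field(default="minicpm")
    use_lora: Optional[bool] = field(default=False)
    max_slice_nums: Optional[int] = field(default=9)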
@@ -94,8 +92,6 @@ def make_supervised_data_module(
     query_nums=64,
     batch_vision=False,
     max_length=2048,
-    video_max_slice_nums=2,
-    max_num_frames=1,
 ) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     dataset_cls = SupervisedDataset
@@ -113,8 +109,6 @@ def make_supervised_data_module(
         query_nums=query_nums,
         batch_vision=batch_vision,
         max_length=max_length,
-        video_max_slice_nums=video_max_slice_nums,
-        max_num_frames=max_num_frames,
     )
 
     if data_args.eval_data_path:
@@ -129,8 +123,6 @@ def make_supervised_data_module(
             query_nums=query_nums,
             batch_vision=batch_vision,
             max_length=max_length,
-            video_max_slice_nums=video_max_slice_nums,
-            max_num_frames=max_num_frames,
         )
     else:
         eval_dataset = None
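A side effect of dropping keyword parameters, shown on a toy function (names are illustrative, not from the repo): callers that still pass the old video kwargs now fail loudly.

def make_module(query_nums=64, batch_vision=False, max_length=2048):
    # Stand-in for make_supervised_data_module after this commit.
    return dict(query_nums=query_nums, batch_vision=batch_vision, max_length=max_length)

make_module(max_length=4096)             # still fine
try:
    make_module(video_max_slice_nums=2)  # removed kwarg
except TypeError as err:
    print(err)                           # unexpected keyword argument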
@@ -210,10 +202,10 @@ def train():
         trust_remote_code=True,
         torch_dtype=compute_dtype,
         device_map=device_map,
+        init_vision=True,
+        init_audio=False,
+        init_tts=False,
     )
+    model.__class__.register_for_auto_class()
 
+    model.processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=True
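A minimal loading sketch matching the new code path, assuming the openbmb/MiniCPM-o-2_6 checkpoint id and that its remote code accepts the init flags exactly as the diff shows:

import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer

path = "openbmb/MiniCPM-o-2_6"  # assumed checkpoint id
model = AutoModel.from_pretrained(
    path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    init_vision=True,   # keep the vision tower for image SFT
    init_audio=False,   # skip the audio encoder
    init_tts=False,     # skip the TTS head
)
# The commit hangs the processor off the model so downstream code
# (e.g. the data collator) can reach it without a second global.
model.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)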
@@ -287,8 +279,6 @@ def train():
         query_nums=model.config.query_num,
         batch_vision=batch_vision,
         max_length=training_args.model_max_length,
-        video_max_slice_nums=training_args.video_max_slice_nums,
-        max_num_frames=training_args.max_num_frames,
     )
 
     training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
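On the last context line: gradient_checkpointing_kwargs is forwarded by the HF Trainer to torch's activation checkpointing, and use_reentrant=False selects the non-reentrant backend, which tolerates inputs that do not require grad (relevant when parts of the model are frozen, e.g. under LoRA). A hedged sketch of the equivalent standalone configuration:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                                   # placeholder
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)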