support video sft and auto save and load all files

This commit is contained in:
fzc8578
2025-01-11 13:50:36 +08:00
parent 8464c94a7b
commit c5e82b1bc7
4 changed files with 170 additions and 22 deletions

View File

@@ -14,7 +14,7 @@ from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, AutoTokenizer, AutoProcessor
from transformers.integrations import deepspeed
from transformers import AutoModel, AutoTokenizer
@@ -53,6 +53,8 @@ class TrainingArguments(transformers.TrainingArguments):
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
video_max_slice_nums: Optional[int] = field(default=2)
max_num_frames: Optional[int] = field(default=1)
@dataclass
@@ -92,6 +94,8 @@ def make_supervised_data_module(
query_nums=64,
batch_vision=False,
max_length=2048,
video_max_slice_nums=2,
max_num_frames=1,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = SupervisedDataset
@@ -109,6 +113,8 @@ def make_supervised_data_module(
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
video_max_slice_nums=video_max_slice_nums,
max_num_frames=max_num_frames,
)
if data_args.eval_data_path:
@@ -123,6 +129,8 @@ def make_supervised_data_module(
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
video_max_slice_nums=video_max_slice_nums,
max_num_frames=max_num_frames,
)
else:
eval_dataset = None
@@ -203,6 +211,9 @@ def train():
torch_dtype=compute_dtype,
device_map=device_map,
)
model.__class__.register_for_auto_class()
model.processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
@@ -276,6 +287,8 @@ def train():
query_nums=model.config.query_num,
batch_vision=batch_vision,
max_length=training_args.model_max_length,
video_max_slice_nums=training_args.video_max_slice_nums,
max_num_frames=training_args.max_num_frames,
)
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}