# Mirror of https://github.com/OpenBMB/MiniCPM-V.git
import glob
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
import transformers
from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoModel, AutoTokenizer

from dataset import SupervisedDataset, data_collator
from trainer import CPMTrainer

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2")

@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = False

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    tune_vision: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=True)

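# A minimal launch sketch for reference only: the script name, data paths, and
# DeepSpeed config file below are placeholders, and the usual HF TrainingArguments
# (output_dir, bf16, batch size, save/eval strategy, ...) apply on top of the
# custom fields defined above.
#
#   torchrun --nproc_per_node=8 finetune.py \
#       --model_name_or_path openbmb/MiniCPM-V-2 \
#       --data_path data/train.json \
#       --eval_data_path data/eval.json \
#       --tune_vision true --tune_llm true \
#       --model_max_length 2048 \
#       --bf16 true \
#       --output_dir output/minicpmv_sft \
#       --deepspeed ds_config_zero2.json
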
def rank0_print(*args):
    if local_rank == 0:
        print(*args)

def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    transform,
    data_collator=None,
    llm_type="minicpm",
    slice_config=None,
    patch_size=14,
    query_nums=64,
    batch_vision=False,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = SupervisedDataset

    rank0_print("Loading data...")

    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(
        train_json,
        transform,
        tokenizer,
        slice_config=slice_config,
        llm_type=llm_type,
        patch_size=patch_size,
        query_nums=query_nums,
        batch_vision=batch_vision,
    )

    if data_args.eval_data_path:
        with open(data_args.eval_data_path, "r") as f:
            eval_json = json.load(f)
        eval_dataset = dataset_cls(
            eval_json,
            transform,
            tokenizer,
            slice_config=slice_config,
            llm_type=llm_type,
            patch_size=patch_size,
            query_nums=query_nums,
            batch_vision=batch_vision,
        )
    else:
        eval_dataset = None

    return dict(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

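# Note: the dict returned above is unpacked directly into CPMTrainer in train()
# (via **data_module), so its keys intentionally match the HF Trainer keyword
# arguments train_dataset / eval_dataset / data_collator.
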
def get_parameter_number(model):
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return {'Total': all_param, 'Trainable': trainable_params}

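# Illustrative use only (hypothetical snippet; train() simply rank0_print()s the dict):
#   counts = get_parameter_number(model)
#   rank0_print(f"Trainable: {counts['Trainable']:,} / {counts['Total']:,} "
#               f"({counts['Trainable'] / max(counts['Total'], 1):.2%})")
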
local_rank = 0

def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )

    (
        model_args,
        data_args,
        training_args,
    ) = parser.parse_args_into_dataclasses()

    if getattr(training_args, "deepspeed", None):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    # Pick the compute dtype from the --fp16 / --bf16 training flags.
    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    local_rank = training_args.local_rank

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    # Under distributed training, pin each process to its own GPU via LOCAL_RANK.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None

    model = AutoModel.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=compute_dtype,
        device_map=device_map,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    )

    # Optionally freeze the vision encoder and/or the LLM backbone.
    if not training_args.tune_vision:
        model.vpm.requires_grad_(False)
    if not training_args.tune_llm:
        model.llm.requires_grad_(False)
    rank0_print(get_parameter_number(model))

    llm_type = "minicpm"
    if "llama3" in model.name_or_path.lower():
        tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
        llm_type = "llama3"
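    # The template assigned above is the Llama-3 instruct chat format
    # (<|start_header_id|> ... <|eot_id|>). It is selected purely by the
    # "llama3" substring check on the checkpoint name, e.g. MiniCPM-Llama3-V
    # variants, which use a Llama-3 backbone (an assumption about the model
    # family; this file only performs the name check).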
    # Load data
    if hasattr(model.config, "slice_config"):
        slice_config = model.config.slice_config.to_dict()
    else:
        slice_config = model.config.to_dict()
    if hasattr(model.config, "batch_vision_input"):
        batch_vision = model.config.batch_vision_input
    else:
        batch_vision = False

    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
        transform=model.transform,
        data_collator=data_collator,
        slice_config=slice_config,
        llm_type=llm_type,
        patch_size=model.config.patch_size,
        query_nums=model.config.query_num,
        batch_vision=batch_vision,
    )

    trainer = CPMTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module,
    )

    trainer.train()
    trainer.save_state()
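    # Note: save_state() persists only the Trainer bookkeeping (trainer_state.json);
    # model checkpoints are written during training according to the usual
    # --save_strategy / --save_steps TrainingArguments. For a final standalone
    # export one would typically also call trainer.save_model(output_dir),
    # which this script does not do.
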
if __name__ == "__main__":
    train()