Mirror of https://github.com/OpenBMB/MiniCPM-V.git, synced 2026-02-05 18:29:18 +08:00
Update to MiniCPM-Llama3-V 2.5
@@ -1,22 +1,22 @@
-import os
+import glob
 import json
 import logging
+import os
 from dataclasses import dataclass, field
-from typing import Dict, Optional, List
+from typing import Dict, List, Optional
+
 import torch
-from torch.utils.data import Dataset
 import transformers
-from trainer import CPMTrainer
-from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
-from deepspeed import zero
-
-from dataset import data_collator, SupervisedDataset
-
-
-from PIL import Image
-from transformers import AutoModel, AutoTokenizer
 from accelerate.utils import DistributedType
+from deepspeed import zero
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from PIL import Image
+from torch.utils.data import Dataset
+from transformers import AutoModel, AutoTokenizer
+
+from dataset import SupervisedDataset, data_collator
+from trainer import CPMTrainer
 
 
 @dataclass
 class ModelArguments:
@@ -44,6 +44,8 @@ class TrainingArguments(transformers.TrainingArguments):
             "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
         },
     )
+    tune_vision: Optional[bool] = field(default=True)
+    tune_llm: Optional[bool] = field(default=True)
 
 
 def rank0_print(*args):
@@ -52,7 +54,15 @@ def rank0_print(*args):
 
 
 def make_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, transform, data_collator=None, slice_config=None,
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    transform,
+    data_collator=None,
+    llm_type="minicpm",
+    slice_config=None,
+    patch_size=14,
+    query_nums=64,
+    batch_vision=False,
 ) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     dataset_cls = SupervisedDataset
@@ -60,19 +70,57 @@ def make_supervised_data_module(
     rank0_print("Loading data...")
 
     train_json = json.load(open(data_args.data_path, "r"))
-    train_dataset = dataset_cls(train_json, transform, tokenizer, slice_config=slice_config)
+    train_dataset = dataset_cls(
+        train_json,
+        transform,
+        tokenizer,
+        slice_config=slice_config,
+        llm_type=llm_type,
+        patch_size=patch_size,
+        query_nums=query_nums,
+        batch_vision=batch_vision,
+    )
 
     if data_args.eval_data_path:
         eval_json = json.load(open(data_args.eval_data_path, "r"))
-        eval_dataset = dataset_cls(eval_json, transform, tokenizer, slice_config=slice_config)
+        eval_dataset = dataset_cls(
+            eval_json,
+            transform,
+            tokenizer,
+            slice_config=slice_config,
+            llm_type=llm_type,
+            patch_size=patch_size,
+            query_nums=query_nums,
+            batch_vision=batch_vision,
+        )
     else:
         eval_dataset = None
 
-    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)
+    return dict(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=data_collator,
+    )
+
+
+def get_parameter_number(model):
+    trainable_params, all_param = 0, 0
+    for param in model.parameters():
+        num_params = param.numel()
+        # if using DS Zero 3 and the weights are initialized empty
+        if num_params == 0 and hasattr(param, "ds_numel"):
+            num_params = param.ds_numel
+
+        all_param += num_params
+        if param.requires_grad:
+            trainable_params += num_params
+
+    return {'Total': all_param, 'Trainable': trainable_params}
 
 
 local_rank = 0
 
 
 def train():
     global local_rank
@@ -85,8 +133,8 @@ def train():
         data_args,
         training_args,
     ) = parser.parse_args_into_dataclasses()
-
-    if getattr(training_args, 'deepspeed', None):
+
+    if getattr(training_args, "deepspeed", None):
         training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
 
     compute_dtype = (
@@ -99,14 +147,50 @@ def train():
 
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
-    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
-
-    model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map)
-    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
-
-    #Load data
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+
+    model = AutoModel.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=True,
+        torch_dtype=compute_dtype,
+        device_map=device_map,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=True
+    )
+
+    if not training_args.tune_vision:
+        model.vpm.requires_grad_(False)
+    if not training_args.tune_llm:
+        model.llm.requires_grad_(False)
+    rank0_print(get_parameter_number(model))
+
+    llm_type = "minicpm"
+    if "llama3" in model.name_or_path.lower():
+        tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
+        llm_type = "llama3"
+
+    # Load data
+    if hasattr(model.config, "slice_config"):
+        slice_config = model.config.slice_config.to_dict()
+    else:
+        slice_config = model.config.to_dict()
+    if hasattr(model.config, "batch_vision_input"):
+        batch_vision = model.config.batch_vision_input
+    else:
+        batch_vision = False
+
     data_module = make_supervised_data_module(
-        tokenizer=tokenizer, data_args=data_args, transform=model.transform, data_collator=data_collator, slice_config=model.config.__dict__,
+        tokenizer=tokenizer,
+        data_args=data_args,
+        transform=model.transform,
+        data_collator=data_collator,
+        slice_config=slice_config,
+        llm_type=llm_type,
+        patch_size=model.config.patch_size,
+        query_nums=model.config.query_num,
+        batch_vision=batch_vision,
     )
 
     trainer = CPMTrainer(
@@ -115,11 +199,10 @@ def train():
         args=training_args,
         **data_module,
     )
-
 
     trainer.train()
     trainer.save_state()
 
 
 if __name__ == "__main__":
     train()
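
A note on the get_parameter_number helper introduced above: it counts only parameters that still require gradients (falling back to ds_numel for weights that DeepSpeed ZeRO-3 has partitioned away), so the rank0_print(get_parameter_number(model)) call added to train() is where the effect of the new tune_vision / tune_llm flags becomes visible. Below is a minimal, self-contained sketch of that behavior; the helper body is copied from the hunk above, while the toy two-branch module and the frozen "vpm" branch are purely illustrative stand-ins for the real model.

# Standalone sketch (illustrative): check that get_parameter_number reports
# only parameters that still require gradients, the way train() reports the
# model after tune_vision / tune_llm freezing.
import torch.nn as nn


def get_parameter_number(model):
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return {'Total': all_param, 'Trainable': trainable_params}


# Toy stand-in for the real model: freeze one branch the same way
# tune_vision=False freezes model.vpm in train().
toy = nn.ModuleDict({"vpm": nn.Linear(4, 4), "llm": nn.Linear(4, 4)})
toy["vpm"].requires_grad_(False)
print(get_parameter_number(toy))  # {'Total': 40, 'Trainable': 20}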