Update to MiniCPM-Llama3-V 2.5

Author: yiranyyu
Date:   2024-05-20 12:44:33 +08:00
parent c0e39dbfe2
commit 2c75097411
27 changed files with 1944 additions and 985 deletions


@@ -1,22 +1,22 @@
-import os
 import glob
 import json
 import logging
+import os
 from dataclasses import dataclass, field
-from typing import Dict, Optional, List
+from typing import Dict, List, Optional
 import torch
-from torch.utils.data import Dataset
 import transformers
-from trainer import CPMTrainer
-from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
-from deepspeed import zero
-from dataset import data_collator, SupervisedDataset
-from PIL import Image
-from transformers import AutoModel, AutoTokenizer
 from accelerate.utils import DistributedType
+from deepspeed import zero
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from PIL import Image
+from torch.utils.data import Dataset
+from transformers import AutoModel, AutoTokenizer
+from dataset import SupervisedDataset, data_collator
+from trainer import CPMTrainer
 @dataclass
 class ModelArguments:
@@ -44,6 +44,8 @@ class TrainingArguments(transformers.TrainingArguments):
             "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
         },
     )
+    tune_vision: Optional[bool] = field(default=True)
+    tune_llm: Optional[bool] = field(default=True)
 def rank0_print(*args):
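The two switches added above are ordinary dataclass fields, so they arrive through the same HfArgumentParser call that train() already uses. Below is a minimal, self-contained sketch (a toy TrainingArguments subclass and made-up CLI values, not the repo's code) of how string values such as --tune_llm false end up as booleans:

```python
# Minimal sketch: a toy TrainingArguments subclass carrying the two new switches,
# parsed the way finetune's HfArgumentParser call would parse them.
from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ToyTrainingArguments(transformers.TrainingArguments):
    tune_vision: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=True)


parser = transformers.HfArgumentParser(ToyTrainingArguments)
# HfArgumentParser converts "true"/"false" strings to booleans for bool fields.
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--tune_vision", "true", "--tune_llm", "false"]
)
print(args.tune_vision, args.tune_llm)  # -> True False
```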
@@ -52,7 +54,15 @@ def rank0_print(*args):
 def make_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, transform, data_collator=None, slice_config=None,
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    transform,
+    data_collator=None,
+    llm_type="minicpm",
+    slice_config=None,
+    patch_size=14,
+    query_nums=64,
+    batch_vision=False,
 ) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     dataset_cls = SupervisedDataset
@@ -60,19 +70,57 @@ def make_supervised_data_module(
     rank0_print("Loading data...")
     train_json = json.load(open(data_args.data_path, "r"))
-    train_dataset = dataset_cls(train_json, transform, tokenizer, slice_config=slice_config)
+    train_dataset = dataset_cls(
+        train_json,
+        transform,
+        tokenizer,
+        slice_config=slice_config,
+        llm_type=llm_type,
+        patch_size=patch_size,
+        query_nums=query_nums,
+        batch_vision=batch_vision,
+    )
     if data_args.eval_data_path:
         eval_json = json.load(open(data_args.eval_data_path, "r"))
-        eval_dataset = dataset_cls(eval_json, transform, tokenizer, slice_config=slice_config)
+        eval_dataset = dataset_cls(
+            eval_json,
+            transform,
+            tokenizer,
+            slice_config=slice_config,
+            llm_type=llm_type,
+            patch_size=patch_size,
+            query_nums=query_nums,
+            batch_vision=batch_vision,
+        )
     else:
         eval_dataset = None
-    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)
+    return dict(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=data_collator,
+    )
+def get_parameter_number(model):
+    trainable_params, all_param = 0, 0
+    for param in model.parameters():
+        num_params = param.numel()
+        # if using DS Zero 3 and the weights are initialized empty
+        if num_params == 0 and hasattr(param, "ds_numel"):
+            num_params = param.ds_numel
+        all_param += num_params
+        if param.requires_grad:
+            trainable_params += num_params
+    return {'Total': all_param, 'Trainable': trainable_params}
 local_rank = 0
 def train():
     global local_rank
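The new get_parameter_number helper simply walks model.parameters(), falling back to param.ds_numel when DeepSpeed ZeRO-3 has replaced the local weights with empty placeholders. A quick toy check, assuming the helper defined above is in scope (hypothetical usage, not from the repo):

```python
# Toy check of get_parameter_number: freeze one layer and confirm only the
# other layer is counted as trainable. Uses the helper defined above.
import torch.nn as nn

toy = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
toy[0].requires_grad_(False)  # freeze the first layer (110 params)

print(get_parameter_number(toy))
# -> {'Total': 132, 'Trainable': 22}   (10*10 + 10 frozen, 10*2 + 2 trainable)
```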
@@ -85,8 +133,8 @@ def train():
         data_args,
         training_args,
     ) = parser.parse_args_into_dataclasses()
-    if getattr(training_args, 'deepspeed', None):
+    if getattr(training_args, "deepspeed", None):
         training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
     compute_dtype = (
@@ -99,14 +147,50 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
-    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
-    model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map)
-    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
-    #Load data
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    model = AutoModel.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=True,
+        torch_dtype=compute_dtype,
+        device_map=device_map,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=True
+    )
+    if not training_args.tune_vision:
+        model.vpm.requires_grad_(False)
+    if not training_args.tune_llm:
+        model.llm.requires_grad_(False)
+    rank0_print(get_parameter_number(model))
+    llm_type = "minicpm"
+    if "llama3" in model.name_or_path.lower():
+        tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
+        llm_type = "llama3"
+    # Load data
+    if hasattr(model.config, "slice_config"):
+        slice_config = model.config.slice_config.to_dict()
+    else:
+        slice_config = model.config.to_dict()
+    if hasattr(model.config, "batch_vision_input"):
+        batch_vision = model.config.batch_vision_input
+    else:
+        batch_vision = False
     data_module = make_supervised_data_module(
-        tokenizer=tokenizer, data_args=data_args, transform=model.transform, data_collator=data_collator, slice_config=model.config.__dict__,
+        tokenizer=tokenizer,
+        data_args=data_args,
+        transform=model.transform,
+        data_collator=data_collator,
+        slice_config=slice_config,
+        llm_type=llm_type,
+        patch_size=model.config.patch_size,
+        query_nums=model.config.query_num,
+        batch_vision=batch_vision,
+    )
     trainer = CPMTrainer(
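When the backbone is Llama-3, the hunk above overrides tokenizer.chat_template and flips llm_type to "llama3". The sketch below renders that exact template string with jinja2 on a made-up two-turn dialogue, so the <|start_header_id|> ... <|eot_id|> framing is visible without downloading the model; the bos_token value is an assumption (Llama-3's <|begin_of_text|>):

```python
# Render the Llama-3 chat template set in the hunk above on a made-up dialogue.
# The template string is copied verbatim; bos_token is assumed to be Llama-3's.
from jinja2 import Template

LLAMA3_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"

messages = [
    {"role": "user", "content": "What is in the picture?"},
    {"role": "assistant", "content": "A red bicycle leaning against a wall."},
]
print(Template(LLAMA3_CHAT_TEMPLATE).render(messages=messages, bos_token="<|begin_of_text|>"))
```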
@@ -115,11 +199,10 @@ def train():
         args=training_args,
         **data_module,
     )
     trainer.train()
     trainer.save_state()
 if __name__ == "__main__":
     train()