Mirror of https://github.com/OpenBMB/MiniCPM-V.git, synced 2026-02-05 02:09:20 +08:00
Update LoRA finetuning code (#154)
* update lora tuning
* update lora fine-tuning code
* update finetuning lora code
* lora code
* lora finetuning code
* updating lora finetuning code
* update lora finetuning code
* Update LoRA finetuning code
* Update LoRA finetuning code
* Update LoRA finetuning code
@@ -56,6 +56,7 @@ class SupervisedDataset(Dataset):
        )
        ret = dict(
            input_ids=ret["input_ids"],
            position_ids=ret["position_ids"],
            labels=ret["target"],
            attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
            pixel_values=ret["pixel_values"],
@@ -72,6 +73,11 @@ def data_collator(examples, padding_value=0):
        batch_first=True,
        padding_value=padding_value,
    )
    position_ids = pad_sequence(
        [example["position_ids"] for example in examples],
        batch_first=True,
        padding_value=padding_value,
    )
    targets = pad_sequence(
        [example["labels"] for example in examples],
        batch_first=True,
@@ -87,6 +93,7 @@ def data_collator(examples, padding_value=0):
    tgt_sizes = [example["tgt_sizes"] for example in examples]
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "labels": targets,
        "attention_mask": attention_mask,
        "image_bound": image_bound,
@@ -130,6 +137,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
    image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
    if len(image_start_tokens) != len(image_end_tokens):
        print("image start token != image end tokens")

    if len(image_start_tokens) > 0:
        image_bound = torch.hstack(
            [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
@@ -137,11 +145,13 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
    else:
        image_bound = []

    position_ids = torch.where(ids != 0, torch.arange(ids.size(0)), torch.tensor(0)).long()
    return {
        "input_ids": ids,
        "target": target,
        "image_bound": image_bound,
        "raw_msg": raw_msg,
        "position_ids": position_ids
    }
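For reference, a small illustration (ours, not part of the diff) of what the newly added `position_ids` line yields when the sequence contains padding id 0:

```python
import torch

# Toy input where trailing zeros are padding; values are illustrative only.
ids = torch.tensor([8, 15, 23, 42, 0, 0])
position_ids = torch.where(ids != 0, torch.arange(ids.size(0)), torch.tensor(0)).long()
print(position_ids)  # tensor([0, 1, 2, 3, 0, 0]) -> padded slots keep position 0
```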
@@ -1,6 +1,6 @@
{
    "fp16": {
        "enabled": false,
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
@@ -9,7 +9,7 @@
    },

    "bf16": {
        "enabled": true
        "enabled": "auto"
    },

    "optimizer": {
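Setting both `enabled` fields to `"auto"` lets the Hugging Face DeepSpeed integration fill them in from the launch-time training arguments, so the same config file serves bf16 and fp16 runs. A rough sketch of the flags it reads (ours; the output path is assumed from the shell scripts below):

```python
from transformers import TrainingArguments

# With "enabled": "auto" in ds_config_zero2.json, the HF/DeepSpeed integration
# resolves fp16.enabled / bf16.enabled from these flags at launch time.
args = TrainingArguments(
    output_dir="output/output_minicpmv2",   # assumed, matches the shell scripts
    bf16=True,                              # -> bf16.enabled = true
    fp16=False,                             # -> fp16.enabled = false
    deepspeed="ds_config_zero2.json",
)
```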
@@ -3,20 +3,22 @@ import json
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union, Literal, Tuple
from types import MethodType
import torch
import transformers
from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoModel, AutoTokenizer
from transformers.integrations import deepspeed
from transformers import AutoModel, AutoTokenizer

from dataset import SupervisedDataset, data_collator
from trainer import CPMTrainer

from peft import LoraConfig, get_peft_model

@dataclass
class ModelArguments:
@@ -31,7 +33,6 @@ class DataArguments:
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = False


@dataclass
@@ -45,15 +46,83 @@ class TrainingArguments(transformers.TrainingArguments):
        },
    )
    tune_vision: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=False)
    llm_type: str = field(default="minicpm")
    use_lora: Optional[bool] = field(default=False)


@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False
    lora_modules_to_save: str = ""
    lora_layer_replication: Optional[List[Tuple[int, int]]] = None
    lora_layers_to_transform: Optional[List[int]] = None
    lora_layers_pattern: Optional[str] = None


def maybe_zero_3(param):
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():  # iterate key/value pairs, not keys
            if k in lora_bias_names:          # keep only biases that belong to LoRA layers
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


local_rank = None
def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
    """Collects the state dict and dump to disk."""
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if trainer.args.use_lora:
            state_dict = get_peft_state_maybe_zero_3(
                trainer.model.named_parameters(), bias
            )
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
@@ -124,18 +193,18 @@ local_rank = 0

def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )

    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    if getattr(training_args, "deepspeed", None):
    if getattr(training_args, "deepspeed", None) :
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    compute_dtype = (
@@ -145,18 +214,17 @@ def train():
    )

    local_rank = training_args.local_rank

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None

    model = AutoModel.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=compute_dtype,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    )
@@ -165,6 +233,32 @@ def train():
        model.vpm.requires_grad_(False)
    if not training_args.tune_llm:
        model.llm.requires_grad_(False)

    if training_args.use_lora:
        if training_args.use_lora and training_args.tune_llm:
            raise ValueError("The model cannot simultaneously adjust LLM parameters and apply LoRA.")

        rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
        for name, param in model.llm.named_parameters():
            param.requires_grad = False
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            layers_to_transform=lora_args.lora_layers_to_transform,
            task_type="CAUSAL_LM",
        )
        if training_args.gradient_checkpointing:
            def get_input_embeddings(self):
                return self.llm.get_input_embeddings()
            model.get_input_embeddings = MethodType(get_input_embeddings, model)
        model = get_peft_model(model, lora_config)
        model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    rank0_print(get_parameter_number(model))

    llm_type = training_args.llm_type
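A side note on the default `lora_target_modules` pattern used above (a sketch of ours, not part of the diff): `peft` treats a string `target_modules` as a regex and matches it against each module name with `re.fullmatch`, so only the language model's attention projections pick up adapters. The module name below is an assumed, typical MiniCPM-V path:

```python
import re

# Default pattern from LoraArguments above, checked against a typical module name.
pattern = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
print(bool(re.fullmatch(pattern, "llm.model.layers.0.self_attn.q_proj")))  # True  -> gets a LoRA adapter
print(bool(re.fullmatch(pattern, "llm.model.layers.0.mlp.gate_proj")))     # False -> left untouched
```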
@@ -194,7 +288,7 @@ def train():
        query_nums=model.config.query_num,
        batch_vision=batch_vision,
    )

    trainer = CPMTrainer(
        model=model,
        tokenizer=tokenizer,
@@ -205,6 +299,11 @@ def train():
    trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        bias=lora_args.lora_bias)


if __name__ == "__main__":
    train()
@@ -20,37 +20,41 @@ DISTRIBUTED_ARGS="
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
torchrun $DISTRIBUTED_ARGS finetune.py \
    --model_name_or_path $MODEL \
    --llm_type $LLM_TYPE \
    --data_path $DATA \
    --eval_data_path $EVAL_DATA \
    --remove_unused_columns false \
    --label_names "labels" \
    --prediction_loss_only false \
    --prediction_loss_only false \
    --bf16 true \
    --bf16_full_eval true \
    --fp16 false \
    --fp16_full_eval false \
    --do_train \
    --do_eval \
    --tune_vision false \
    --tune_llm false \
    --model_max_length 2048 \
    --max_steps 80000 \
    --eval_steps 200 \
    --max_steps 10000 \
    --eval_steps 1000 \
    --output_dir output/output_minicpmv2 \
    --logging_dir output/output_minicpmv2 \
    --logging_strategy "steps" \
    --per_device_train_batch_size 8 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 10 \
    --learning_rate 5e-7 \
    --learning_rate 1e-6 \
    --weight_decay 0.1 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.01 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --gradient_checkpointing True \
    --gradient_checkpointing true \
    --deepspeed ds_config_zero2.json \
    --report_to "tensorboard" # wandb
    --report_to "tensorboard"
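As a quick sanity check (ours, not part of the diff), the updated settings imply the following effective global batch size:

```python
# Values taken from the updated finetune_ds.sh above (8 GPUs on one node).
gpus = 8
per_device_train_batch_size = 2
gradient_accumulation_steps = 1
global_batch_size = gpus * per_device_train_batch_size * gradient_accumulation_steps
print(global_batch_size)  # 16 sequences per optimizer step
```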
finetune/finetune_lora.sh (new file, 61 lines)
@@ -0,0 +1,61 @@
#!/bin/bash

GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001

MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the finetuning section of the README for more information.
DATA="path/to/training_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="llama3" # if using openbmb/MiniCPM-V-2, set LLM_TYPE=minicpm

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
    --model_name_or_path $MODEL \
    --llm_type $LLM_TYPE \
    --data_path $DATA \
    --eval_data_path $EVAL_DATA \
    --remove_unused_columns false \
    --label_names "labels" \
    --prediction_loss_only false \
    --bf16 true \
    --bf16_full_eval true \
    --do_train \
    --do_eval \
    --tune_vision false \
    --tune_llm false \
    --use_lora true \
    --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \
    --model_max_length 2048 \
    --max_steps 10000 \
    --eval_steps 1000 \
    --output_dir output/output_minicpmv2_lora \
    --logging_dir output/output_minicpmv2_lora \
    --logging_strategy "steps" \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 10 \
    --learning_rate 1e-6 \
    --weight_decay 0.1 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.01 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --gradient_checkpointing true \
    --deepspeed ds_config_zero2.json \
    --report_to "tensorboard" # wandb
@@ -69,6 +69,65 @@ sh finetune_ds.sh

Note that Llama3 uses a different chat_template for training and inference. We modified the chat_template for training, so please take care to restore the original chat_template when running inference on a checkpoint trained this way.
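One possible way to do that restoration (a sketch of ours, assuming the base model's tokenizer still carries the inference-time template; this is not code from the commit, and the checkpoint path is a placeholder):

```python
from transformers import AutoTokenizer

# Copy the inference chat_template back from the base model's tokenizer.
base_tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True
)
trained_tokenizer = AutoTokenizer.from_pretrained(
    "path/to/your/checkpoint", trust_remote_code=True  # assumed local path
)
trained_tokenizer.chat_template = base_tokenizer.chat_template
```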
### LoRA finetuning

LoRA allows lightweight model tuning in which only a small subset of parameters is updated. We provide a LoRA implementation based on `peft`. To launch your training, run the following script:

```
sh finetune_lora.sh
```

After training, you can load the model with the path to the adapter. We advise you to use an absolute path for your pretrained model, because LoRA saves only the adapter, and the absolute path recorded in the adapter configuration JSON file is used to locate the pretrained model to load.

```python
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    # path to the output directory
    path_to_adapter,
    device_map="auto",
    trust_remote_code=True
).eval()
```
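Optionally, if you want a standalone checkpoint that no longer depends on `peft` at load time, the adapter can be merged into the base weights after loading (a sketch under that assumption, not something this commit adds; the output path is a placeholder):

```python
# Merge LoRA weights into the base model and drop the adapter wrappers,
# then save a plain checkpoint.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("path/to/merged_model")  # assumed output path
```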
### Finetuning FAQs
<details>
<summary>Q: How do I use the `flash_attention_2` implementation when loading a pretrained model?</summary>

A: If your environment supports `flash_attn2`, you can add the argument `_attn_implementation="flash_attention_2"` when loading a model with `AutoModel.from_pretrained`. For example:

```python
model = AutoModel.from_pretrained('model_name', _attn_implementation="flash_attention_2")
```
</details>

<details>
<summary>Q: What if our data is resized to 512? Can we use the original image size instead?</summary>

A: Our model supports up to 1344x1344 lossless encoding. If you are currently resizing your images to 512, you might want to try using the original image sizes instead. Our system automatically includes a high-definition image encoding scheme by default.

</details>

<details>
<summary>Q: What should we do if we encounter out-of-memory (OOM) errors?</summary>

A: If you experience OOM issues, consider reducing the per-device batch size. To maintain an equivalent total batch size, you can increase `gradient_accumulation_steps` accordingly. This approach lets you manage memory usage while still processing the desired amount of data per training step; see the sketch below.
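For instance, a small worked example (ours; the numbers mirror `finetune_lora.sh` and are illustrative only):

```python
# Keeping the global batch size fixed while reducing per-GPU memory:
# the "before" values mirror finetune_lora.sh; "after" is an assumed OOM-friendly setting.
gpus = 8
before = dict(per_device_train_batch_size=2, gradient_accumulation_steps=1)
after = dict(per_device_train_batch_size=1, gradient_accumulation_steps=2)
for cfg in (before, after):
    print(gpus * cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"])  # 16 both times
```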
</details>

<details>
<summary>Q: How can we determine the maximum length for our training data, and what if we do not want to train the vision encoder?</summary>

A: We recommend using the function [here](https://github.com/OpenBMB/MiniCPM-V/blob/main/finetune/dataset.py#L220) to sample the lengths of your training data. Note that the `input_ids` length includes the image portion. Once you have determined the maximum length, you can specify it in the startup command via `--model_max_length xxx`.

Additionally, if you prefer not to train the vision encoder, you can add `--tune_vision false` to your command.

</details>

<details>
<summary>Q: How can we adjust training hyperparameters when using LoRA to train our model?</summary>

A: You can refer to the [LoRA documentation](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) for guidance on adjusting your training hyperparameters when using LoRA. This documentation provides detailed information on configuring the parameters specific to the LoRA adaptation technique.
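For example, here is a sketch of ours (not from this commit) of a heavier adapter than the `LoraArguments` defaults, written as a `peft` `LoraConfig`; with the scripts in this repo you would pass the equivalent `--lora_r`, `--lora_alpha`, `--lora_dropout`, and `--lora_target_modules` flags instead:

```python
from peft import LoraConfig

# Illustrative values only; tune them for your task and memory budget.
lora_config = LoraConfig(
    r=128,                      # higher rank -> more trainable parameters
    lora_alpha=256,             # scaling factor, often set to roughly 2x r
    lora_dropout=0.1,
    target_modules=r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)",
    bias="none",
    task_type="CAUSAL_LM",
)
```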
</details>

#### Customizing Hyperparameters
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) and the [LoRA documentation](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
@@ -13,14 +13,10 @@ class CPMTrainer(Trainer):
            labels = inputs.pop("labels")
        else:
            labels = None

        vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(
            inputs)

        outputs = self.model.llm(
            inputs_embeds=vllm_embedding,
            use_cache=False,
        )
        if not self.args.use_lora:
            outputs = self.model(data=inputs, use_cache=False)
        else:
            outputs = self.model.base_model(data=inputs, use_cache=False)

        if labels is not None:
            # Flatten the tokens