update lora finetune inference bug (#224)

This commit is contained in:
qianyu chen
2024-06-07 18:00:22 +08:00
committed by GitHub
parent 31eaa26ee1
commit 9bd93a281c
6 changed files with 61 additions and 17 deletions

View File

@@ -13,6 +13,7 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
@@ -194,10 +195,10 @@ def conversation_to_ids_llama3(conversation, tokenizer):
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=False
conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template,
)
input_ids = tokenizer.apply_chat_template(
conversation, tokenize=True, add_generation_prompt=False
conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template,
)
input_ids = np.array(input_ids)

View File

@@ -51,7 +51,6 @@ class TrainingArguments(transformers.TrainingArguments):
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
scale_resolution: Optional[int] = field(default=448)
@dataclass
@@ -270,17 +269,15 @@ def train():
)
model = get_peft_model(model, lora_config)
model.base_model.resampler.requires_grad_(True)
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
if training_args.tune_vision:
model.base_model.vpm.requires_grad_(True)
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
rank0_print(get_parameter_number(model))
llm_type = training_args.llm_type
if llm_type == "llama3":
tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
rank0_print(f'llm_type={llm_type}')
@@ -288,11 +285,9 @@ def train():
# Load data
if hasattr(model.config, "slice_config"):
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
model.config.slice_config.scale_resolution = training_args.scale_resolution
slice_config = model.config.slice_config.to_dict()
else:
model.config.max_slice_nums = training_args.max_slice_nums
model.config.scale_resolution = training_args.scale_resolution
slice_config = model.config.to_dict()
if hasattr(model.config, "batch_vision_input"):

View File

@@ -38,7 +38,6 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--tune_llm true \
--model_max_length 2048 \
--max_slice_nums 9 \
--scale_resolution 448 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output_minicpmv2 \

View File

@@ -40,7 +40,6 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \
--model_max_length 2048 \
--max_slice_nums 9 \
--scale_resolution 448 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output_minicpmv2_lora \

View File

@@ -74,7 +74,7 @@ Specially, Llama3 has a different chat_template for training and inference, we m
The LoRA allows light-weight model tuning with only a small subset of parameters updated. We provide the LoRA implementation based on `peft`. To launch your training, run the following script:
```
sh finetune_ds_lora.sh
sh finetune_lora.sh
```
After training, you could load the model with the path to the adapter. We advise you to use absolute path for your pretrained model. This is because LoRA only saves the adapter and the absolute path in the adapter configuration json file is used for finding out the pretrained model to load.
@@ -82,12 +82,18 @@ After training, you could load the model with the path to the adapter. We advise
```
from peft import AutoPeftModelForCausalLM
path_to_adapter="path_to_adapter"
model = AutoPeftModelForCausalLM.from_pretrained(
# path to the output directory
path_to_adapter,
device_map="auto",
trust_remote_code=True
).eval()
vpm_resampler_embedtokens_weight = torch.load(f"{path_to_adapter}/vpm_resampler_embedtokens.pt")
msg = model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False)
```
@@ -122,10 +128,6 @@ AWhen you face Out of Memory (OOM) issues during training large models, the f
```
--batch_size 1
```
- **Lower image resolution**: If your model processes image data, reducing the input resolution of images can effectively decrease memory usage.
```
--scale_resolution 448
```
- **Reduce the number of slices (`slice`)**: When handling large datasets such as large images files, reducing the number of slices processed each time can lower memory requirements.
```
--max_slice_nums 9

View File

@@ -20,12 +20,14 @@ class CPMTrainer(Trainer):
if not self.args.use_lora:
outputs = self.model(data = inputs, use_cache=False)
else:
outputs = self.model.base_model(data = inputs, use_cache=False)
with self.model._enable_peft_forward_hooks(**inputs):
outputs = self.model.base_model(data = inputs, use_cache=False)
else:
if not self.args.use_lora:
outputs = self.model(data = inputs, use_cache=False)
else:
outputs = self.model.base_model(data = inputs, use_cache=False)
with self.model._enable_peft_forward_hooks(**inputs):
outputs = self.model.base_model(data = inputs, use_cache=False)
if labels is not None:
# Flatten the tokens
@@ -174,6 +176,7 @@ class CPMTrainer(Trainer):
logits = logits[0]
return (loss, logits, labels)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
@@ -219,5 +222,50 @@ class CPMTrainer(Trainer):
self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
if self.args.use_lora:
from collections import OrderedDict
state_dict_vision = OrderedDict()
for key, values in state_dict.items():
if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key:
state_dict_vision[key] = values
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", )
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))