From 9bd93a281cf0da7e23dd11d597fe044bc92c9dbd Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:00:22 +0800 Subject: [PATCH] update lora finetune inference bug (#224) --- finetune/dataset.py | 5 ++-- finetune/finetune.py | 7 +----- finetune/finetune_ds.sh | 1 - finetune/finetune_lora.sh | 1 - finetune/readme.md | 12 +++++---- finetune/trainer.py | 52 +++++++++++++++++++++++++++++++++++++-- 6 files changed, 61 insertions(+), 17 deletions(-) diff --git a/finetune/dataset.py b/finetune/dataset.py index c2dbfda..92807c3 100644 --- a/finetune/dataset.py +++ b/finetune/dataset.py @@ -13,6 +13,7 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from transformers import AutoProcessor, AutoTokenizer +llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}" class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" @@ -194,10 +195,10 @@ def conversation_to_ids_llama3(conversation, tokenizer): input_ids = [] context = [] raw_msg = tokenizer.apply_chat_template( - conversation, tokenize=False, add_generation_prompt=False + conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template, ) input_ids = tokenizer.apply_chat_template( - conversation, tokenize=True, add_generation_prompt=False + conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template, ) input_ids = np.array(input_ids) diff --git a/finetune/finetune.py b/finetune/finetune.py index bf94c3d..3333454 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -51,7 +51,6 @@ class TrainingArguments(transformers.TrainingArguments): llm_type: str = field(default="minicpm") use_lora: Optional[bool] = field(default=False) max_slice_nums: Optional[int] = field(default=9) - scale_resolution: Optional[int] = field(default=448) @dataclass @@ -270,17 +269,15 @@ def train(): ) model = get_peft_model(model, lora_config) model.base_model.resampler.requires_grad_(True) + model.base_model.llm.model.embed_tokens.weight.requires_grad_(True) if training_args.tune_vision: model.base_model.vpm.requires_grad_(True) - model.base_model.llm.model.embed_tokens.weight.requires_grad_(True) if training_args.gradient_checkpointing: model.enable_input_require_grads() rank0_print(get_parameter_number(model)) llm_type = training_args.llm_type - if llm_type == "llama3": - tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}" rank0_print(f'llm_type={llm_type}') @@ -288,11 +285,9 @@ def train(): # Load data if hasattr(model.config, "slice_config"): model.config.slice_config.max_slice_nums = training_args.max_slice_nums - model.config.slice_config.scale_resolution = training_args.scale_resolution slice_config = model.config.slice_config.to_dict() else: model.config.max_slice_nums = training_args.max_slice_nums - model.config.scale_resolution = training_args.scale_resolution slice_config = model.config.to_dict() if hasattr(model.config, "batch_vision_input"): diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh index 2a86e6b..5dc3a3e 100644 --- a/finetune/finetune_ds.sh +++ b/finetune/finetune_ds.sh @@ -38,7 +38,6 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --tune_llm true \ --model_max_length 2048 \ --max_slice_nums 9 \ - --scale_resolution 448 \ --max_steps 10000 \ --eval_steps 1000 \ --output_dir output/output_minicpmv2 \ diff --git a/finetune/finetune_lora.sh b/finetune/finetune_lora.sh index 25d61c7..49c375b 100644 --- a/finetune/finetune_lora.sh +++ b/finetune/finetune_lora.sh @@ -40,7 +40,6 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \ --model_max_length 2048 \ --max_slice_nums 9 \ - --scale_resolution 448 \ --max_steps 10000 \ --eval_steps 1000 \ --output_dir output/output_minicpmv2_lora \ diff --git a/finetune/readme.md b/finetune/readme.md index 4945404..2f4e095 100644 --- a/finetune/readme.md +++ b/finetune/readme.md @@ -74,7 +74,7 @@ Specially, Llama3 has a different chat_template for training and inference, we m The LoRA allows light-weight model tuning with only a small subset of parameters updated. We provide the LoRA implementation based on `peft`. To launch your training, run the following script: ``` -sh finetune_ds_lora.sh +sh finetune_lora.sh ``` After training, you could load the model with the path to the adapter. We advise you to use absolute path for your pretrained model. This is because LoRA only saves the adapter and the absolute path in the adapter configuration json file is used for finding out the pretrained model to load. @@ -82,12 +82,18 @@ After training, you could load the model with the path to the adapter. We advise ``` from peft import AutoPeftModelForCausalLM +path_to_adapter="path_to_adapter" + model = AutoPeftModelForCausalLM.from_pretrained( # path to the output directory path_to_adapter, device_map="auto", trust_remote_code=True ).eval() + +vpm_resampler_embedtokens_weight = torch.load(f"{path_to_adapter}/vpm_resampler_embedtokens.pt") + +msg = model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False) ``` @@ -122,10 +128,6 @@ A:When you face Out of Memory (OOM) issues during training large models, the f ``` --batch_size 1 ``` -- **Lower image resolution**: If your model processes image data, reducing the input resolution of images can effectively decrease memory usage. -``` ---scale_resolution 448 -``` - **Reduce the number of slices (`slice`)**: When handling large datasets such as large images files, reducing the number of slices processed each time can lower memory requirements. ``` --max_slice_nums 9 diff --git a/finetune/trainer.py b/finetune/trainer.py index fbb8c89..bea2eff 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -20,12 +20,14 @@ class CPMTrainer(Trainer): if not self.args.use_lora: outputs = self.model(data = inputs, use_cache=False) else: - outputs = self.model.base_model(data = inputs, use_cache=False) + with self.model._enable_peft_forward_hooks(**inputs): + outputs = self.model.base_model(data = inputs, use_cache=False) else: if not self.args.use_lora: outputs = self.model(data = inputs, use_cache=False) else: - outputs = self.model.base_model(data = inputs, use_cache=False) + with self.model._enable_peft_forward_hooks(**inputs): + outputs = self.model.base_model(data = inputs, use_cache=False) if labels is not None: # Flatten the tokens @@ -174,6 +176,7 @@ class CPMTrainer(Trainer): logits = logits[0] return (loss, logits, labels) + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: """ Perform a training step on a batch of inputs. @@ -219,5 +222,50 @@ class CPMTrainer(Trainer): self.accelerator.backward(loss) return loss.detach() / self.args.gradient_accumulation_steps + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, supported_classes): + if state_dict is None: + state_dict = self.model.state_dict() + if isinstance(unwrap_model(self.model), supported_classes): + unwrap_model(self.model).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if self.args.save_safetensors: + safetensors.torch.save_file( + state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"} + ) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + if self.args.use_lora: + from collections import OrderedDict + state_dict_vision = OrderedDict() + for key, values in state_dict.items(): + if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key: + state_dict_vision[key] = values + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", ) + else: + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))