From e002c0e6ec2418e25e4019fb6ca911a2dd18adc5 Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:32:17 +0800 Subject: [PATCH] =?UTF-8?q?zero3=E6=94=AF=E6=8C=81=20(#273)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit update for lora's modules_to_save --- finetune/finetune.py | 8 +++----- finetune/readme.md | 28 +++++++++++++--------------- finetune/trainer.py | 41 ++++++++++------------------------------- 3 files changed, 26 insertions(+), 51 deletions(-) diff --git a/finetune/finetune.py b/finetune/finetune.py index 3333454..671d6c7 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -250,6 +250,9 @@ def train(): rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.") for name, param in model.llm.named_parameters(): param.requires_grad = False + modules_to_save = ['embed_tokens','resampler'] + if training_args.tune_vision: + modules_to_save.append('vpm') lora_config = LoraConfig( r=lora_args.lora_r, lora_alpha=lora_args.lora_alpha, @@ -257,7 +260,6 @@ def train(): lora_dropout=lora_args.lora_dropout, bias=lora_args.lora_bias, layers_to_transform=lora_args.lora_layers_to_transform, - task_type="CAUSAL_LM", ) if not hasattr(model, 'get_input_embeddings'): def get_input_embeddings(self): @@ -268,10 +270,6 @@ def train(): model, use_gradient_checkpointing=training_args.gradient_checkpointing ) model = get_peft_model(model, lora_config) - model.base_model.resampler.requires_grad_(True) - model.base_model.llm.model.embed_tokens.weight.requires_grad_(True) - if training_args.tune_vision: - model.base_model.vpm.requires_grad_(True) if training_args.gradient_checkpointing: model.enable_input_require_grads() diff --git a/finetune/readme.md b/finetune/readme.md index 2f4e095..c656c0d 100644 --- a/finetune/readme.md +++ b/finetune/readme.md @@ -80,20 +80,16 @@ sh finetune_lora.sh After training, you could load the model with the path to the adapter. We advise you to use absolute path for your pretrained model. This is because LoRA only saves the adapter and the absolute path in the adapter configuration json file is used for finding out the pretrained model to load. ``` -from peft import AutoPeftModelForCausalLM +from peft import AutoPeftModel -path_to_adapter="path_to_adapter" +path_to_adapter="path_to_your_fine_tuned_checkpoint" -model = AutoPeftModelForCausalLM.from_pretrained( +model = AutoPeftModel.from_pretrained( # path to the output directory path_to_adapter, device_map="auto", trust_remote_code=True -).eval() - -vpm_resampler_embedtokens_weight = torch.load(f"{path_to_adapter}/vpm_resampler_embedtokens.pt") - -msg = model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False) +).eval().cuda() ``` @@ -173,14 +169,16 @@ A: The error as described in [issues 168](https://github.com/OpenBMB/MiniCPM-V/i 1.**Reload the Fine-Tuned Model:** Make sure you correctly load the checkpoint that has been fine-tuned using lora techniques. Use the following code example to guide you: ```python - from peft import AutoPeftModelForCausalLM + from peft import AutoPeftModel - model = AutoPeftModelForCausalLM.from_pretrained( - 'path_to_your_fine_tuned_checkpoint', # Path to your fine-tuned checkpoint directory - output='output/minicpmv2_lora', - device_map='auto', - trust_remote_code=True - ).eval() +path_to_adapter="path_to_your_fine_tuned_checkpoint" + +model = AutoPeftModel.from_pretrained( + # path to the output directory + path_to_adapter, + device_map="auto", + trust_remote_code=True +).eval().cuda() ``` 2.**Update the `model_minicpmv.py` File:** - **Verification:** Make sure you verify and update your `model_minicpmv.py` file to ensure it is the latest version. diff --git a/finetune/trainer.py b/finetune/trainer.py index fa57bd0..cc45c97 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -15,19 +15,12 @@ class CPMTrainer(Trainer): else: labels = None self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device) - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0): - if not self.args.use_lora: - outputs = self.model(data = inputs, use_cache=False) - else: - with self.model._enable_peft_forward_hooks(**inputs): - outputs = self.model.base_model(data = inputs, use_cache=False) + + if not self.args.use_lora: + outputs = self.model(data = inputs, use_cache=False) else: - if not self.args.use_lora: - outputs = self.model(data = inputs, use_cache=False) - else: - with self.model._enable_peft_forward_hooks(**inputs): - outputs = self.model.base_model(data = inputs, use_cache=False) + with self.model._enable_peft_forward_hooks(**inputs): + outputs = self.model.base_model(data = inputs, use_cache=False) if labels is not None: # Flatten the tokens @@ -215,11 +208,7 @@ class CPMTrainer(Trainer): with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0): - self.accelerator.backward(loss) - else: - self.accelerator.backward(loss) + self.accelerator.backward(loss) return loss.detach() / self.args.gradient_accumulation_steps @@ -249,20 +238,10 @@ class CPMTrainer(Trainer): else: torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - if self.args.use_lora: - from collections import OrderedDict - state_dict_vision = OrderedDict() - for key, values in state_dict.items(): - if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key: - state_dict_vision[key] = values - self.model.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) - torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", ) - else: - self.model.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) + + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir)