ZeRO-3 support (#273)

Update for LoRA's modules_to_save
qianyu chen
2024-07-15 10:32:17 +08:00
committed by GitHub
parent ef7cfa81ec
commit e002c0e6ec
3 changed files with 26 additions and 51 deletions
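
For context, a minimal sketch of the configuration style this commit moves to, assuming PEFT's standard `LoraConfig`/`get_peft_model` API; the model ID, `tune_vision` flag, and `target_modules` list below are illustrative stand-ins, not the repo's actual values:

```python
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

# Illustrative: any MiniCPM-V checkpoint exposing .llm / .vpm / .resampler submodules.
model = AutoModel.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True)
tune_vision = True  # stand-in for training_args.tune_vision

modules_to_save = ['embed_tokens', 'resampler']   # trained in full alongside the LoRA adapters
if tune_vision:
    modules_to_save.append('vpm')

lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # illustrative, not the repo's pattern
    modules_to_save=modules_to_save,  # PEFT keeps trainable copies of these modules in the adapter
)
model = get_peft_model(model, lora_config)
# No follow-up requires_grad_(True) calls on resampler/embed_tokens/vpm are needed;
# the modules_to_save copies are already trainable, and DeepSpeed ZeRO-3 can partition
# them like any other parameter.
```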


@@ -250,6 +250,9 @@ def train():
rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.") rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
for name, param in model.llm.named_parameters(): for name, param in model.llm.named_parameters():
param.requires_grad = False param.requires_grad = False
modules_to_save = ['embed_tokens','resampler']
if training_args.tune_vision:
modules_to_save.append('vpm')
lora_config = LoraConfig( lora_config = LoraConfig(
r=lora_args.lora_r, r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha, lora_alpha=lora_args.lora_alpha,
@@ -257,7 +260,6 @@ def train():
     lora_dropout=lora_args.lora_dropout,
     bias=lora_args.lora_bias,
     layers_to_transform=lora_args.lora_layers_to_transform,
-    task_type="CAUSAL_LM",
 )
 if not hasattr(model, 'get_input_embeddings'):
     def get_input_embeddings(self):
@@ -268,10 +270,6 @@ def train():
     model, use_gradient_checkpointing=training_args.gradient_checkpointing
 )
 model = get_peft_model(model, lora_config)
-model.base_model.resampler.requires_grad_(True)
-model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
-if training_args.tune_vision:
-    model.base_model.vpm.requires_grad_(True)
 if training_args.gradient_checkpointing:
     model.enable_input_require_grads()


@@ -80,20 +80,16 @@ sh finetune_lora.sh
 After training, you could load the model with the path to the adapter. We advise you to use absolute path for your pretrained model. This is because LoRA only saves the adapter and the absolute path in the adapter configuration json file is used for finding out the pretrained model to load.
 ```
-from peft import AutoPeftModelForCausalLM
-path_to_adapter="path_to_adapter"
-model = AutoPeftModelForCausalLM.from_pretrained(
+from peft import AutoPeftModel
+path_to_adapter="path_to_your_fine_tuned_checkpoint"
+model = AutoPeftModel.from_pretrained(
     # path to the output directory
     path_to_adapter,
     device_map="auto",
     trust_remote_code=True
-).eval()
-vpm_resampler_embedtokens_weight = torch.load(f"{path_to_adapter}/vpm_resampler_embedtokens.pt")
-msg = model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False)
+).eval().cuda()
 ```
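
The README above stresses that the adapter stores an absolute path to the pretrained model. A small illustrative check (not part of the repo's docs) for debugging load failures:

```python
# The adapter checkpoint records where the base model lives; if that path is relative
# or the base model has moved, AutoPeftModel.from_pretrained will fail to find it.
import json

path_to_adapter = "path_to_your_fine_tuned_checkpoint"  # same placeholder as above
with open(f"{path_to_adapter}/adapter_config.json") as f:
    adapter_cfg = json.load(f)
print(adapter_cfg["base_model_name_or_path"])  # should be an absolute path to the pretrained MiniCPM-V
```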
@@ -173,14 +169,16 @@ A: The error as described in [issues 168](https://github.com/OpenBMB/MiniCPM-V/i
 1.**Reload the Fine-Tuned Model:** Make sure you correctly load the checkpoint that has been fine-tuned using lora techniques. Use the following code example to guide you:
 ```python
-from peft import AutoPeftModelForCausalLM
-model = AutoPeftModelForCausalLM.from_pretrained(
-    'path_to_your_fine_tuned_checkpoint',
-    output='output/minicpmv2_lora',
-    device_map='auto',
+from peft import AutoPeftModel
+path_to_adapter="path_to_your_fine_tuned_checkpoint"
+# Path to your fine-tuned checkpoint directory
+model = AutoPeftModel.from_pretrained(
+    # path to the output directory
+    path_to_adapter,
+    device_map="auto",
     trust_remote_code=True
-).eval()
+).eval().cuda()
 ```
 2.**Update the `model_minicpmv.py` File:**
 - **Verification:** Make sure you verify and update your `model_minicpmv.py` file to ensure it is the latest version.
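
As an optional follow-up to step 1 (standard PEFT API rather than a step from this FAQ), the loaded adapter can be merged into the base model to produce a standalone checkpoint; the output path is illustrative:

```python
# Assumes `model` is the AutoPeftModel loaded above; merge_and_unload folds the LoRA
# weights (and the modules_to_save copies) back into the base model.
merged = model.merge_and_unload()
merged.save_pretrained("path_to_merged_model", safe_serialization=True)
```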


@@ -15,14 +15,7 @@ class CPMTrainer(Trainer):
 else:
     labels = None
 self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device)
-if is_deepspeed_zero3_enabled():
-    with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
-        if not self.args.use_lora:
-            outputs = self.model(data = inputs, use_cache=False)
-        else:
-            with self.model._enable_peft_forward_hooks(**inputs):
-                outputs = self.model.base_model(data = inputs, use_cache=False)
-else:
 if not self.args.use_lora:
     outputs = self.model(data = inputs, use_cache=False)
 else:
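
For background on the removed branch (a sketch of the standard DeepSpeed API, not code from this commit): under ZeRO-3 every parameter is sharded across ranks, and `GatheredParameters` temporarily reassembles the full tensors inside its context. The old code gathered the resampler's attention weights around the forward and backward passes; with those modules now handled through `modules_to_save`, DeepSpeed manages them like any other trainable parameter and the manual gather is unnecessary.

```python
import deepspeed

# Inside the context, the full (un-sharded) parameters are materialized; outside it,
# each rank only sees its ZeRO-3 shard. `model` is assumed to be a ZeRO-3 wrapped MiniCPM-V.
with deepspeed.zero.GatheredParameters(model.resampler.attn.parameters(), modifier_rank=0):
    for p in model.resampler.attn.parameters():
        print(p.shape)
```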
@@ -214,10 +207,6 @@ class CPMTrainer(Trainer):
 if self.use_apex:
     with amp.scale_loss(loss, self.optimizer) as scaled_loss:
         scaled_loss.backward()
-else:
-    if is_deepspeed_zero3_enabled():
-        with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
-            self.accelerator.backward(loss)
-    else:
 else:
     self.accelerator.backward(loss)
@@ -249,17 +238,7 @@ class CPMTrainer(Trainer):
     else:
         torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
 else:
-    if self.args.use_lora:
-        from collections import OrderedDict
-        state_dict_vision = OrderedDict()
-        for key, values in state_dict.items():
-            if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key:
-                state_dict_vision[key] = values
-        self.model.save_pretrained(
-            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-        )
-        torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", )
-    else:
     self.model.save_pretrained(
         output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
     )
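
Finally, a sketch of why the separate `vpm_resampler_embedtokens.pt` dump removed above is no longer needed, assuming `model` is the `get_peft_model`-wrapped MiniCPM-V with `modules_to_save` set and an illustrative output directory: PEFT serializes both the LoRA matrices and the `modules_to_save` copies into the adapter checkpoint, so `AutoPeftModel.from_pretrained` restores everything in one step.

```python
from safetensors.torch import load_file

output_dir = "output/minicpmv_lora"  # illustrative
model.save_pretrained(output_dir, safe_serialization=True)

# The single adapter file now contains the LoRA weights and the full copies of the
# modules listed in modules_to_save (embed_tokens, resampler, optionally vpm).
adapter_state = load_file(f"{output_dir}/adapter_model.safetensors")
print(len(adapter_state), "tensors saved")
print([k for k in sorted(adapter_state) if "lora_" in k][:3])
```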