ZeRO-3 support (#273)

Update for LoRA's modules_to_save
qianyu chen committed 2024-07-15 10:32:17 +08:00 (committed by GitHub)
parent ef7cfa81ec
commit e002c0e6ec
3 changed files with 26 additions and 51 deletions

@@ -250,6 +250,9 @@ def train():
         rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
         for name, param in model.llm.named_parameters():
             param.requires_grad = False
+        modules_to_save = ['embed_tokens','resampler']
+        if training_args.tune_vision:
+            modules_to_save.append('vpm')
         lora_config = LoraConfig(
             r=lora_args.lora_r,
             lora_alpha=lora_args.lora_alpha,
@@ -257,7 +260,6 @@ def train():
             lora_dropout=lora_args.lora_dropout,
             bias=lora_args.lora_bias,
             layers_to_transform=lora_args.lora_layers_to_transform,
-            task_type="CAUSAL_LM",
         )
         if not hasattr(model, 'get_input_embeddings'):
             def get_input_embeddings(self):
@@ -268,10 +270,6 @@ def train():
             model, use_gradient_checkpointing=training_args.gradient_checkpointing
         )
         model = get_peft_model(model, lora_config)
-        model.base_model.resampler.requires_grad_(True)
-        model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
-        if training_args.tune_vision:
-            model.base_model.vpm.requires_grad_(True)
         if training_args.gradient_checkpointing:
             model.enable_input_require_grads()
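
For context on the change: peft's `modules_to_save` option is presumably what makes the deleted `requires_grad_(True)` calls unnecessary. Modules listed there are wrapped in a fully trainable copy that peft also serializes with the adapter, so gradients no longer need to be flipped by hand after the model is wrapped — a step that is fragile under DeepSpeed ZeRO-3, where parameters are partitioned across ranks. Below is a minimal runnable sketch using a toy stand-in model rather than MiniCPM-V; only the `embed_tokens` name mirrors the diff, everything else is illustrative.

```python
# Minimal sketch of peft's modules_to_save mechanism (toy model, not MiniCPM-V).
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-ins: embed_tokens gets a full fine-tune via modules_to_save,
        # q_proj gets a LoRA adapter via target_modules.
        self.embed_tokens = nn.Embedding(100, 32)
        self.q_proj = nn.Linear(32, 32)

    def forward(self, ids):
        return self.q_proj(self.embed_tokens(ids))

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj"],
    # peft replaces each listed module with a trainable copy and saves it
    # alongside the LoRA weights -- no manual requires_grad_(True) needed.
    modules_to_save=["embed_tokens"],
)
model = get_peft_model(ToyModel(), lora_config)
model.print_trainable_parameters()
```

The trainable set ends up being the LoRA matrices on `q_proj` plus the full `embed_tokens` copy, which matches what the removed `requires_grad_` block in the training script used to do manually.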