Mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-05 18:29:18 +08:00)
@@ -250,6 +250,9 @@ def train():
        rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
        for name, param in model.llm.named_parameters():
            param.requires_grad = False
        modules_to_save = ['embed_tokens','resampler']
        if training_args.tune_vision:
            modules_to_save.append('vpm')
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
@@ -257,7 +260,6 @@ def train():
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            layers_to_transform=lora_args.lora_layers_to_transform,
            task_type="CAUSAL_LM",
        )
        if not hasattr(model, 'get_input_embeddings'):
            def get_input_embeddings(self):
@@ -268,10 +270,6 @@ def train():
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        model = get_peft_model(model, lora_config)
        model.base_model.resampler.requires_grad_(True)
        model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
        if training_args.tune_vision:
            model.base_model.vpm.requires_grad_(True)
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()
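For context on the pattern in these hunks, here is a minimal, self-contained sketch of the same LoRA wiring with the Hugging Face peft library. The base model name (facebook/opt-125m), target module names, and hyperparameter values are illustrative placeholders, not the settings MiniCPM-V uses.

# Minimal sketch of the LoRA setup pattern shown in the diff.
# Assumptions: `transformers` and `peft` are installed; the model name,
# target modules, and hyperparameters below are placeholders for illustration.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Freeze the base language model; only the LoRA adapters and the modules
# listed in `modules_to_save` receive gradients after wrapping.
for param in model.parameters():
    param.requires_grad = False

lora_config = LoraConfig(
    r=8,                                  # adapter rank
    lora_alpha=16,                        # scaling factor
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    lora_dropout=0.05,
    bias="none",
    modules_to_save=["embed_tokens"],     # kept fully trainable, analogous to
                                          # ['embed_tokens','resampler'] in the diff
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()        # sanity check: only adapters + saved modules train

The diff applies the same idea by hand: it freezes model.llm, builds a LoraConfig from the lora_args fields, wraps the model with get_peft_model, and then re-enables gradients on the resampler, embedding weights, and (optionally) the vision module.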