Mirror of https://github.com/OpenBMB/MiniCPM-V.git, synced 2026-02-04 17:59:18 +08:00
Update model_minicpmv.py for latest compatibility (#174)
```diff
@@ -47,7 +47,7 @@ class TrainingArguments(transformers.TrainingArguments):
         },
     )
     tune_vision: Optional[bool] = field(default=True)
-    tune_llm: Optional[bool] = field(default=False)
+    tune_llm: Optional[bool] = field(default=True)
     llm_type: str = field(default="minicpm")
     use_lora: Optional[bool] = field(default=False)
```
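This hunk flips the default of `tune_llm` from `False` to `True`, so full-parameter fine-tuning of the language model becomes the default path while `use_lora` stays opt-in. As a minimal sketch of how such dataclass flags are consumed, assuming the script is driven by `HfArgumentParser` (the field definitions mirror the hunk; the `__main__` wiring below is an illustration, not necessarily the repo's exact entry point):

```python
# Sketch only: the parser wiring is an assumption; the fields mirror the diff.
from dataclasses import dataclass, field
from typing import Optional

import transformers

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    tune_vision: Optional[bool] = field(default=True)
    tune_llm: Optional[bool] = field(default=True)  # default was False before this commit
    llm_type: str = field(default="minicpm")
    use_lora: Optional[bool] = field(default=False)

if __name__ == "__main__":
    # Run with e.g.: --output_dir out --tune_llm False --use_lora True
    parser = transformers.HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses()
    print(training_args.tune_llm, training_args.use_lora)
```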
```diff
@@ -252,12 +252,15 @@ def train():
             layers_to_transform=lora_args.lora_layers_to_transform,
             task_type="CAUSAL_LM",
         )
-        if training_args.gradient_checkpointing:
+        if not hasattr(model, 'get_input_embeddings'):
             def get_input_embeddings(self):
                 return self.llm.get_input_embeddings()
             model.get_input_embeddings = MethodType(get_input_embeddings, model)
         model = get_peft_model(model, lora_config)
         model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
         if training_args.tune_vision:
             model.base_model.vpm.requires_grad_(True)
             model.base_model.resampler.requires_grad_(True)
+        if training_args.gradient_checkpointing:
+            model.enable_input_require_grads()
+
```
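This hunk makes the `get_input_embeddings` shim unconditional: it is attached whenever the wrapper model lacks the method, not only when gradient checkpointing is enabled, since PEFT may look the method up during `get_peft_model` (presumably the "latest compatibility" the commit title refers to). The `MethodType` call binds a plain function to a single instance so that `self` resolves to the model. A self-contained sketch of the pattern, where `InnerLLM` and `Wrapper` are stand-ins rather than the repo's classes:

```python
# Stand-in classes; only the hasattr/MethodType pattern mirrors the diff.
from types import MethodType

import torch.nn as nn

class InnerLLM(nn.Module):
    """Plays the wrapped language model, which exposes the method."""
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)

    def get_input_embeddings(self):
        return self.embed_tokens

class Wrapper(nn.Module):
    """Plays the multimodal wrapper, which lacks the method."""
    def __init__(self):
        super().__init__()
        self.llm = InnerLLM()

model = Wrapper()

if not hasattr(model, "get_input_embeddings"):
    def get_input_embeddings(self):
        # Delegate to the inner LLM, exactly as the diff does.
        return self.llm.get_input_embeddings()

    # MethodType binds the function to this one instance, so `self` works.
    model.get_input_embeddings = MethodType(get_input_embeddings, model)

print(model.get_input_embeddings())  # Embedding(10, 4)
```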
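The gradient-checkpointing handling also moves after `get_peft_model` and now calls `enable_input_require_grads()`. With LoRA the base weights are frozen, so checkpointed segments would otherwise recompute from tensors that require no grad; `enable_input_require_grads()` registers a forward hook on the input embeddings that forces their outputs to require grad. A hedged sketch of the resulting call order, with `gpt2` and its `c_attn` target standing in for the actual model:

```python
# Placeholder model/config; only the call ordering mirrors the change.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

gradient_checkpointing = True  # stands in for training_args.gradient_checkpointing

model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"],
                         lora_dropout=0.05, task_type="CAUSAL_LM")

model = get_peft_model(model, lora_config)  # freezes everything but the adapters
# Keep the embedding table trainable, as the diff does for embed_tokens:
model.get_input_embeddings().weight.requires_grad_(True)

if gradient_checkpointing:
    # Hook so embedding outputs require grad despite frozen base weights;
    # without this, checkpointing under LoRA can produce no gradients at all.
    model.enable_input_require_grads()
```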
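Because `get_peft_model` freezes everything outside the adapters, the diff then re-enables gradients on `embed_tokens` and, when `tune_vision` is set, on the vision tower (`vpm`) and the resampler. A quick way to sanity-check what ends up trainable after these `requires_grad_` calls (continuing from the sketch above):

```python
# Summarize trainable vs. total parameters after the requires_grad_ tweaks.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```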