Mirror of https://github.com/OpenBMB/MiniCPM-V.git
update lora finetune inference bug (#224)
```diff
@@ -20,12 +20,14 @@ class CPMTrainer(Trainer):
             if not self.args.use_lora:
                 outputs = self.model(data = inputs, use_cache=False)
             else:
-                outputs = self.model.base_model(data = inputs, use_cache=False)
+                with self.model._enable_peft_forward_hooks(**inputs):
+                    outputs = self.model.base_model(data = inputs, use_cache=False)
         else:
             if not self.args.use_lora:
                 outputs = self.model(data = inputs, use_cache=False)
             else:
-                outputs = self.model.base_model(data = inputs, use_cache=False)
+                with self.model._enable_peft_forward_hooks(**inputs):
+                    outputs = self.model.base_model(data = inputs, use_cache=False)
 
         if labels is not None:
             # Flatten the tokens
```
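The change above routes the LoRA forward pass through PEFT's forward hooks instead of calling `base_model` bare, so the adapter layers actually participate in the loss computation. A minimal sketch of the pattern, assuming `model` is a `peft.PeftModel` wrapping MiniCPM-V and `inputs` is the collated batch dict (the helper name is illustrative, not from the repo):

```python
from contextlib import nullcontext

def forward_with_optional_lora(model, inputs, use_lora: bool):
    # Without LoRA, call the wrapper model directly. With LoRA, go through
    # base_model, but first enter _enable_peft_forward_hooks (a PEFT-internal
    # context manager) so PEFT's adapter logic is applied during the call.
    ctx = model._enable_peft_forward_hooks(**inputs) if use_lora else nullcontext()
    with ctx:
        target = model.base_model if use_lora else model
        return target(data=inputs, use_cache=False)
```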
```diff
@@ -174,6 +176,7 @@ class CPMTrainer(Trainer):
+            logits = logits[0]
 
         return (loss, logits, labels)
 
     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
```
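The added `logits = logits[0]` appears to guard the evaluation path: when the forward pass goes through `base_model` with PEFT hooks, the logits collected in `prediction_step` can come back nested one level deep. A hedged sketch of the same unwrap with an explicit type check (the guard is part of this sketch, not of the repo code):

```python
import torch

def unwrap_logits(logits):
    # Evaluation logits may arrive wrapped in a one-element tuple/list on the
    # LoRA path; keep only the underlying tensor.
    if isinstance(logits, (tuple, list)) and len(logits) == 1:
        logits = logits[0]
    return logits

# Example: unwrap_logits((torch.zeros(2, 8),)).shape == (2, 8)
```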
```diff
@@ -219,5 +222,50 @@ class CPMTrainer(Trainer):
         self.accelerator.backward(loss)
 
         return loss.detach() / self.args.gradient_accumulation_steps
 
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are the process zero, so we don't check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+
+        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, supported_classes):
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+
+            if isinstance(unwrap_model(self.model), supported_classes):
+                unwrap_model(self.model).save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(
+                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+                    )
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            if self.args.use_lora:
+                from collections import OrderedDict
+                state_dict_vision = OrderedDict()
+                for key, values in state_dict.items():
+                    if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key:
+                        state_dict_vision[key] = values
+                self.model.save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+                torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", )
+            else:
+                self.model.save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
```
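With this `_save`, a LoRA run writes two artifacts per checkpoint: the PEFT checkpoint via `save_pretrained`, and a separate `vpm_resampler_embedtokens.pt` holding the vision encoder (`vpm`), resampler, and `embed_tokens` weights that the adapter does not cover. A hedged sketch of how such a checkpoint might be restored for inference; the base model id and checkpoint path are placeholders, and `strict=False` assumes the saved keys line up with the PEFT-wrapped module names:

```python
import torch
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer

base_id = "openbmb/MiniCPM-V-2"      # placeholder: whichever base model was finetuned
ckpt_dir = "output/checkpoint-1000"  # placeholder: a directory written by _save()

model = AutoModel.from_pretrained(base_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(base_id, trust_remote_code=True)

# Attach the LoRA adapter that save_pretrained() wrote into the checkpoint dir.
model = PeftModel.from_pretrained(model, ckpt_dir)

# Load the weights _save() exported separately (vision tower, resampler, embeddings).
extra = torch.load(f"{ckpt_dir}/vpm_resampler_embedtokens.pt", map_location="cpu")
missing, unexpected = model.load_state_dict(extra, strict=False)

model.eval()
```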