Mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-04 09:49:20 +08:00)
250 lines · 10 KiB · Python

import torch
import torch.nn as nn
import deepspeed
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
from transformers.trainer import *
from transformers.integrations import is_deepspeed_zero3_enabled
from typing import Any, Dict, List, Optional, Tuple, Union


class CPMTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

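        # Forward pass: without LoRA the wrapped model is called directly; with LoRA the PEFT
        # forward hooks are enabled and the call is routed through the underlying base model so
        # that the adapter layers participate. The whole batch dict is passed via `data`.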
        if not self.args.use_lora:
            outputs = self.model(data=inputs, use_cache=False)
        else:
            with self.model._enable_peft_forward_hooks(**inputs):
                outputs = self.model.base_model(data=inputs, use_cache=False)

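        # When labels are provided, compute the loss here: flatten logits to
        # (batch * seq_len, vocab_size) and labels to (batch * seq_len,) before the
        # token-level cross-entropy.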
        if labels is not None:
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of
        # `return_loss` is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = (
            True if len(self.label_names) == 0 and return_loss else False
        )

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(
                tuple(inputs.get(name) for name in self.label_names)
            )
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
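            # SageMaker model-parallel path: `smp_forward_only` returns per-micro-batch outputs;
            # the loss is averaged with `reduce_mean()` and the logits are concatenated with
            # `smp_nested_concat` below.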
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(
                            v
                            for k, v in raw_outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(
                            v for k, v in raw_outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
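                # Standard path: reuse `compute_loss` when labels (or a label-free loss) are available,
                # otherwise just run a forward pass and collect the logits.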
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(
                            model, inputs, return_outputs=True
                        )
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(
                            v
                            for k, v in outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(
                            v for k, v in outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

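        # Drop the (potentially large multimodal) input tensors and release cached CUDA memory
        # before the backward pass to reduce peak memory usage.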
        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

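        # Report the loss scaled down by the number of gradient-accumulation steps.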
        return loss.detach() / self.args.gradient_accumulation_steps

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
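        # PEFT/LoRA models count as supported classes, so their `save_pretrained` stores only the
        # adapter weights rather than a full state dict.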
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
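
# Illustrative usage sketch: `CPMTrainer` plugs into the standard `transformers.Trainer` API.
# The snippet below is hypothetical; it assumes `model`, `tokenizer`, `train_dataset`, and
# `data_collator` are prepared by the finetuning script, and that `training_args` is a
# TrainingArguments subclass exposing the extra `use_lora` flag read in `compute_loss` above.
#
#     trainer = CPMTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         tokenizer=tokenizer,
#         data_collator=data_collator,
#     )
#     trainer.train()
#     trainer.save_model(training_args.output_dir)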