Mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-04 17:59:18 +08:00)
upload finetune script
finetune/trainer.py (new file, 158 lines)
@@ -0,0 +1,158 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI, ZHIHU Inc. (zhihu.com)
#
# @author: chenqianyu <cqy1195@zhihu.com>
# @date: 2024/5/03
#
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List, Dict, Any

from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled

# smp_forward_only / smp_nested_concat only exist when SageMaker Model Parallelism
# is enabled, so import them under the same guard that uses them below.
if is_sagemaker_mp_enabled():
    from transformers.trainer_pt_utils import smp_forward_only, smp_nested_concat
class CPMTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

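        # get_vllm_embedding (MiniCPM-V's multimodal embedding step) merges the vision
        # encoder outputs into the text token embeddings, so the underlying LLM can be
        # driven purely through `inputs_embeds` below.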
        vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(inputs)

        outputs = self.model.llm(
            inputs_embeds=vllm_embedding,
            use_cache=False,
        )

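        # Standard token-level cross entropy: logits are flattened to
        # (batch * seq_len, vocab_size) and labels to (batch * seq_len,).
        # Positions set to -100 (CrossEntropyLoss's default ignore_index) are ignored.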
        if labels is not None:
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of
        # `return_loss` is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance), so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

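        # Evaluation forward pass. Under SageMaker Model Parallelism the forward
        # returns per-micro-batch results, so the loss is reduce-averaged and the
        # logits concatenated; otherwise compute_loss above (or a plain forward,
        # when no labels are available) is used.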
        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

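        # Detach and unwrap the logits before returning them; the Trainer evaluation
        # loop gathers these tensors across processes for metric computation.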
        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
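The file above is the whole of the commit. As a rough usage sketch (not part of the commit), CPMTrainer drops into the standard Trainer workflow; everything below other than the class itself is an assumption for illustration: the model id, the training arguments, and the train_dataset / data_collator, which are assumed to yield the multimodal fields `get_vllm_embedding` consumes (e.g. `input_ids`, `pixel_values`) together with `labels`.

# Minimal sketch, assuming it is run from the finetune/ directory of the repo
# and that train_dataset / data_collator come from your own finetuning code.
from transformers import AutoModel, AutoTokenizer, TrainingArguments

from trainer import CPMTrainer  # the file added in this commit

model = AutoModel.from_pretrained("openbmb/MiniCPM-V", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V", trust_remote_code=True)

training_args = TrainingArguments(
    output_dir="output/minicpmv-sft",   # placeholder path
    per_device_train_batch_size=1,
    learning_rate=1e-6,
    num_train_epochs=1,
    bf16=True,
)

trainer = CPMTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,    # placeholder: samples with input_ids / pixel_values / ... / labels
    data_collator=data_collator,    # placeholder: batches the fields get_vllm_embedding expects
)
trainer.train()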