Update to MiniCPM-Llama3-V 2.5

author yiranyyu
date   2024-05-20 12:44:33 +08:00
parent c0e39dbfe2
commit 2c75097411
27 changed files with 1944 additions and 985 deletions


@@ -1,23 +1,22 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from typing import Tuple, Union, Optional, List, Dict, Any
 from transformers import Trainer
 from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import is_sagemaker_mp_enabled
 
 
 class CPMTrainer(Trainer):
-    def compute_loss(
-        self,
-        model,
-        inputs,
-        return_outputs=False
-    ):
+    def compute_loss(self, model, inputs, return_outputs=False):
         if "labels" in inputs:
             labels = inputs.pop("labels")
         else:
             labels = None
 
-        vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(inputs)
+        vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(
+            inputs)
         outputs = self.model.llm(
             inputs_embeds=vllm_embedding,
             use_cache=False,
@@ -26,7 +25,8 @@ class CPMTrainer(Trainer):
         if labels is not None:
             # Flatten the tokens
             loss_fct = nn.CrossEntropyLoss()
-            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
+            logits = outputs.logits.view(-1,
+                                         self.model.config.vocab_size).contiguous()
             labels = labels.view(-1).long().contiguous()
             # Enable model parallelism
             labels = labels.to(logits.device)
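
The rewrapped lines above are a standard flattened cross-entropy: every token position becomes one classification example over the vocabulary. A minimal, runnable sketch of the same computation; the shapes (batch=2, seq_len=4, vocab_size=8) are illustrative assumptions, not values from the commit:

import torch
import torch.nn as nn

# Toy stand-ins for outputs.logits and the labels popped from inputs.
vocab_size = 8
logits = torch.randn(2, 4, vocab_size)          # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 4))   # (batch, seq)

loss_fct = nn.CrossEntropyLoss()
# view(-1, vocab_size) merges batch and sequence dims; moving labels to
# logits.device is the model-parallelism safeguard noted in the diff.
loss = loss_fct(
    logits.view(-1, vocab_size).contiguous(),
    labels.view(-1).long().contiguous().to(logits.device),
)
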
@@ -35,19 +35,20 @@ class CPMTrainer(Trainer):
             if isinstance(outputs, dict) and "loss" not in outputs:
                 raise ValueError(
                     "The model did not return a loss from the inputs, only the following keys: "
-                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {
+                        ','.join(inputs.keys())}."
                 )
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
             loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
 
         return (loss, outputs) if return_outputs else loss
 
     def prediction_step(
-        self,
-        model: nn.Module,
-        inputs:Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[List[str]] = None,
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on `model` using `inputs`.
@@ -72,25 +73,34 @@ class CPMTrainer(Trainer):
             Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
             logits and labels (each being optional).
         """
-        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        has_labels = (
+            False
+            if len(self.label_names) == 0
+            else all(inputs.get(k) is not None for k in self.label_names)
+        )
         # For CLIP-like models capable of returning loss values.
         # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
         # is `True` in `model.forward`.
         return_loss = inputs.get("return_loss", None)
         if return_loss is None:
             return_loss = self.can_return_loss
-        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+        loss_without_labels = (
+            True if len(self.label_names) == 0 and return_loss else False
+        )
 
         inputs = self._prepare_inputs(inputs)
         if ignore_keys is None:
             if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+                ignore_keys = getattr(
+                    self.model.config, "keys_to_ignore_at_inference", []
+                )
             else:
                 ignore_keys = []
 
         # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
         if has_labels or loss_without_labels:
-            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            labels = nested_detach(tuple(inputs.get(name)
+                                         for name in self.label_names))
             if len(labels) == 1:
                 labels = labels[0]
             else:
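
The nested_detach call rewrapped in this hunk comes from transformers.trainer_pt_utils (imported at the top of the file); prediction_step uses it to strip gradient tracking from the label tensors it grabs before compute_loss can pop them. A small runnable sketch with a toy tensor (the values are assumptions):

import torch
from transformers.trainer_pt_utils import nested_detach

# nested_detach walks nested containers and detaches every tensor inside.
labels = (torch.ones(2, 4, requires_grad=True) * 2,)
detached = nested_detach(labels)
assert not detached[0].requires_grad  # gradients no longer tracked
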
@@ -102,7 +112,11 @@ class CPMTrainer(Trainer):
                 if has_labels or loss_without_labels:
                     if isinstance(raw_outputs, dict):
                         loss_mb = raw_outputs["loss"]
-                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                        logits_mb = tuple(
+                            v
+                            for k, v in raw_outputs.items()
+                            if k not in ignore_keys + ["loss"]
+                        )
                     else:
                         loss_mb = raw_outputs[0]
                         logits_mb = raw_outputs[1:]
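
The tuple comprehension reformatted in this hunk is the output-filtering idiom used throughout prediction_step: keep every entry of a model-output dict except the loss and any keys flagged in keys_to_ignore_at_inference. A self-contained sketch; the dict below is a stand-in for a real ModelOutput, not data from the commit:

import torch

raw_outputs = {
    "loss": torch.tensor(0.7),
    "logits": torch.randn(2, 4, 8),
    "past_key_values": torch.zeros(1),  # the kind of key a config typically ignores
}
ignore_keys = ["past_key_values"]       # e.g. config.keys_to_ignore_at_inference
logits_mb = tuple(
    v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]
)
assert len(logits_mb) == 1              # only "logits" survives the filter
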
@@ -112,18 +126,26 @@ class CPMTrainer(Trainer):
                 else:
                     loss = None
                     if isinstance(raw_outputs, dict):
-                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                        logits_mb = tuple(
+                            v for k, v in raw_outputs.items() if k not in ignore_keys
+                        )
                     else:
                         logits_mb = raw_outputs
                     logits = smp_nested_concat(logits_mb)
             else:
                 if has_labels or loss_without_labels:
                     with self.compute_loss_context_manager():
-                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                        loss, outputs = self.compute_loss(
+                            model, inputs, return_outputs=True
+                        )
                     loss = loss.mean().detach()
 
                     if isinstance(outputs, dict):
-                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                        logits = tuple(
+                            v
+                            for k, v in outputs.items()
+                            if k not in ignore_keys + ["loss"]
+                        )
                     else:
                         logits = outputs[1:]
                 else:
@@ -131,7 +153,9 @@ class CPMTrainer(Trainer):
                     with self.compute_loss_context_manager():
                         outputs = model(**inputs)
                     if isinstance(outputs, dict):
-                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                        logits = tuple(
+                            v for k, v in outputs.items() if k not in ignore_keys
+                        )
                     else:
                         logits = outputs
                     # TODO: this needs to be fixed and made cleaner later.
@@ -146,5 +170,3 @@ class CPMTrainer(Trainer):
             logits = logits[0]
 
         return (loss, logits, labels)
-
-
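
The closing hunk shows prediction_step's return contract: a (loss, logits, labels) triple in which a singleton logits tuple has already been unwrapped to a bare tensor. An illustrative sketch with toy tensors (all values assumed):

import torch

loss = torch.tensor(0.7)
labels = torch.randint(0, 8, (2, 4))
logits = (torch.randn(2, 4, 8),)  # one output survived the key filtering
if len(logits) == 1:
    logits = logits[0]            # callers receive the tensor itself
result = (loss, logits, labels)
assert isinstance(result[1], torch.Tensor)
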