Update to MiniCPM-Llama3-V 2.5

This commit is contained in:
yiranyyu
2024-05-20 12:44:33 +08:00
parent c0e39dbfe2
commit 2c75097411
27 changed files with 1944 additions and 985 deletions

View File

@@ -1,90 +1,115 @@
import os
import math
import json
import copy
import json
import logging
import math
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from typing import Dict, Optional, List
from PIL import Image
from dataclasses import dataclass, field
from transformers import AutoTokenizer, AutoProcessor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, transform, tokenizer, slice_config):
def __init__(
self,
raw_data,
transform,
tokenizer,
slice_config,
llm_type="minicpm",
patch_size=14,
query_nums=64,
batch_vision=False,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
self.tokenizer = tokenizer
self.transform = transform
self.slice_config = slice_config
self.llm_type = llm_type
self.patch_size = patch_size
self.query_nums=query_nums
self.batch_vision = batch_vision
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(image, self.raw_data[i]["conversations"], self.tokenizer, self.transform, slice_config=self.slice_config)
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
labels=ret["target"],
attention_mask=ret["input_ids"].ne(self.tokenizer.pad_token_id),
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
return ret
def data_collator(examples, padding_value=0):
input_ids = pad_sequence([example["input_ids"] for example in examples], batch_first=True, padding_value=padding_value)
targets = pad_sequence([example["labels"] for example in examples], batch_first=True, padding_value=padding_value)
attention_mask = pad_sequence([example["attention_mask"] for example in examples], batch_first=True, padding_value=padding_value)
input_ids = pad_sequence(
[example["input_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
targets = pad_sequence(
[example["labels"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
attention_mask = pad_sequence(
[example["attention_mask"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
pixel_values = [example["pixel_values"] for example in examples]
image_bound = [example["image_bound"] for example in examples]
return {"input_ids": input_ids, "labels":targets, "attention_mask": attention_mask, "image_bound": image_bound, "pixel_values": pixel_values}
tgt_sizes = [example["tgt_sizes"] for example in examples]
return {
"input_ids": input_ids,
"labels": targets,
"attention_mask": attention_mask,
"image_bound": image_bound,
"tgt_sizes": tgt_sizes,
"pixel_values": pixel_values,
}
def conversation_to_ids(conversation, tokenizer):
def conversation_to_ids(conversation, tokenizer, llm_type=None):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
{'role': 'assistant', 'content': 'This is a cat.'}]
"""
raw_msg = ''
input_ids = []
context = []
for idx, msg in enumerate(conversation):
role = msg['role']
message = msg['content']
assert role in ['user', 'assistant']
if role == 'user':
prefix = '<用户>'
else:
prefix = '<AI>'
# append eos
if idx == len(conversation) - 1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
message_ids = tokenizer.encode(message)[1:]
if llm_type == "llama3":
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
)
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role == 'assistant':
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
raw_msg += (prefix + message)
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
@@ -94,45 +119,137 @@ def conversation_to_ids(conversation, tokenizer):
if context[i] == 0:
target[i - 1] = ids[i]
if context[i] == 1 and context[i - 1] == 0:
target[i - 1] = tokenizer.eos_id
if hasattr(tokenizer, "eot_id"):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id
# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print('image start token != image end tokens')
if len(image_start_tokens)>0:
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
print("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
else:
image_bound = []
return {
'input_ids': ids,
'target': target,
'image_bound': image_bound,
'raw_msg': raw_msg,
"input_ids": ids,
"target": target,
"image_bound": image_bound,
"raw_msg": raw_msg,
}
def preprocess(image, conversation, tokenizer, transform, query_nums=64, slice_config=None):
def conversation_to_ids_minicpm(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
for idx, msg in enumerate(conversation):
role = msg["role"]
message = msg["content"]
assert role in ["user", "assistant"]
if role == "user":
prefix = "<用户>"
else:
prefix = "<AI>"
# append eos
if idx == len(conversation) - 1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
message_ids = tokenizer.encode(message)[1:]
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role == "assistant":
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
raw_msg += prefix + message
return input_ids, context, raw_msg
def conversation_to_ids_llama3(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=False
)
input_ids = tokenizer.apply_chat_template(
conversation, tokenize=True, add_generation_prompt=False
)
input_ids = np.array(input_ids)
start_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>")
)[0]
assistant_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)[0]
end_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>")
)[0]
eot_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
context = np.ones_like(input_ids, dtype=np.int8)
for assistant_idx in assistant_idxs:
if assistant_idx in set((start_header_idxs + end_header_idxs) / 2):
st = assistant_idx + 3 # assistant<|end_header_id|>\n\n
for eot_idx in eot_idxs:
if eot_idx > st:
context[st: eot_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
context = np.hstack(context)
return input_ids, context, raw_msg
def preprocess(
image,
conversation,
tokenizer,
transform,
query_nums=64,
slice_config=None,
llm_type=None,
patch_size=14,
batch_vision=False,
):
"""
single image preprocess, the image will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
assert len(conversation) > 1, "conversation length must large than 2"
assert conversation[0]['role'] == 'user', "the first role must be user"
assert conversation[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
assert 'patch_size' in slice_config
assert 'max_slice_nums' in slice_config
assert 'scale_resolution' in slice_config
default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
if slice_config:
images = []
source_image, patches, best_grid = slice_image(
image, slice_config['max_slice_nums'], slice_config['scale_resolution'], slice_config['patch_size']
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
@@ -142,30 +259,51 @@ def preprocess(image, conversation, tokenizer, transform, query_nums=64, slice_c
images.append(patches[i][j])
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums
)
tokenizer, best_grid, query_nums)
images = [transform(i) for i in images]
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if '<image>' in conversation[0]['content']:
conversation[0]['content'] = conversation[0]['content'].replace('<image>', image_placeholder)
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]['content'] = image_placeholder + '\n' + conversation[0]['content']
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
if batch_vision:
tgt_sizes = []
reshape_images = []
for image in images:
H, W = image.shape[1:]
reshape_image = reshape_by_patch(image, patch_size)
reshape_images.append(reshape_image)
tgt_sizes.append([H // patch_size, W // patch_size])
if tgt_sizes:
tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
input_dict["pixel_values"] = reshape_images
input_dict["tgt_sizes"] = tgt_sizes
else:
input_dict["pixel_values"] = images
input_dict["tgt_sizes"] = []
input_dict = conversation_to_ids(conversation, tokenizer)
input_dict['pixel_values'] = images
return input_dict
def slice_image(
image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / (scale_resolution * scale_resolution)
ratio = original_width * original_height / \
(scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
@@ -186,7 +324,8 @@ def slice_image(
candidate_split_grids_nums.append(i)
# source image, down-sampling and ensure divided by patch_size
best_resize = find_best_resize(original_size, scale_resolution, patch_size)
best_resize = find_best_resize(
original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
@@ -285,6 +424,22 @@ def get_grid_placeholder(tokenizer, grid, query_num):
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder
def reshape_by_patch(image_tensor, patch_size):
"""
:param image_tensor: shape [3, H, W]
:param patch_size:
:return: [3, patch_size, HW/patch_size]
"""
patches = torch.nn.functional.unfold(
image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)
)
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches

View File

@@ -1,22 +1,22 @@
import os
import glob
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, Optional, List
from typing import Dict, List, Optional
import torch
from torch.utils.data import Dataset
import transformers
from trainer import CPMTrainer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed import zero
from dataset import data_collator, SupervisedDataset
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoModel, AutoTokenizer
from dataset import SupervisedDataset, data_collator
from trainer import CPMTrainer
@dataclass
class ModelArguments:
@@ -44,6 +44,8 @@ class TrainingArguments(transformers.TrainingArguments):
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
tune_vision: Optional[bool] = field(default=True)
tune_llm: Optional[bool] = field(default=True)
def rank0_print(*args):
@@ -52,7 +54,15 @@ def rank0_print(*args):
def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args, transform, data_collator=None, slice_config=None,
tokenizer: transformers.PreTrainedTokenizer,
data_args,
transform,
data_collator=None,
llm_type="minicpm",
slice_config=None,
patch_size=14,
query_nums=64,
batch_vision=False,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = SupervisedDataset
@@ -60,19 +70,57 @@ def make_supervised_data_module(
rank0_print("Loading data...")
train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(train_json, transform, tokenizer, slice_config=slice_config)
train_dataset = dataset_cls(
train_json,
transform,
tokenizer,
slice_config=slice_config,
llm_type=llm_type,
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
)
if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(eval_json, transform, tokenizer, slice_config=slice_config)
eval_dataset = dataset_cls(
eval_json,
transform,
tokenizer,
slice_config=slice_config,
llm_type=llm_type,
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
)
else:
eval_dataset = None
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)
return dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
)
def get_parameter_number(model):
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return {'Total': all_param, 'Trainable': trainable_params}
local_rank = 0
def train():
global local_rank
@@ -85,8 +133,8 @@ def train():
data_args,
training_args,
) = parser.parse_args_into_dataclasses()
if getattr(training_args, 'deepspeed', None):
if getattr(training_args, "deepspeed", None):
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
compute_dtype = (
@@ -99,14 +147,50 @@ def train():
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
#Load data
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
model = AutoModel.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
torch_dtype=compute_dtype,
device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
)
if not training_args.tune_vision:
model.vpm.requires_grad_(False)
if not training_args.tune_llm:
model.llm.requires_grad_(False)
rank0_print(get_parameter_number(model))
llm_type = "minicpm"
if "llama3" in model.name_or_path.lower():
tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
llm_type = "llama3"
# Load data
if hasattr(model.config, "slice_config"):
slice_config = model.config.slice_config.to_dict()
else:
slice_config = model.config.to_dict()
if hasattr(model.config, "batch_vision_input"):
batch_vision = model.config.batch_vision_input
else:
batch_vision = False
data_module = make_supervised_data_module(
tokenizer=tokenizer, data_args=data_args, transform=model.transform, data_collator=data_collator, slice_config=model.config.__dict__,
tokenizer=tokenizer,
data_args=data_args,
transform=model.transform,
data_collator=data_collator,
slice_config=slice_config,
llm_type=llm_type,
patch_size=model.config.patch_size,
query_nums=model.config.query_num,
batch_vision=batch_vision,
)
trainer = CPMTrainer(
@@ -115,11 +199,10 @@ def train():
args=training_args,
**data_module,
)
trainer.train()
trainer.save_state()
if __name__ == "__main__":
train()

View File

@@ -1,18 +1,13 @@
# Minicpm-V2 Finetuning
# MiniCPM-V Finetuning
<div align="center">
[English](README.md)
</div>
We offer the official scripts for easy finetuning of the pretrained minicpm-v2 model on downstream tasks. Our finetune scripts use DeepSpeed by default.
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
### Data preparation
To prepare your finetuning data, you should (1) formulate each sample as a dictionary consisting of an id, an image path list with an image (optional, not required for pure-text example), and a list of conversations, and (2) save data samples in JSON files.
To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
For the vision-language example with image, you are required to define placeholder(s) <ImageHere> to define the position to insert the image embeddings.
For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
<details>
<summary>
@@ -57,10 +52,19 @@ For the vision-language example with image, you are required to define placehold
### Full-parameter finetuning
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. To launch your training, run the following script:
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path and DATA path in the shell scripts.
```shell
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
DATA="path/to/trainging_data" # json file
EVAL_DATA="path/to/test_data" # json file
```
To launch your training, run the following script:
```
sh finetune_ds.sh
```
#### Customizing Hyperparameters
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.

View File

@@ -1,23 +1,22 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List, Dict, Any
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
class CPMTrainer(Trainer):
def compute_loss(
self,
model,
inputs,
return_outputs=False
):
def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(inputs)
vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(
inputs)
outputs = self.model.llm(
inputs_embeds=vllm_embedding,
use_cache=False,
@@ -26,7 +25,8 @@ class CPMTrainer(Trainer):
if labels is not None:
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
logits = outputs.logits.view(-1,
self.model.config.vocab_size).contiguous()
labels = labels.view(-1).long().contiguous()
# Enable model parallelism
labels = labels.to(logits.device)
@@ -35,19 +35,20 @@ class CPMTrainer(Trainer):
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
f"{','.join(outputs.keys())}. For reference, the inputs it received are {
','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def prediction_step(
self,
model: nn.Module,
inputs:Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
@@ -72,25 +73,34 @@ class CPMTrainer(Trainer):
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
has_labels = (
False
if len(self.label_names) == 0
else all(inputs.get(k) is not None for k in self.label_names)
)
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
return_loss = self.can_return_loss
loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
loss_without_labels = (
True if len(self.label_names) == 0 and return_loss else False
)
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
ignore_keys = getattr(
self.model.config, "keys_to_ignore_at_inference", []
)
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
labels = nested_detach(tuple(inputs.get(name)
for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
@@ -102,7 +112,11 @@ class CPMTrainer(Trainer):
if has_labels or loss_without_labels:
if isinstance(raw_outputs, dict):
loss_mb = raw_outputs["loss"]
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
logits_mb = tuple(
v
for k, v in raw_outputs.items()
if k not in ignore_keys + ["loss"]
)
else:
loss_mb = raw_outputs[0]
logits_mb = raw_outputs[1:]
@@ -112,18 +126,26 @@ class CPMTrainer(Trainer):
else:
loss = None
if isinstance(raw_outputs, dict):
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
logits_mb = tuple(
v for k, v in raw_outputs.items() if k not in ignore_keys
)
else:
logits_mb = raw_outputs
logits = smp_nested_concat(logits_mb)
else:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss, outputs = self.compute_loss(
model, inputs, return_outputs=True
)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
logits = tuple(
v
for k, v in outputs.items()
if k not in ignore_keys + ["loss"]
)
else:
logits = outputs[1:]
else:
@@ -131,7 +153,9 @@ class CPMTrainer(Trainer):
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
logits = tuple(
v for k, v in outputs.items() if k not in ignore_keys
)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
@@ -146,5 +170,3 @@ class CPMTrainer(Trainer):
logits = logits[0]
return (loss, logits, labels)