diff --git a/finetune/__init__.py b/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/dataset.py b/finetune/dataset.py new file mode 100644 index 0000000..b93035a --- /dev/null +++ b/finetune/dataset.py @@ -0,0 +1,290 @@ +import os +import math +import json +import copy +import logging + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from typing import Dict, Optional, List +from PIL import Image + + +from dataclasses import dataclass, field +from transformers import AutoTokenizer, AutoProcessor +from torch.utils.data import Dataset + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + def __init__(self, raw_data, transform, tokenizer, slice_config): + super(SupervisedDataset, self).__init__() + self.raw_data = raw_data + self.tokenizer = tokenizer + self.transform = transform + self.slice_config = slice_config + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + image = Image.open(self.raw_data[i]["image"]).convert("RGB") + ret = preprocess(image, self.raw_data[i]["conversations"], self.tokenizer, self.transform, slice_config=self.slice_config) + ret = dict( + input_ids=ret["input_ids"], + labels=ret["target"], + attention_mask=ret["input_ids"].ne(self.tokenizer.pad_token_id), + pixel_values=ret["pixel_values"], + image_bound=ret["image_bound"], + ) + + return ret + + +def data_collator(examples, padding_value=0): + input_ids = pad_sequence([example["input_ids"] for example in examples], batch_first=True, padding_value=padding_value) + targets = pad_sequence([example["labels"] for example in examples], batch_first=True, padding_value=padding_value) + attention_mask = pad_sequence([example["attention_mask"] for example in examples], batch_first=True, padding_value=padding_value) + pixel_values = [example["pixel_values"] for example in examples] + image_bound = [example["image_bound"] for example in examples] + return {"input_ids": input_ids, "labels":targets, "attention_mask": attention_mask, "image_bound": image_bound, "pixel_values": pixel_values} + + +def conversation_to_ids(conversation, tokenizer): + """ + for single image multi-turn conversation + conversation: [{'role': 'user', 'content': 'Describe this image'}, + {'role': 'assistant', 'content': 'This is a cat.'}] + """ + raw_msg = '' + input_ids = [] + context = [] + for idx, msg in enumerate(conversation): + role = msg['role'] + message = msg['content'] + assert role in ['user', 'assistant'] + if role == 'user': + prefix = '<用户>' + else: + prefix = '' + # append eos + if idx == len(conversation) - 1: + message = message + tokenizer.eos_token + prefix_ids = tokenizer.encode(prefix)[1:] # remove bos + message_ids = tokenizer.encode(message)[1:] + + input_ids.append(prefix_ids) + input_ids.append(message_ids) + + context.append(np.ones((len(prefix_ids),), dtype=np.int8)) + if role == 'assistant': + context.append(np.zeros((len(message_ids),), dtype=np.int8)) + else: + context.append(np.ones((len(message_ids),), dtype=np.int8)) + + raw_msg += (prefix + message) + + ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32)) + context = torch.from_numpy(np.hstack(context, dtype=np.int8)) + + # build target + target = torch.full_like(ids, -100, dtype=torch.int32) + for i in range(1, len(ids)): + if context[i] == 0: + target[i - 1] = ids[i] + if context[i] == 1 and context[i - 1] == 0: + target[i - 1] = tokenizer.eos_id + + # build image bound + image_start_tokens = torch.where(ids 
== tokenizer.im_start_id)[0]
+    image_start_tokens += 1
+    image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
+    if len(image_start_tokens) != len(image_end_tokens):
+        print('the number of image start tokens and image end tokens does not match')
+    if len(image_start_tokens) > 0:
+        image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
+    else:
+        image_bound = []
+
+    return {
+        'input_ids': ids,
+        'target': target,
+        'image_bound': image_bound,
+        'raw_msg': raw_msg,
+    }
+
+
+def preprocess(image, conversation, tokenizer, transform, query_nums=64, slice_config=None):
+    """
+    single-image preprocess; the image is placed at the top of the conversation
+    """
+    conversation = copy.deepcopy(conversation)
+    assert len(conversation) > 1, "conversation length must be greater than 1"
+    assert conversation[0]['role'] == 'user', "the first role must be user"
+
+    if slice_config is not None:
+        assert isinstance(slice_config, Dict)
+        assert 'patch_size' in slice_config
+        assert 'max_slice_nums' in slice_config
+        assert 'scale_resolution' in slice_config
+    default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
+    if slice_config:
+        images = []
+        source_image, patches, best_grid = slice_image(
+            image, slice_config['max_slice_nums'], slice_config['scale_resolution'], slice_config['patch_size']
+        )
+        images.append(source_image)
+        image_placeholder = default_image_placeholder
+        if len(patches) > 0:
+            for i in range(len(patches)):
+                for j in range(len(patches[0])):
+                    images.append(patches[i][j])
+
+            image_placeholder += get_grid_placeholder(
+                tokenizer, best_grid, query_nums
+            )
+        images = [transform(i) for i in images]
+    else:
+        images = [transform(image)]
+        image_placeholder = default_image_placeholder
+    if '<image>' in conversation[0]['content']:
+        conversation[0]['content'] = conversation[0]['content'].replace('<image>', image_placeholder)
+    else:
+        conversation[0]['content'] = image_placeholder + '\n' + conversation[0]['content']
+
+    input_dict = conversation_to_ids(conversation, tokenizer)
+    input_dict['pixel_values'] = images
+    return input_dict
+
+
+def slice_image(
+    image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
+):
+    original_size = image.size
+    original_width, original_height = original_size
+    log_ratio = math.log(original_width / original_height)
+    ratio = original_width * original_height / (scale_resolution * scale_resolution)
+    multiple = min(math.ceil(ratio), max_slice_nums)
+
+    source_image = None
+    best_grid = None
+    patches = []
+
+    if multiple <= 1 or never_split:
+        # don't need to slice, upsample
+        best_size = find_best_resize(
+            original_size, scale_resolution, patch_size, allow_upscale=True
+        )
+        source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+    else:
+        candidate_split_grids_nums = []
+        for i in [multiple - 1, multiple, multiple + 1]:
+            if i == 1 or i > max_slice_nums:
+                continue
+            candidate_split_grids_nums.append(i)
+
+        # source image: down-sample and ensure the size is divisible by patch_size
+        best_resize = find_best_resize(original_size, scale_resolution, patch_size)
+        source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+        candidate_grids = []
+
+        # find best grid
+        for split_grids_nums in candidate_split_grids_nums:
+            m = 1
+            while m <= split_grids_nums:
+                if split_grids_nums % m == 0:
+                    candidate_grids.append([m, split_grids_nums // m])
+                m += 1
+
+        best_grid = [1, 1]
+        min_error = float("inf")
+        for grid in candidate_grids:
+            error = abs(log_ratio - math.log(grid[0] /
grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + refine_size = get_refine_size( + original_size, best_grid, scale_resolution, patch_size, allow_upscale=True + ) + + refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) + patches = split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + +def ensure_divide(length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + +def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + if (width * height > scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = ensure_divide(width, patch_size) + best_height = ensure_divide(height, patch_size) + return (best_width, best_height) + + +def get_refine_size( + original_size, grid, scale_resolution, patch_size, allow_upscale=False +): + width, height = original_size + grid_x, grid_y = grid + + refine_width = ensure_divide(width, grid_x) + refine_height = ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = refine_height / grid_y + + best_grid_size = find_best_resize( + (grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale, + ) + + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + + return refine_size + + +def split_to_patches(image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + + return patches + + +def get_grid_placeholder(tokenizer, grid, query_num): + image_placeholder = ( + tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + ) + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(image_placeholder) + slices.append("".join(lines)) + slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end + return slice_placeholder + diff --git a/finetune/ds_config_zero2.json b/finetune/ds_config_zero2.json new file mode 100644 index 0000000..d0cbb97 --- /dev/null +++ b/finetune/ds_config_zero2.json @@ -0,0 +1,54 @@ +{ + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + "bf16": { + "enabled": true + }, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/finetune/ds_config_zero3.json b/finetune/ds_config_zero3.json new file mode 100644 
index 0000000..5c3cd9c --- /dev/null +++ b/finetune/ds_config_zero3.json @@ -0,0 +1,61 @@ + +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} + diff --git a/finetune/finetune.py b/finetune/finetune.py new file mode 100644 index 0000000..6a751e3 --- /dev/null +++ b/finetune/finetune.py @@ -0,0 +1,125 @@ +import os +import glob +import json +import logging +from dataclasses import dataclass, field +from typing import Dict, Optional, List +import torch +from torch.utils.data import Dataset +import transformers +from trainer import CPMTrainer +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed import zero + +from dataset import data_collator, SupervisedDataset + + +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from accelerate.utils import DistributedType + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
+ }, + ) + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args, transform, data_collator=None, slice_config=None, +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = SupervisedDataset + + rank0_print("Loading data...") + + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = dataset_cls(train_json, transform, tokenizer, slice_config=slice_config) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = dataset_cls(eval_json, transform, tokenizer, slice_config=slice_config) + else: + eval_dataset = None + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator) + + +local_rank = 0 + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + + ( + model_args, + data_args, + training_args, + ) = parser.parse_args_into_dataclasses() + + if getattr(training_args, 'deepspeed', None): + training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + local_rank = training_args.local_rank + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + + model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map) + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + + #Load data + data_module = make_supervised_data_module( + tokenizer=tokenizer, data_args=data_args, transform=model.transform, data_collator=data_collator, slice_config=model.config.__dict__, + ) + + trainer = CPMTrainer( + model=model, + tokenizer=tokenizer, + args=training_args, + **data_module, + ) + + trainer.train() + trainer.save_state() + + +if __name__ == "__main__": + train() + diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh new file mode 100644 index 0000000..3654db1 --- /dev/null +++ b/finetune/finetune_ds.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +GPUS_PER_NODE=8 +NNODES=1 +NODE_RANK=0 +MASTER_ADDR=localhost +MASTER_PORT=6001 + +MODEL="path/to/minicpmv2" +# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. +# See the section for finetuning in README for more information. 
+DATA="path/to/trainging_data" +EVAL_DATA="path/to/test_data" + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" +torchrun $DISTRIBUTED_ARGS finetune.py \ + --model_name_or_path $MODEL \ + --data_path $DATA \ + --eval_data_path $EVAL_DATA \ + --remove_unused_columns false \ + --label_names "labels" \ + --prediction_loss_only false \ + --bf16 true \ + --bf16_full_eval true \ + --do_train \ + --do_eval \ + --max_steps 80000 \ + --eval_steps 200 \ + --output_dir output/output_minicpmv2 \ + --logging_dir output/output_minicpmv2 \ + --logging_strategy "steps" \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "steps" \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 10 \ + --learning_rate 5e-7 \ + --weight_decay 0.1 \ + --adam_beta2 0.95 \ + --warmup_ratio 0.01 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --gradient_checkpointing True \ + --deepspeed ds_config_zero2.json \ + --report_to "tensorboard" # wandb diff --git a/finetune/readme.md b/finetune/readme.md new file mode 100644 index 0000000..26c3c47 --- /dev/null +++ b/finetune/readme.md @@ -0,0 +1,66 @@ +# Minicpm-V2 Finetuning + +
+
+[English](README.md)
+
+
+We offer the official scripts for easy finetuning of the pretrained MiniCPM-V 2 model on downstream tasks. Our finetune scripts use DeepSpeed by default.
+
+### Data preparation
+
+To prepare your finetuning data, you should (1) formulate each sample as a dictionary consisting of an id, an image path (optional; not required for pure-text samples), and a list of conversations, and (2) save the data samples in JSON files.
+
+For vision-language samples with an image, insert the `<image>` placeholder in the first user message to mark where the image embeddings should go; if the placeholder is missing, the image is prepended to the first user message automatically.
+
+
+A vision-language example (vl_finetune_data.json) with one sample:
+
+```
+[
+    {
+        "id": "0",
+        "image": "path/to/image_0.jpg",
+        "conversations": [
+            {
+                "role": "user",
+                "content": "<image>\nHow many desserts are on the white plate?"
+            },
+            {
+                "role": "assistant",
+                "content": "There are three desserts on the white plate."
+            },
+            {
+                "role": "user",
+                "content": "What type of desserts are they?"
+            },
+            {
+                "role": "assistant",
+                "content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
+            },
+            {
+                "role": "user",
+                "content": "What is the setting of the image?"
+            },
+            {
+                "role": "assistant",
+                "content": "The image is set on a table top with a plate containing the three desserts."
+            }
+        ]
+    }
+]
+```
+
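+
+Before launching training, it can help to sanity-check your JSON against what `dataset.py` expects. The snippet below is a minimal sketch (not part of the official scripts); the file name `vl_finetune_data.json` and the exact checks are assumptions to adapt to your data.
+
+```
+import json
+import os
+
+
+def check_sample(sample):
+    """Lightweight checks mirroring what SupervisedDataset / preprocess() expect."""
+    assert "conversations" in sample and len(sample["conversations"]) > 1
+    conversations = sample["conversations"]
+    # preprocess() requires the first turn to come from the user
+    assert conversations[0]["role"] == "user", "the first role must be user"
+    for turn in conversations:
+        assert turn["role"] in ("user", "assistant")
+        assert isinstance(turn["content"], str)
+    if "image" in sample:
+        # SupervisedDataset opens this path with PIL, so it must exist
+        assert os.path.exists(sample["image"]), f"missing image: {sample['image']}"
+        if "<image>" not in conversations[0]["content"]:
+            print(f"sample {sample.get('id')}: no <image> placeholder; "
+                  "the image will be prepended to the first user message")
+
+
+with open("vl_finetune_data.json") as f:
+    data = json.load(f)
+
+for sample in data:
+    check_sample(sample)
+print(f"checked {len(data)} samples")
+```
+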
+
+### Full-parameter finetuning
+
+Full-parameter finetuning requires updating all parameters of the LLM during the whole training process. To launch your training, run the following script:
+
+```
+sh finetune_ds.sh
+```
+#### Customizing Hyperparameters
+To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
diff --git a/finetune/trainer.py b/finetune/trainer.py
new file mode 100644
index 0000000..53ffaa6
--- /dev/null
+++ b/finetune/trainer.py
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+from typing import Tuple, Union, Optional, List, Dict, Any
+from transformers import Trainer
+from transformers.trainer_pt_utils import nested_detach
+from transformers.utils import is_sagemaker_mp_enabled
+class CPMTrainer(Trainer):
+    def compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False
+    ):
+        if "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
+
+        vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(inputs)
+
+        outputs = self.model.llm(
+            inputs_embeds=vllm_embedding,
+            use_cache=False,
+        )
+
+        if labels is not None:
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
+            labels = labels.view(-1).long().contiguous()
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+        else:
+            if isinstance(outputs, dict) and "loss" not in outputs:
+                raise ValueError(
+                    "The model did not return a loss from the inputs, only the following keys: "
+                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+                )
+            # We don't use .loss here since the model may return tuples instead of ModelOutput.
+            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+        return (loss, outputs) if return_outputs else loss
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+
+        Return:
+            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+            logits and labels (each being optional).
+ """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + +