update finetuning code

This commit is contained in:
qianyu chen
2024-05-08 09:51:34 +08:00
committed by GitHub
parent 9f345c4020
commit f6cbd4fb25
9 changed files with 0 additions and 37 deletions

View File

@@ -1,8 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI, ZHIHU Inc. (zhihu.com)
#
# @author: chenqianyu <cqy1195@zhihu.com@zhihu.com>
# @date: 2024/5/02
#

View File

@@ -1,300 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI, ZHIHU Inc. (zhihu.com)
#
# @author: wangchongyi <wangchongyi@zhihu.com>
# @author: chenqianyu <cqy1195@zhihu.com>
# @date: 2024/5/06
#
import os
import math
import json
import copy
import logging
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from typing import Dict, Optional, List
from PIL import Image
from dataclasses import dataclass, field
from transformers import AutoTokenizer, AutoProcessor
from torch.utils.data import Dataset
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, transform, tokenizer, slice_config):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
self.tokenizer = tokenizer
self.transform = transform
self.slice_config = slice_config
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(image, self.raw_data[i]["conversations"], self.tokenizer, self.transform, slice_config=self.slice_config)
ret = dict(
input_ids=ret["input_ids"],
labels=ret["target"],
attention_mask=ret["input_ids"].ne(self.tokenizer.pad_token_id),
pixel_values=ret["pixel_values"],
image_bound=ret["image_bound"],
)
return ret
def data_collator(examples, padding_value=0):
input_ids = pad_sequence([example["input_ids"] for example in examples], batch_first=True, padding_value=padding_value)
targets = pad_sequence([example["labels"] for example in examples], batch_first=True, padding_value=padding_value)
attention_mask = pad_sequence([example["attention_mask"] for example in examples], batch_first=True, padding_value=padding_value)
pixel_values = [example["pixel_values"] for example in examples]
image_bound = [example["image_bound"] for example in examples]
return {"input_ids": input_ids, "labels":targets, "attention_mask": attention_mask, "image_bound": image_bound, "pixel_values": pixel_values}
def conversation_to_ids(conversation, tokenizer):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
{'role': 'assistant', 'content': 'This is a cat.'}]
"""
raw_msg = ''
input_ids = []
context = []
for idx, msg in enumerate(conversation):
role = msg['role']
message = msg['content']
assert role in ['user', 'assistant']
if role == 'user':
prefix = '<用户>'
else:
prefix = '<AI>'
# append eos
if idx == len(conversation) - 1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
message_ids = tokenizer.encode(message)[1:]
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role == 'assistant':
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
raw_msg += (prefix + message)
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
for i in range(1, len(ids)):
if context[i] == 0:
target[i - 1] = ids[i]
if context[i] == 1 and context[i - 1] == 0:
target[i - 1] = tokenizer.eos_id
# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print('image start token != image end tokens')
if len(image_start_tokens)>0:
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
else:
image_bound = []
return {
'input_ids': ids,
'target': target,
'image_bound': image_bound,
'raw_msg': raw_msg,
}
def preprocess(image, conversation, tokenizer, transform, query_nums=64, slice_config=None):
"""
single image preprocess, the image will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
assert len(conversation) > 1, "conversation length must large than 2"
assert conversation[0]['role'] == 'user', "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
assert 'patch_size' in slice_config
assert 'max_slice_nums' in slice_config
assert 'scale_resolution' in slice_config
default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
if slice_config:
images = []
source_image, patches, best_grid = slice_image(
image, slice_config['max_slice_nums'], slice_config['scale_resolution'], slice_config['patch_size']
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums
)
images = [transform(i) for i in images]
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if '<image>' in conversation[0]['content']:
conversation[0]['content'] = conversation[0]['content'].replace('<image>', image_placeholder)
else:
conversation[0]['content'] = image_placeholder + '\n' + conversation[0]['content']
input_dict = conversation_to_ids(conversation, tokenizer)
input_dict['pixel_values'] = images
return input_dict
def slice_image(
image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / (scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
best_grid = None
patches = []
if multiple <= 1 or never_split:
# dont need to slice, upsample
best_size = find_best_resize(
original_size, scale_resolution, patch_size, allow_upscale=True
)
source_image = image.resize(best_size, Image.Resampling.BICUBIC)
else:
candidate_split_grids_nums = []
for i in [multiple - 1, multiple, multiple + 1]:
if i == 1 or i > max_slice_nums:
continue
candidate_split_grids_nums.append(i)
# source image, down-sampling and ensure divided by patch_size
best_resize = find_best_resize(original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
# find best grid
for split_grids_nums in candidate_split_grids_nums:
m = 1
while m <= split_grids_nums:
if split_grids_nums % m == 0:
candidate_grids.append([m, split_grids_nums // m])
m += 1
best_grid = [1, 1]
min_error = float("inf")
for grid in candidate_grids:
error = abs(log_ratio - math.log(grid[0] / grid[1]))
if error < min_error:
best_grid = grid
min_error = error
refine_size = get_refine_size(
original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
)
refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
patches = split_to_patches(refine_image, best_grid)
return source_image, patches, best_grid
def ensure_divide(length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
if (width * height > scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
best_width = ensure_divide(width, patch_size)
best_height = ensure_divide(height, patch_size)
return (best_width, best_height)
def get_refine_size(
original_size, grid, scale_resolution, patch_size, allow_upscale=False
):
width, height = original_size
grid_x, grid_y = grid
refine_width = ensure_divide(width, grid_x)
refine_height = ensure_divide(height, grid_y)
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
best_grid_size = find_best_resize(
(grid_width, grid_height),
scale_resolution,
patch_size,
allow_upscale=allow_upscale,
)
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
def split_to_patches(image, grid):
patches = []
width, height = image.size
grid_x = int(width / grid[0])
grid_y = int(height / grid[1])
for i in range(0, height, grid_y):
images = []
for j in range(0, width, grid_x):
box = (j, i, j + grid_x, i + grid_y)
patch = image.crop(box)
images.append(patch)
patches.append(images)
return patches
def get_grid_placeholder(tokenizer, grid, query_num):
image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
)
cols = grid[0]
rows = grid[1]
slices = []
for i in range(rows):
lines = []
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
return slice_placeholder

View File

@@ -1,54 +0,0 @@
{
"fp16": {
"enabled": false,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -1,61 +0,0 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -1,133 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI, ZHIHU Inc. (zhihu.com)
#
# @author: chenqianyu <cqy1195@zhihu.com@zhihu.com>
# @date: 2024/5/03
#
import os
import glob
import json
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
import transformers
from trainer import CPMTrainer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed import zero
from dataset import data_collator, SupervisedDataset
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from accelerate.utils import DistributedType
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2")
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
eval_data_path: str = field(
default=None, metadata={"help": "Path to the evaluation data."}
)
lazy_preprocess: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=2048,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
def rank0_print(*args):
if local_rank == 0:
print(*args)
def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args, transform, data_collator=None, slice_config=None,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = SupervisedDataset
rank0_print("Loading data...")
train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(train_json, transform, tokenizer, slice_config=slice_config)
if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(eval_json, transform, tokenizer, slice_config=slice_config)
else:
eval_dataset = None
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)
local_rank = 0
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
(
model_args,
data_args,
training_args,
) = parser.parse_args_into_dataclasses()
if getattr(training_args, 'deepspeed', None):
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
compute_dtype = (
torch.float16
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)
local_rank = training_args.local_rank
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
#Load data
data_module = make_supervised_data_module(
tokenizer=tokenizer, data_args=data_args, transform=model.transform, data_collator=data_collator, slice_config=model.config.__dict__,
)
trainer = CPMTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module,
)
trainer.train()
trainer.save_state()
if __name__ == "__main__":
train()

View File

@@ -1,53 +0,0 @@
#!/bin/bash
GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="path/to/minicpmv2"
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data"
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--eval_data_path $EVAL_DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 true \
--bf16_full_eval true \
--do_train \
--do_eval \
--max_steps 80000 \
--eval_steps 200 \
--output_dir output/output_minicpmv2 \
--logging_dir output/output_minicpmv2 \
--logging_strategy "steps" \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 5e-7 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--gradient_checkpointing True \
--deepspeed ds_config_zero2.json \
--report_to "tensorboard" # wandb

View File

@@ -1,69 +0,0 @@
# Minicpm-V2 Finetuning
<div align="center">
[English](README.md)
</div>
We offer the official scripts for easy finetuning of the pretrained minicpm-v2 model on downstream tasks. Our finetune scripts use DeepSpeed by default.
### Data preparation
To prepare your finetuning data, you should (1) formulate each sample as a dictionary consisting of an id, an image path list with an image (optional, not required for pure-text example), and a list of conversations, and (2) save data samples in JSON files.
For the vision-language example with image, you are required to define placeholder(s) <ImageHere> to define the position to insert the image embeddings.
<details>
<summary>
<b>vision-language example (vl_finetune_data.json) with 1 samples.</b>
</summary>
```
[
{
"id": "0",
"image": 'path/to/image_0.jpg',
"conversations": [
{
'role': 'user',
'content': '<image>\nHow many desserts are on the white plate?'
},
{
'role': 'assistant',
'content': 'There are three desserts on the white plate.'
},
{
'role': 'user',
'content': 'What type of desserts are they?'
},
{
'role': 'assistant',
'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.'
},
{
'role': 'user',
'content': 'What is the setting of the image?'},
{
'role': 'assistant',
'content': 'The image is set on a table top with a plate containing the three desserts.'
},
]
},
]
```
</details>
### Full-parameter finetuning
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. To launch your training, run the following script:
```
sh finetune_ds.sh
```
#### Customizing Hyperparameters
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
### LoRA finetuning
**This part is still unfinished, and we will complete it as soon as possible.**

View File

@@ -1,158 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI, ZHIHU Inc. (zhihu.com)
#
# @author: chenqianyu <cqy1195@zhihu.com@zhihu.com>
# @date: 2024/5/03
#
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List, Dict, Any
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
class CPMTrainer(Trainer):
def compute_loss(
self,
model,
inputs,
return_outputs=False
):
if "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(inputs)
outputs = self.model.llm(
inputs_embeds=vllm_embedding,
use_cache=False,
)
if labels is not None:
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous()
labels = labels.view(-1).long().contiguous()
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def prediction_step(
self,
model: nn.Module,
inputs:Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
return_loss = self.can_return_loss
loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
with torch.no_grad():
if is_sagemaker_mp_enabled():
raw_outputs = smp_forward_only(model, inputs)
if has_labels or loss_without_labels:
if isinstance(raw_outputs, dict):
loss_mb = raw_outputs["loss"]
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
else:
loss_mb = raw_outputs[0]
logits_mb = raw_outputs[1:]
loss = loss_mb.reduce_mean().detach().cpu()
logits = smp_nested_concat(logits_mb)
else:
loss = None
if isinstance(raw_outputs, dict):
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
else:
logits_mb = raw_outputs
logits = smp_nested_concat(logits_mb)
else:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
return (loss, logits, labels)