From b8104397050ba80726a84a33180ec7194bf6a64a Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Tue, 7 May 2024 22:13:43 +0800 Subject: [PATCH 1/4] Create finetune --- omnilmm/finetune | 1 + 1 file changed, 1 insertion(+) create mode 100644 omnilmm/finetune diff --git a/omnilmm/finetune b/omnilmm/finetune new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/omnilmm/finetune @@ -0,0 +1 @@ + From fcacda295b356ffa4c935db1036643db00ad84b4 Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Tue, 7 May 2024 22:28:42 +0800 Subject: [PATCH 2/4] upload fintune script --- finetune/__init__.py | 8 + finetune/dataset.py | 300 ++++++++++++++++++++++++++++++++++ finetune/ds_config_zero2.json | 54 ++++++ finetune/ds_config_zero3.json | 61 +++++++ finetune/finetune.py | 133 +++++++++++++++ finetune/finetune_ds.sh | 53 ++++++ finetune/readme.md | 69 ++++++++ finetune/trainer.py | 158 ++++++++++++++++++ omnilmm/finetune | 1 - 9 files changed, 836 insertions(+), 1 deletion(-) create mode 100644 finetune/__init__.py create mode 100644 finetune/dataset.py create mode 100644 finetune/ds_config_zero2.json create mode 100644 finetune/ds_config_zero3.json create mode 100644 finetune/finetune.py create mode 100644 finetune/finetune_ds.sh create mode 100644 finetune/readme.md create mode 100644 finetune/trainer.py delete mode 100644 omnilmm/finetune diff --git a/finetune/__init__.py b/finetune/__init__.py new file mode 100644 index 0000000..9b79fbc --- /dev/null +++ b/finetune/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) +# +# @author: chenqianyu +# @date: 2024/5/02 +# diff --git a/finetune/dataset.py b/finetune/dataset.py new file mode 100644 index 0000000..19592eb --- /dev/null +++ b/finetune/dataset.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) +# +# @author: wangchongyi +# @author: chenqianyu +# @date: 2024/5/06 +# + +import os +import math +import json +import copy +import logging + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from typing import Dict, Optional, List +from PIL import Image + + +from dataclasses import dataclass, field +from transformers import AutoTokenizer, AutoProcessor +from torch.utils.data import Dataset + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + def __init__(self, raw_data, transform, tokenizer, slice_config): + super(SupervisedDataset, self).__init__() + self.raw_data = raw_data + self.tokenizer = tokenizer + self.transform = transform + self.slice_config = slice_config + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + image = Image.open(self.raw_data[i]["image"]).convert("RGB") + ret = preprocess(image, self.raw_data[i]["conversations"], self.tokenizer, self.transform, slice_config=self.slice_config) + ret = dict( + input_ids=ret["input_ids"], + labels=ret["target"], + attention_mask=ret["input_ids"].ne(self.tokenizer.pad_token_id), + pixel_values=ret["pixel_values"], + image_bound=ret["image_bound"], + ) + + return ret + + +def data_collator(examples, padding_value=0): + input_ids = pad_sequence([example["input_ids"] for example in examples], batch_first=True, padding_value=padding_value) + targets = pad_sequence([example["labels"] for example in examples], batch_first=True, padding_value=padding_value) + attention_mask = pad_sequence([example["attention_mask"] for example in examples], batch_first=True, padding_value=padding_value) + pixel_values = [example["pixel_values"] for example in examples] + image_bound = [example["image_bound"] for example in examples] + return {"input_ids": input_ids, "labels":targets, "attention_mask": attention_mask, "image_bound": image_bound, "pixel_values": pixel_values} + + +def conversation_to_ids(conversation, tokenizer): + """ + for single image multi-turn conversation + conversation: [{'role': 'user', 'content': 'Describe this image'}, + {'role': 'assistant', 'content': 'This is a cat.'}] + """ + raw_msg = '' + input_ids = [] + context = [] + for idx, msg in enumerate(conversation): + role = msg['role'] + message = msg['content'] + assert role in ['user', 'assistant'] + if role == 'user': + prefix = '<用户>' + else: + prefix = '' + # append eos + if idx == len(conversation) - 1: + message = message + tokenizer.eos_token + prefix_ids = tokenizer.encode(prefix)[1:] # remove bos + message_ids = tokenizer.encode(message)[1:] + + input_ids.append(prefix_ids) + input_ids.append(message_ids) + + context.append(np.ones((len(prefix_ids),), dtype=np.int8)) + if role == 'assistant': + context.append(np.zeros((len(message_ids),), dtype=np.int8)) + else: + context.append(np.ones((len(message_ids),), dtype=np.int8)) + + raw_msg += (prefix + message) + + ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32)) + context = torch.from_numpy(np.hstack(context, dtype=np.int8)) + + # build target + target = torch.full_like(ids, -100, dtype=torch.int32) + for i in range(1, len(ids)): + if context[i] == 0: + target[i - 1] = ids[i] + if context[i] == 1 and context[i - 1] == 0: + target[i - 1] = tokenizer.eos_id + + # build image bound + image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0] + if len(image_start_tokens) != len(image_end_tokens): + print('image start token != image end tokens') + if len(image_start_tokens)>0: + image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]) + else: + image_bound = [] + + return { + 'input_ids': ids, + 'target': target, + 'image_bound': image_bound, + 'raw_msg': raw_msg, + } + + +def preprocess(image, conversation, tokenizer, transform, query_nums=64, slice_config=None): + """ + single image preprocess, the image will be placed at the top of the conversation + """ + conversation = copy.deepcopy(conversation) + assert len(conversation) > 1, "conversation length must large than 2" + assert conversation[0]['role'] == 'user', "the first role must be user" + + if slice_config is not None: + assert isinstance(slice_config, Dict) + assert 'patch_size' in slice_config + assert 'max_slice_nums' in slice_config + assert 'scale_resolution' in slice_config + default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + if slice_config: + images = [] + source_image, patches, best_grid = slice_image( + image, slice_config['max_slice_nums'], slice_config['scale_resolution'], slice_config['patch_size'] + ) + images.append(source_image) + image_placeholder = default_image_placeholder + if len(patches) > 0: + for i in range(len(patches)): + for j in range(len(patches[0])): + images.append(patches[i][j]) + + image_placeholder += get_grid_placeholder( + tokenizer, best_grid, query_nums + ) + images = [transform(i) for i in images] + else: + images = [transform(image)] + image_placeholder = default_image_placeholder + if '' in conversation[0]['content']: + conversation[0]['content'] = conversation[0]['content'].replace('', image_placeholder) + else: + conversation[0]['content'] = image_placeholder + '\n' + conversation[0]['content'] + + input_dict = conversation_to_ids(conversation, tokenizer) + input_dict['pixel_values'] = images + return input_dict + + + +def slice_image( + image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False +): + original_size = image.size + original_width, original_height = original_size + log_ratio = math.log(original_width / original_height) + ratio = original_width * original_height / (scale_resolution * scale_resolution) + multiple = min(math.ceil(ratio), max_slice_nums) + + source_image = None + best_grid = None + patches = [] + + if multiple <= 1 or never_split: + # dont need to slice, upsample + best_size = find_best_resize( + original_size, scale_resolution, patch_size, allow_upscale=True + ) + source_image = image.resize(best_size, Image.Resampling.BICUBIC) + else: + candidate_split_grids_nums = [] + for i in [multiple - 1, multiple, multiple + 1]: + if i == 1 or i > max_slice_nums: + continue + candidate_split_grids_nums.append(i) + + # source image, down-sampling and ensure divided by patch_size + best_resize = find_best_resize(original_size, scale_resolution, patch_size) + source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) + candidate_grids = [] + + # find best grid + for split_grids_nums in candidate_split_grids_nums: + m = 1 + while m <= split_grids_nums: + if split_grids_nums % m == 0: + candidate_grids.append([m, split_grids_nums // m]) + m += 1 + + best_grid = [1, 1] + min_error = float("inf") + for grid in candidate_grids: + error = abs(log_ratio - math.log(grid[0] / grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + refine_size = get_refine_size( + original_size, best_grid, scale_resolution, patch_size, allow_upscale=True + ) + + refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) + patches = split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + +def ensure_divide(length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + +def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + if (width * height > scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = ensure_divide(width, patch_size) + best_height = ensure_divide(height, patch_size) + return (best_width, best_height) + + +def get_refine_size( + original_size, grid, scale_resolution, patch_size, allow_upscale=False +): + width, height = original_size + grid_x, grid_y = grid + + refine_width = ensure_divide(width, grid_x) + refine_height = ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = refine_height / grid_y + + best_grid_size = find_best_resize( + (grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale, + ) + + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + + return refine_size + + +def split_to_patches(image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + + return patches + + +def get_grid_placeholder(tokenizer, grid, query_num): + image_placeholder = ( + tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + ) + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(image_placeholder) + slices.append("".join(lines)) + slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end + return slice_placeholder + diff --git a/finetune/ds_config_zero2.json b/finetune/ds_config_zero2.json new file mode 100644 index 0000000..d0cbb97 --- /dev/null +++ b/finetune/ds_config_zero2.json @@ -0,0 +1,54 @@ +{ + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + "bf16": { + "enabled": true + }, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/finetune/ds_config_zero3.json b/finetune/ds_config_zero3.json new file mode 100644 index 0000000..5c3cd9c --- /dev/null +++ b/finetune/ds_config_zero3.json @@ -0,0 +1,61 @@ + +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} + diff --git a/finetune/finetune.py b/finetune/finetune.py new file mode 100644 index 0000000..9f2f7ec --- /dev/null +++ b/finetune/finetune.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) +# +# @author: chenqianyu +# @date: 2024/5/03 +# +import os +import glob +import json +import logging +from dataclasses import dataclass, field +from typing import Dict, Optional, List +import torch +from torch.utils.data import Dataset +import transformers +from trainer import CPMTrainer +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed import zero + +from dataset import data_collator, SupervisedDataset + + +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from accelerate.utils import DistributedType + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args, transform, data_collator=None, slice_config=None, +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = SupervisedDataset + + rank0_print("Loading data...") + + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = dataset_cls(train_json, transform, tokenizer, slice_config=slice_config) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = dataset_cls(eval_json, transform, tokenizer, slice_config=slice_config) + else: + eval_dataset = None + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator) + + +local_rank = 0 + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + + ( + model_args, + data_args, + training_args, + ) = parser.parse_args_into_dataclasses() + + if getattr(training_args, 'deepspeed', None): + training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + local_rank = training_args.local_rank + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + + model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map) + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + + #Load data + data_module = make_supervised_data_module( + tokenizer=tokenizer, data_args=data_args, transform=model.transform, data_collator=data_collator, slice_config=model.config.__dict__, + ) + + trainer = CPMTrainer( + model=model, + tokenizer=tokenizer, + args=training_args, + **data_module, + ) + + trainer.train() + trainer.save_state() + + +if __name__ == "__main__": + train() + diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh new file mode 100644 index 0000000..3654db1 --- /dev/null +++ b/finetune/finetune_ds.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +GPUS_PER_NODE=8 +NNODES=1 +NODE_RANK=0 +MASTER_ADDR=localhost +MASTER_PORT=6001 + +MODEL="path/to/minicpmv2" +# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. +# See the section for finetuning in README for more information. +DATA="path/to/trainging_data" +EVAL_DATA="path/to/test_data" + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" +torchrun $DISTRIBUTED_ARGS finetune.py \ + --model_name_or_path $MODEL \ + --data_path $DATA \ + --eval_data_path $EVAL_DATA \ + --remove_unused_columns false \ + --label_names "labels" \ + --prediction_loss_only false \ + --bf16 true \ + --bf16_full_eval true \ + --do_train \ + --do_eval \ + --max_steps 80000 \ + --eval_steps 200 \ + --output_dir output/output_minicpmv2 \ + --logging_dir output/output_minicpmv2 \ + --logging_strategy "steps" \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "steps" \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 10 \ + --learning_rate 5e-7 \ + --weight_decay 0.1 \ + --adam_beta2 0.95 \ + --warmup_ratio 0.01 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --gradient_checkpointing True \ + --deepspeed ds_config_zero2.json \ + --report_to "tensorboard" # wandb diff --git a/finetune/readme.md b/finetune/readme.md new file mode 100644 index 0000000..136aa84 --- /dev/null +++ b/finetune/readme.md @@ -0,0 +1,69 @@ +# Minicpm-V2 Finetuning + +
+ +[English](README.md) + +
+ +We offer the official scripts for easy finetuning of the pretrained minicpm-v2 model on downstream tasks. Our finetune scripts use DeepSpeed by default. + +### Data preparation + +To prepare your finetuning data, you should (1) formulate each sample as a dictionary consisting of an id, an image path list with an image (optional, not required for pure-text example), and a list of conversations, and (2) save data samples in JSON files. + +For the vision-language example with image, you are required to define placeholder(s) to define the position to insert the image embeddings. + +
+ + vision-language example (vl_finetune_data.json) with 1 samples. + + +``` + [ + { + "id": "0", + "image": 'path/to/image_0.jpg', + "conversations": [ + { + 'role': 'user', + 'content': '\nHow many desserts are on the white plate?' + }, + { + 'role': 'assistant', + 'content': 'There are three desserts on the white plate.' + }, + { + 'role': 'user', + 'content': 'What type of desserts are they?' + }, + { + 'role': 'assistant', + 'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.' + }, + { + 'role': 'user', + 'content': 'What is the setting of the image?'}, + { + 'role': 'assistant', + 'content': 'The image is set on a table top with a plate containing the three desserts.' + }, + ] + }, + ] +``` + +
+ +### Full-parameter finetuning + +Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. To launch your training, run the following script: + +``` +sh finetune_ds.sh +``` +#### Customizing Hyperparameters +To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset. +### LoRA finetuning + +**This part is still unfinished, and we will complete it as soon as possible.** diff --git a/finetune/trainer.py b/finetune/trainer.py new file mode 100644 index 0000000..e8794dc --- /dev/null +++ b/finetune/trainer.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) +# +# @author: chenqianyu +# @date: 2024/5/03 +# +import torch +import torch.nn as nn +from typing import Tuple, Union, Optional, List, Dict, Any +from transformers import Trainer +from transformers.trainer_pt_utils import nested_detach +from transformers.utils import is_sagemaker_mp_enabled +class CPMTrainer(Trainer): + def compute_loss( + self, + model, + inputs, + return_outputs=False + ): + if "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + + vllm_embedding, vision_hidden_states = self.model.get_vllm_embedding(inputs) + + outputs = self.model.llm( + inputs_embeds=vllm_embedding, + use_cache=False, + ) + + if labels is not None: + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = outputs.logits.view(-1, self.model.config.vocab_size).contiguous() + labels = labels.view(-1).long().contiguous() + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def prediction_step( + self, + model: nn.Module, + inputs:Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + diff --git a/omnilmm/finetune b/omnilmm/finetune deleted file mode 100644 index 8b13789..0000000 --- a/omnilmm/finetune +++ /dev/null @@ -1 +0,0 @@ - From 9f345c40206bb343288c0fb29498adf9ba2fd590 Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Tue, 7 May 2024 22:29:44 +0800 Subject: [PATCH 3/4] add finetuning script --- {finetune => omnilmm/finetune}/__init__.py | 0 {finetune => omnilmm/finetune}/dataset.py | 0 {finetune => omnilmm/finetune}/ds_config_zero2.json | 0 {finetune => omnilmm/finetune}/ds_config_zero3.json | 0 {finetune => omnilmm/finetune}/finetune.py | 0 {finetune => omnilmm/finetune}/finetune_ds.sh | 0 {finetune => omnilmm/finetune}/readme.md | 0 {finetune => omnilmm/finetune}/trainer.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename {finetune => omnilmm/finetune}/__init__.py (100%) rename {finetune => omnilmm/finetune}/dataset.py (100%) rename {finetune => omnilmm/finetune}/ds_config_zero2.json (100%) rename {finetune => omnilmm/finetune}/ds_config_zero3.json (100%) rename {finetune => omnilmm/finetune}/finetune.py (100%) rename {finetune => omnilmm/finetune}/finetune_ds.sh (100%) rename {finetune => omnilmm/finetune}/readme.md (100%) rename {finetune => omnilmm/finetune}/trainer.py (100%) diff --git a/finetune/__init__.py b/omnilmm/finetune/__init__.py similarity index 100% rename from finetune/__init__.py rename to omnilmm/finetune/__init__.py diff --git a/finetune/dataset.py b/omnilmm/finetune/dataset.py similarity index 100% rename from finetune/dataset.py rename to omnilmm/finetune/dataset.py diff --git a/finetune/ds_config_zero2.json b/omnilmm/finetune/ds_config_zero2.json similarity index 100% rename from finetune/ds_config_zero2.json rename to omnilmm/finetune/ds_config_zero2.json diff --git a/finetune/ds_config_zero3.json b/omnilmm/finetune/ds_config_zero3.json similarity index 100% rename from finetune/ds_config_zero3.json rename to omnilmm/finetune/ds_config_zero3.json diff --git a/finetune/finetune.py b/omnilmm/finetune/finetune.py similarity index 100% rename from finetune/finetune.py rename to omnilmm/finetune/finetune.py diff --git a/finetune/finetune_ds.sh b/omnilmm/finetune/finetune_ds.sh similarity index 100% rename from finetune/finetune_ds.sh rename to omnilmm/finetune/finetune_ds.sh diff --git a/finetune/readme.md b/omnilmm/finetune/readme.md similarity index 100% rename from finetune/readme.md rename to omnilmm/finetune/readme.md diff --git a/finetune/trainer.py b/omnilmm/finetune/trainer.py similarity index 100% rename from finetune/trainer.py rename to omnilmm/finetune/trainer.py From f6cbd4fb25bbf61551d9ca9f7c4eb2af6428b907 Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Wed, 8 May 2024 09:51:34 +0800 Subject: [PATCH 4/4] update finetuning code --- finetune/__init__.py | 0 {omnilmm/finetune => finetune}/dataset.py | 10 ---------- {omnilmm/finetune => finetune}/ds_config_zero2.json | 0 {omnilmm/finetune => finetune}/ds_config_zero3.json | 0 {omnilmm/finetune => finetune}/finetune.py | 8 -------- {omnilmm/finetune => finetune}/finetune_ds.sh | 0 {omnilmm/finetune => finetune}/readme.md | 3 --- {omnilmm/finetune => finetune}/trainer.py | 8 -------- omnilmm/finetune/__init__.py | 8 -------- 9 files changed, 37 deletions(-) create mode 100644 finetune/__init__.py rename {omnilmm/finetune => finetune}/dataset.py (98%) rename {omnilmm/finetune => finetune}/ds_config_zero2.json (100%) rename {omnilmm/finetune => finetune}/ds_config_zero3.json (100%) rename {omnilmm/finetune => finetune}/finetune.py (95%) rename {omnilmm/finetune => finetune}/finetune_ds.sh (100%) rename {omnilmm/finetune => finetune}/readme.md (96%) rename {omnilmm/finetune => finetune}/trainer.py (97%) delete mode 100644 omnilmm/finetune/__init__.py diff --git a/finetune/__init__.py b/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/omnilmm/finetune/dataset.py b/finetune/dataset.py similarity index 98% rename from omnilmm/finetune/dataset.py rename to finetune/dataset.py index 19592eb..b93035a 100644 --- a/omnilmm/finetune/dataset.py +++ b/finetune/dataset.py @@ -1,13 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) -# -# @author: wangchongyi -# @author: chenqianyu -# @date: 2024/5/06 -# - import os import math import json diff --git a/omnilmm/finetune/ds_config_zero2.json b/finetune/ds_config_zero2.json similarity index 100% rename from omnilmm/finetune/ds_config_zero2.json rename to finetune/ds_config_zero2.json diff --git a/omnilmm/finetune/ds_config_zero3.json b/finetune/ds_config_zero3.json similarity index 100% rename from omnilmm/finetune/ds_config_zero3.json rename to finetune/ds_config_zero3.json diff --git a/omnilmm/finetune/finetune.py b/finetune/finetune.py similarity index 95% rename from omnilmm/finetune/finetune.py rename to finetune/finetune.py index 9f2f7ec..6a751e3 100644 --- a/omnilmm/finetune/finetune.py +++ b/finetune/finetune.py @@ -1,11 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) -# -# @author: chenqianyu -# @date: 2024/5/03 -# import os import glob import json diff --git a/omnilmm/finetune/finetune_ds.sh b/finetune/finetune_ds.sh similarity index 100% rename from omnilmm/finetune/finetune_ds.sh rename to finetune/finetune_ds.sh diff --git a/omnilmm/finetune/readme.md b/finetune/readme.md similarity index 96% rename from omnilmm/finetune/readme.md rename to finetune/readme.md index 136aa84..26c3c47 100644 --- a/omnilmm/finetune/readme.md +++ b/finetune/readme.md @@ -64,6 +64,3 @@ sh finetune_ds.sh ``` #### Customizing Hyperparameters To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset. -### LoRA finetuning - -**This part is still unfinished, and we will complete it as soon as possible.** diff --git a/omnilmm/finetune/trainer.py b/finetune/trainer.py similarity index 97% rename from omnilmm/finetune/trainer.py rename to finetune/trainer.py index e8794dc..53ffaa6 100644 --- a/omnilmm/finetune/trainer.py +++ b/finetune/trainer.py @@ -1,11 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) -# -# @author: chenqianyu -# @date: 2024/5/03 -# import torch import torch.nn as nn from typing import Tuple, Union, Optional, List, Dict, Any diff --git a/omnilmm/finetune/__init__.py b/omnilmm/finetune/__init__.py deleted file mode 100644 index 9b79fbc..0000000 --- a/omnilmm/finetune/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI, ZHIHU Inc. (zhihu.com) -# -# @author: chenqianyu -# @date: 2024/5/02 -#