From cd64150b5122f8ee8c677d481c97918485129b52 Mon Sep 17 00:00:00 2001
From: qianyu chen <38046403+qyc-98@users.noreply.github.com>
Date: Thu, 15 Aug 2024 11:24:50 +0800
Subject: [PATCH] update finetune for multi-image SFT (#462)

---
 README.md                 |   1 +
 README_zh.md              |   2 +-
 finetune/dataset.py       | 177 +++++++++++++++++++++++++-------------
 finetune/finetune.py      |   3 +
 finetune/finetune_ds.sh   |   4 +-
 finetune/finetune_lora.sh |   5 +-
 finetune/readme.md        |  46 +++++++++-
 7 files changed, 170 insertions(+), 68 deletions(-)

diff --git a/README.md b/README.md
index d031945..de9795e 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ Join our 💬 WeChat
 
 #### 📌 Pinned
 
+* [2024.08.15] We now also support multi-image SFT. For more details, please refer to the [finetune/README.md](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune) file.
 * [2024.08.14] MiniCPM-V 2.6 now also supports [fine-tuning](https://github.com/modelscope/ms-swift/issues/1613) with the SWIFT framework!
 * [2024.08.10] 🚀🚀🚀 MiniCPM-Llama3-V 2.5 is now fully supported by [official](https://github.com/ggerganov/llama.cpp) llama.cpp! GGUF models of various sizes are available [here](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf). Please note that MiniCPM-V 2.6 still needs to use [our fork](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md).
 * [2024.08.06] 🔥🔥🔥 We open-source MiniCPM-V 2.6, which outperforms GPT-4V on single image, multi-image and video understanding. It advances popular features of MiniCPM-Llama3-V 2.5, and can support real-time video understanding on iPad. Try it now!

diff --git a/README_zh.md b/README_zh.md
index 499ab7e..1d59211 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -32,7 +32,7 @@
 ## 更新日志
 
 #### 📌 置顶
-
+* [2024.08.15] MiniCPM-V 2.6 现在支持多图像 SFT。有关更多详细信息，请参阅[finetune/README.md](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune)
 * [2024.08.14] MiniCPM-V 2.6 现在可以通过 SWIFT 框架 [微调](https://github.com/modelscope/ms-swift/issues/1613) 了！
 * [2024.08.10] 🚀🚀🚀 llama.cpp [官方仓库](https://github.com/ggerganov/llama.cpp)正式支持 MiniCPM-Llama3-V 2.5 啦！点击[这里](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main)查看各种大小的 GGUF 版本。但还请使用者注意 MiniCPM-V 2.6 仍然需要**拉取我们最新的 fork 来使用**：[llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md) 。我们将继续积极推进将这些功能合并到 llama.cpp 官方仓库
 * [2024.08.06] 🔥🔥🔥 我们开源了 MiniCPM-V 2.6，该模型在单图、多图和视频理解方面取得了优于 GPT-4V 的表现。我们还进一步提升了 MiniCPM-Llama3-V 2.5 的多项亮点能力，并首次支持了 iPad 上的实时视频理解。欢迎试用！
diff --git a/finetune/dataset.py b/finetune/dataset.py
index 09dadcc..dbf6bbc 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -3,6 +3,8 @@ import json
 import logging
 import math
 import os
+import re
+import random
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional
 
@@ -12,6 +14,9 @@ from PIL import Image
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, AutoTokenizer
+import logging
+
+logger = logging.getLogger(__name__)
 
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
 
@@ -28,6 +33,7 @@ class SupervisedDataset(Dataset):
         patch_size=14,
         query_nums=64,
         batch_vision=False,
+        max_length=2048,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -38,35 +44,46 @@ class SupervisedDataset(Dataset):
         self.patch_size = patch_size
         self.query_nums=query_nums
         self.batch_vision = batch_vision
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        image = Image.open(self.raw_data[i]["image"]).convert("RGB")
-        ret = preprocess(
-            image,
-            self.raw_data[i]["conversations"],
-            self.tokenizer,
-            self.transform,
-            query_nums=self.query_nums,
-            slice_config=self.slice_config,
-            llm_type=self.llm_type,
-            patch_size=self.patch_size,
-            batch_vision=self.batch_vision,
-        )
-        ret = dict(
-            input_ids=ret["input_ids"],
-            position_ids=ret["position_ids"],
-            labels=ret["target"],
-            attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
-            pixel_values=ret["pixel_values"],
-            tgt_sizes=ret["tgt_sizes"],
-            image_bound=ret["image_bound"],
-        )
-
+        try:
+            if isinstance(self.raw_data[i]["image"], str):
+                images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
+            elif isinstance(self.raw_data[i]["image"], Dict):
+                ### for multi-image input, the placeholder for every image is <image_xx>, such as <image_00>, <image_01>
+                images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+
+            ret = preprocess(
+                images_dict,
+                self.raw_data[i]["conversations"],
+                self.tokenizer,
+                self.transform,
+                query_nums=self.query_nums,
+                slice_config=self.slice_config,
+                llm_type=self.llm_type,
+                patch_size=self.patch_size,
+                batch_vision=self.batch_vision,
+                max_length=self.max_length
+            )
+            ret = dict(
+                input_ids=ret["input_ids"],
+                position_ids=ret["position_ids"],
+                labels=ret["target"],
+                attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
+                pixel_values=ret["pixel_values"],
+                tgt_sizes=ret["tgt_sizes"],
+                image_bound=ret["image_bound"],
+            )
+        except Exception as e:
+            logger.error(f"data fetch error at index {i}: {e}")
+            return self.__getitem__(random.randint(0, len(self) - 1))
         return ret
 
+
 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -105,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
     }
 
 
-def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -126,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
 
     ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
     context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if ids.shape[-1] > max_length:
+        logger.warning(f"The input length ({ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+        ids = ids[:max_length]
+        context = context[:max_length]
+
+    if torch.all(context):
+        logger.error("No tokens available to compute loss.")
+        raise Exception("No tokens available to compute loss.")
 
     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -151,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
             image_start_tokens += 1
         image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
         if len(image_start_tokens) != len(image_end_tokens):
-            print("image start token != image end tokens")
+            logger.error("image start token != image end tokens")
+            raise Exception("image start token != image end tokens")
 
         if len(image_start_tokens) > 0:
             image_bound = torch.hstack(
@@ -281,10 +307,9 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
 
     return input_ids, context, raw_msg
 
-
 def preprocess(
-    image,
-    conversation,
+    images_dict,
+    conversations,
     tokenizer,
     transform,
     query_nums=64,
@@ -292,13 +317,14 @@ def preprocess(
     llm_type=None,
     patch_size=14,
     batch_vision=False,
+    max_length=2048,
 ):
     """
-    single image preprocess, the image will be placed at the top of the conversation
+    single or multi-image preprocess, the image(s) will be placed at the top of the conversation
     """
-    conversation = copy.deepcopy(conversation)
-    assert len(conversation) > 1, "conversation length must large than 2"
-    assert conversation[0]["role"] == "user", "the first role must be user"
+    conversations = copy.deepcopy(conversations)
+    assert len(conversations) > 1, "conversations length must be larger than 1"
+    assert conversations[0]["role"] == "user", "the first role must be user"
 
     if slice_config is not None:
         assert isinstance(slice_config, Dict)
@@ -313,40 +339,69 @@ def preprocess(
     if llm_type=='qwen2':
         new_schema = True
         use_image_id = True
-    if slice_config:
-        images = []
-        image_id_cnt = 0
-        source_image, patches, best_grid = slice_image(
-            image,
-            slice_config["max_slice_nums"],
-            slice_config["scale_resolution"],
-            slice_config["patch_size"],
-        )
-        images.append(source_image)
-        image_placeholder = default_image_placeholder
-        if len(patches) > 0:
-            for i in range(len(patches)):
-                for j in range(len(patches[0])):
-                    images.append(patches[i][j])
-        if use_image_id:
-            image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
-            image_id_cnt += 1
-        image_placeholder += get_grid_placeholder(
-            tokenizer, best_grid, query_nums, new_schema = new_schema)
-        images = [transform(i) for i in images]
+    image_placeholder_dict = {}
+    images = []
+    image_id_cnt = 0
+    for img_name, image in images_dict.items():
+        if slice_config:
+            source_image, patches, best_grid = slice_image(
+                image,
+                slice_config["max_slice_nums"],
+                slice_config["scale_resolution"],
+                slice_config["patch_size"],
+            )
+            images.append(source_image)
+            image_placeholder = default_image_placeholder
+            if len(patches) > 0:
+                for i in range(len(patches)):
+                    for j in range(len(patches[0])):
+                        images.append(patches[i][j])
+            if use_image_id:
+                image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+                image_id_cnt += 1
+            image_placeholder += get_grid_placeholder(
+                tokenizer, best_grid, query_nums, new_schema = new_schema)
+            image_placeholder_dict[img_name] = image_placeholder
+        else:
+            images.append(image)
+            if use_image_id:
+                image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+                image_id_cnt += 1
+            else:
+                image_placeholder = default_image_placeholder
+            image_placeholder_dict[img_name] = image_placeholder
+
+    images = [transform(i) for i in images]
+
+    if len(images_dict) == 1 and "<image>" in images_dict:
+        if "<image>" in conversations[0]["content"]:
+            conversations[0]["content"] = conversations[0]["content"].replace(
+                "<image>", image_placeholder
+            )
+        else:
+            conversations[0]["content"] = (
+                image_placeholder + "\n" + conversations[0]["content"]
+            )
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
     else:
-        images = [transform(image)]
-        image_placeholder = default_image_placeholder
-    if "<image>" in conversation[0]["content"]:
-        conversation[0]["content"] = conversation[0]["content"].replace(
-            "<image>", image_placeholder
-        )
-    else:
-        conversation[0]["content"] = (
-            image_placeholder + "\n" + conversation[0]["content"]
-        )
-
-    input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
+        pattern = r'<image_\d+>'
+        new_conversations = []
+        for conversation in conversations:
+            content = conversation['content']
+            parts = re.split(f'({pattern})', content)
+            for i, part in enumerate(parts):
+                if not part.strip():
+                    continue
+                if re.match(pattern, part):
+                    if part in image_placeholder_dict:
+                        parts[i] = image_placeholder_dict[part]
+                    else:
+                        raise Exception(f"not found {part} in image dict")
+            conversation['content'] = '\n'.join(parts)
+            new_conversations.append(conversation)
+        conversations = new_conversations
+
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
 
     if batch_vision:
         tgt_sizes = []
diff --git a/finetune/finetune.py b/finetune/finetune.py
index 04cf2eb..0c596d7 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -108,6 +108,7 @@ def make_supervised_data_module(
         patch_size=patch_size,
         query_nums=query_nums,
         batch_vision=batch_vision,
+        max_length=max_length,
     )
 
     if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
             patch_size=patch_size,
             query_nums=query_nums,
             batch_vision=batch_vision,
+            max_length=max_length,
         )
     else:
         eval_dataset = None
@@ -276,6 +278,7 @@ def train():
         max_length=training_args.model_max_length,
     )
 
+    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
     trainer = CPMTrainer(
         model=model,
         tokenizer=tokenizer,
diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh
index 92fd577..c049471 100644
--- a/finetune/finetune_ds.sh
+++ b/finetune/finetune_ds.sh
@@ -13,7 +13,7 @@ MODEL="openbmb/MiniCPM-V-2_6"
 DATA="path/to/trainging_data"
 EVAL_DATA="path/to/test_data"
 LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
-
+MODEL_MAX_Length=2048 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096
 
 
 DISTRIBUTED_ARGS="
@@ -39,7 +39,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --do_eval \
     --tune_vision true \
     --tune_llm true \
-    --model_max_length 2048 \
+    --model_max_length $MODEL_MAX_Length \
     --max_slice_nums 9 \
     --max_steps 10000 \
    --eval_steps 1000 \
diff --git a/finetune/finetune_lora.sh b/finetune/finetune_lora.sh
index 2c12525..19437b1 100644
--- a/finetune/finetune_lora.sh
+++ b/finetune/finetune_lora.sh
@@ -14,6 +14,9 @@ EVAL_DATA="path/to/test_data"
 LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
 #if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
 
+
+MODEL_MAX_Length=2048 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096
+
 DISTRIBUTED_ARGS="
     --nproc_per_node $GPUS_PER_NODE \
     --nnodes $NNODES \
@@ -39,7 +42,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --tune_llm false \
     --use_lora true \
     --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
-    --model_max_length 2048 \
+    --model_max_length $MODEL_MAX_Length \
     --max_slice_nums 9 \
     --max_steps 10000 \
     --eval_steps 1000 \
diff --git a/finetune/readme.md b/finetune/readme.md
index 700609f..9ba8c96 100644
--- a/finetune/readme.md
+++ b/finetune/readme.md
@@ -5,13 +5,15 @@ We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-
 
 ### Data preparation
 
-To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
+To prepare your fine-tuning data, you should formulate each sample as a dictionary consisting of an id, an image path (or a dictionary of image paths), and a list of conversations. Then, save the data samples in JSON files.
 
-For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
+For vision-language tasks, you must provide placeholders such as **\<image\>** or **\<image_00\>** to define where to insert the image embeddings within the conversation. If no placeholder is provided, the image will be placed at the front of the conversation by default.
 
+#### Single Image Example
+If your input consists of a single image, you can use the single placeholder **\<image\>** to indicate where the image should be inserted in the conversation.
 <details>
   <summary>
-    vision-language example (vl_finetune_data.json) with 1 samples.
+    Single image example (vl_finetune_data.json) with 1 sample.
   </summary>
 
 ```
@@ -50,6 +52,44 @@ For the vision-language example with image, you are required to provide **\<image\>**
 </details>
 
+#### Multiple Images Example
+For inputs containing multiple images, utilize a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**) and each value is the path of the corresponding image.
+<details>
+  <summary>
+    Multiple images example (vl_finetune_data.json) with 1 sample.
+  </summary>
+
+```
+  [
+    {
+      "id": "0",
+      "image": {
+        "<image_00>": "path/to/image_0.jpg",
+        "<image_01>": "path/to/image_1.jpg",
+        "<image_02>": "path/to/image_2.jpg",
+        "<image_03>": "path/to/image_3.jpg"
+      },
+      "conversations": [
+        {
+          "role": "user",
+          "content": "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n<image_02>\n<image_03>\n"
+        },
+        {
+          "role": "assistant",
+          "content": "To create a text-only video as shown in the images, follow these steps in CapCut..."
+        }
+      ]
+    }
+  ]
+```
+</details>
+
 ### Full-parameter finetuning
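
The multi-image branch added to `preprocess` above (the `pattern = r'<image_\d+>'` block) is what ties the readme's `<image_00>`-style keys to the expanded vision placeholders: each conversation turn is split on the placeholder markers, and every marker is swapped for the placeholder string built for that image. The snippet below is a minimal, standalone sketch of that substitution step, handy for sanity-checking a multi-image JSON sample before training; the `substitute_image_placeholders` helper and the `[IMG-n]` dummy strings are illustrative assumptions, not part of the patch or of the MiniCPM-V API.

```python
import re

def substitute_image_placeholders(conversations, image_placeholder_dict):
    """Mirror of the patch's multi-image path: replace <image_XX> markers
    in each turn with the placeholder text prepared for that image."""
    pattern = r'<image_\d+>'
    resolved = []
    for turn in conversations:
        # Keep the markers as separate fragments so they can be replaced in place.
        parts = re.split(f'({pattern})', turn["content"])
        for idx, part in enumerate(parts):
            if not part.strip():
                continue  # skip empty fragments produced by the split
            if re.match(pattern, part):
                if part not in image_placeholder_dict:
                    raise KeyError(f"{part} has no entry in the sample's image dict")
                parts[idx] = image_placeholder_dict[part]
        resolved.append({"role": turn["role"], "content": "\n".join(parts)})
    return resolved

if __name__ == "__main__":
    # Dummy stand-ins for the tokenizer-specific placeholders the real code builds.
    placeholders = {"<image_00>": "[IMG-0]", "<image_01>": "[IMG-1]"}
    sample = [
        {"role": "user", "content": "Compare the two screenshots.\n<image_00>\n<image_01>"},
        {"role": "assistant", "content": "The second screenshot adds a caption track."},
    ]
    for turn in substitute_image_placeholders(sample, placeholders):
        print(turn["role"], "->", turn["content"].replace("\n", " | "))
```

As in the patch, an `<image_XX>` marker with no matching key in the sample's `"image"` dictionary is treated as an error rather than silently dropped.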