update finetuen for multi images sft (#462)

This commit is contained in:
qianyu chen
2024-08-15 11:24:50 +08:00
committed by GitHub
parent 825abf10e2
commit cd64150b51
7 changed files with 170 additions and 68 deletions

View File

@@ -29,6 +29,7 @@ Join our <a href="docs/wechat.md" target="_blank"> 💬 WeChat</a>
#### 📌 Pinned
* [2024.08.15] We now also support multi-image SFT. For more details, please refer to the [finetune/README.md](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune) file.
* [2024.08.14] MiniCPM-V 2.6 now also supports [fine-tuning](https://github.com/modelscope/ms-swift/issues/1613) with the SWIFT framework!
* [2024.08.10] 🚀🚀🚀 MiniCPM-Llama3-V 2.5 is now fully supported by [official](https://github.com/ggerganov/llama.cpp) llama.cpp! GGUF models of various sizes are available [here](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf). Please note that MiniCPM-V 2.6 still needs to use [our fork](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md).
* [2024.08.06] 🔥🔥🔥 We open-source MiniCPM-V 2.6, which outperforms GPT-4V on single image, multi-image and video understanding. It advances popular features of MiniCPM-Llama3-V 2.5, and can support real-time video understanding on iPad. Try it now!

View File

@@ -32,7 +32,7 @@
## 更新日志 <!-- omit in toc -->
#### 📌 置顶
* [2024.08.15] MiniCPM-V 2.6 现在支持多图像 SFT。有关更多详细信息请参阅[finetune/README.md](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune)
* [2024.08.14] MiniCPM-V 2.6 现在可以通过 SWIFT 框架 [微调](https://github.com/modelscope/ms-swift/issues/1613) 了!
* [2024.08.10] 🚀🚀🚀 llama.cpp [官方仓库](https://github.com/ggerganov/llama.cpp)正式支持 MiniCPM-Llama3-V 2.5 啦!点击[这里](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main)查看各种大小的 GGUF 版本。但还请使用者注意 MiniCPM-V 2.6 仍然需要**拉取我们最新的 fork 来使用**[llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md) 。我们将继续积极推进将这些功能合并到 llama.cpp 官方仓库
* [2024.08.06] 🔥🔥🔥 我们开源了 MiniCPM-V 2.6,该模型在单图、多图和视频理解方面取得了优于 GPT-4V 的表现。我们还进一步提升了 MiniCPM-Llama3-V 2.5 的多项亮点能力,并首次支持了 iPad 上的实时视频理解。欢迎试用!

View File

@@ -3,6 +3,8 @@ import json
import logging
import math
import os
import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@@ -12,6 +14,9 @@ from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
import logging
logger = logging.getLogger(__name__)
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
@@ -28,6 +33,7 @@ class SupervisedDataset(Dataset):
patch_size=14,
query_nums=64,
batch_vision=False,
max_length=2048,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
@@ -38,35 +44,46 @@ class SupervisedDataset(Dataset):
self.patch_size = patch_size
self.query_nums=query_nums
self.batch_vision = batch_vision
self.max_length = max_length
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
try:
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
ret = preprocess(
images_dict,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
max_length=self.max_length
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
except:
logger.error(f"data fetch error")
return self.__getitem__(random.randint(0, len(self)))
return ret
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -105,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
}
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -126,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
if input_ids.shape[-1] > max_length:
ids =ids[:max_length]
context = context[:max_length]
logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
if torch.all(context):
logger.error("No tokens available to compute loss.")
raise Exception("No tokens available to compute loss.")
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -151,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")
logger.error("image start token != image end tokens")
raise Exception("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
@@ -281,10 +307,9 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
return input_ids, context, raw_msg
def preprocess(
image,
conversation,
images_dict,
conversations,
tokenizer,
transform,
query_nums=64,
@@ -292,13 +317,14 @@ def preprocess(
llm_type=None,
patch_size=14,
batch_vision=False,
max_length=2048,
):
"""
single image preprocess, the image will be placed at the top of the conversation
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
assert len(conversation) > 1, "conversation length must large than 2"
assert conversation[0]["role"] == "user", "the first role must be user"
conversations = copy.deepcopy(conversations)
assert len(conversations) > 1, "conversations length must large than 2"
assert conversations[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
@@ -313,40 +339,69 @@ def preprocess(
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder_dict = {}
images = []
image_id_cnt = 0
for img_name, image in images_dict.items():
if slice_config:
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
image_placeholder_dict[img_name] = image_placeholder
else:
images.append(image)
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
image_placeholder = default_image_placeholder
image_placeholder_dict[img_name] = image_placeholder
images = [transform(i) for i in images]
if len(images_dict) == 1 and "<image>" in images_dict:
if "<image>" in conversations[0]["content"]:
conversations[0]["content"] = conversations[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversations[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
pattern = r'<image_\d+>'
new_conversations = []
for conversation in conversations:
content = conversation['content']
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
if not part.strip():
continue
if re.match(pattern, part):
if part in image_placeholder_dict:
parts[i] = image_placeholder_dict[part]
else:
raise Exception(f"not found {part} in image dict")
conversation['content'] = '\n'.join(parts)
new_conversations.append(conversation)
conversations = new_conversations
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
if batch_vision:
tgt_sizes = []

View File

@@ -108,6 +108,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)
if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)
else:
eval_dataset = None
@@ -276,6 +278,7 @@ def train():
max_length=training_args.model_max_length,
)
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
trainer = CPMTrainer(
model=model,
tokenizer=tokenizer,

View File

@@ -13,7 +13,7 @@ MODEL="openbmb/MiniCPM-V-2_6"
DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
@@ -39,7 +39,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--do_eval \
--tune_vision true \
--tune_llm true \
--model_max_length 2048 \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \

View File

@@ -14,6 +14,9 @@ EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2"
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
@@ -39,7 +42,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
--model_max_length 2048 \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \

View File

@@ -5,13 +5,15 @@ We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-
### Data preparation
To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
To prepare your fine-tuning data, you should formulate each sample as a dictionary consisting of an id, an image path (or list of images), and a list of conversations. Then, save the data samples in JSON files.
For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
For vision-language tasks, you must provide placeholders like **\<image\>** or **\<image_XX\>** to define where to insert the image embeddings within the conversation. If no placeholder is provided, the image will be placed at the front of the conversation by default.
#### Single Image Example
If your input consists of a single image, you can use a single placeholder **\<image\>** to indicate where the image should be inserted in the conversation.
<details>
<summary>
<b>vision-language example (vl_finetune_data.json) with 1 samples.</b>
<b>Single image example (vl_finetune_data.json) with 1 samples.</b>
</summary>
```
@@ -50,6 +52,44 @@ For the vision-language example with image, you are required to provide **\<imag
</details>
#### Multiple Images Example
For inputs containing multiple images, utilize a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\**) with the corresponding image path as its value. These placeholders can then be used within the conversation to seamlessly insert images at specific positions.
Additionally, to optimize resource management, especially when dealing with large batches of images during training or inference, consider reducing `max_slice_nums`. For example, in version 2.6, a single image is represented by 64 tokens. When `slice=9`, an image with a maximum resolution of 1344x1344 will consume nearly 64*(9+1) tokens. To minimize the number of tokens used per image, you can set `slice=1`, resulting in a single image being represented by 64 tokens.
If the total token count exceeds `max_length`, truncation will be applied. For multi-image supervised fine-tuning (SFT), it's recommended to set `MODEL_MAX_LENGTH=4096` in your script for better performance.
<details>
<summary>
<b>Multiple images example (vl_finetune_data.json) with 1 samples.</b>
</summary>
```
[
{
"id": "0",
"image": {
"<image_00>": "path/to/image_0.jpg",
"<image_01>": "path/to/image_1.jpg",
"<image_02>": "path/to/image_2.jpg",
"<image_03>": "path/to/image_3.jpg"
},
"conversations": [
{
"role": "user",
"content": "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n<image_02>\n<image_03>\n"
},
{
"role": "assistant",
"content": "To create a text-only video as shown in the images, follow these steps in CapCut..."
}
]
}
]
```
</details>
### Full-parameter finetuning