diff --git a/README.md b/README.md
index d031945..de9795e 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ Join our 💬 WeChat
#### 📌 Pinned
+* [2024.08.15] We now also support multi-image SFT. For more details, please refer to the [finetune/README.md](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune) file.
* [2024.08.14] MiniCPM-V 2.6 now also supports [fine-tuning](https://github.com/modelscope/ms-swift/issues/1613) with the SWIFT framework!
* [2024.08.10] 🚀🚀🚀 MiniCPM-Llama3-V 2.5 is now fully supported by [official](https://github.com/ggerganov/llama.cpp) llama.cpp! GGUF models of various sizes are available [here](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf). Please note that MiniCPM-V 2.6 still needs to use [our fork](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md).
* [2024.08.06] 🔥🔥🔥 We open-source MiniCPM-V 2.6, which outperforms GPT-4V on single image, multi-image and video understanding. It advances popular features of MiniCPM-Llama3-V 2.5, and can support real-time video understanding on iPad. Try it now!
diff --git a/README_zh.md b/README_zh.md
index 499ab7e..1d59211 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -32,7 +32,7 @@
## 更新日志
#### 📌 置顶
-
+* [2024.08.15] MiniCPM-V 2.6 现在支持多图像 SFT。有关更多详细信息,请参阅[finetune/README.md](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune)
* [2024.08.14] MiniCPM-V 2.6 现在可以通过 SWIFT 框架 [微调](https://github.com/modelscope/ms-swift/issues/1613) 了!
* [2024.08.10] 🚀🚀🚀 llama.cpp [官方仓库](https://github.com/ggerganov/llama.cpp)正式支持 MiniCPM-Llama3-V 2.5 啦!点击[这里](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main)查看各种大小的 GGUF 版本。但还请使用者注意 MiniCPM-V 2.6 仍然需要**拉取我们最新的 fork 来使用**:[llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md) 。我们将继续积极推进将这些功能合并到 llama.cpp 官方仓库
* [2024.08.06] 🔥🔥🔥 我们开源了 MiniCPM-V 2.6,该模型在单图、多图和视频理解方面取得了优于 GPT-4V 的表现。我们还进一步提升了 MiniCPM-Llama3-V 2.5 的多项亮点能力,并首次支持了 iPad 上的实时视频理解。欢迎试用!
diff --git a/finetune/dataset.py b/finetune/dataset.py
index 09dadcc..dbf6bbc 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -3,6 +3,8 @@ import json
import logging
import math
import os
+import re
+import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@@ -12,6 +14,9 @@ from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
+
+# logger used to report data-loading problems during training (logging is already imported above)
+logger = logging.getLogger(__name__)
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
@@ -28,6 +33,7 @@ class SupervisedDataset(Dataset):
patch_size=14,
query_nums=64,
batch_vision=False,
+ max_length=2048,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
@@ -38,35 +44,46 @@ class SupervisedDataset(Dataset):
self.patch_size = patch_size
self.query_nums=query_nums
self.batch_vision = batch_vision
+ self.max_length = max_length
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
- image = Image.open(self.raw_data[i]["image"]).convert("RGB")
- ret = preprocess(
- image,
- self.raw_data[i]["conversations"],
- self.tokenizer,
- self.transform,
- query_nums=self.query_nums,
- slice_config=self.slice_config,
- llm_type=self.llm_type,
- patch_size=self.patch_size,
- batch_vision=self.batch_vision,
- )
- ret = dict(
- input_ids=ret["input_ids"],
- position_ids=ret["position_ids"],
- labels=ret["target"],
- attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
- pixel_values=ret["pixel_values"],
- tgt_sizes=ret["tgt_sizes"],
- image_bound=ret["image_bound"],
- )
-
+ try:
+ if isinstance(self.raw_data[i]["image"], str):
+                images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
+ elif isinstance(self.raw_data[i]["image"], Dict):
+                ### for multi-image input, the template for every image is <image_xx>, such as <image_00>, <image_01>
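+                ### e.g. "image": {"<image_00>": "path/to/image_0.jpg", "<image_01>": "path/to/image_1.jpg"}
+                ### each key is used as the placeholder for that image inside the conversation contents (see finetune/readme.md)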
+ images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+
+ ret = preprocess(
+ images_dict,
+ self.raw_data[i]["conversations"],
+ self.tokenizer,
+ self.transform,
+ query_nums=self.query_nums,
+ slice_config=self.slice_config,
+ llm_type=self.llm_type,
+ patch_size=self.patch_size,
+ batch_vision=self.batch_vision,
+ max_length=self.max_length
+ )
+ ret = dict(
+ input_ids=ret["input_ids"],
+ position_ids=ret["position_ids"],
+ labels=ret["target"],
+ attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
+ pixel_values=ret["pixel_values"],
+ tgt_sizes=ret["tgt_sizes"],
+ image_bound=ret["image_bound"],
+ )
+        except Exception as e:
+            logger.error(f"data fetch error at index {i}: {e}")
+            # fall back to a random valid sample so one bad record does not abort training
+            return self.__getitem__(random.randint(0, len(self) - 1))
return ret
+
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -105,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
}
-def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -126,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if ids.shape[-1] > max_length:
+        logger.warning(f"The input length ({ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+        ids = ids[:max_length]
+        context = context[:max_length]
+
+ if torch.all(context):
+ logger.error("No tokens available to compute loss.")
+ raise Exception("No tokens available to compute loss.")
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -151,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
- print("image start token != image end tokens")
+        logger.error("the number of image start tokens does not match the number of image end tokens")
+        raise Exception("the number of image start tokens does not match the number of image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
@@ -281,10 +307,9 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
return input_ids, context, raw_msg
-
def preprocess(
- image,
- conversation,
+ images_dict,
+ conversations,
tokenizer,
transform,
query_nums=64,
@@ -292,13 +317,14 @@ def preprocess(
llm_type=None,
patch_size=14,
batch_vision=False,
+ max_length=2048,
):
"""
- single image preprocess, the image will be placed at the top of the conversation
+    single- or multi-image preprocess; the image(s) will be placed at the top of the conversation
"""
- conversation = copy.deepcopy(conversation)
- assert len(conversation) > 1, "conversation length must large than 2"
- assert conversation[0]["role"] == "user", "the first role must be user"
+ conversations = copy.deepcopy(conversations)
+    assert len(conversations) > 1, "conversations must contain at least 2 turns"
+ assert conversations[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
@@ -313,40 +339,69 @@ def preprocess(
if llm_type=='qwen2':
new_schema = True
use_image_id = True
- if slice_config:
- images = []
- image_id_cnt = 0
- source_image, patches, best_grid = slice_image(
- image,
- slice_config["max_slice_nums"],
- slice_config["scale_resolution"],
- slice_config["patch_size"],
- )
- images.append(source_image)
- image_placeholder = default_image_placeholder
- if len(patches) > 0:
- for i in range(len(patches)):
- for j in range(len(patches[0])):
- images.append(patches[i][j])
+ image_placeholder_dict = {}
+ images = []
+ image_id_cnt = 0
+ for img_name, image in images_dict.items():
+ if slice_config:
+ source_image, patches, best_grid = slice_image(
+ image,
+ slice_config["max_slice_nums"],
+ slice_config["scale_resolution"],
+ slice_config["patch_size"],
+ )
+ images.append(source_image)
+ image_placeholder = default_image_placeholder
+ if len(patches) > 0:
+ for i in range(len(patches)):
+ for j in range(len(patches[0])):
+ images.append(patches[i][j])
+ if use_image_id:
+ image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+ image_id_cnt += 1
+ image_placeholder += get_grid_placeholder(
+ tokenizer, best_grid, query_nums, new_schema = new_schema)
+ image_placeholder_dict[img_name] = image_placeholder
+ else:
+ images.append(image)
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
- image_placeholder += get_grid_placeholder(
- tokenizer, best_grid, query_nums, new_schema = new_schema)
- images = [transform(i) for i in images]
+ else:
+ image_placeholder = default_image_placeholder
+ image_placeholder_dict[img_name] = image_placeholder
+
+ images = [transform(i) for i in images]
+
+    if len(images_dict) == 1 and "<image>" in images_dict:
+        if "<image>" in conversations[0]["content"]:
+            conversations[0]["content"] = conversations[0]["content"].replace(
+                "<image>", image_placeholder
+            )
+        else:
+            conversations[0]["content"] = (
+                image_placeholder + "\n" + conversations[0]["content"]
+            )
+ input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
else:
- images = [transform(image)]
- image_placeholder = default_image_placeholder
-    if "<image>" in conversation[0]["content"]:
-        conversation[0]["content"] = conversation[0]["content"].replace(
-            "<image>", image_placeholder
- )
- else:
- conversation[0]["content"] = (
- image_placeholder + "\n" + conversation[0]["content"]
- )
-
- input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
+        pattern = r'<image_\d+>'
+ new_conversations = []
+ for conversation in conversations:
+ content = conversation['content']
+ parts = re.split(f'({pattern})', content)
+ for i, part in enumerate(parts):
+ if not part.strip():
+ continue
+ if re.match(pattern, part):
+ if part in image_placeholder_dict:
+ parts[i] = image_placeholder_dict[part]
+ else:
+                    raise Exception(f"{part} not found in image dict")
+ conversation['content'] = '\n'.join(parts)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
if batch_vision:
tgt_sizes = []
diff --git a/finetune/finetune.py b/finetune/finetune.py
index 04cf2eb..0c596d7 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -108,6 +108,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
+ max_length=max_length,
)
if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
+ max_length=max_length,
)
else:
eval_dataset = None
@@ -276,6 +278,7 @@ def train():
max_length=training_args.model_max_length,
)
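+    # use non-reentrant activation checkpointing; the reentrant variant can fail when some modules are frozen (e.g. LoRA tuning or tune_vision/tune_llm set to false)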
+    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
trainer = CPMTrainer(
model=model,
tokenizer=tokenizer,
diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh
index 92fd577..c049471 100644
--- a/finetune/finetune_ds.sh
+++ b/finetune/finetune_ds.sh
@@ -13,7 +13,7 @@ MODEL="openbmb/MiniCPM-V-2_6"
DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
-
+MODEL_MAX_Length=2048 # if you conduct multi-image SFT, please set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
@@ -39,7 +39,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--do_eval \
--tune_vision true \
--tune_llm true \
- --model_max_length 2048 \
+ --model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
diff --git a/finetune/finetune_lora.sh b/finetune/finetune_lora.sh
index 2c12525..19437b1 100644
--- a/finetune/finetune_lora.sh
+++ b/finetune/finetune_lora.sh
@@ -14,6 +14,9 @@ EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2"
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
+
+MODEL_MAX_Length=2048 # if you conduct multi-image SFT, please set MODEL_MAX_Length=4096
+
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
@@ -39,7 +42,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
- --model_max_length 2048 \
+ --model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
diff --git a/finetune/readme.md b/finetune/readme.md
index 700609f..9ba8c96 100644
--- a/finetune/readme.md
+++ b/finetune/readme.md
@@ -5,13 +5,15 @@ We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-
### Data preparation
-To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
+To prepare your fine-tuning data, you should formulate each sample as a dictionary consisting of an id, an image path (or a dict mapping image placeholders to image paths), and a list of conversations. Then, save the data samples in JSON files.
-For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
+For vision-language tasks, you must provide placeholders such as **\<image\>** or **\<image_00\>** to define where to insert the image embeddings within the conversation. If no placeholder is provided, the image will be placed at the front of the conversation by default.
+#### Single Image Example
+If your input consists of a single image, you can use a single placeholder **\<image\>** to indicate where the image should be inserted in the conversation.
- vision-language example (vl_finetune_data.json) with 1 samples.
+ Single image example (vl_finetune_data.json) with 1 sample.
```
@@ -50,6 +52,44 @@ For the vision-language example with image, you are required to provide **\
+#### Multiple Images Example
+For inputs containing multiple images, utilize a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**) and the corresponding value is the path to that image. Each placeholder can then be used within the conversation contents to indicate where its image should be inserted.
+
+ Multiple images example (vl_finetune_data.json) with 1 sample.
+
+
+```
+ [
+ {
+ "id": "0",
+ "image": {
+        "<image_00>": "path/to/image_0.jpg",
+        "<image_01>": "path/to/image_1.jpg",
+        "<image_02>": "path/to/image_2.jpg",
+        "<image_03>": "path/to/image_3.jpg"
+ },
+ "conversations": [
+ {
+ "role": "user",
+          "content": "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n<image_02>\n<image_03>\n"
+ },
+ {
+ "role": "assistant",
+ "content": "To create a text-only video as shown in the images, follow these steps in CapCut..."
+ }
+ ]
+ }
+ ]
+```
+
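+During fine-tuning, `finetune/dataset.py` resolves every placeholder matching `<image_\d+>` in the conversations and raises an error if a placeholder has no corresponding key in the `image` dictionary. You can sanity-check a data file before training; the script below is only an illustrative sketch (it assumes your data file is named `vl_finetune_data.json` as in the examples above):
+
+```
+import json
+import re
+
+# every <image_XX> used in the conversations must be a key of the "image" dict,
+# otherwise dataset.py raises "... not found in image dict" during training
+pattern = re.compile(r"<image_\d+>")
+
+with open("vl_finetune_data.json") as f:
+    samples = json.load(f)
+
+for sample in samples:
+    if not isinstance(sample["image"], dict):
+        continue  # single-image samples use the plain <image> placeholder
+    used = {m for turn in sample["conversations"] for m in pattern.findall(turn["content"])}
+    missing = used - set(sample["image"])
+    if missing:
+        print(f"sample {sample['id']}: no image provided for {sorted(missing)}")
+```
+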
### Full-parameter finetuning