mirror of https://github.com/OpenBMB/MiniCPM-V.git
update for multi images sft
@@ -3,6 +3,8 @@ import json
import logging
import math
import os
import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional

@@ -43,30 +45,41 @@ class SupervisedDataset(Dataset):
return len(self.raw_data)

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)

# try:
if 1:
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}

ret = preprocess(
images_dict,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
# except:
# print(f"data fetch error")
# return self.__getitem__(random.randint(0, len(self)))
return ret


def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -283,8 +296,8 @@ def conversation_to_ids_qwen2(conversation, tokenizer):


def preprocess(
image,
conversation,
images_dict,
conversations,
tokenizer,
transform,
query_nums=64,
@@ -294,11 +307,11 @@ def preprocess(
batch_vision=False,
):
"""
single image preprocess, the image will be placed at the top of the conversation
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
assert len(conversation) > 1, "conversation length must large than 2"
assert conversation[0]["role"] == "user", "the first role must be user"
conversations = copy.deepcopy(conversations)
assert len(conversations) > 1, "conversations length must large than 2"
assert conversations[0]["role"] == "user", "the first role must be user"

if slice_config is not None:
assert isinstance(slice_config, Dict)
@@ -313,41 +326,67 @@ def preprocess(
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder_dict = {}
images = []
image_id_cnt = 0
for img_name, image in images_dict.items():
if slice_config:
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
image_placeholder_dict[img_name] = image_placeholder
else:
images.append(image)
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
image_placeholder = default_image_placeholder
image_placeholder_dict[img_name] = image_placeholder

images = [transform(i) for i in images]

if len(images_dict) == 1 and "<image>" in images_dict:
if "<image>" in conversations[0]["content"]:
conversations[0]["content"] = conversations[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversations[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)

input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)

pattern = r'<image_\d+>'
new_conversations = []
for conversation in conversations:
content = conversation['content']
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
if re.match(pattern, part):
if part in image_placeholder_dict:
parts[i] = image_placeholder_dict[part]
else:
print(f'Unreplaced image tag: {part}')
conversation['content'] = ''.join(parts)
new_conversations.append(conversation)
conversations = new_conversations

input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
if batch_vision:
tgt_sizes = []
reshape_images = []

@@ -205,6 +205,7 @@ def train():
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
)
tokenizer.model_max_length = training_args.model_max_length

if not training_args.tune_vision:
model.vpm.requires_grad_(False)
@@ -276,6 +277,7 @@ def train():
max_length=training_args.model_max_length,
)

training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
trainer = CPMTrainer(
model=model,
tokenizer=tokenizer,

@@ -13,7 +13,7 @@ MODEL="openbmb/MiniCPM-V-2_6"
DATA="path/to/training_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if using openbmb/MiniCPM-V-2, set LLM_TYPE="minicpm"; if using openbmb/MiniCPM-Llama3-V-2_5, set LLM_TYPE="llama3"

MODEL_MAX_Length=4096 # if using openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, set MODEL_MAX_Length=2048


DISTRIBUTED_ARGS="
@@ -39,7 +39,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--do_eval \
--tune_vision true \
--tune_llm true \
--model_max_length 2048 \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \

@@ -14,6 +14,9 @@ EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2"
# if using openbmb/MiniCPM-V-2, set LLM_TYPE="minicpm"
# if using openbmb/MiniCPM-Llama3-V-2_5, set LLM_TYPE="llama3"

MODEL_MAX_Length=4096 # if using openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, set MODEL_MAX_Length=2048

DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
@@ -39,7 +42,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
--model_max_length 2048 \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \

@@ -5,13 +5,15 @@ We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-

### Data preparation

To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
To prepare your fine-tuning data, formulate each sample as a dictionary consisting of an id, an image path (or a dictionary of image placeholders and paths for multiple images), and a list of conversations. Then save the data samples in JSON files.

For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
For vision-language tasks, you must provide placeholders like **\<image\>** or **\<image_XX\>** to define where to insert the image embeddings within the conversation. If no placeholder is provided, the image will be placed at the front of the conversation by default.

#### Single Image Example
If your input consists of a single image, you can use a single placeholder **\<image\>** to indicate where the image should be inserted in the conversation.
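A minimal sketch of such a sample (the id, path, and conversation text below are placeholders; the full official example is in the details block that follows):

```
[
  {
    "id": "0",
    "image": "path/to/image_0.jpg",
    "conversations": [
      {
        "role": "user",
        "content": "<image>\nWhat is shown in this picture?"
      },
      {
        "role": "assistant",
        "content": "The picture shows ..."
      }
    ]
  }
]
```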
<details>
<summary>
<b>vision-language example (vl_finetune_data.json) with 1 sample.</b>
<b>Single image example (vl_finetune_data.json) with 1 sample.</b>
</summary>

```
@@ -50,6 +52,38 @@ For the vision-language example with image, you are required to provide **\<imag

</details>

#### Multiple Images Example
For inputs with multiple images, you should use a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**), and the corresponding value is the image path. You can then use these placeholders in the conversation to insert the images at specific positions.

<details>
<summary>
<b>Multiple images example (vl_finetune_data.json) with 1 sample.</b>
</summary>

```
[
  {
    "id": "0",
    "image": {
      "<image_00>": "path/to/image_0.jpg",
      "<image_01>": "path/to/image_1.jpg",
      "<image_02>": "path/to/image_2.jpg",
      "<image_03>": "path/to/image_3.jpg"
    },
    "conversations": [
      {
        "role": "user",
        "content": "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n<image_02>\n<image_03>\n"
      },
      {
        "role": "assistant",
        "content": "To create a text-only video as shown in the images, follow these steps in CapCut..."
      }
    ]
  }
]
```
</details>
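
During preprocessing, each `<image_XX>` tag is resolved by splitting the conversation turn on the tag pattern and substituting the placeholder built for that image. A minimal sketch of that lookup, assuming dummy placeholder strings (the real placeholders are produced by the slicing and image-id logic in `dataset.py`):

```
import re

# Sketch of the <image_XX> substitution used in the updated preprocess():
# split on the tag pattern, then swap each known tag for its placeholder.
# The placeholder strings here are dummies for illustration only.
pattern = r'<image_\d+>'
image_placeholder_dict = {
    "<image_00>": "[placeholder for image 0]",
    "<image_01>": "[placeholder for image 1]",
}

content = "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n"
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
    if re.match(pattern, part):
        # Unknown tags are left untouched here; the training code prints a warning for them.
        parts[i] = image_placeholder_dict.get(part, part)
print(''.join(parts))
```
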
### Full-parameter finetuning