update finetune for multi-image sft (#462)

This commit is contained in:
qianyu chen
2024-08-15 11:24:50 +08:00
committed by GitHub
parent 825abf10e2
commit cd64150b51
7 changed files with 170 additions and 68 deletions

View File

@@ -3,6 +3,8 @@ import json
import logging
import math
import os
import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@@ -12,6 +14,9 @@ from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
import logging
logger = logging.getLogger(__name__)
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
@@ -28,6 +33,7 @@ class SupervisedDataset(Dataset):
patch_size=14,
query_nums=64,
batch_vision=False,
max_length=2048,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
@@ -38,35 +44,46 @@ class SupervisedDataset(Dataset):
self.patch_size = patch_size
self.query_nums=query_nums
self.batch_vision = batch_vision
self.max_length = max_length
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
try:
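# "image" is either a single path (use the "<image>" placeholder) or a dict mapping "<image_xx>" placeholders to paths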
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
ret = preprocess(
images_dict,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
max_length=self.max_length
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
except Exception as e:
logger.error(f"data fetch error: {e}")
return self.__getitem__(random.randint(0, len(self) - 1))
return ret
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
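# truncate each sequence to max_length, then pad to the longest sequence in the batch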
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -105,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
}
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -126,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
if ids.shape[-1] > max_length:
logger.warning(f"The input length ({ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
ids = ids[:max_length]
context = context[:max_length]
if torch.all(context):
logger.error("No tokens available to compute loss.")
raise Exception("No tokens available to compute loss.")
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -151,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")
logger.error("image start token != image end tokens")
raise Exception("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
@@ -281,10 +307,9 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
return input_ids, context, raw_msg
def preprocess(
image,
conversation,
images_dict,
conversations,
tokenizer,
transform,
query_nums=64,
@@ -292,13 +317,14 @@ def preprocess(
llm_type=None,
patch_size=14,
batch_vision=False,
max_length=2048,
):
"""
single image preprocess, the image will be placed at the top of the conversation
single or multi-image preprocess; the image(s) will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
assert len(conversation) > 1, "conversation length must large than 2"
assert conversation[0]["role"] == "user", "the first role must be user"
conversations = copy.deepcopy(conversations)
assert len(conversations) > 1, "conversations length must be greater than 1"
assert conversations[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
@@ -313,40 +339,69 @@ def preprocess(
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder_dict = {}
images = []
image_id_cnt = 0
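# build the vision inputs and a text placeholder for every image; the placeholders are substituted into the conversation below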
for img_name, image in images_dict.items():
if slice_config:
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
image_placeholder_dict[img_name] = image_placeholder
else:
images.append(image)
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
image_placeholder = default_image_placeholder
image_placeholder_dict[img_name] = image_placeholder
images = [transform(i) for i in images]
if len(images_dict) == 1 and "<image>" in images_dict:
if "<image>" in conversations[0]["content"]:
conversations[0]["content"] = conversations[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversations[0]["content"] = (
image_placeholder + "\n" + conversations[0]["content"]
)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
pattern = r'<image_\d+>'
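# replace every <image_xx> placeholder in the conversation text with the expanded placeholder string built above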
new_conversations = []
for conversation in conversations:
content = conversation['content']
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
if not part.strip():
continue
if re.match(pattern, part):
if part in image_placeholder_dict:
parts[i] = image_placeholder_dict[part]
else:
raise Exception(f"not found {part} in image dict")
conversation['content'] = '\n'.join(parts)
new_conversations.append(conversation)
conversations = new_conversations
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
if batch_vision:
tgt_sizes = []

View File

@@ -108,6 +108,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)
if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)
else:
eval_dataset = None
@@ -276,6 +278,7 @@ def train():
max_length=training_args.model_max_length,
)
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
trainer = CPMTrainer(
model=model,
tokenizer=tokenizer,

View File

@@ -13,7 +13,7 @@ MODEL="openbmb/MiniCPM-V-2_6"
DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
MODEL_MAX_Length=2048 # for multi-image sft, set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
@@ -39,7 +39,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--do_eval \
--tune_vision true \
--tune_llm true \
--model_max_length 2048 \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \

View File

@@ -14,6 +14,9 @@ EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2"
# if you use openbmb/MiniCPM-V-2, set LLM_TYPE=minicpm
# if you use openbmb/MiniCPM-Llama3-V-2_5, set LLM_TYPE=llama3
MODEL_MAX_Length=2048 # for multi-image sft, set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
@@ -39,7 +42,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
--model_max_length 2048 \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \

View File

@@ -5,13 +5,15 @@ We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-
### Data preparation
To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
To prepare your fine-tuning data, formulate each sample as a dictionary consisting of an id, an image path (or a dictionary mapping image placeholders to image paths), and a list of conversations. Then save the data samples in JSON files.
For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
For vision-language tasks, provide placeholders like **\<image\>** or **\<image_XX\>** to define where the image embeddings are inserted within the conversation. If no placeholder is provided, the image will be placed at the front of the conversation by default.
#### Single Image Example
If your input consists of a single image, you can use a single placeholder **\<image\>** to indicate where the image should be inserted in the conversation.
<details>
<summary>
<b>vision-language example (vl_finetune_data.json) with 1 samples.</b>
<b>Single image example (vl_finetune_data.json) with 1 sample.</b>
</summary>
```
@@ -50,6 +52,44 @@ For the vision-language example with image, you are required to provide **\<imag
</details>
#### Multiple Images Example
For inputs containing multiple images, use a dictionary where each key is a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**) and the corresponding image path is its value. These placeholders can then be used within the conversation to insert images at specific positions.
Additionally, to keep sequence lengths manageable, especially when a sample contains many images, consider reducing `max_slice_nums`. In version 2.6, a single image is represented by 64 tokens; with `max_slice_nums=9`, an image at the maximum resolution of 1344x1344 consumes nearly 64*(9+1) = 640 tokens. Setting `max_slice_nums=1` disables slicing, so each image is represented by just 64 tokens.
If the total token count exceeds `max_length`, the input will be truncated. For multi-image supervised fine-tuning (SFT), it is recommended to set `MODEL_MAX_Length=4096` in the training script.
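As a rough illustration of the arithmetic above, the sketch below estimates the image-token budget per sample (numbers taken from the 2.6 description: 64 tokens per image, plus 64 per slice when slicing is enabled; the helper name is hypothetical):

```
def estimate_image_tokens(num_images, max_slice_nums):
    # rough upper bound; ignores placeholder and grid separator tokens
    tokens_per_image = 64 * (max_slice_nums + 1) if max_slice_nums > 1 else 64
    return num_images * tokens_per_image

print(estimate_image_tokens(4, 9))  # 2560 -> exceeds a 2048 limit, hence MODEL_MAX_Length=4096
print(estimate_image_tokens(4, 1))  # 256  -> slicing disabled leaves plenty of room for text
```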
<details>
<summary>
<b>Multiple images example (vl_finetune_data.json) with 1 sample.</b>
</summary>
```
[
{
"id": "0",
"image": {
"<image_00>": "path/to/image_0.jpg",
"<image_01>": "path/to/image_1.jpg",
"<image_02>": "path/to/image_2.jpg",
"<image_03>": "path/to/image_3.jpg"
},
"conversations": [
{
"role": "user",
"content": "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n<image_02>\n<image_03>\n"
},
{
"role": "assistant",
"content": "To create a text-only video as shown in the images, follow these steps in CapCut..."
}
]
}
]
```
</details>
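Before launching training, it can help to sanity-check the JSON against the format above. The snippet below is only a sketch and not part of the official scripts; the file name `vl_finetune_data.json` and the specific checks are assumptions.

```
import json
import re

PLACEHOLDER = re.compile(r"<image_\d+>")

with open("vl_finetune_data.json") as f:  # assumed file name
    samples = json.load(f)

for sample in samples:
    assert "id" in sample and "image" in sample and "conversations" in sample
    text = "\n".join(turn["content"] for turn in sample["conversations"])
    image = sample["image"]
    if isinstance(image, str):
        pass  # single image: "<image>" is optional and defaults to the front of the conversation
    elif isinstance(image, dict):
        for name in image:
            assert PLACEHOLDER.fullmatch(name), f"bad placeholder name: {name}"
            assert name in text, f"{name} never appears in the conversations"
    else:
        raise TypeError("'image' must be a path string or a placeholder-to-path dict")
```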
### Full-parameter finetuning