Update for multi-image SFT

qianyu chen
2024-08-13 13:57:48 +08:00
committed by GitHub
parent 61e942ec7c
commit 7842ec1228
5 changed files with 143 additions and 65 deletions

View File

@@ -3,6 +3,8 @@ import json
 import logging
 import math
 import os
+import re
+import random
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional
@@ -43,30 +45,41 @@ class SupervisedDataset(Dataset):
         return len(self.raw_data)

     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        image = Image.open(self.raw_data[i]["image"]).convert("RGB")
-        ret = preprocess(
-            image,
-            self.raw_data[i]["conversations"],
-            self.tokenizer,
-            self.transform,
-            query_nums=self.query_nums,
-            slice_config=self.slice_config,
-            llm_type=self.llm_type,
-            patch_size=self.patch_size,
-            batch_vision=self.batch_vision,
-        )
-        ret = dict(
-            input_ids=ret["input_ids"],
-            position_ids=ret["position_ids"],
-            labels=ret["target"],
-            attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
-            pixel_values=ret["pixel_values"],
-            tgt_sizes=ret["tgt_sizes"],
-            image_bound=ret["image_bound"],
-        )
+        # try:
+        if 1:
+            if isinstance(self.raw_data[i]["image"], str):
+                images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
+            elif isinstance(self.raw_data[i]["image"], Dict):
+                ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
+                images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+
+            ret = preprocess(
+                images_dict,
+                self.raw_data[i]["conversations"],
+                self.tokenizer,
+                self.transform,
+                query_nums=self.query_nums,
+                slice_config=self.slice_config,
+                llm_type=self.llm_type,
+                patch_size=self.patch_size,
+                batch_vision=self.batch_vision,
+            )
+            ret = dict(
+                input_ids=ret["input_ids"],
+                position_ids=ret["position_ids"],
+                labels=ret["target"],
+                attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
+                pixel_values=ret["pixel_values"],
+                tgt_sizes=ret["tgt_sizes"],
+                image_bound=ret["image_bound"],
+            )
+        # except:
+        #     print(f"data fetch error")
+        #     return self.__getitem__(random.randint(0, len(self)))
         return ret

 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -283,8 +296,8 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
 def preprocess(
-    image,
-    conversation,
+    images_dict,
+    conversations,
     tokenizer,
     transform,
     query_nums=64,
@@ -294,11 +307,11 @@ def preprocess(
     batch_vision=False,
 ):
     """
-    single image preprocess, the image will be placed at the top of the conversation
+    single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
     """
-    conversation = copy.deepcopy(conversation)
-    assert len(conversation) > 1, "conversation length must large than 2"
-    assert conversation[0]["role"] == "user", "the first role must be user"
+    conversations = copy.deepcopy(conversations)
+    assert len(conversations) > 1, "conversations length must large than 2"
+    assert conversations[0]["role"] == "user", "the first role must be user"

     if slice_config is not None:
         assert isinstance(slice_config, Dict)
@@ -313,41 +326,67 @@ def preprocess(
     if llm_type=='qwen2':
         new_schema = True
         use_image_id = True
-    if slice_config:
-        images = []
-        image_id_cnt = 0
-        source_image, patches, best_grid = slice_image(
-            image,
-            slice_config["max_slice_nums"],
-            slice_config["scale_resolution"],
-            slice_config["patch_size"],
-        )
-        images.append(source_image)
-        image_placeholder = default_image_placeholder
-        if len(patches) > 0:
-            for i in range(len(patches)):
-                for j in range(len(patches[0])):
-                    images.append(patches[i][j])
-        if use_image_id:
-            image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
-            image_id_cnt += 1
-        image_placeholder += get_grid_placeholder(
-            tokenizer, best_grid, query_nums, new_schema = new_schema)
-        images = [transform(i) for i in images]
-    else:
-        images = [transform(image)]
-        image_placeholder = default_image_placeholder
-    if "<image>" in conversation[0]["content"]:
-        conversation[0]["content"] = conversation[0]["content"].replace(
-            "<image>", image_placeholder
-        )
-    else:
-        conversation[0]["content"] = (
-            image_placeholder + "\n" + conversation[0]["content"]
-        )
-    input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
+    image_placeholder_dict = {}
+    images = []
+    image_id_cnt = 0
+    for img_name, image in images_dict.items():
+        if slice_config:
+            source_image, patches, best_grid = slice_image(
+                image,
+                slice_config["max_slice_nums"],
+                slice_config["scale_resolution"],
+                slice_config["patch_size"],
+            )
+            images.append(source_image)
+            image_placeholder = default_image_placeholder
+            if len(patches) > 0:
+                for i in range(len(patches)):
+                    for j in range(len(patches[0])):
+                        images.append(patches[i][j])
+            if use_image_id:
+                image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+                image_id_cnt += 1
+            image_placeholder += get_grid_placeholder(
+                tokenizer, best_grid, query_nums, new_schema = new_schema)
+            image_placeholder_dict[img_name] = image_placeholder
+        else:
+            images.append(image)
+            if use_image_id:
+                image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+                image_id_cnt += 1
+            else:
+                image_placeholder = default_image_placeholder
+            image_placeholder_dict[img_name] = image_placeholder
+
+    images = [transform(i) for i in images]
+
+    if len(images_dict) == 1 and "<image>" in images_dict:
+        if "<image>" in conversations[0]["content"]:
+            conversations[0]["content"] = conversations[0]["content"].replace(
+                "<image>", image_placeholder
+            )
+        else:
+            conversations[0]["content"] = (
+                image_placeholder + "\n" + conversations[0]["content"]
+            )
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+    else:
+        pattern = r'<image_\d+>'
+        new_conversations = []
+        for conversation in conversations:
+            content = conversation['content']
+            parts = re.split(f'({pattern})', content)
+            for i, part in enumerate(parts):
+                if re.match(pattern, part):
+                    if part in image_placeholder_dict:
+                        parts[i] = image_placeholder_dict[part]
+                    else:
+                        print(f'Unreplaced image tag: {part}')
+            conversation['content'] = ''.join(parts)
+            new_conversations.append(conversation)
+        conversations = new_conversations
+
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)

     if batch_vision:
         tgt_sizes = []
         reshape_images = []
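As a side note for readers of this hunk: the `<image_XX>` substitution added above can be exercised in isolation. The snippet below is a minimal, self-contained sketch of that pattern; the placeholder expansions and the sample content are made up for illustration and are not taken from the repository.

```python
import re

# Split the content on the <image_XX> pattern (the capture group keeps the tags),
# swap each known tag for its expanded placeholder string, and rejoin.
pattern = r'<image_\d+>'
image_placeholder_dict = {
    "<image_00>": "<IMG_0_TOKENS>",   # hypothetical expanded placeholder
    "<image_01>": "<IMG_1_TOKENS>",
}
content = "Compare these screenshots:\n<image_00>\n<image_01>\n"

parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
    if re.match(pattern, part):
        # unknown tags are left untouched, mirroring the warning branch in the diff
        parts[i] = image_placeholder_dict.get(part, part)
print(''.join(parts))
```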

View File

@@ -205,6 +205,7 @@ def train():
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=True
     )
+    tokenizer.model_max_length = training_args.model_max_length

     if not training_args.tune_vision:
         model.vpm.requires_grad_(False)
@@ -276,6 +277,7 @@ def train():
         max_length=training_args.model_max_length,
     )

+    training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
     trainer = CPMTrainer(
         model=model,
         tokenizer=tokenizer,
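For context on the `gradient_checkpointing_kwargs` line added above: the same option can also be supplied when constructing the `transformers` `TrainingArguments`. The snippet below is an illustrative sketch, not code from this repository; the output path is hypothetical and argument availability depends on the installed `transformers` version.

```python
from transformers import TrainingArguments

# Passing the non-reentrant checkpointing flag at construction time instead of
# patching the args object afterwards.
training_args = TrainingArguments(
    output_dir="output/minicpmv_sft",  # hypothetical path
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```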

View File

@@ -13,7 +13,7 @@ MODEL="openbmb/MiniCPM-V-2_6"
DATA="path/to/trainging_data" DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data" EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3" LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
MODEL_MAX_Length=4096 # if use openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, please set MODEL_MAX_Length=2048
DISTRIBUTED_ARGS=" DISTRIBUTED_ARGS="
@@ -39,7 +39,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --do_eval \
     --tune_vision true \
     --tune_llm true \
-    --model_max_length 2048 \
+    --model_max_length $MODEL_MAX_Length \
     --max_slice_nums 9 \
     --max_steps 10000 \
     --eval_steps 1000 \

View File

@@ -14,6 +14,9 @@ EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" LLM_TYPE="qwen2"
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3 #if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
MODEL_MAX_Length=4096 # if use openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, please set MODEL_MAX_Length=2048
DISTRIBUTED_ARGS=" DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \ --nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \ --nnodes $NNODES \
@@ -39,7 +42,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
     --tune_llm false \
     --use_lora true \
     --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
-    --model_max_length 2048 \
+    --model_max_length $MODEL_MAX_Length \
     --max_slice_nums 9 \
     --max_steps 10000 \
     --eval_steps 1000 \

View File

@@ -5,13 +5,15 @@ We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-
 ### Data preparation

-To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
+To prepare your fine-tuning data, formulate each sample as a dictionary consisting of an id, an image path (or a dictionary of image paths for multi-image samples), and a list of conversations. Then save the data samples in JSON files.

-For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
+For vision-language tasks, provide placeholders such as **\<image\>** or **\<image_XX\>** to mark where the image embeddings should be inserted in the conversation. If no placeholder is provided, the image is placed at the front of the conversation by default.
+
+#### Single Image Example
+
+If your input consists of a single image, use the single placeholder **\<image\>** to indicate where the image should be inserted in the conversation.

 <details>
   <summary>
-    <b>vision-language example (vl_finetune_data.json) with 1 samples.</b>
+    <b>Single image example (vl_finetune_data.json) with 1 sample.</b>
   </summary>

 ```
@@ -50,6 +52,38 @@ For the vision-language example with image, you are required to provide **\<imag
 </details>

+#### Multiple Images Example
+
+For inputs with multiple images, use a dictionary in which each key is a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**) and the corresponding value is the image path. You can then use these placeholders in the conversation to insert the images at specific positions.
+
+<details>
+  <summary>
+    <b>Multiple images example (vl_finetune_data.json) with 1 sample.</b>
+  </summary>
+
+```
+[
+  {
+    "id": "0",
+    "image": {
+      "<image_00>": "path/to/image_0.jpg",
+      "<image_01>": "path/to/image_1.jpg",
+      "<image_02>": "path/to/image_2.jpg",
+      "<image_03>": "path/to/image_3.jpg"
+    },
+    "conversations": [
+      {
+        "role": "user",
+        "content": "How to create such text-only videos using CapCut?\n<image_00>\n<image_01>\n<image_02>\n<image_03>\n"
+      },
+      {
+        "role": "assistant",
+        "content": "To create a text-only video as shown in the images, follow these steps in CapCut..."
+      }
+    ]
+  }
+]
+```
+</details>

 ### Full-parameter finetuning
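The multi-image sample layout documented in the README hunk above can be sanity-checked before training with a short script. The following is an illustrative sketch, not part of the commit; it only assumes the JSON structure described in this README (the file name `vl_finetune_data.json` comes from the examples above).

```python
import json
import re

# Illustrative check: every <image_XX> (or <image>) tag used in a sample's conversations
# should have a matching entry in its "image" field, and unused images are reported.
def check_sample(sample):
    image_field = sample["image"]
    # single-image samples may use a plain path string with the generic <image> tag
    names = set(image_field) if isinstance(image_field, dict) else {"<image>"}
    used = set()
    for turn in sample["conversations"]:
        used.update(re.findall(r"<image_\d+>|<image>", turn["content"]))
    return used <= names, names - used

with open("vl_finetune_data.json") as f:
    data = json.load(f)

for sample in data:
    ok, unused = check_sample(sample)
    status = "ok" if ok else "has unresolved image tags"
    extra = f" unused images: {sorted(unused)}" if unused else ""
    print(sample["id"], status + extra)
```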