update for multi images sft

This commit is contained in:
qianyu chen
2024-08-13 13:57:48 +08:00
committed by GitHub
parent 61e942ec7c
commit 7842ec1228
5 changed files with 143 additions and 65 deletions

View File

@@ -3,6 +3,8 @@ import json
import logging
import math
import os
import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@@ -43,30 +45,41 @@ class SupervisedDataset(Dataset):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
# try:
if 1:
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
ret = preprocess(
images_dict,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
# except:
# print(f"data fetch error")
# return self.__getitem__(random.randint(0, len(self)))
return ret
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -283,8 +296,8 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
def preprocess(
image,
conversation,
images_dict,
conversations,
tokenizer,
transform,
query_nums=64,
@@ -294,11 +307,11 @@ def preprocess(
batch_vision=False,
):
"""
single image preprocess, the image will be placed at the top of the conversation
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
assert len(conversation) > 1, "conversation length must large than 2"
assert conversation[0]["role"] == "user", "the first role must be user"
conversations = copy.deepcopy(conversations)
assert len(conversations) > 1, "conversations length must large than 2"
assert conversations[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
@@ -313,41 +326,67 @@ def preprocess(
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder_dict = {}
images = []
image_id_cnt = 0
for img_name, image in images_dict.items():
if slice_config:
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
image_placeholder_dict[img_name] = image_placeholder
else:
images.append(image)
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
image_placeholder = default_image_placeholder
image_placeholder_dict[img_name] = image_placeholder
images = [transform(i) for i in images]
if len(images_dict) == 1 and "<image>" in images_dict:
if "<image>" in conversations[0]["content"]:
conversations[0]["content"] = conversations[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversations[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
pattern = r'<image_\d+>'
new_conversations = []
for conversation in conversations:
content = conversation['content']
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
if re.match(pattern, part):
if part in image_placeholder_dict:
parts[i] = image_placeholder_dict[part]
else:
print(f'Unreplaced image tag: {part}')
conversation['content'] = ''.join(parts)
new_conversations.append(conversation)
conversations = new_conversations
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
if batch_vision:
tgt_sizes = []
reshape_images = []