Update to MiniCPM-Llama3-V 2.5

yiranyyu
2024-05-20 12:44:33 +08:00
parent c0e39dbfe2
commit 2c75097411
27 changed files with 1944 additions and 985 deletions


@@ -1,90 +1,115 @@
import copy
import json
import logging
import math
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
raw_data,
transform,
tokenizer,
slice_config,
llm_type="minicpm",
patch_size=14,
query_nums=64,
batch_vision=False,
):
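        # Notes on the newly added arguments (descriptive only):
        #   llm_type:     selects the prompt encoding, "minicpm" (default) or
        #                 "llama3" (chat-template based); see conversation_to_ids.
        #   patch_size:   ViT patch side length in pixels, used to derive
        #                 tgt_sizes and to reshape pixel patches.
        #   query_nums:   number of image placeholder tokens inserted per image
        #                 slot (matches the resampler's query count).
        #   batch_vision: if True, pixel values are flattened patch-wise
        #                 (reshape_by_patch) and tgt_sizes are returned so the
        #                 vision encoder can batch variable-sized slices.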
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
self.tokenizer = tokenizer
self.transform = transform
self.slice_config = slice_config
self.llm_type = llm_type
self.patch_size = patch_size
        self.query_nums = query_nums
self.batch_vision = batch_vision
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
return ret
def data_collator(examples, padding_value=0):
input_ids = pad_sequence(
[example["input_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
targets = pad_sequence(
[example["labels"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
attention_mask = pad_sequence(
[example["attention_mask"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
pixel_values = [example["pixel_values"] for example in examples]
image_bound = [example["image_bound"] for example in examples]
return {"input_ids": input_ids, "labels":targets, "attention_mask": attention_mask, "image_bound": image_bound, "pixel_values": pixel_values}
tgt_sizes = [example["tgt_sizes"] for example in examples]
return {
"input_ids": input_ids,
"labels": targets,
"attention_mask": attention_mask,
"image_bound": image_bound,
"tgt_sizes": tgt_sizes,
"pixel_values": pixel_values,
}
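# A minimal usage sketch (illustrative only; `train_data`, `transform`,
# `tokenizer` and `slice_config` are assumed to come from the surrounding
# fine-tuning script, and the padding value may differ in practice):
#
#   from functools import partial
#   from torch.utils.data import DataLoader
#
#   dataset = SupervisedDataset(
#       train_data, transform, tokenizer, slice_config,
#       llm_type="llama3", batch_vision=True,
#   )
#   loader = DataLoader(
#       dataset, batch_size=2,
#       collate_fn=partial(data_collator, padding_value=tokenizer.pad_token_id),
#   )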
def conversation_to_ids(conversation, tokenizer, llm_type=None):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
{'role': 'assistant', 'content': 'This is a cat.'}]
"""
if llm_type == "llama3":
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
)
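    # `context` marks, per token, whether it belongs to the prompt (1, ignored
    # by the loss) or to an assistant reply (0, used as a label below).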
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
@@ -94,45 +119,137 @@ def conversation_to_ids(conversation, tokenizer):
if context[i] == 0:
target[i - 1] = ids[i]
if context[i] == 1 and context[i - 1] == 0:
if hasattr(tokenizer, "eot_id"):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id
# build image bound
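    # each (start, end) pair spans the image placeholder tokens between
    # <im_start> and <im_end>; the model later overwrites these positions
    # with the projected visual features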
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
else:
image_bound = []
return {
"input_ids": ids,
"target": target,
"image_bound": image_bound,
"raw_msg": raw_msg,
}
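# MiniCPM-style prompting: each turn is "<用户>" (user) or "<AI>" (assistant)
# followed by the message text; `[1:]` strips the BOS token that the tokenizer
# prepends to every encode() call so the segments can be concatenated.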
def conversation_to_ids_minicpm(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
for idx, msg in enumerate(conversation):
role = msg["role"]
message = msg["content"]
assert role in ["user", "assistant"]
if role == "user":
prefix = "<用户>"
else:
prefix = "<AI>"
# append eos
if idx == len(conversation) - 1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
message_ids = tokenizer.encode(message)[1:]
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role == "assistant":
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
raw_msg += prefix + message
return input_ids, context, raw_msg
def conversation_to_ids_llama3(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=False
)
input_ids = tokenizer.apply_chat_template(
conversation, tokenize=True, add_generation_prompt=False
)
input_ids = np.array(input_ids)
start_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>")
)[0]
assistant_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)[0]
end_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>")
)[0]
eot_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
context = np.ones_like(input_ids, dtype=np.int8)
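    # Llama-3 layout: "<|start_header_id|>assistant<|end_header_id|>\n\n<reply><|eot_id|>".
    # For every "assistant" token that is itself a role header (it sits midway
    # between a start/end header pair), zero out `context` from the first reply
    # token through its <|eot_id|>, so only assistant replies are supervised.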
for assistant_idx in assistant_idxs:
if assistant_idx in set((start_header_idxs + end_header_idxs) / 2):
st = assistant_idx + 3 # assistant<|end_header_id|>\n\n
for eot_idx in eot_idxs:
if eot_idx > st:
context[st: eot_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
context = np.hstack(context)
return input_ids, context, raw_msg
def preprocess(
image,
conversation,
tokenizer,
transform,
query_nums=64,
slice_config=None,
llm_type=None,
patch_size=14,
batch_vision=False,
):
"""
single image preprocess, the image will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
    assert len(conversation) > 1, "conversation length must be greater than 1"
assert conversation[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
if slice_config:
images = []
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
@@ -142,30 +259,51 @@ def preprocess(image, conversation, tokenizer, transform, query_nums=64, slice_c
images.append(patches[i][j])
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums)
images = [transform(i) for i in images]
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
if batch_vision:
tgt_sizes = []
reshape_images = []
for image in images:
H, W = image.shape[1:]
reshape_image = reshape_by_patch(image, patch_size)
reshape_images.append(reshape_image)
tgt_sizes.append([H // patch_size, W // patch_size])
if tgt_sizes:
tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
input_dict["pixel_values"] = reshape_images
input_dict["tgt_sizes"] = tgt_sizes
else:
input_dict["pixel_values"] = images
input_dict["tgt_sizes"] = []
return input_dict
def slice_image(
image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / \
(scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
@@ -186,7 +324,8 @@ def slice_image(
candidate_split_grids_nums.append(i)
# source image, down-sampling and ensure divided by patch_size
best_resize = find_best_resize(
original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
@@ -285,6 +424,22 @@ def get_grid_placeholder(tokenizer, grid, query_num):
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder
def reshape_by_patch(image_tensor, patch_size):
"""
:param image_tensor: shape [3, H, W]
:param patch_size:
:return: [3, patch_size, HW/patch_size]
"""
patches = torch.nn.functional.unfold(
image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)
)
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches
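# Shape walkthrough (e.g. a 3 x 224 x 224 image with patch_size=14):
#   unfold            -> [3*14*14, 256]   (256 = 16 x 16 patches)
#   reshape/permute   -> [3, 14, 256, 14]
#   final reshape     -> [3, 14, 3584]    (3584 = 224*224 / 14)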