Update to MiniCPM-V 2.6

This commit is contained in:
yiranyyu
2024-08-06 12:26:49 +08:00
parent 1cb882d473
commit b1a15299e6
28 changed files with 3692 additions and 191 deletions

View File

@@ -105,7 +105,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
}
def conversation_to_ids(conversation, tokenizer, llm_type=None):
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -115,6 +115,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
elif llm_type == "qwen2":
input_ids, context, raw_msg = conversation_to_ids_qwen2(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
@@ -125,6 +129,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
for i in range(1, len(ids)):
if context[i] == 0:
target[i - 1] = ids[i]
@@ -133,14 +138,21 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id
# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if new_schema:
start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
else:
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
@@ -230,6 +242,46 @@ def conversation_to_ids_llama3(conversation, tokenizer):
return input_ids, context, raw_msg
def conversation_to_ids_qwen2(conversation, tokenizer):
raw_msg = ""
chat = []
context = []
for idx, msg in enumerate(conversation):
role = msg["role"]
message = msg["content"]
assert role in ["user", "assistant"]
if role == "user":
prefix = "user"
else:
prefix = "assistant"
chat.append({"role":prefix, "content":message})
raw_msg += prefix + message
assert set([i['role'] for i in chat]) & set(['assistant'])
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
input_ids = np.array(input_ids)
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
context = np.ones_like(input_ids, dtype=np.int8)
for assistant_idx in assistant_idxs:
if assistant_idx-1 in set(start_idxs):
st = assistant_idx + 1
for end_idx in end_idxs:
if end_idx > st:
context[st: end_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
context = np.hstack(context)
return input_ids, context, raw_msg
def preprocess(
image,
conversation,
@@ -256,8 +308,14 @@ def preprocess(
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
new_schema = False
use_image_id = False
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
@@ -270,9 +328,11 @@ def preprocess(
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{idx}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums)
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
images = [transform(image)]
@@ -286,7 +346,7 @@ def preprocess(
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
if batch_vision:
tgt_sizes = []
@@ -424,7 +484,7 @@ def split_to_patches(image, grid):
return patches
def get_grid_placeholder(tokenizer, grid, query_num):
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
)
@@ -437,7 +497,10 @@ def get_grid_placeholder(tokenizer, grid, query_num):
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + \
if new_schema:
slice_placeholder = '\n'.join(slices)
else:
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder
@@ -455,4 +518,4 @@ def reshape_by_patch(image_tensor, patch_size):
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches
return patches