mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-05 18:29:18 +08:00)

Update to MiniCPM-V 2.6
@@ -105,7 +105,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
 }


-def conversation_to_ids(conversation, tokenizer, llm_type=None):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
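For reference, the conversation object the docstring describes is a plain list of role/content dicts; a minimal illustrative example (the content values here are invented):

    conversation = [
        {"role": "user", "content": "Describe this image"},
        {"role": "assistant", "content": "There are two cats on a sofa."},
        {"role": "user", "content": "What colors are they?"},
        {"role": "assistant", "content": "One is black and one is orange."},
    ]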
@@ -115,6 +115,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
         input_ids, context, raw_msg = conversation_to_ids_llama3(
             conversation, tokenizer
         )
+    elif llm_type == "qwen2":
+        input_ids, context, raw_msg = conversation_to_ids_qwen2(
+            conversation, tokenizer
+        )
     else:
         input_ids, context, raw_msg = conversation_to_ids_minicpm(
             conversation, tokenizer
@@ -125,6 +129,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):

     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
+
     for i in range(1, len(ids)):
         if context[i] == 0:
             target[i - 1] = ids[i]
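The loop above builds next-token labels shifted by one: position i-1 is supervised only when the token at i belongs to an assistant span (context == 0); every other position stays -100 so cross-entropy ignores it. A standalone sketch with toy values:

    import torch

    ids = torch.tensor([10, 11, 12, 13, 14])
    context = torch.tensor([1, 1, 0, 0, 1])   # positions 2-3 are assistant output

    target = torch.full_like(ids, -100)
    for i in range(1, len(ids)):
        if context[i] == 0:
            target[i - 1] = ids[i]            # token at i-1 predicts token at i

    print(target.tolist())                    # [-100, 12, 13, -100, -100]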
@@ -133,14 +138,21 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
                 target[i - 1] = tokenizer.eot_id
             else:
                 target[i - 1] = tokenizer.eos_id

     # build image bound
-    image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
-    image_start_tokens += 1
-    image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
+    if new_schema:
+        start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
+        end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
+        image_start_tokens = torch.where(start_cond)[0]
+        image_start_tokens += 1
+        image_end_tokens = torch.where(end_cond)[0]
+    else:
+        image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
+        image_start_tokens += 1
+        image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
         print("image start token != image end tokens")

     if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
             [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
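The image-bound step pairs each start marker (shifted one past the marker itself) with the matching end marker, yielding an N x 2 tensor of [start, end) spans over the placeholder tokens; under new_schema, slice markers count as image regions too. A toy illustration with made-up token ids (1 = im_start, 2 = im_end):

    import torch

    ids = torch.tensor([5, 1, 0, 0, 0, 2, 6])    # ... <im_start> unk unk unk <im_end> ...
    starts = torch.where(ids == 1)[0] + 1        # first placeholder position
    ends = torch.where(ids == 2)[0]              # position of <im_end>
    image_bound = torch.hstack([starts.unsqueeze(-1), ends.unsqueeze(-1)])
    print(image_bound)                           # tensor([[2, 5]]): placeholders are ids[2:5]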
@@ -230,6 +242,46 @@ def conversation_to_ids_llama3(conversation, tokenizer):
     return input_ids, context, raw_msg


+def conversation_to_ids_qwen2(conversation, tokenizer):
+    raw_msg = ""
+    chat = []
+    context = []
+    for idx, msg in enumerate(conversation):
+        role = msg["role"]
+        message = msg["content"]
+        assert role in ["user", "assistant"]
+        if role == "user":
+            prefix = "user"
+        else:
+            prefix = "assistant"
+        chat.append({"role":prefix, "content":message})
+        raw_msg += prefix + message
+    assert set([i['role'] for i in chat]) & set(['assistant'])
+
+    ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
+    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
+    input_ids = np.array(input_ids)
+
+    start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
+    assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
+    end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
+
+    context = np.ones_like(input_ids, dtype=np.int8)
+
+    for assistant_idx in assistant_idxs:
+        if assistant_idx-1 in set(start_idxs):
+            st = assistant_idx + 1
+            for end_idx in end_idxs:
+                if end_idx > st:
+                    context[st: end_idx + 1] = 0
+                    break
+
+    input_ids = np.hstack(input_ids)
+    context = np.hstack(context)
+    return input_ids, context, raw_msg
+
+
+
 def preprocess(
     image,
     conversation,
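The new qwen2 path leans on the tokenizer's ChatML template: an assistant span is an 'assistant' token immediately preceded by <|im_start|>, unmasked through the next <|im_end|> inclusive, so the model also learns to emit the stop token. A standalone sketch of that masking loop with fake ids (100 = <|im_start|>, 101 = <|im_end|>, 7 = 'assistant'; running the real thing needs a Qwen2 tokenizer):

    import numpy as np

    input_ids = np.array([100, 1, 101, 100, 7, 8, 9, 101])
    start_idxs = np.where(input_ids == 100)[0]
    assistant_idxs = np.where(input_ids == 7)[0]
    end_idxs = np.where(input_ids == 101)[0]

    context = np.ones_like(input_ids, dtype=np.int8)
    for a in assistant_idxs:
        if a - 1 in set(start_idxs):        # 'assistant' right after <|im_start|>
            st = a + 1
            for e in end_idxs:
                if e > st:
                    context[st:e + 1] = 0   # unmask the reply through <|im_end|>
                    break

    print(context.tolist())                 # [1, 1, 1, 1, 1, 0, 0, 0]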
@@ -256,8 +308,14 @@ def preprocess(
     default_image_placeholder = (
         tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
     )
+    new_schema = False
+    use_image_id = False
+    if llm_type=='qwen2':
+        new_schema = True
+        use_image_id = True
     if slice_config:
         images = []
+        image_id_cnt = 0
         source_image, patches, best_grid = slice_image(
             image,
             slice_config["max_slice_nums"],
@@ -270,9 +328,11 @@ def preprocess(
             for i in range(len(patches)):
                 for j in range(len(patches[0])):
                     images.append(patches[i][j])

+            if use_image_id:
+                image_placeholder = f'{tokenizer.im_id_start}{idx}{tokenizer.im_id_end}' + image_placeholder
+                image_id_cnt += 1
             image_placeholder += get_grid_placeholder(
-                tokenizer, best_grid, query_nums)
+                tokenizer, best_grid, query_nums, new_schema = new_schema)
         images = [transform(i) for i in images]
     else:
         images = [transform(image)]
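Taken together with the flags set earlier, a qwen2 run prefixes each image placeholder with an indexed id tag before appending the grid placeholder. A rough illustration with a mock tokenizer (the literal tag strings below are assumptions; the real ones come from the tokenizer's special tokens):

    class MockTok:
        # Assumed special-token strings, for illustration only.
        im_start, im_end = "<image>", "</image>"
        im_id_start, im_id_end = "<image_id>", "</image_id>"
        unk_token = "<unk>"

    tok = MockTok()
    query_nums, idx = 3, 0
    placeholder = tok.im_start + tok.unk_token * query_nums + tok.im_end
    placeholder = f"{tok.im_id_start}{idx}{tok.im_id_end}" + placeholder
    print(placeholder)   # <image_id>0</image_id><image><unk><unk><unk></image>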
@@ -286,7 +346,7 @@ def preprocess(
             image_placeholder + "\n" + conversation[0]["content"]
         )

-    input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
+    input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)

     if batch_vision:
         tgt_sizes = []
@@ -424,7 +484,7 @@ def split_to_patches(image, grid):
     return patches


-def get_grid_placeholder(tokenizer, grid, query_num):
+def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
     image_placeholder = (
         tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
     )
@@ -437,7 +497,10 @@ def get_grid_placeholder(tokenizer, grid, query_num):
         for j in range(cols):
             lines.append(image_placeholder)
         slices.append("".join(lines))
-    slice_placeholder = tokenizer.slice_start + \
+    if new_schema:
+        slice_placeholder = '\n'.join(slices)
+    else:
+        slice_placeholder = tokenizer.slice_start + \
         "\n".join(slices) + tokenizer.slice_end
     return slice_placeholder

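Under new_schema the per-row slice strings are joined with newlines and nothing else; the old schema wrapped the whole grid in slice_start/slice_end markers. A quick contrast for a 2x2 grid (the token literals here are placeholders, not the real special tokens):

    slices = ["<s><s>", "<s><s>"]                        # one joined string per grid row
    old = "<slice>" + "\n".join(slices) + "</slice>"     # pre-2.6 schema
    new = "\n".join(slices)                              # new_schema (qwen2)
    print(old)
    print(new)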
@@ -455,4 +518,4 @@ def reshape_by_patch(image_tensor, patch_size):
     patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
     patches = patches.permute(0, 1, 3, 2).reshape(
         image_tensor.size(0), patch_size, -1)
-    return patches
\ No newline at end of file
+    return patches
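For the tail of reshape_by_patch shown above: assuming `patches` reaches those lines shaped (channels, patch, patch, n_patches), the permute/reshape lays the patches side by side into one (channels, patch, n_patches * patch) strip. A quick shape check:

    import torch

    C, P, N = 3, 14, 5                     # channels, patch size, number of patches
    patches = torch.randn(C, P, P, N)
    out = patches.permute(0, 1, 3, 2).reshape(C, P, -1)
    print(out.shape)                       # torch.Size([3, 14, 70]) == (C, P, N * P)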