mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 02:09:20 +08:00
Refactor dataset.py for improved data handling
Enhanced data fetching and processing logic in dataset.py to handle thinking states and improve error handling.
This commit is contained in:
@@ -69,6 +69,7 @@ class SupervisedDataset(Dataset):
|
|||||||
batch_vision=self.batch_vision,
|
batch_vision=self.batch_vision,
|
||||||
max_length=self.max_length
|
max_length=self.max_length
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = dict(
|
ret = dict(
|
||||||
input_ids=ret["input_ids"],
|
input_ids=ret["input_ids"],
|
||||||
position_ids=ret["position_ids"],
|
position_ids=ret["position_ids"],
|
||||||
@@ -283,20 +284,30 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
|
|||||||
chat.append({"role":prefix, "content":message})
|
chat.append({"role":prefix, "content":message})
|
||||||
raw_msg += prefix + message
|
raw_msg += prefix + message
|
||||||
assert set([i['role'] for i in chat]) & set(['assistant'])
|
assert set([i['role'] for i in chat]) & set(['assistant'])
|
||||||
|
if '<think>' in chat[-1]['content'] and '</think>' in chat[-1]['content']:
|
||||||
|
enable_thinking = True
|
||||||
|
else:
|
||||||
|
enable_thinking = False
|
||||||
|
|
||||||
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
|
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking)
|
||||||
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
|
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, enable_thinking=enable_thinking)
|
||||||
input_ids = np.array(input_ids)
|
input_ids = np.array(input_ids)
|
||||||
|
if "<think>\n\n</think>\n\n" in ret:
|
||||||
|
offset = 4
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
|
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
|
||||||
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
|
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
|
||||||
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
|
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
|
||||||
|
|
||||||
context = np.ones_like(input_ids, dtype=np.int8)
|
context = np.ones_like(input_ids, dtype=np.int8)
|
||||||
|
|
||||||
for assistant_idx in assistant_idxs:
|
for i, assistant_idx in enumerate(assistant_idxs):
|
||||||
if assistant_idx-1 in set(start_idxs):
|
if assistant_idx-1 in set(start_idxs):
|
||||||
st = assistant_idx + 1
|
if i == len(assistant_idxs) -1:
|
||||||
|
st = assistant_idx + 2 + offset
|
||||||
|
else:
|
||||||
|
st = assistant_idx + 2
|
||||||
for end_idx in end_idxs:
|
for end_idx in end_idxs:
|
||||||
if end_idx > st:
|
if end_idx > st:
|
||||||
context[st: end_idx + 1] = 0
|
context[st: end_idx + 1] = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user