mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
Refactor dataset.py for improved data handling
Enhance the data fetching and processing logic in dataset.py to handle thinking states and improve error handling.
This commit is contained in:
@@ -69,6 +69,7 @@ class SupervisedDataset(Dataset):
|
||||
batch_vision=self.batch_vision,
|
||||
max_length=self.max_length
|
||||
)
|
||||
|
||||
ret = dict(
|
||||
input_ids=ret["input_ids"],
|
||||
position_ids=ret["position_ids"],
|
||||
@@ -80,7 +81,7 @@ class SupervisedDataset(Dataset):
|
||||
)
|
||||
except:
|
||||
logger.error(f"data fetch error")
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
return ret
|
||||
|
||||
|
||||
@@ -283,20 +284,30 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
|
||||
chat.append({"role":prefix, "content":message})
|
||||
raw_msg += prefix + message
|
||||
assert set([i['role'] for i in chat]) & set(['assistant'])
|
||||
if '<think>' in chat[-1]['content'] and '</think>' in chat[-1]['content']:
|
||||
enable_thinking = True
|
||||
else:
|
||||
enable_thinking = False
|
||||
|
||||
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
|
||||
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
|
||||
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking)
|
||||
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, enable_thinking=enable_thinking)
|
||||
input_ids = np.array(input_ids)
|
||||
|
||||
if "<think>\n\n</think>\n\n" in ret:
|
||||
offset = 4
|
||||
else:
|
||||
offset = 0
|
||||
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
|
||||
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
|
||||
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
|
||||
|
||||
context = np.ones_like(input_ids, dtype=np.int8)
|
||||
|
||||
for assistant_idx in assistant_idxs:
|
||||
for i, assistant_idx in enumerate(assistant_idxs):
|
||||
if assistant_idx-1 in set(start_idxs):
|
||||
st = assistant_idx + 1
|
||||
if i == len(assistant_idxs) -1:
|
||||
st = assistant_idx + 2 + offset
|
||||
else:
|
||||
st = assistant_idx + 2
|
||||
for end_idx in end_idxs:
|
||||
if end_idx > st:
|
||||
context[st: end_idx + 1] = 0
|
||||
|
||||
Reference in New Issue
Block a user