Refactor dataset.py for improved data handling

Enhanced data fetching and processing logic in dataset.py to handle thinking states and improve error handling.
Author: qianyu chen
Date: 2025-09-12 15:55:43 +08:00
Committed by: GitHub
Parent: 7233ef5473
Commit: af22b8f2ed

dataset.py
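The core of the change: when a sample's final assistant turn already contains an explicit <think>...</think> block, the chat template is applied with enable_thinking=True so the reasoning text is preserved; otherwise enable_thinking=False is passed, and a later check compensates for the empty think stub the template may insert. Below is a minimal sketch of that detection step, assuming a Qwen3-style tokenizer whose chat template accepts enable_thinking (the checkpoint name and sample conversation are placeholders, not part of this commit):

from transformers import AutoTokenizer

# Hypothetical checkpoint; any Qwen3-style tokenizer whose chat template
# understands the enable_thinking flag behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

chat = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<think>\nTwo plus two is four.\n</think>\n\n4"},
]

# Same detection rule as in the diff: thinking is enabled only when the last
# assistant turn already carries an explicit <think>...</think> block.
last = chat[-1]["content"]
enable_thinking = "<think>" in last and "</think>" in last

text = tokenizer.apply_chat_template(
    chat, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking
)
# If the rendered text still contains an empty "<think>\n\n</think>\n\n" stub,
# the masking code in the diff below skips those four tokens via an offset.
print(enable_thinking, "<think>\n\n</think>\n\n" in text)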

@@ -69,6 +69,7 @@ class SupervisedDataset(Dataset):
             batch_vision=self.batch_vision,
             max_length=self.max_length
         )
+
         ret = dict(
             input_ids=ret["input_ids"],
             position_ids=ret["position_ids"],
@@ -283,20 +284,30 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
chat.append({"role":prefix, "content":message}) chat.append({"role":prefix, "content":message})
raw_msg += prefix + message raw_msg += prefix + message
assert set([i['role'] for i in chat]) & set(['assistant']) assert set([i['role'] for i in chat]) & set(['assistant'])
if '<think>' in chat[-1]['content'] and '</think>' in chat[-1]['content']:
enable_thinking = True
else:
enable_thinking = False
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False) ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking)
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False) input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, enable_thinking=enable_thinking)
input_ids = np.array(input_ids) input_ids = np.array(input_ids)
if "<think>\n\n</think>\n\n" in ret:
offset = 4
else:
offset = 0
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0] start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0] assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0] end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
context = np.ones_like(input_ids, dtype=np.int8) context = np.ones_like(input_ids, dtype=np.int8)
for assistant_idx in assistant_idxs: for i, assistant_idx in enumerate(assistant_idxs):
if assistant_idx-1 in set(start_idxs): if assistant_idx-1 in set(start_idxs):
st = assistant_idx + 1 if i == len(assistant_idxs) -1:
st = assistant_idx + 2 + offset
else:
st = assistant_idx + 2
for end_idx in end_idxs: for end_idx in end_idxs:
if end_idx > st: if end_idx > st:
context[st: end_idx + 1] = 0 context[st: end_idx + 1] = 0
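For reference, the masking changed by the second hunk can be exercised in isolation. The sketch below mirrors the new logic on a toy token sequence; every token id is made up (in dataset.py they come from tokenizer.convert_tokens_to_ids), and the break at the first matching <|im_end|> is an assumption about lines not shown in this hunk:

import numpy as np

# Toy ids standing in for real vocabulary entries (all hypothetical).
IM_START, IM_END, ASSISTANT, NL = 1, 2, 3, 4
THINK_STUB = [90, 91, 92, 93]          # stands in for "<think>\n\n</think>\n\n" (4 tokens)
USER, REPLY = [10, 11, 12], [20, 21]   # arbitrary content tokens

def build_context_mask(input_ids, offset):
    # Mirrors conversation_to_ids_qwen2: 1 marks context/prompt tokens,
    # 0 marks the assistant reply span.
    input_ids = np.array(input_ids)
    start_idxs = np.where(input_ids == IM_START)[0]
    assistant_idxs = np.where(input_ids == ASSISTANT)[0]
    end_idxs = np.where(input_ids == IM_END)[0]
    context = np.ones_like(input_ids, dtype=np.int8)
    for i, assistant_idx in enumerate(assistant_idxs):
        if assistant_idx - 1 in set(start_idxs):
            # Skip "<|im_start|>assistant\n"; the final turn additionally
            # skips the 4-token empty think stub when offset == 4.
            st = assistant_idx + 2 + (offset if i == len(assistant_idxs) - 1 else 0)
            for end_idx in end_idxs:
                if end_idx > st:
                    context[st: end_idx + 1] = 0
                    break  # assumed: stop at the first end marker after st
    return context

# <|im_start|> user... <|im_end|> <|im_start|> assistant \n [stub] reply <|im_end|>
ids = [IM_START, *USER, IM_END, IM_START, ASSISTANT, NL, *THINK_STUB, *REPLY, IM_END]
print(build_context_mask(ids, offset=4))
# The four stub tokens stay 1 (left out of the masked span); the reply tokens
# and their <|im_end|> become 0.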