diff --git a/finetune/dataset.py b/finetune/dataset.py
index 5012ec6..3d46f16 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -69,6 +69,7 @@ class SupervisedDataset(Dataset):
batch_vision=self.batch_vision,
max_length=self.max_length
)
+
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
@@ -80,7 +81,7 @@ class SupervisedDataset(Dataset):
)
except:
logger.error(f"data fetch error")
- return self.__getitem__(random.randint(0, len(self)))
+ return self.__getitem__(random.randint(0, len(self)))
return ret
@@ -283,20 +284,30 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
chat.append({"role":prefix, "content":message})
raw_msg += prefix + message
assert set([i['role'] for i in chat]) & set(['assistant'])
+    if '<think>' in chat[-1]['content'] and '</think>' in chat[-1]['content']:
+ enable_thinking = True
+ else:
+ enable_thinking = False
- ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
- input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
+ ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking)
+ input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, enable_thinking=enable_thinking)
input_ids = np.array(input_ids)
-
+    if "<think>\n\n</think>\n\n" in ret:
+ offset = 4
+ else:
+ offset = 0
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
context = np.ones_like(input_ids, dtype=np.int8)
- for assistant_idx in assistant_idxs:
+ for i, assistant_idx in enumerate(assistant_idxs):
if assistant_idx-1 in set(start_idxs):
- st = assistant_idx + 1
+ if i == len(assistant_idxs) -1:
+ st = assistant_idx + 2 + offset
+ else:
+ st = assistant_idx + 2
for end_idx in end_idxs:
if end_idx > st:
context[st: end_idx + 1] = 0