diff --git a/finetune/dataset.py b/finetune/dataset.py
index 5012ec6..3d46f16 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -69,6 +69,7 @@ class SupervisedDataset(Dataset):
                 batch_vision=self.batch_vision,
                 max_length=self.max_length
             )
+
             ret = dict(
                 input_ids=ret["input_ids"],
                 position_ids=ret["position_ids"],
@@ -80,7 +81,7 @@ class SupervisedDataset(Dataset):
             )
         except:
             logger.error(f"data fetch error")
-            return self.__getitem__(random.randint(0, len(self)))
+            return self.__getitem__(random.randint(0, len(self)))
         return ret

@@ -283,20 +284,30 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
         chat.append({"role":prefix, "content":message})
         raw_msg += prefix + message
     assert set([i['role'] for i in chat]) & set(['assistant'])
+    if '<think>' in chat[-1]['content'] and '</think>' in chat[-1]['content']:
+        enable_thinking = True
+    else:
+        enable_thinking = False
-    ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
-    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
+    ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False, enable_thinking=enable_thinking)
+    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, enable_thinking=enable_thinking)
     input_ids = np.array(input_ids)
-
+    if "<think>\n\n</think>\n\n" in ret:
+        offset = 4
+    else:
+        offset = 0
     start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
     assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
     end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]

     context = np.ones_like(input_ids, dtype=np.int8)
-    for assistant_idx in assistant_idxs:
+    for i, assistant_idx in enumerate(assistant_idxs):
         if assistant_idx-1 in set(start_idxs):
-            st = assistant_idx + 1
+            if i == len(assistant_idxs) -1:
+                st = assistant_idx + 2 + offset
+            else:
+                st = assistant_idx + 2
             for end_idx in end_idxs:
                 if end_idx > st:
                     context[st: end_idx + 1] = 0