Fix LoRA finetune inference bug (#224)

This commit is contained in:
qianyu chen
2024-06-07 18:00:22 +08:00
committed by GitHub
parent 31eaa26ee1
commit 9bd93a281c
6 changed files with 61 additions and 17 deletions

View File

@@ -13,6 +13,7 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
@@ -194,10 +195,10 @@ def conversation_to_ids_llama3(conversation, tokenizer):
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=False
conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template,
)
input_ids = tokenizer.apply_chat_template(
conversation, tokenize=True, add_generation_prompt=False
conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template,
)
input_ids = np.array(input_ids)