add instruct

This commit is contained in:
lyuxiang.lx
2025-12-11 09:43:25 +00:00
parent 3298d6f3e3
commit ebef63066f
5 changed files with 36 additions and 3 deletions

View File

@@ -242,6 +242,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
else:
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
yield sample
@@ -390,6 +394,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
@@ -403,6 +410,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"instruct_token": instruct_token,
"instruct_token_len": instruct_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}

View File

@@ -674,6 +674,9 @@ class CosyVoice3LM(Qwen2LM):
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# NOTE should append instruct_token to sequence, not implemented yet
instruct_token = batch['instruct_token'].to(device)
instruct_token_len = batch['instruct_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)