use spk_embedding when sft

lyuxiang.lx
2024-07-10 17:49:32 +08:00
parent a723ea375e
commit 0fd15bb12b
5 changed files with 8 additions and 2 deletions

@@ -97,7 +97,7 @@ class TransformerLM(torch.nn.Module):
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
         # 1. prepare llm_target
         lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
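The hunk above only switches the LM to read the generic batch['embedding'] key; per the commit title, the remaining changed files (not shown here) presumably decide which embedding fills that key at collate time. A minimal sketch of that selection, assuming a hypothetical pad_batch collate function and a use_spk_embedding flag, neither of which is confirmed by this diff:

    import torch

    def pad_batch(samples, use_spk_embedding):
        # Hypothetical collate step: stack the per-sample embeddings that the
        # dataset provides (key names assumed from the commit title).
        batch = {
            # ... padded text_token / speech_token tensors and their lengths ...
            'utt_embedding': torch.stack([s['utt_embedding'] for s in samples]),
            'spk_embedding': torch.stack([s['spk_embedding'] for s in samples]),
        }
        # SFT fine-tunes on a fixed speaker, so condition on the speaker-level
        # embedding; otherwise keep the per-utterance embedding.
        batch['embedding'] = batch['spk_embedding'] if use_spk_embedding else batch['utt_embedding']
        return batch

The indirection keeps TransformerLM.forward mode-agnostic: the model always reads batch['embedding'], and only the data pipeline needs to know whether a run is SFT or zero-shot.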