mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
fix vocoder train
This commit is contained in:
@@ -326,7 +326,8 @@ class Qwen2LM(TransformerLM):
|
||||
# unistream sequence
|
||||
else:
|
||||
this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
|
||||
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i], self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
|
||||
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
|
||||
self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
|
||||
lm_target.append(this_lm_target)
|
||||
lm_input.append(this_lm_input)
|
||||
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user