fix vocoder train

This commit is contained in:
lyuxiang.lx
2025-03-07 16:39:13 +08:00
parent fcc054f64e
commit a69b7e275d
12 changed files with 108 additions and 17 deletions

View File

@@ -326,7 +326,8 @@ class Qwen2LM(TransformerLM):
# unistream sequence
else:
this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i], self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
lm_target.append(this_lm_target)
lm_input.append(this_lm_input)
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)