fix cosyvoice3 training

Author: lyuxiang.lx
Date: 2025-12-29 10:02:59 +00:00
parent 8524c81acd
commit 4d7295a9a7
5 changed files with 25 additions and 23 deletions


@@ -301,18 +301,23 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
-    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
         speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        # NOTE add instruct_token in CosyVoice3
+        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
+            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
-                this_lm_target, this_lm_input = [], []
-                this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(sos_emb.squeeze(dim=0))
+                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
+                if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+                    this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                    this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
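With the new optional arguments, the bistream branch seeds each sample with the sos position and, when instruct tokens are supplied, masks the instruct positions in the target while still feeding their embeddings to the LM. A minimal standalone sketch of that prefix construction (toy shapes; the value IGNORE_ID = -1 is an assumption for illustration, not taken from this diff):

import torch
from torch.nn.utils.rnn import unpad_sequence

IGNORE_ID = -1                                   # assumed ignore label, for illustration only
emb_dim = 4
sos_emb = torch.randn(1, 1, emb_dim)             # stand-in for speech_embedding.weight[self.sos]
instruct_token_len = torch.tensor([3, 2], dtype=torch.int32)
instruct_token_emb = torch.randn(2, 3, emb_dim)  # padded to the batch max length

# same unpad pattern as in the diff: one [len_i, emb_dim] tensor per sample
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)

i = 0
this_lm_target = [IGNORE_ID]                                # sos position carries no loss
this_lm_input = [sos_emb.squeeze(dim=0)]
this_lm_target += [IGNORE_ID] * int(instruct_token_len[i])  # instruct positions are masked too
this_lm_input.append(instruct_token_emb[i])                 # ...but their embeddings still condition the LM
assert len(this_lm_target) == 1 + int(instruct_token_len[i])
assert torch.concat(this_lm_input, dim=0).shape == (1 + int(instruct_token_len[i]), emb_dim)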
@@ -333,8 +338,8 @@ class Qwen2LM(TransformerLM):
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
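For the unistream case the change widens the ignored prefix by the instruct length and splices the instruct embeddings between sos and the text. A toy sketch of the resulting target layout (plain Python, assumed values; the real eos value comes from self.eos_token):

IGNORE_ID = -1                      # assumed ignore label, for illustration only
eos_token = 9999                    # placeholder standing in for self.eos_token
instruct_len, text_len = 2, 3
speech_tokens = [11, 12, 13, 14]

this_lm_target = [IGNORE_ID] * (1 + instruct_len + text_len) + speech_tokens + [eos_token]
# -> [-1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 9999]
assert len(this_lm_target) == 1 + instruct_len + text_len + len(speech_tokens) + 1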
@@ -681,6 +686,7 @@ class CosyVoice3LM(Qwen2LM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
         # 3. sos and task_id
         sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
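The instruct tokens reuse the same Qwen2 text embedding table as text_token. A small stand-in sketch (an nn.Embedding in place of the actual self.llm.model.model.embed_tokens, shapes assumed):

import torch

vocab_size, emb_dim = 1000, 16
embed_tokens = torch.nn.Embedding(vocab_size, emb_dim)  # stand-in for self.llm.model.model.embed_tokens
text_token = torch.randint(0, vocab_size, (2, 5))       # [batch, text_len]
instruct_token = torch.randint(0, vocab_size, (2, 3))   # [batch, instruct_len]
text_token_emb = embed_tokens(text_token)               # [2, 5, 16]
instruct_token_emb = embed_tokens(instruct_token)       # [2, 3, 16]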
@@ -691,7 +697,7 @@ class CosyVoice3LM(Qwen2LM):
         # 3. prepare llm_input/target
         lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
+                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
         lm_target = lm_target.to(device)
         # 4. run lm forward
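The per-sample targets and inputs built in the loop above are ragged, so somewhere before the LM forward they get padded back into batch tensors alongside lm_input_len. A sketch of that padding step under assumed shapes (the padding values below are illustrative, not taken from this diff):

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID, emb_dim = -1, 4
lm_target = [torch.tensor([IGNORE_ID, 11, 12]), torch.tensor([IGNORE_ID, IGNORE_ID, 21, 22, 23])]
lm_input = [torch.randn(3, emb_dim), torch.randn(5, emb_dim)]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)

lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)  # [batch, max_len]
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=0)            # [batch, max_len, emb_dim]
print(lm_target.shape, lm_input.shape, lm_input_len)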