mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix sequence logic
This commit is contained in:
@@ -311,13 +311,15 @@ class Qwen2LM(TransformerLM):
|
|||||||
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
|
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
|
||||||
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
|
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
|
||||||
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
|
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
|
||||||
|
else:
|
||||||
|
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
|
||||||
|
instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
|
||||||
for i in range(len(text_token)):
|
for i in range(len(text_token)):
|
||||||
# bistream sequence
|
# bistream sequence
|
||||||
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
|
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
|
||||||
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
|
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
|
||||||
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
|
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
|
||||||
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
|
this_lm_input.append(instruct_token_emb[i])
|
||||||
this_lm_input.append(instruct_token_emb[i])
|
|
||||||
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
|
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
|
||||||
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
|
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
|
||||||
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
|
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
|
||||||
|
|||||||
Reference in New Issue
Block a user