fix sequence logic

Author: lyuxiang.lx
Date: 2026-01-07 07:06:04 +00:00
parent 652132ebaa
commit f97d50d559


@@ -311,11 +311,13 @@ class Qwen2LM(TransformerLM):
         if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
             instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
             instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
+        else:
+            instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
+            instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
-                if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
-                    this_lm_target += [IGNORE_ID] * instruct_token_len[i]
-                    this_lm_input.append(instruct_token_emb[i])
+                this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
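
The idea behind the new else branch is easier to see in isolation: when a batch carries no instruct tokens, zero-length per-sample placeholders stand in for them, so the per-sample loop can prepend instruct embeddings unconditionally instead of branching on None. Below is a minimal, self-contained sketch of that pattern, not code from the repository; the toy batch, the emb_dim name, and the final assert are illustrative assumptions, with 896 taken from the placeholder width in the diff.

# Minimal sketch of the zero-length-placeholder pattern (assumptions noted above).
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

IGNORE_ID = -1   # stand-in for the repo's IGNORE_ID constant
emb_dim = 896    # embedding width used by the diff's empty placeholder

# Toy padded batch of two text sequences with true lengths 3 and 2.
text_token_len = torch.tensor([3, 2])
text_token = pad_sequence([torch.randint(0, 10, (3,)), torch.randint(0, 10, (2,))], batch_first=True)
text_token_emb = pad_sequence([torch.randn(3, emb_dim), torch.randn(2, emb_dim)], batch_first=True)

# Split the padded tensors back into per-sample tensors, as the method does.
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)

# No instruct tokens supplied: build zero-length placeholders per sample, so the
# loop below never has to re-check for None.
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
instruct_token_emb = [torch.empty(0, emb_dim).to(text_token_emb[0])] * len(text_token)

for i in range(len(text_token)):
    this_lm_target = [IGNORE_ID] * len(instruct_token[i])  # contributes nothing here
    this_lm_input = torch.concat([instruct_token_emb[i], text_token_emb[i]], dim=0)
    assert this_lm_input.shape[0] == text_token_len[i]     # the placeholder adds zero rows

The placeholder tensors are created with .to(text_token[0]) / .to(text_token_emb[0]) so they inherit the dtype and device of the real batch, which keeps the later concatenation valid on GPU as well as CPU.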