lyuxiang.lx
2025-08-22 14:42:34 +08:00
parent 6b5eef62cc
commit f76f5abcc1

@@ -127,7 +127,7 @@ class TransformerLM(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
         embedding = embedding.unsqueeze(1)
-        # 3. eos and task_id
+        # 3. sos and task_id
         sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
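Note (illustration, not part of the commit): the fix aligns the comment with the code beneath it; both lines look up the sos and task_id rows, not eos. A minimal shape sketch of that lookup, with a hypothetical embedding table of hidden size D:

import torch

vocab, D = 4, 8                       # illustrative sizes only
llm_embedding = torch.nn.Embedding(vocab, D)
sos = 0                               # assumed index of the sos token

# weight[sos] is a 1-D tensor of shape (D,); reshape(1, 1, -1) turns it
# into (1, 1, D) so it can sit on the time axis of a (B, T, D) batch.
sos_emb = llm_embedding.weight[sos].reshape(1, 1, -1)
assert sos_emb.shape == (1, 1, D)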
@@ -300,7 +300,7 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}

-    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
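Note (illustration, not part of the commit): the refactor threads sos_emb and task_id_emb through as arguments instead of reading them from self.llm_embedding inside the method, so a subclass can supply embeddings taken from a different table (CosyVoice3LM below uses speech_embedding). A sketch of the old versus new call, with dummy tensors:

import torch

D = 8
sos_emb = torch.zeros(1, 1, D)      # built once in forward()
task_id_emb = torch.zeros(1, 1, D)  # ditto

# Old call:
#   self.prepare_lm_input_target(text_token, text_token_emb, text_token_len,
#                                speech_token, speech_token_emb, speech_token_len)
# New call (embeddings threaded through as the 1st and 5th arguments):
#   self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len,
#                                task_id_emb, speech_token, speech_token_emb, speech_token_len)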
@@ -311,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
+                this_lm_input.append(sos_emb.squeeze(dim=0))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
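Note (illustration, not part of the commit): in this bistream branch, chunks of mix_ratio[0] text tokens are interleaved with chunks of mix_ratio[1] speech tokens, and the loop runs ceil((text_len + 1) / mix_ratio[0]) times. A tiny worked check, assuming the mix_ratio of [5, 15] used elsewhere in CosyVoice2:

import math

mix_ratio = [5, 15]     # assumed [text_chunk, speech_chunk] sizes
text_token_len = 12     # example length
n_chunks = math.ceil((text_token_len + 1) / mix_ratio[0])
assert n_chunks == 3    # chunks cover text tokens 0:5, 5:10, 10:12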
@@ -327,14 +327,13 @@ class Qwen2LM(TransformerLM):
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                         this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
-                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
                 this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
-                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
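Note (illustration, not part of the commit): the pieces concatenated per utterance are 2-D (len, D) tensors, while the passed-in sos_emb / task_id_emb are (1, 1, D). squeeze(dim=0) drops exactly the leading batch axis, producing the same (1, D) row the old reshape(1, -1) did. A quick shape check:

import torch

D = 8
sos_emb = torch.zeros(1, 1, D)            # as built in forward()
row = sos_emb.squeeze(dim=0)              # (1, D): one time step
assert row.shape == (1, D)
piece = torch.zeros(4, D)                 # e.g. text_token_emb[i]
seq = torch.concat([row, piece], dim=0)   # (5, D), identical to the old
assert seq.shape == (5, D)                #   weight[...].reshape(1, -1) path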
@@ -362,11 +361,15 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         # 2. encode speech_token
         speech_token_emb = self.speech_embedding(speech_token)
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len)
         lm_target = lm_target.to(device)
         # 4. run lm forward
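Note (illustration, not part of the commit; the inserted "# 3. sos and task_id" label landing ahead of "# 2." is a numbering quirk in the committed code itself): for the unistream case the resulting layout is input [sos, text, task_id, speech] against target [IGNORE × (1 + text_len), speech, eos], so loss is only taken on speech positions and the final eos. A worked toy example, with IGNORE_ID and eos_token as placeholders:

IGNORE_ID = -1      # placeholder; the real constant comes from cosyvoice
eos_token = 900     # hypothetical eos id
text_len = 3
speech = [11, 12, 13, 14]

# input : [sos] [t0 t1 t2] [task_id] [s0 s1 s2 s3]    (9 embeddings)
# target:  IGN   IGN IGN IGN  s0       s1 s2 s3  eos  (9 labels)
# i.e. the task_id position predicts the first speech token, the last
# speech token predicts eos, and text positions carry no loss.
lm_target = [IGNORE_ID] * (1 + text_len) + speech + [eos_token]
assert len(lm_target) == 1 + text_len + 1 + len(speech)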
@@ -391,6 +394,10 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         # 2. encode speech_token
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
@@ -400,8 +407,8 @@ class Qwen2LM(TransformerLM):
         speech_token_combined_emb = self.speech_embedding(speech_token_combined)
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
-                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
+                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
         lm_target = lm_target.to(device)
         # 4. run lm forward
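Note (illustration, not part of the commit): this is the preference-pair forward; repeat(2, 1) duplicates the text batch so each prompt is paired once with the accepted and once with the rejected speech tokens before flowing through the same prepare_lm_input_target. A sketch of the doubling, with dummy tensors:

import torch

B, T = 2, 4
text_token = torch.arange(B * T).reshape(B, T)
text_token_len = torch.tensor([4, 3])

doubled = text_token.repeat(2, 1)        # (2B, T): [batch; batch]
doubled_len = text_token_len.repeat(2)   # (2B,)
assert doubled.shape == (2 * B, T)
# Rows 0..B-1 pair with the accepted speech tokens, rows B..2B-1 with
# the rejected ones (pairing order assumed from the diff).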
@@ -650,6 +657,43 @@ class CosyVoice3LM(Qwen2LM):
         self.stop_token_ids = [speech_token_size + i for i in range(4)]
         self.vllm_output_queue = {}

+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        # 3. sos and task_id
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        # 2. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len)
+        lm_target = lm_target.to(device)
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}
+
     @torch.inference_mode()
     def inference(
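Note (illustration, not part of the commit): this new forward is the payoff of the refactor. It mirrors Qwen2LM.forward except that the sos and task_id embeddings are read from self.speech_embedding rather than self.llm_embedding, which is exactly what passing them into prepare_lm_input_target makes possible. A minimal contrast, assuming both tables are nn.Embedding with the same hidden size:

import torch

D = 8
llm_embedding = torch.nn.Embedding(10, D)     # Qwen2LM's table (sketch)
speech_embedding = torch.nn.Embedding(10, D)  # CosyVoice3LM's table (sketch)
sos = 1                                       # hypothetical token index

# Qwen2LM.forward:
sos_emb_qwen2 = llm_embedding.weight[sos].reshape(1, 1, -1)
# CosyVoice3LM.forward:
sos_emb_cv3 = speech_embedding.weight[sos].reshape(1, 1, -1)

# Same (1, 1, D) shape either way, so prepare_lm_input_target needs no change.
assert sos_emb_qwen2.shape == sos_emb_cv3.shape == (1, 1, D)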