diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 85ad6fb..cba0c57 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -127,7 +127,7 @@ class TransformerLM(torch.nn.Module): embedding = self.spk_embed_affine_layer(embedding) embedding = embedding.unsqueeze(1) - # 3. eos and task_id + # 3. sos and task_id sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1) task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) @@ -300,7 +300,7 @@ class Qwen2LM(TransformerLM): self.stop_token_ids = [speech_token_size + i for i in range(3)] self.vllm_output_queue = {} - def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len): + def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len): lm_target, lm_input = [], [] text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) @@ -311,7 +311,7 @@ class Qwen2LM(TransformerLM): if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]: this_lm_target, this_lm_input = [], [] this_lm_target.append(IGNORE_ID) - this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1)) + this_lm_input.append(sos_emb.squeeze(dim=0)) for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()): this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist() this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist() @@ -327,14 +327,13 @@ class Qwen2LM(TransformerLM): this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist() this_lm_target.append(self.eos_token) this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:]) - this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1)) + this_lm_input.append(task_id_emb.squeeze(dim=0)) this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:]) this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0) # unistream sequence else: this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token]) - this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i], - self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0) + this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0) lm_target.append(this_lm_target) lm_input.append(this_lm_input) lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) @@ -362,11 +361,15 @@ class Qwen2LM(TransformerLM): # 1. encode text_token text_token_emb = self.llm.model.model.embed_tokens(text_token) + # 3. sos and task_id + sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + # 2. encode speech_token speech_token_emb = self.speech_embedding(speech_token) # 3. prepare llm_input/target - lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len) + lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len) lm_target = lm_target.to(device) # 4. run lm forward @@ -391,6 +394,10 @@ class Qwen2LM(TransformerLM): # 1. encode text_token text_token_emb = self.llm.model.model.embed_tokens(text_token) + # 3. sos and task_id + sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + # 2. encode speech_token speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True) @@ -400,8 +407,8 @@ class Qwen2LM(TransformerLM): speech_token_combined_emb = self.speech_embedding(speech_token_combined) # 3. prepare llm_input/target - lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2), - speech_token_combined, speech_token_combined_emb, speech_token_combined_len) + lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2), + task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len) lm_target = lm_target.to(device) # 4. run lm forward @@ -650,6 +657,43 @@ class CosyVoice3LM(Qwen2LM): self.stop_token_ids = [speech_token_size + i for i in range(4)] self.vllm_output_queue = {} + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + + # 1. encode text_token + text_token_emb = self.llm.model.model.embed_tokens(text_token) + + # 3. sos and task_id + sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1) + task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 2. encode speech_token + speech_token_emb = self.speech_embedding(speech_token) + + # 3. prepare llm_input/target + lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len) + lm_target = lm_target.to(device) + + # 4. run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target.to(device)) + acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID) + return {'loss': loss, 'acc': acc} @torch.inference_mode() def inference(