From dc96e4c98430ce7ac4068df8e54d0f115dcdbfd2 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Thu, 21 Aug 2025 21:03:58 +0800
Subject: [PATCH] update

---
 cosyvoice/cli/cosyvoice.py     |  2 +-
 cosyvoice/llm/llm.py           | 78 +++++++++++++++++-----------
 cosyvoice/utils/class_utils.py |  8 ++--
 3 files changed, 45 insertions(+), 43 deletions(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index 7731863..910fa74 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -207,7 +207,7 @@ class CosyVoice3(CosyVoice):
             raise ValueError('{} not found!'.format(hyper_yaml_path))
         with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index cc2f2da..85ad6fb 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -56,8 +56,9 @@ class TransformerLM(torch.nn.Module):
         )
 
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
+        self.eos_token = self.speech_token_size
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
@@ -85,10 +86,10 @@ class TransformerLM(torch.nn.Module):
         encoder_out = self.text_encoder_affine_layer(encoder_out)
         return encoder_out, encoder_out_lens
 
-    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                     for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
@@ -127,14 +128,14 @@ class TransformerLM(torch.nn.Module):
         embedding = embedding.unsqueeze(1)
 
         # 3. eos and task_id
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
 
         # 4. encode speech_token
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                          task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
@@ -193,13 +194,13 @@ class TransformerLM(torch.nn.Module):
             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +216,8 @@ class TransformerLM(torch.nn.Module):
                                                                    att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                   device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            # force continue decode first token
-            if i == 0:
-                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
             yield top_ids
@@ -276,9 +274,10 @@ class Qwen2LM(TransformerLM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
-        self.fill_token = 2
+        self.eos_token = speech_token_size
+        self.fill_token = speech_token_size + 2
 
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
@@ -312,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,21 +319,21 @@ class Qwen2LM(TransformerLM):
                         assert len(this_speech_token) == self.mix_ratio[1]
                         this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                         this_lm_target += this_speech_token
-                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_target.append(self.fill_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                     else:
                         this_lm_target += [-1] * len(this_text_token)
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
-                        this_lm_target.append(self.speech_token_size)
+                        this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                         this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
                                               self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
@@ -445,13 +444,13 @@ class Qwen2LM(TransformerLM):
         text = self.llm.model.model.embed_tokens(text)
 
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -501,10 +500,8 @@ class Qwen2LM(TransformerLM):
                                                       cache=cache)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids in self.stop_token_ids:
                 break
-            if top_ids > self.speech_token_size:
-                continue
             # in stream mode, yield token one by one
             yield top_ids
             out_tokens.append(top_ids)
@@ -526,13 +523,13 @@ class Qwen2LM(TransformerLM):
         device = prompt_text.device
 
         # 1. prepare input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb], dim=1)
+        lm_input = torch.concat([sos_emb], dim=1)
 
         # 2. iterate text
         out_tokens = []
@@ -554,12 +551,12 @@ class Qwen2LM(TransformerLM):
                     break
             # no prompt_speech_token_emb remain, can decode some speech token
             if prompt_speech_token_emb.size(1) == 0:
-                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                     logging.info('get fill token, need to append more text token')
                     if text_cache.size(1) >= self.mix_ratio[0]:
                         lm_input_text = text_cache[:, :self.mix_ratio[0]]
                         logging.info('append {} text token'.format(lm_input_text.size(1)))
-                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                             lm_input = lm_input_text
                         else:
                             lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +571,16 @@ class Qwen2LM(TransformerLM):
                                                       cache=cache)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             if next_fill_index != -1 and len(out_tokens) == next_fill_index:
-                top_ids = self.speech_token_size + 2
+                top_ids = self.fill_token
                 next_fill_index += (self.mix_ratio[1] + 1)
             else:
                 top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
-            if top_ids == self.speech_token_size + 2:
+            if top_ids == self.fill_token:
                 next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                 logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
             out_tokens.append(top_ids)
             if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size + 2:
+                if top_ids == self.fill_token:
                     break
                 else:
                     raise ValueError('should not get token {}'.format(top_ids))
@@ -602,7 +599,7 @@ class Qwen2LM(TransformerLM):
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
             out_tokens.append(top_ids)
             if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size:
+                if top_ids == self.eos_token:
                     break
                 else:
                     raise ValueError('should not get token {}'.format(top_ids))
@@ -628,10 +625,10 @@ class CosyVoice3LM(Qwen2LM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos = 0
-        self.eos = 1
-        self.task_id = 2
-        self.fill_token = 3
+        self.sos = speech_token_size + 0
+        self.eos_token = speech_token_size + 1
+        self.task_id = speech_token_size + 2
+        self.fill_token = speech_token_size + 3
 
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
@@ -649,6 +646,11 @@ class CosyVoice3LM(Qwen2LM):
         self.sampling = sampling
         self.mix_ratio = mix_ratio
 
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(4)]
+        self.vllm_output_queue = {}
+
+
     @torch.inference_mode()
     def inference(
             self,
@@ -670,13 +672,13 @@ class CosyVoice3LM(Qwen2LM):
         text = self.llm.model.model.embed_tokens(text)
 
         # 3. concat llm_input
-        sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1)
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
diff --git a/cosyvoice/utils/class_utils.py b/cosyvoice/utils/class_utils.py
index c52fec4..aab8326 100644
--- a/cosyvoice/utils/class_utils.py
+++ b/cosyvoice/utils/class_utils.py
@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
-from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
 from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
 from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 
 
 COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,6 +80,6 @@ def get_model_type(configs):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoice2Model
-    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
-        return CosyVoice2Model
+    if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
+        return CosyVoice3Model
     raise TypeError('No valid model type found!')
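Note on the token-id layout after this patch (a minimal illustrative sketch, kept outside the diff; the concrete speech_token_size value below is an assumption for illustration, the real value comes from the model yaml):

    # Sketch: special-token ids relative to the speech codebook, per the changes above.
    speech_token_size = 6561                # hypothetical example value; set by the model config

    # Qwen2LM after this patch: eos/fill sit just past the speech codebook
    qwen2_eos_token = speech_token_size     # end-of-speech id
    qwen2_fill_token = speech_token_size + 2  # fill id used in bistream mode

    # CosyVoice3LM after this patch: sos/eos/task/fill are rows of speech_embedding
    cv3_sos = speech_token_size + 0
    cv3_eos_token = speech_token_size + 1
    cv3_task_id = speech_token_size + 2
    cv3_fill_token = speech_token_size + 3
    stop_token_ids = [speech_token_size + i for i in range(4)]  # treated as stop ids during decoding

Because the offsets are baked into the ids, CosyVoice3LM can look up sos/task embeddings directly via self.speech_embedding.weight[self.sos] and self.speech_embedding.weight[self.task_id], as the last llm.py hunk shows.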