From 54e9384fb11886eab713aa0e90e033fa2a095f85 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?=
Date: Wed, 26 Feb 2025 20:25:14 +0800
Subject: [PATCH] update export_codec_vllm

---
 cosyvoice/cli/model.py | 4 ++++
 cosyvoice/llm/llm.py   | 4 +++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 5374e7a..b5ea3af 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -331,6 +331,7 @@ class CosyVoice2Model(CosyVoiceModel):
         pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
         dtype = torch.bfloat16
 
+        # lm_head
         new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
         with torch.no_grad():
             new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
@@ -339,6 +340,8 @@ class CosyVoice2Model(CosyVoiceModel):
             new_lm_head.bias[vocab_size:] = 0
         self.llm.llm.model.lm_head = new_lm_head
         new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        # embed_tokens
+        embed_tokens = self.llm.llm.model.model.embed_tokens
         with torch.no_grad():
             new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
             new_codec_embed.weight[vocab_size:] = 0
@@ -356,6 +359,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.llm.llm.model.save_pretrained(model_path)
         self.llm.llm.model.config.vocab_size = tmp_vocab_size
         self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+        self.llm.llm.model.set_input_embeddings(embed_tokens)
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 1b12acf..331881f 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -353,12 +353,14 @@ class Qwen2LM(TransformerLM):
                 if str(request_output.request_id) != str(request_id):
                     continue
                 if not request_output.finished:
-                    print(f"Partial request output: {request_output}")
+                    # print(f"Partial request output: {request_output}")
                     out_token = list(request_output.outputs[0].token_ids)[-1]
                     yield out_token
                     out_token_ids.append(out_token)
                 else:
                     break
+            if not vllm_codec_engine.has_unfinished_requests():
+                break
 
     @torch.inference_mode()
     def inference_bistream(