From e15222b17cb0c48c810a0b014405c1a42f3a080c Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Tue, 30 Dec 2025 09:24:52 +0000
Subject: [PATCH] refine code

---
 cosyvoice/cli/model.py | 44 ++++++++++++++++++------------------------
 1 file changed, 19 insertions(+), 25 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 694104b..b589bcd 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -103,35 +103,29 @@ class CosyVoiceModel:
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
-                for i in self.llm.inference_bistream(text=text,
+                token_generator = self.llm.inference_bistream(text=text,
+                                                              prompt_text=prompt_text.to(self.device),
+                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                              embedding=llm_embedding.to(self.device))
+            else:
+                token_generator = self.llm.inference(text=text.to(self.device),
+                                                     text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    if i in self.silent_tokens:
-                        cur_silent_token_num += 1
-                        if cur_silent_token_num > max_silent_token_num:
-                            continue
-                    else:
-                        cur_silent_token_num = 0
-                    self.tts_speech_token_dict[uuid].append(i)
-            else:
-                for i in self.llm.inference(text=text.to(self.device),
-                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_text=prompt_text.to(self.device),
-                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device),
-                                            uuid=uuid):
-                    if i in self.silent_tokens:
-                        cur_silent_token_num += 1
-                        if cur_silent_token_num > max_silent_token_num:
-                            continue
-                    else:
-                        cur_silent_token_num = 0
-                    self.tts_speech_token_dict[uuid].append(i)
+                                                     embedding=llm_embedding.to(self.device),
+                                                     uuid=uuid)
+            for i in token_generator:
+                if i in self.silent_tokens:
+                    cur_silent_token_num += 1
+                    if cur_silent_token_num > max_silent_token_num:
+                        continue
+                else:
+                    cur_silent_token_num = 0
+                self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def vc_job(self, source_speech_token, uuid):
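
Note (reviewer sketch): the refactor funnels both the bistream and non-bistream
branches into one "for i in token_generator:" loop that caps runs of silent
tokens. A minimal standalone sketch of that filtering logic follows, assuming
hypothetical stand-ins SILENT_TOKENS, MAX_SILENT_TOKEN_NUM, and
filter_silent_tokens for the instance state (self.silent_tokens,
max_silent_token_num) used in cosyvoice/cli/model.py; this is not part of the
project's API.

    from typing import Iterable, List

    # Hypothetical stand-ins for self.silent_tokens / max_silent_token_num.
    SILENT_TOKENS = {0, 1}    # assumed ids of silence tokens
    MAX_SILENT_TOKEN_NUM = 3  # assumed cap on consecutive silent tokens

    def filter_silent_tokens(token_generator: Iterable[int]) -> List[int]:
        # Mirrors the patched loop: silent tokens past the cap are skipped,
        # and any non-silent token resets the run counter.
        out: List[int] = []
        cur_silent_token_num = 0
        for i in token_generator:
            if i in SILENT_TOKENS:
                cur_silent_token_num += 1
                if cur_silent_token_num > MAX_SILENT_TOKEN_NUM:
                    continue  # suppress the excess silence
            else:
                cur_silent_token_num = 0
            out.append(i)
        return out

    # A run of five silent tokens (id 0) is truncated to three.
    assert filter_silent_tokens([7, 0, 0, 0, 0, 0, 9]) == [7, 0, 0, 0, 9]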