From 37e48dd3184d19707372e8a8775e0f8a3a3ef0ac Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Tue, 15 Apr 2025 16:15:20 +0800 Subject: [PATCH] optimize vc code --- cosyvoice/cli/cosyvoice.py | 2 +- cosyvoice/cli/model.py | 77 ++++++++------------------------------ 2 files changed, 17 insertions(+), 62 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 93db014..efebe4d 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -128,7 +128,7 @@ class CosyVoice: def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate) start_time = time.time() - for model_output in self.model.vc(**model_input, stream=stream, speed=speed): + for model_output in self.model.tts(**model_input, stream=stream, speed=speed): speech_len = model_output['tts_speech'].shape[1] / self.sample_rate logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) yield model_output diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 677f486..a14cfbe 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -122,6 +122,10 @@ class CosyVoiceModel: self.tts_speech_token_dict[uuid].append(i) self.llm_end_dict[uuid] = True + def vc_job(self, source_speech_token, uuid): + self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist() + self.llm_end_dict[uuid] = True + def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): with torch.cuda.amp.autocast(self.fp16): tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), @@ -162,11 +166,11 @@ class CosyVoiceModel: tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) return tts_speech - def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), + def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192), prompt_text=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), - prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): + prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) with self.lock: @@ -174,7 +178,10 @@ class CosyVoiceModel: self.hift_cache_dict[this_uuid] = None self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) - p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) + if source_speech_token.shape[1] == 0: + p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) + else: + p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid)) p.start() if stream is True: token_hop_len = self.token_min_hop_len @@ -226,61 +233,6 @@ class CosyVoiceModel: self.flow_cache_dict.pop(this_uuid) torch.cuda.empty_cache() - def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs): - # this_uuid is used to track variables related to this inference thread - this_uuid = str(uuid.uuid1()) - with self.lock: - self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True - self.hift_cache_dict[this_uuid] = None - self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) - self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) - if stream is True: - token_hop_len = self.token_min_hop_len - while True: - if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ - .unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - finalize=False) - yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] - # increase token_hop_len for better speech quality - token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: - break - # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - finalize=True) - yield {'tts_speech': this_tts_speech.cpu()} - else: - # deal with all tokens - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - finalize=True, - speed=speed) - yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict.pop(this_uuid) - self.llm_end_dict.pop(this_uuid) - self.mel_overlap_dict.pop(this_uuid) - self.hift_cache_dict.pop(this_uuid) - self.flow_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() - class CosyVoice2Model(CosyVoiceModel): @@ -386,18 +338,21 @@ class CosyVoice2Model(CosyVoiceModel): tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) return tts_speech - def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), + def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192), prompt_text=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), - prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): + prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None self.flow_cache_dict[this_uuid] = self.init_flow_cache() - p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) + if source_speech_token.shape[1] == 0: + p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) + else: + p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid)) p.start() if stream is True: assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"