From 1d881df8b205670eec7a270ca9c681ba045a159b Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 29 Aug 2024 19:10:08 +0800 Subject: [PATCH] fix vocoder speech overlap --- cosyvoice/cli/model.py | 145 ++++++++++++++++++--------------- cosyvoice/hifigan/generator.py | 12 ++- cosyvoice/utils/common.py | 10 ++- 3 files changed, 93 insertions(+), 74 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 664effe..1184f0d 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -31,18 +31,25 @@ class CosyVoiceModel: self.flow = flow self.hift = hift self.token_min_hop_len = 100 - self.token_max_hop_len = 400 + self.token_max_hop_len = 200 self.token_overlap_len = 20 - self.speech_overlap_len = 34 * 256 - self.window = np.hamming(2 * self.speech_overlap_len) + # mel fade in out + self.mel_overlap_len = 34 + self.mel_window = np.hamming(2 * self.mel_overlap_len) + # hift cache + self.mel_cache_len = 20 + self.source_cache_len = int(self.mel_cache_len * 256) + # rtf and decoding related self.stream_scale_factor = 1 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() self.lock = threading.Lock() # dict used to store session related variable - self.tts_speech_token = {} - self.llm_end = {} + self.tts_speech_token_dict = {} + self.llm_end_dict = {} + self.mel_overlap_dict = {} + self.hift_cache_dict = {} def load(self, llm_model, flow_model, hift_model): self.llm.load_state_dict(torch.load(llm_model, map_location=self.device)) @@ -64,102 +71,108 @@ class CosyVoiceModel: self.flow.decoder.estimator = xxx self.flow.decoder.session = xxx - def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid): + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: for i in self.llm.inference(text=text.to(self.device), - text_len=text_len.to(self.device), + text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), prompt_text=prompt_text.to(self.device), - prompt_text_len=prompt_text_len.to(self.device), + prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), prompt_speech_token=llm_prompt_speech_token.to(self.device), - prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device), + prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), embedding=llm_embedding.to(self.device).half(), sampling=25, max_token_text_ratio=30, min_token_text_ratio=3): - self.tts_speech_token[this_uuid].append(i) - self.llm_end[this_uuid] = True + self.tts_speech_token_dict[uuid].append(i) + self.llm_end_dict[uuid] = True - def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding): + def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False): with self.flow_hift_context: tts_mel = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), - prompt_token_len=prompt_token_len.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=prompt_feat_len.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), embedding=embedding.to(self.device)) - tts_speech = self.hift.inference(mel=tts_mel).cpu() + # mel overlap fade in out + if self.mel_overlap_dict[uuid] is not None: + tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window) + # append hift cache + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + # keep overlap mel and hift cache + if finalize is False: + self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:] + tts_mel = tts_mel[:, :, :-self.mel_overlap_len] + tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source) + self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]} + tts_speech = tts_speech[:, :-self.source_cache_len] + else: + tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source) return tts_speech - def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192), - prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32), - llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), - flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), - prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False): + def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), + prompt_text=torch.zeros(1, 0, dtype=torch.int32), + llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), + flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), + prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) with self.lock: - self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False - p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device), - llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device), this_uuid)) + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None + p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) p.start() + p.join() if stream is True: - cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len + token_hop_len = self.token_min_hop_len while True: time.sleep(0.1) - if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len: - this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1) + if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: + this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1) with self.flow_hift_context: this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token.to(self.device), - prompt_token_len=flow_prompt_speech_token_len.to(self.device), - prompt_feat=prompt_speech_feat.to(self.device), - prompt_feat_len=prompt_speech_feat_len.to(self.device), - embedding=flow_embedding.to(self.device)) - # fade in/out if necessary - if cache_speech is not None: - this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window) - yield {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]} - cache_speech = this_tts_speech[:, -self.speech_overlap_len:] - cache_token = self.tts_speech_token[this_uuid][:token_hop_len] + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=False) + yield {'tts_speech': this_tts_speech.cpu()} with self.lock: - self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:] + self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] # increase token_hop_len for better speech quality token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) - if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len: + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: break - p.join() + # p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1) - if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None: - cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1] - this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1) - else: - cache_token_len = 0 + this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1) with self.flow_hift_context: this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token.to(self.device), - prompt_token_len=flow_prompt_speech_token_len.to(self.device), - prompt_feat=prompt_speech_feat.to(self.device), - prompt_feat_len=prompt_speech_feat_len.to(self.device), - embedding=flow_embedding.to(self.device)) - this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):] - if cache_speech is not None: - this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window) - yield {'tts_speech': this_tts_speech} + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=True) + yield {'tts_speech': this_tts_speech.cpu()} else: # deal with all tokens - p.join() - this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1) + # p.join() + this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1) with self.flow_hift_context: this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token.to(self.device), - prompt_token_len=flow_prompt_speech_token_len.to(self.device), - prompt_feat=prompt_speech_feat.to(self.device), - prompt_feat_len=prompt_speech_feat_len.to(self.device), - embedding=flow_embedding.to(self.device)) - yield {'tts_speech': this_tts_speech} + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=True) + yield {'tts_speech': this_tts_speech.cpu()} with self.lock: - self.tts_speech_token.pop(this_uuid) - self.llm_end.pop(this_uuid) + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + self.mel_overlap_dict.pop(this_uuid) + self.hift_cache_dict.pop(this_uuid) torch.cuda.synchronize() diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index a45419b..fd61834 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module): inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) return inverse_transform - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: f0 = self.f0_predictor(x) s = self._f02source(f0) + # use cache_source to avoid glitch + if cache_source.shape[2] == 0: + s[:, :, :cache_source.shape[2]] = cache_source + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) @@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module): x = self._istft(magnitude, phase) x = torch.clamp(x, -self.audio_limit, self.audio_limit) - return x + return x, s def remove_weight_norm(self): print('Removing weight norm...') @@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module): l.remove_weight_norm() @torch.inference_mode() - def inference(self, mel: torch.Tensor) -> torch.Tensor: - return self.forward(x=mel) + def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + return self.forward(x=mel, cache_source=cache_source) diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 51be904..07e1f92 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -131,7 +131,9 @@ def random_sampling(weighted_scores, decoded_tokens, sampling): top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) return top_ids -def fade_in_out(fade_in_speech, fade_out_speech, window): - speech_overlap_len = int(window.shape[0] / 2) - fade_in_speech[:, :speech_overlap_len] = fade_in_speech[:, :speech_overlap_len] * window[:speech_overlap_len] + fade_out_speech[:, -speech_overlap_len:] * window[speech_overlap_len:] - return fade_in_speech +def fade_in_out(fade_in_mel, fade_out_mel, window): + device = fade_in_mel.device + fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel.to(device)