From 02f941d34885bdb08c4cbcbb4bb8e2cecad3d430 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Wed, 24 Jul 2024 19:18:09 +0800 Subject: [PATCH] update model inference --- cosyvoice/cli/cosyvoice.py | 8 +-- cosyvoice/cli/model.py | 96 ++++++++++++++++++++---------- cosyvoice/flow/length_regulator.py | 2 +- cosyvoice/llm/llm.py | 4 +- webui.py | 39 +++++------- 5 files changed, 85 insertions(+), 64 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 54c08c0..e2601eb 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -46,9 +46,9 @@ class CosyVoice: return spks def inference_sft(self, tts_text, spk_id, stream=False): - start_time = time.time() for i in self.frontend.text_normalize(tts_text, split=True): model_input = self.frontend.frontend_sft(i, spk_id) + start_time = time.time() for model_output in self.model.inference(**model_input, stream=stream): speech_len = model_output['tts_speech'].shape[1] / 22050 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) @@ -56,10 +56,10 @@ class CosyVoice: start_time = time.time() def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False): - start_time = time.time() prompt_text = self.frontend.text_normalize(prompt_text, split=False) for i in self.frontend.text_normalize(tts_text, split=True): model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k) + start_time = time.time() for model_output in self.model.inference(**model_input, stream=stream): speech_len = model_output['tts_speech'].shape[1] / 22050 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) @@ -69,9 +69,9 @@ class CosyVoice: def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False): if self.frontend.instruct is True: raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir)) - start_time = time.time() for i in self.frontend.text_normalize(tts_text, split=True): model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k) + start_time = time.time() for model_output in self.model.inference(**model_input, stream=stream): speech_len = model_output['tts_speech'].shape[1] / 22050 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) @@ -81,10 +81,10 @@ class CosyVoice: def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False): if self.frontend.instruct is False: raise ValueError('{} do not support instruct inference'.format(self.model_dir)) - start_time = time.time() instruct_text = self.frontend.text_normalize(instruct_text, split=False) for i in self.frontend.text_normalize(tts_text, split=True): model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) + start_time = time.time() for model_output in self.model.inference(**model_input, stream=stream): speech_len = model_output['tts_speech'].shape[1] / 22050 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 7a0fdaf..7fb61ed 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -13,6 +13,9 @@ # limitations under the License. import torch import numpy as np +import threading +import time +from contextlib import nullcontext class CosyVoiceModel: @@ -25,10 +28,13 @@ class CosyVoiceModel: self.llm = llm self.flow = flow self.hift = hift - self.stream_win_len = 60 - self.stream_hop_len = 50 - self.overlap = 4395 # 10 token equals 4395 sample point + self.stream_win_len = 60 * 4 + self.stream_hop_len = 50 * 4 + self.overlap = 4395 * 4 # 10 token equals 4395 sample point self.window = np.hamming(2 * self.overlap) + self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.lock = threading.Lock() def load(self, llm_model, flow_model, hift_model): self.llm.load_state_dict(torch.load(llm_model, map_location=self.device)) @@ -38,13 +44,8 @@ class CosyVoiceModel: self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) self.hift.to(self.device).eval() - def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192), - prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32), - llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), - flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), - prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False): - if stream is True: - tts_speech_token, cache_speech = [], None + def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding): + with self.llm_context: for i in self.llm.inference(text=text.to(self.device), text_len=text_len.to(self.device), prompt_text=prompt_text.to(self.device), @@ -56,10 +57,56 @@ class CosyVoiceModel: sampling=25, max_token_text_ratio=30, min_token_text_ratio=3, - stream=stream): - tts_speech_token.append(i) - if len(tts_speech_token) == self.stream_win_len: - this_tts_speech_token = torch.concat(tts_speech_token, dim=1) + stream=True): + self.tts_speech_token.append(i) + self.llm_end = True + + def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding): + with self.flow_hift_context: + tts_mel = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=prompt_token_len.to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=prompt_feat_len.to(self.device), + embedding=embedding.to(self.device)) + tts_speech = self.hift.inference(mel=tts_mel).cpu() + return tts_speech + + def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192), + prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32), + llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32), + prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False): + if stream is True: + self.tts_speech_token, self.llm_end, cache_speech = [], False, None + p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device), + llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device))) + p.start() + while True: + time.sleep(0.1) + if len(self.tts_speech_token) >= self.stream_win_len: + this_tts_speech_token = torch.concat(self.tts_speech_token[:self.stream_win_len], dim=1) + with self.flow_hift_context: + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token.to(self.device), + prompt_token_len=flow_prompt_speech_token_len.to(self.device), + prompt_feat=prompt_speech_feat.to(self.device), + prompt_feat_len=prompt_speech_feat_len.to(self.device), + embedding=flow_embedding.to(self.device)) + # fade in/out if necessary + if cache_speech is not None: + this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:] + yield {'tts_speech': this_tts_speech[:, :-self.overlap]} + cache_speech = this_tts_speech[:, -self.overlap:] + with self.lock: + self.tts_speech_token = self.tts_speech_token[self.stream_hop_len:] + if self.llm_end is True: + break + # deal with remain tokens + if cache_speech is None or len(self.tts_speech_token) > self.stream_win_len - self.stream_hop_len: + this_tts_speech_token = torch.concat(self.tts_speech_token, dim=1) + with self.flow_hift_context: this_tts_mel = self.flow.inference(token=this_tts_speech_token, token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device), prompt_token=flow_prompt_speech_token.to(self.device), @@ -68,29 +115,14 @@ class CosyVoiceModel: prompt_feat_len=prompt_speech_feat_len.to(self.device), embedding=flow_embedding.to(self.device)) this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu() - # fade in/out if necessary - if cache_speech is not None: - this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:] - yield {'tts_speech': this_tts_speech[:, :-self.overlap]} - cache_speech = this_tts_speech[:, -self.overlap:] - tts_speech_token = tts_speech_token[-(self.stream_win_len - self.stream_hop_len):] - # deal with remain tokens - if cache_speech is None or len(tts_speech_token) > self.stream_win_len - self.stream_hop_len: - this_tts_speech_token = torch.concat(tts_speech_token, dim=1) - this_tts_mel = self.flow.inference(token=this_tts_speech_token, - token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device), - prompt_token=flow_prompt_speech_token.to(self.device), - prompt_token_len=flow_prompt_speech_token_len.to(self.device), - prompt_feat=prompt_speech_feat.to(self.device), - prompt_feat_len=prompt_speech_feat_len.to(self.device), - embedding=flow_embedding.to(self.device)) - this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu() if cache_speech is not None: this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:] yield {'tts_speech': this_tts_speech} else: - assert len(tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len) + assert len(self.tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len) yield {'tts_speech': cache_speech} + p.join() + torch.cuda.synchronize() else: tts_speech_token = [] for i in self.llm.inference(text=text.to(self.device), diff --git a/cosyvoice/flow/length_regulator.py b/cosyvoice/flow/length_regulator.py index 622f29a..5d4348e 100755 --- a/cosyvoice/flow/length_regulator.py +++ b/cosyvoice/flow/length_regulator.py @@ -43,7 +43,7 @@ class InterpolateRegulator(nn.Module): def forward(self, x, ylens=None): # x in (B, T, D) mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) - x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') out = self.model(x).transpose(1, 2).contiguous() olens = ylens return out * mask, olens diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 38b7b4c..704a49e 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -174,7 +174,7 @@ class TransformerLM(torch.nn.Module): embedding = self.spk_embed_affine_layer(embedding) embedding = embedding.unsqueeze(dim=1) else: - embedding = torch.zeros(1, 0, self.llm_input_size).to(device) + embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) # 3. concat llm_input sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) @@ -182,7 +182,7 @@ class TransformerLM(torch.nn.Module): if prompt_speech_token_len != 0: prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) else: - prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device) + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) # 4. cal min/max_length diff --git a/webui.py b/webui.py index ce90e26..be74f04 100644 --- a/webui.py +++ b/webui.py @@ -24,14 +24,8 @@ import torchaudio import random import librosa -import logging -logging.getLogger('matplotlib').setLevel(logging.WARNING) - from cosyvoice.cli.cosyvoice import CosyVoice -from cosyvoice.utils.file_utils import load_wav, speed_change - -logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(levelname)s %(message)s') +from cosyvoice.utils.file_utils import load_wav, speed_change, logging def generate_seed(): seed = random.randint(1, 100000000) @@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成 '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮', '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮', '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'} +stream_mode_list = [('否', False), ('是', True)] def change_instruction(mode_checkbox_group): return instruct_dict[mode_checkbox_group] -def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor): +def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor): if prompt_wav_upload is not None: prompt_wav = prompt_wav_upload elif prompt_wav_record is not None: @@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro if mode_checkbox_group == '预训练音色': logging.info('get sft inference request') set_all_random_seed(seed) - output = cosyvoice.inference_sft(tts_text, sft_dropdown) + for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream): + yield (target_sr, i['tts_speech'].numpy().flatten()) elif mode_checkbox_group == '3s极速复刻': logging.info('get zero_shot inference request') prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr)) set_all_random_seed(seed) - output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) + for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream): + yield (target_sr, i['tts_speech'].numpy().flatten()) elif mode_checkbox_group == '跨语种复刻': logging.info('get cross_lingual inference request') prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr)) set_all_random_seed(seed) - output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) + for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream): + yield (target_sr, i['tts_speech'].numpy().flatten()) else: logging.info('get instruct inference request') set_all_random_seed(seed) - output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text) - - if speed_factor != 1.0: - try: - audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor)) - audio_data = audio_data.numpy().flatten() - except Exception as e: - print(f"Failed to change speed of audio: \n{e}") - else: - audio_data = output['tts_speech'].numpy().flatten() - - return (target_sr, audio_data) + for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream): + yield (target_sr, i['tts_speech'].numpy().flatten()) def main(): with gr.Blocks() as demo: @@ -155,6 +143,7 @@ def main(): mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0]) instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5) sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25) + stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1]) with gr.Column(scale=0.25): seed_button = gr.Button(value="\U0001F3B2") seed = gr.Number(value=0, label="随机推理种子") @@ -167,11 +156,11 @@ def main(): generate_button = gr.Button("生成音频") - audio_output = gr.Audio(label="合成音频") + audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True) seed_button.click(generate_seed, inputs=[], outputs=seed) generate_button.click(generate_audio, - inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor], + inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor], outputs=[audio_output]) mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text]) demo.queue(max_size=4, default_concurrency_limit=2)