diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index ea8c448..54c08c0 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import torch
+import time
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.utils.file_utils import logging


 class CosyVoice:
@@ -44,40 +45,48 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks

-    def inference_sft(self, tts_text, spk_id):
-        tts_speeches = []
+    def inference_sft(self, tts_text, spk_id, stream=False):
+        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
+        start_time = time.time()
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        tts_speeches = []
+        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_instruct(self, tts_text, spk_id, instruct_text):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+        start_time = time.time()
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
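
With this change the inference_* methods become generators: instead of returning one concatenated waveform they yield {'tts_speech': tensor} chunks as they are synthesized, logging the per-chunk real-time factor (RTF). A minimal caller sketch follows; the model directory, speaker id, and output filename are placeholders for your own setup, not values taken from the patch.

# --- usage sketch, not part of the patch ---
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')  # placeholder model dir
chunks = []
for chunk in cosyvoice.inference_sft('Hello, streaming world.', '中文女', stream=True):  # placeholder spk_id
    # each yield is a dict holding a [1, T] float tensor at 22050 Hz
    chunks.append(chunk['tts_speech'])
torchaudio.save('sft_stream.wav', torch.concat(chunks, dim=1), 22050)
# --- end sketch ---
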
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index f4625e3..7a0fdaf 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import numpy as np
+


 class CosyVoiceModel:
@@ -23,6 +25,10 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.stream_win_len = 60
+        self.stream_hop_len = 50
+        self.overlap = 4395  # 10 token equals 4395 sample point
+        self.window = np.hamming(2 * self.overlap)

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -36,25 +42,79 @@
                   prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
-        tts_speech_token = self.llm.inference(text=text.to(self.device),
-                                              text_len=text_len.to(self.device),
-                                              prompt_text=prompt_text.to(self.device),
-                                              prompt_text_len=prompt_text_len.to(self.device),
-                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                              embedding=llm_embedding.to(self.device),
-                                              beam_size=1,
-                                              sampling=25,
-                                              max_token_text_ratio=30,
-                                              min_token_text_ratio=3)
-        tts_mel = self.flow.inference(token=tts_speech_token,
-                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                      prompt_token=flow_prompt_speech_token.to(self.device),
-                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                      prompt_feat=prompt_speech_feat.to(self.device),
-                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                      embedding=flow_embedding.to(self.device))
-        tts_speech = self.hift.inference(mel=tts_mel).cpu()
-        torch.cuda.empty_cache()
-        return {'tts_speech': tts_speech}
+                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+        if stream is True:
+            tts_speech_token, cache_speech = [], None
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=text_len.to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=prompt_text_len.to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                        embedding=llm_embedding.to(self.device),
+                                        beam_size=1,
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3,
+                                        stream=stream):
+                tts_speech_token.append(i)
+                if len(tts_speech_token) == self.stream_win_len:
+                    this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                    this_tts_mel = self.flow.inference(token=this_tts_speech_token,
+                                                       token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                                       prompt_token=flow_prompt_speech_token.to(self.device),
+                                                       prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                       prompt_feat=prompt_speech_feat.to(self.device),
+                                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                       embedding=flow_embedding.to(self.device))
+                    this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
+                    # fade in/out if necessary
+                    if cache_speech is not None:
+                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                    yield {'tts_speech': this_tts_speech[:, :-self.overlap]}
+                    cache_speech = this_tts_speech[:, -self.overlap:]
+                    tts_speech_token = tts_speech_token[-(self.stream_win_len - self.stream_hop_len):]
+            # deal with remain tokens
+            if cache_speech is None or len(tts_speech_token) > self.stream_win_len - self.stream_hop_len:
+                this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                this_tts_mel = self.flow.inference(token=this_tts_speech_token,
+                                                   token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                                   prompt_token=flow_prompt_speech_token.to(self.device),
+                                                   prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                   prompt_feat=prompt_speech_feat.to(self.device),
+                                                   prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                   embedding=flow_embedding.to(self.device))
+                this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
+                if cache_speech is not None:
+                    this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                yield {'tts_speech': this_tts_speech}
+            else:
+                assert len(tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
+                yield {'tts_speech': cache_speech}
+        else:
+            tts_speech_token = []
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=text_len.to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=prompt_text_len.to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                        embedding=llm_embedding.to(self.device),
+                                        beam_size=1,
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3,
+                                        stream=stream):
+                tts_speech_token.append(i)
+            assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
+            tts_speech_token = torch.concat(tts_speech_token, dim=1)
+            tts_mel = self.flow.inference(token=tts_speech_token,
+                                          token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                          prompt_token=flow_prompt_speech_token.to(self.device),
+                                          prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                          prompt_feat=prompt_speech_feat.to(self.device),
+                                          prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                          embedding=flow_embedding.to(self.device))
+            tts_speech = self.hift.inference(mel=tts_mel).cpu()
+            torch.cuda.empty_cache()
+            yield {'tts_speech': tts_speech}
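
The streaming branch above converts every 60 accumulated speech tokens into audio with a 50-token hop, and smooths the 4395-sample overlap between consecutive chunks using the two halves of a Hamming window. Below is a standalone sketch of that cross-fade on placeholder arrays; it only illustrates the windowing arithmetic and is not code from the patch.

# --- cross-fade sketch, not part of the patch ---
import numpy as np

overlap = 4395                    # 10 speech tokens ~ 4395 samples at 22050 Hz (from the diff)
window = np.hamming(2 * overlap)  # window[:overlap] rises, window[-overlap:] falls

prev_chunk = np.random.randn(22050).astype(np.float32)  # placeholder audio: previous synthesis window
next_chunk = np.random.randn(22050).astype(np.float32)  # placeholder audio: current synthesis window

cache = prev_chunk[-overlap:]  # tail withheld from the previous yield
next_chunk[:overlap] = next_chunk[:overlap] * window[:overlap] + cache * window[-overlap:]
# prev_chunk[:-overlap] was already emitted; next_chunk (minus its own tail) is emitted next
# --- end sketch ---
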
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 3b418c5..38b7b4c 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -158,6 +158,7 @@ class TransformerLM(torch.nn.Module):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            stream: bool = False,
     ) -> torch.Tensor:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -199,8 +200,13 @@ class TransformerLM(torch.nn.Module):
             top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
+            # in stream mode, yield token one by one
+            if stream is True:
+                yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-        return torch.tensor([out_tokens], dtype=torch.int64, device=device)
+        # in non-stream mode, yield all token
+        if stream is False:
+            yield torch.tensor([out_tokens], dtype=torch.int64, device=device)
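
TransformerLM.inference is now a generator in both modes, which is why CosyVoiceModel.inference always iterates over it: with stream=True it yields one [1, 1] token tensor per decoding step, with stream=False it yields the whole [1, N] sequence exactly once. The toy function below illustrates that yield contract only; it is not CosyVoice code.

# --- generator-contract sketch, not part of the patch ---
import torch

def fake_llm_inference(token_ids, stream=False):
    """Mimics the yield pattern: one token per yield in stream mode, one batch otherwise."""
    if stream:
        for t in token_ids:
            yield torch.tensor([[t]], dtype=torch.int64)     # shape [1, 1]
    else:
        yield torch.tensor([token_ids], dtype=torch.int64)   # shape [1, len(token_ids)]

assert len(list(fake_llm_inference([7, 8, 9], stream=True))) == 3
assert len(list(fake_llm_inference([7, 8, 9], stream=False))) == 1
# --- end sketch ---
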
diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py
index d4179e1..40e7b20 100644
--- a/cosyvoice/utils/file_utils.py
+++ b/cosyvoice/utils/file_utils.py
@@ -15,6 +15,10 @@
 import json
 import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')


 def read_lists(list_file):
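
The logging setup in file_utils.py configures the root logger once (DEBUG level, timestamped format) and silences matplotlib's debug output, so the per-chunk RTF lines emitted in cosyvoice.py share one format. A small illustration of what a caller sees; the numeric values and timestamp below are made up.

# --- logging sketch, not part of the patch ---
from cosyvoice.utils.file_utils import logging

speech_len, elapsed = 2.5, 0.8  # hypothetical seconds of audio / wall-clock seconds
logging.info('yield speech len {}, rtf {}'.format(speech_len, elapsed / speech_len))
# prints something like: 2024-01-01 12:00:00,000 INFO yield speech len 2.5, rtf 0.32
# --- end sketch ---
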