diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index 48babf3..be82f0c 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -67,6 +67,8 @@ class CosyVoice:
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+            if len(i) < 0.5 * len(prompt_text):
+                logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 00e4af0..cf9c231 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -202,6 +202,9 @@ class TransformerLM(torch.nn.Module):
                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            # force decoding to continue: forbid the end-of-speech token as the first output
+            if i == 0:
+                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
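
The second hunk masks the end-of-speech token at the very first decoding step, so the LLM cannot terminate before emitting at least one speech token. Below is a minimal, self-contained sketch of that masking behaviour; the token id, vocabulary size, greedy sampler, and toy log-probabilities are illustrative assumptions, not the repository's API.

```python
# Sketch of the first-token EOS suppression added in llm.py (assumed values below).
import torch

speech_token_size = 4096          # assumed id of the end-of-speech token
vocab = speech_token_size + 1     # speech tokens 0..4095 plus the EOS slot

def greedy_step(logp, step):
    # Mirrors the patched behaviour: at step 0 the EOS log-prob is forced to -inf,
    # so the argmax (or any sampler) can never pick EOS before the first speech token.
    if step == 0:
        logp[:, speech_token_size] = -float('inf')
    return logp.argmax(dim=-1).item()

# Toy log-probs where EOS would otherwise win immediately.
logp = torch.full((1, vocab), -10.0)
logp[0, speech_token_size] = 0.0   # EOS is the most likely token
logp[0, 123] = -1.0                # best non-EOS speech token

assert greedy_step(logp.clone(), step=0) == 123                 # EOS suppressed, decoding continues
assert greedy_step(logp.clone(), step=1) == speech_token_size   # later steps may still stop
```

The first hunk is simpler: it only logs a warning when a normalized synthesis segment is less than half the length of the prompt text, a condition the authors note can degrade zero-shot quality.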