diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 910fa74..47f8336 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -67,9 +67,9 @@ class CosyVoice: spks = list(self.frontend.spk2info.keys()) return spks - def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id): + def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id): assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id' - model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '') + model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '') del model_input['text'] del model_input['text_len'] self.frontend.spk2info[zero_shot_spk_id] = model_input @@ -89,12 +89,12 @@ class CosyVoice: yield model_output start_time = time.time() - def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): + def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend) for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text): logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text)) - model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id) + model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id) start_time = time.time() logging.info('synthesis text {}'.format(i)) for model_output in self.model.tts(**model_input, stream=stream, speed=speed): @@ -103,9 +103,9 @@ class CosyVoice: yield model_output start_time = time.time() - def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): + def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id) + model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id) start_time = time.time() logging.info('synthesis text {}'.format(i)) for model_output in self.model.tts(**model_input, stream=stream, speed=speed): @@ -129,8 +129,8 @@ class CosyVoice: yield model_output start_time = time.time() - def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): - model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate) + def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0): + model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate) start_time = time.time() for model_output in self.model.tts(**model_input, stream=stream, speed=speed): speech_len = model_output['tts_speech'].shape[1] / self.sample_rate @@ -181,10 +181,10 @@ class CosyVoice2(CosyVoice): def inference_instruct(self, *args, **kwargs): raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!') - def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): + def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True): assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!' for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id) + model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id) start_time = time.time() logging.info('synthesis text {}'.format(i)) for model_output in self.model.tts(**model_input, stream=stream, speed=speed): diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index f98b0d6..d614883 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -32,7 +32,7 @@ except ImportError: from wetext import Normalizer as ZhNormalizer from wetext import Normalizer as EnNormalizer use_ttsfrd = False -from cosyvoice.utils.file_utils import logging +from cosyvoice.utils.file_utils import logging, load_wav from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation @@ -89,7 +89,8 @@ class CosyVoiceFrontEnd: for i in range(text_token.shape[1]): yield text_token[:, i: i + 1] - def _extract_speech_token(self, speech): + def _extract_speech_token(self, prompt_wav): + speech = load_wav(prompt_wav, 16000) assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' feat = whisper.log_mel_spectrogram(speech, n_mels=128) speech_token = self.speech_tokenizer_session.run(None, @@ -101,7 +102,8 @@ class CosyVoiceFrontEnd: speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) return speech_token, speech_token_len - def _extract_spk_embedding(self, speech): + def _extract_spk_embedding(self, prompt_wav): + speech = load_wav(prompt_wav, 16000) feat = kaldi.fbank(speech, num_mel_bins=80, dither=0, @@ -112,7 +114,8 @@ class CosyVoiceFrontEnd: embedding = torch.tensor([embedding]).to(self.device) return embedding - def _extract_speech_feat(self, speech): + def _extract_speech_feat(self, prompt_wav): + speech = load_wav(prompt_wav, 24000) speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device) speech_feat = speech_feat.unsqueeze(dim=0) speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device) @@ -154,19 +157,18 @@ class CosyVoiceFrontEnd: model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding} return model_input - def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id): + def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id): tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) if zero_shot_spk_id == '': prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text) - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) - speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample) - speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k) + speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav) + speech_token, speech_token_len = self._extract_speech_token(prompt_wav) if resample_rate == 24000: # cosyvoice2, force speech_feat % speech_token = 2 token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1]) speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len - embedding = self._extract_spk_embedding(prompt_speech_16k) + embedding = self._extract_spk_embedding(prompt_wav) model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len, 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len, 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, @@ -178,8 +180,8 @@ class CosyVoiceFrontEnd: model_input['text_len'] = tts_text_token_len return model_input - def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id): - model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id) + def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id): + model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id) # in cross lingual mode, we remove prompt in llm del model_input['prompt_text'] del model_input['prompt_text_len'] @@ -196,17 +198,16 @@ class CosyVoiceFrontEnd: model_input['prompt_text_len'] = instruct_text_token_len return model_input - def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id): - model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id) + def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id): + model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_wav, resample_rate, zero_shot_spk_id) del model_input['llm_prompt_speech_token'] del model_input['llm_prompt_speech_token_len'] return model_input - def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate): - prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k) - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) - prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample) - embedding = self._extract_spk_embedding(prompt_speech_16k) + def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate): + prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav) + prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav) + embedding = self._extract_spk_embedding(prompt_wav) source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k) model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len, 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len, diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index a92f8e7..374a90e 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -41,11 +41,11 @@ def read_json_lists(list_file): return results -def load_wav(wav, target_sr): +def load_wav(wav, target_sr, min_sr=16000): speech, sample_rate = torchaudio.load(wav, backend='soundfile') speech = speech.mean(dim=0, keepdim=True) if sample_rate != target_sr: - assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) + assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) return speech