update model inference

This commit is contained in:
lyuxiang.lx
2024-07-24 19:18:09 +08:00
parent a13411c561
commit 02f941d348
5 changed files with 85 additions and 64 deletions

View File

@@ -46,9 +46,9 @@ class CosyVoice:
return spks
def inference_sft(self, tts_text, spk_id, stream=False):
start_time = time.time()
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_sft(i, spk_id)
start_time = time.time()
for model_output in self.model.inference(**model_input, stream=stream):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -56,10 +56,10 @@ class CosyVoice:
start_time = time.time()
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
start_time = time.time()
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
start_time = time.time()
for model_output in self.model.inference(**model_input, stream=stream):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -69,9 +69,9 @@ class CosyVoice:
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
if self.frontend.instruct is True:
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
start_time = time.time()
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
start_time = time.time()
for model_output in self.model.inference(**model_input, stream=stream):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -81,10 +81,10 @@ class CosyVoice:
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
if self.frontend.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
start_time = time.time()
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
for model_output in self.model.inference(**model_input, stream=stream):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))