update tempo change

lyuxiang.lx
2024-09-18 16:13:40 +08:00
parent c901a12789
commit e19e80fcd8
4 changed files with 23 additions and 31 deletions

cosyvoice/cli/cosyvoice.py

@@ -53,43 +53,43 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks

-    def inference_sft(self, tts_text, spk_id, stream=False):
+    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

-    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
@@ -97,7 +97,7 @@ class CosyVoice:
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
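
Taken together, the cosyvoice.py changes just thread a speed keyword from each public inference_* entry point down to CosyVoiceModel.inference. A minimal usage sketch of the new argument (the model directory and speaker id below are illustrative placeholders, not part of this diff):

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

# Placeholder model path; point this at whatever checkpoint you have locally.
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
# speed > 1.0 gives faster/shorter speech, speed < 1.0 slower/longer.
# Tempo change requires stream=False (see the guard in model.py below).
for k, out in enumerate(cosyvoice.inference_sft('Hello, this is a test.', '中文女', stream=False, speed=1.5)):
    torchaudio.save('sft_speed_{}.wav'.format(k), out['tts_speech'], 22050)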

cosyvoice/cli/model.py

@@ -15,6 +15,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
@@ -91,7 +92,7 @@ class CosyVoiceModel:
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         tts_mel = self.flow.inference(token=token.to(self.device),
                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_token=prompt_token.to(self.device),
@@ -116,6 +117,9 @@ class CosyVoiceModel:
             self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
             tts_speech = tts_speech[:, :-self.source_cache_len]
         else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech
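
The tempo change itself is a single operation: before vocoding, the mel spectrogram is linearly resampled along its time axis, so the frame count scales by 1/speed while the 80 mel bins (and hence pitch) are untouched. A standalone sketch of the same call, with shapes mirroring token2wav (the dummy tensor is for illustration only):

import torch
from torch.nn import functional as F

# tts_mel in token2wav is (batch, 80 mel bins, frames); values here are dummy.
tts_mel = torch.randn(1, 80, 200)
speed = 1.5
# mode='linear' interpolates a 3-D tensor along its last (time) dimension:
# 200 frames -> int(200 / 1.5) = 133, i.e. a shorter, faster utterance.
tts_mel_fast = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
print(tts_mel_fast.shape)  # torch.Size([1, 80, 133])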
@@ -123,7 +127,7 @@ class CosyVoiceModel:
             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -169,7 +173,8 @@ class CosyVoiceModel:
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
                                              uuid=this_uuid,
-                                             finalize=True)
+                                             finalize=True,
+                                             speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
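
The assert added in token2wav is the design choice behind "non-stream only": in streaming mode token2wav runs once per chunk and hift_cache_dict[uuid] carries vocoder overlap state between calls, so resampling individual mel chunks would desynchronize the cached source signal. With stream=False there is a single finalize=True call and the cache is still None, so the guard passes. A hypothetical illustration of the failure mode, reusing the cosyvoice object from the sketch above:

# With stream=True the vocoder cache is populated between chunks, so the
# final token2wav call trips the assert added above.
try:
    for out in cosyvoice.inference_sft('Hello again.', '中文女', stream=True, speed=1.5):
        pass
except AssertionError as e:
    print(e)  # speed change only support non-stream inference mode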