mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Merge pull request #715 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
This commit is contained in:
@@ -18,7 +18,7 @@ from hyperpyyaml import load_hyperpyyaml
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
||||
from cosyvoice.cli.model import CosyVoiceModel
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ class CosyVoice:
|
||||
'{}/spk2info.pt'.format(model_dir),
|
||||
instruct,
|
||||
configs['allowed_special'])
|
||||
self.sample_rate = configs['sample_rate']
|
||||
if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
|
||||
load_jit = False
|
||||
fp16 = False
|
||||
@@ -64,7 +65,7 @@ class CosyVoice:
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
@@ -74,11 +75,11 @@ class CosyVoice:
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
||||
if len(i) < 0.5 * len(prompt_text):
|
||||
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
||||
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
||||
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
@@ -87,11 +88,11 @@ class CosyVoice:
|
||||
if self.frontend.instruct is True:
|
||||
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
@@ -105,16 +106,52 @@ class CosyVoice:
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
||||
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
|
||||
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
||||
start_time = time.time()
|
||||
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
class CosyVoice2(CosyVoice):
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
|
||||
instruct = True if '-Instruct' in model_dir else False
|
||||
self.model_dir = model_dir
|
||||
if not os.path.exists(model_dir):
|
||||
model_dir = snapshot_download(model_dir)
|
||||
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'Qwen2-0.5B-CosyVoice-BlankEN')})
|
||||
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||
configs['feat_extractor'],
|
||||
'{}/campplus.onnx'.format(model_dir),
|
||||
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
||||
'{}/spk2info.pt'.format(model_dir),
|
||||
instruct,
|
||||
configs['allowed_special'])
|
||||
self.sample_rate = configs['sample_rate']
|
||||
if torch.cuda.is_available() is False and load_jit is True:
|
||||
load_jit = False
|
||||
logging.warning('cpu do not support jit, force set to False')
|
||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
|
||||
if load_trt is True and load_onnx is True:
|
||||
load_onnx = False
|
||||
logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
|
||||
if load_onnx:
|
||||
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
|
||||
del configs
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
import json
|
||||
import onnxruntime
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -66,9 +67,7 @@ class CosyVoiceFrontEnd:
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
||||
'failed to initialize ttsfrd resource'
|
||||
self.frd.set_lang_type('pinyin')
|
||||
self.frd.enable_pinyin_mix(True)
|
||||
self.frd.set_breakmodel_index(1)
|
||||
self.frd.set_lang_type('pinyinvg')
|
||||
else:
|
||||
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
|
||||
self.en_tn_model = EnNormalizer()
|
||||
@@ -112,26 +111,28 @@ class CosyVoiceFrontEnd:
|
||||
text = text.strip()
|
||||
if contains_chinese(text):
|
||||
if self.use_ttsfrd:
|
||||
text = self.frd.get_frd_extra_info(text, 'input')
|
||||
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
||||
text = ''.join(texts)
|
||||
else:
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = text.replace("\n", "")
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
text = text.replace(".", "。")
|
||||
text = text.replace(" - ", ",")
|
||||
text = remove_bracket(text)
|
||||
text = re.sub(r'[,,、]+$', '。', text)
|
||||
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
||||
token_min_n=60, merge_len=20, comma_split=False))
|
||||
text = text.replace("\n", "")
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
text = text.replace(".", "。")
|
||||
text = text.replace(" - ", ",")
|
||||
text = remove_bracket(text)
|
||||
text = re.sub(r'[,,、]+$', '。', text)
|
||||
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
||||
token_min_n=60, merge_len=20, comma_split=False))
|
||||
else:
|
||||
if self.use_ttsfrd:
|
||||
text = self.frd.get_frd_extra_info(text, 'input')
|
||||
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
||||
text = ''.join(texts)
|
||||
else:
|
||||
text = self.en_tn_model.normalize(text)
|
||||
text = spell_out_number(text, self.inflect_parser)
|
||||
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
||||
token_min_n=60, merge_len=20, comma_split=False))
|
||||
text = spell_out_number(text, self.inflect_parser)
|
||||
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
||||
token_min_n=60, merge_len=20, comma_split=False))
|
||||
if split is False:
|
||||
return text
|
||||
return texts
|
||||
@@ -142,11 +143,11 @@ class CosyVoiceFrontEnd:
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
||||
return model_input
|
||||
|
||||
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
|
||||
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
||||
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
||||
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
||||
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
|
||||
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
||||
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
||||
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
||||
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
||||
@@ -157,8 +158,8 @@ class CosyVoiceFrontEnd:
|
||||
'llm_embedding': embedding, 'flow_embedding': embedding}
|
||||
return model_input
|
||||
|
||||
def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
|
||||
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
|
||||
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
||||
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
||||
# in cross lingual mode, we remove prompt in llm
|
||||
del model_input['prompt_text']
|
||||
del model_input['prompt_text_len']
|
||||
@@ -175,10 +176,10 @@ class CosyVoiceFrontEnd:
|
||||
model_input['prompt_text_len'] = instruct_text_token_len
|
||||
return model_input
|
||||
|
||||
def frontend_vc(self, source_speech_16k, prompt_speech_16k):
|
||||
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
||||
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
||||
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
|
||||
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
||||
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
||||
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
||||
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
||||
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
||||
|
||||
@@ -57,15 +57,15 @@ class CosyVoiceModel:
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
self.llm.to(self.device).eval()
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=False)
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
||||
@@ -255,3 +255,168 @@ class CosyVoiceModel:
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
|
||||
|
||||
class CosyVoice2Model:
|
||||
|
||||
def __init__(self,
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.token_hop_len = 2 * self.flow.input_frame_rate
|
||||
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
|
||||
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
|
||||
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
|
||||
# hift cache
|
||||
self.mel_cache_len = 8
|
||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||
# speech fade in out
|
||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||
# rtf and decoding related
|
||||
self.stream_scale_factor = 1
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
self.llm.to(self.device).eval()
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
self.flow.decoder.fp16 = False
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_onnx(self, flow_decoder_estimator_model):
|
||||
import onnxruntime
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
del self.flow.decoder.estimator
|
||||
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model):
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
||||
self.flow.decoder.fp16 = True
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
with self.llm_context:
|
||||
for i in self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
self.llm_end_dict[uuid] = True
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
else:
|
||||
hift_cache_source = torch.zeros(1, 1, 0)
|
||||
# keep overlap mel and hift cache
|
||||
if finalize is False:
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
|
||||
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_offset = 0
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
|
||||
.unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=token_offset,
|
||||
finalize=False)
|
||||
token_offset += self.token_hop_len
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=token_offset,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=0,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
|
||||
Reference in New Issue
Block a user