Refactor CosyVoice inference methods to streamline CUDA stream management

- Removed the queue-based stream pool and integrated direct CUDA stream usage for improved performance.
- Simplified inference methods by eliminating unnecessary synchronization and stream management code.
- Enhanced logging for better tracking of synthesis operations and performance metrics.
- Updated the model class to support CUDA stream context management, ensuring efficient resource utilization during inference.
This commit is contained in:
禾息
2025-04-16 14:15:14 +08:00
parent fd9b7d45e2
commit 7f4c9a2c64
2 changed files with 226 additions and 237 deletions

View File

@@ -22,7 +22,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
import queue
class CosyVoice:
@@ -60,11 +60,6 @@ class CosyVoice:
self.fp16, self.estimator_count)
del configs
thread_count = 10
self.stream_pool = queue.Queue(maxsize=thread_count)
for _ in range(thread_count):
self.stream_pool.put(torch.cuda.Stream(self.device))
def list_available_spks(self):
spks = list(self.frontend.spk2info.keys())
@@ -74,104 +69,80 @@ class CosyVoice:
self.frontend.add_spk_info(spk_id, spk_info)
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
"""使用预定义的说话人执行 zero_shot 推理"""
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
start_time = time.time()
last_time = start_time
chunk_index = 0
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
yield model_output
last_time = time.time()
chunk_index += 1
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
if self.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id)
start_time = time.time()
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
"""使用预定义的说话人执行 zero_shot 推理"""
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
start_time = time.time()
last_time = start_time
chunk_index = 0
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
yield model_output
last_time = time.time()
chunk_index += 1
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
if self.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
start_time = time.time()
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
class CosyVoice2(CosyVoice):
@@ -215,42 +186,29 @@ class CosyVoice2(CosyVoice):
self.fp16, self.estimator_count)
del configs
thread_count = 10
self.stream_pool = queue.Queue(maxsize=thread_count)
for _ in range(thread_count):
self.stream_pool.put(torch.cuda.Stream(self.device))
def inference_instruct(self, *args, **kwargs):
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)
def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
cuda_stream = self.stream_pool.get()
with torch.cuda.stream(cuda_stream):
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
cuda_stream.synchronize()
self.stream_pool.put(cuda_stream)

View File

@@ -23,6 +23,7 @@ import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.flow.flow_matching import EstimatorWrapper
import queue
class CosyVoiceModel:
@@ -66,6 +67,12 @@ class CosyVoiceModel:
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.stream_context_pool = queue.Queue()
for _ in range(10):
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.is_cuda_available = torch.cuda.is_available()
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
self.llm.to(self.device).eval()
@@ -166,63 +173,70 @@ class CosyVoiceModel:
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()
stream_context = self.stream_context_pool.get()
with stream_context:
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
self.synchronize_stream()
self.stream_context_pool.put(stream_context)
torch.cuda.empty_cache()
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
@@ -278,6 +292,10 @@ class CosyVoiceModel:
self.hift_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()
def synchronize_stream(self):
if self.is_cuda_available:
torch.cuda.current_stream().synchronize()
class CosyVoice2Model(CosyVoiceModel):
@@ -314,6 +332,12 @@ class CosyVoice2Model(CosyVoiceModel):
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.stream_context_pool = queue.Queue()
for _ in range(10):
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.is_cuda_available = torch.cuda.is_available()
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
@@ -359,57 +383,64 @@ class CosyVoice2Model(CosyVoiceModel):
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
if stream is True:
token_offset = 0
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=token_offset,
finalize=False)
token_offset += self.token_hop_len
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=token_offset,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=0,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
torch.cuda.empty_cache()
self.synchronize_stream()
stream_context = self.stream_context_pool.get()
with torch.cuda.stream(stream_context):
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
if stream is True:
token_offset = 0
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=token_offset,
finalize=False)
token_offset += self.token_hop_len
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=token_offset,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=0,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.synchronize_stream()
self.stream_context_pool.put(stream_context)
torch.cuda.empty_cache()
class VllmCosyVoice2Model(CosyVoice2Model):