mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
force set use_flow_cache
This commit is contained in:
@@ -401,10 +401,6 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
# NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
|
||||
if self.use_flow_cache is True:
|
||||
flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
|
||||
prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
@@ -412,6 +408,10 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
|
||||
# NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
|
||||
flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
|
||||
prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
@@ -442,6 +442,7 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
|
||||
Reference in New Issue
Block a user