Merge pull request #455 from boji123/bj_dev_stream_fix_promptcache

[debug] support flow cache for sharper tts_mel output (handles prompt bug)
Authored by Xiang Lyu on 2024-10-16 14:12:40 +08:00, committed by GitHub
3 changed files with 35 additions and 9 deletions


@@ -52,6 +52,7 @@ class CosyVoiceModel:
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}
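For context, these dicts are keyed by a per-request UUID so several streaming sessions can run concurrently; the new flow_cache_dict carries the flow decoder's cache from one audio chunk to the next. A minimal sketch of that bookkeeping pattern (dict names mirror the diff; the SessionState class, open_session/close_session, and the cleanup logic are illustrative, not the project's exact code):

import threading
import uuid

class SessionState:
    """Per-request bookkeeping, mirroring the dicts touched in this commit (sketch only)."""
    def __init__(self):
        self.lock = threading.Lock()
        self.tts_speech_token_dict = {}   # tokens produced so far by the LLM job
        self.llm_end_dict = {}            # whether the LLM job has finished
        self.flow_cache_dict = {}         # flow decoder cache carried between chunks
        self.mel_overlap_dict = {}        # cached mel tail used for crossfading
        self.hift_cache_dict = {}         # vocoder (HiFT) streaming cache

    def open_session(self):
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.flow_cache_dict[this_uuid] = None   # populated by flow.inference on the first chunk
            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
        return this_uuid

    def close_session(self, this_uuid):
        # hypothetical cleanup; the real model drops these entries when a request ends
        with self.lock:
            for d in (self.tts_speech_token_dict, self.llm_end_dict, self.flow_cache_dict,
                      self.mel_overlap_dict, self.hift_cache_dict):
                d.pop(this_uuid, None)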
@@ -100,13 +101,17 @@ class CosyVoiceModel:
         self.llm_end_dict[uuid] = True

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel = self.flow.inference(token=token.to(self.device),
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_token=prompt_token.to(self.device),
                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_feat=prompt_feat.to(self.device),
                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+                                      embedding=embedding.to(self.device),
+                                      required_cache_size=self.mel_overlap_len,
+                                      flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
         # mel overlap fade in out
         if self.mel_overlap_dict[uuid] is not None:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
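The cache returned by flow.inference is stored back into flow_cache_dict[uuid] so the next chunk resumes where this one left off, while mel_overlap_dict keeps the tail frames of the previous chunk so the new chunk can be crossfaded in. A hedged sketch of that kind of overlap crossfade (tensor shapes, the crossfade_mel helper, the 34-frame overlap, and the Hamming window are assumptions, not the project's fade_in_out implementation):

import numpy as np
import torch

def crossfade_mel(new_mel: torch.Tensor, prev_overlap: torch.Tensor, window: torch.Tensor) -> torch.Tensor:
    # new_mel:      (B, n_mels, T)  current chunk returned by flow.inference
    # prev_overlap: (B, n_mels, L)  tail frames cached from the previous chunk
    # window:       (2*L,)          fade curve; first half fades the new chunk in,
    #                               second half fades the cached tail out
    overlap = window.shape[0] // 2
    out = new_mel.clone()
    out[:, :, :overlap] = (new_mel[:, :, :overlap] * window[:overlap]
                           + prev_overlap[:, :, -overlap:] * window[overlap:])
    return out

# Example with an assumed 34-frame overlap and a Hamming fade curve.
overlap_len = 34
window = torch.from_numpy(np.hamming(2 * overlap_len)).float()
chunk = torch.randn(1, 80, 200)
prev_tail = torch.randn(1, 80, overlap_len)
smoothed = crossfade_mel(chunk, prev_tail, window)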
@@ -145,6 +150,7 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.flow_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
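The surrounding pattern: llm_job runs in a background thread and appends speech tokens under this_uuid, while the main generator repeatedly slices off a chunk, calls token2wav (which now also reads and updates flow_cache_dict[this_uuid]), and yields audio. A simplified sketch of that consumer loop (the stream_chunks function, chunk size, polling interval, and final flush are assumptions, not the exact upstream code):

import time
import torch

def stream_chunks(model, this_uuid, prompt_token, prompt_feat, embedding, chunk_size=25):
    # Consume tokens produced by the llm_job thread and synthesize audio chunk by chunk.
    while True:
        time.sleep(0.1)
        if len(model.tts_speech_token_dict[this_uuid]) >= chunk_size:
            token_chunk = torch.tensor(model.tts_speech_token_dict[this_uuid][:chunk_size]).unsqueeze(dim=0)
            # token2wav reads self.flow_cache_dict[this_uuid] and stores the updated cache back
            speech = model.token2wav(token=token_chunk, prompt_token=prompt_token,
                                     prompt_feat=prompt_feat, embedding=embedding,
                                     uuid=this_uuid, finalize=False)
            yield {'tts_speech': speech.cpu()}
            with model.lock:
                model.tts_speech_token_dict[this_uuid] = model.tts_speech_token_dict[this_uuid][chunk_size:]
        elif model.llm_end_dict[this_uuid]:
            break
    # final flush: synthesize whatever tokens remain, letting token2wav finalize the caches
    token_chunk = torch.tensor(model.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
    speech = model.token2wav(token=token_chunk, prompt_token=prompt_token,
                             prompt_feat=prompt_feat, embedding=embedding,
                             uuid=this_uuid, finalize=True)
    yield {'tts_speech': speech.cpu()}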