This commit is contained in:
lyuxiang.lx
2024-09-05 16:15:34 +08:00
parent eeebc45313
commit 90433f5373
35 changed files with 189 additions and 122 deletions

View File

@@ -18,7 +18,7 @@ import time
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
import numpy as np
class CosyVoiceModel:
@@ -80,27 +80,27 @@ class CosyVoiceModel:
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
with self.llm_context:
for i in self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device).half(),
sampling=25,
max_token_text_ratio=30,
min_token_text_ratio=3):
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device).half(),
sampling=25,
max_token_text_ratio=30,
min_token_text_ratio=3):
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
with self.flow_hift_context:
tts_mel = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device))
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device))
# mel overlap fade in out
if self.mel_overlap_dict[uuid] is not None:
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
@@ -129,7 +129,8 @@ class CosyVoiceModel:
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
if stream is True:
@@ -140,12 +141,12 @@ class CosyVoiceModel:
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
with self.flow_hift_context:
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
@@ -157,11 +158,11 @@ class CosyVoiceModel:
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
with self.flow_hift_context:
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
@@ -169,11 +170,11 @@ class CosyVoiceModel:
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
with self.flow_hift_context:
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)