revert trt TODO

This commit is contained in:
lyuxiang.lx
2024-08-29 23:35:19 +08:00
parent 1d881df8b2
commit 1ab3186799
3 changed files with 5 additions and 22 deletions

View File

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
class CosyVoice:
def __init__(self, model_dir, load_jit=True, load_trt=True):
def __init__(self, model_dir, load_jit=True):
instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
if not os.path.exists(model_dir):
@@ -42,9 +42,6 @@ class CosyVoice:
if load_jit:
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
'{}/llm.llm.fp16.zip'.format(model_dir))
if load_trt:
# TODO
self.model.load_trt()
del configs
def list_avaliable_spks(self):

View File

@@ -66,11 +66,6 @@ class CosyVoiceModel:
llm_llm = torch.jit.load(llm_llm_model)
self.llm.llm = llm_llm
def load_trt(self):
# TODO 你需要的TRT推理的准备
self.flow.decoder.estimator = xxx
self.flow.decoder.session = xxx
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
with self.llm_context:
for i in self.llm.inference(text=text.to(self.device),
@@ -126,7 +121,6 @@ class CosyVoiceModel:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
p.join()
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
@@ -147,7 +141,7 @@ class CosyVoiceModel:
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
# p.join()
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
with self.flow_hift_context:
@@ -160,7 +154,7 @@ class CosyVoiceModel:
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
# p.join()
p.join()
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
with self.flow_hift_context:
this_tts_speech = self.token2wav(token=this_tts_speech_token,

View File

@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.forward_estimator(
cfg_dphi_dt = self.estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
@@ -96,14 +96,6 @@ class ConditionalCFM(BASECFM):
return sol[-1]
# TODO
def forward_estimator(self):
if isinstance(self.estimator, trt):
assert self.training is False, 'tensorrt cannot be used in training'
return xxx
else:
return self.estimator.forward
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss