From 1ab31867998fb9b8f456520870e9794dee03efea Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Thu, 29 Aug 2024 23:35:19 +0800
Subject: [PATCH] revert trt TODO

---
 cosyvoice/cli/cosyvoice.py      |  5 +----
 cosyvoice/cli/model.py          | 10 ++--------
 cosyvoice/flow/flow_matching.py | 12 ++----------
 3 files changed, 5 insertions(+), 22 deletions(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index d5fbd4e..49fe15f 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 
 class CosyVoice:
-    def __init__(self, model_dir, load_jit=True, load_trt=True):
+    def __init__(self, model_dir, load_jit=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -42,9 +42,6 @@ class CosyVoice:
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                 '{}/llm.llm.fp16.zip'.format(model_dir))
-        if load_trt:
-            # TODO
-            self.model.load_trt()
         del configs
 
     def list_avaliable_spks(self):
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 1184f0d..99ccbe5 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -66,11 +66,6 @@ class CosyVoiceModel:
             llm_llm = torch.jit.load(llm_llm_model)
             self.llm.llm = llm_llm
 
-    def load_trt(self):
-        # TODO: prepare what you need for TRT inference
-        self.flow.decoder.estimator = xxx
-        self.flow.decoder.session = xxx
-
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
@@ -126,7 +121,6 @@ class CosyVoiceModel:
         self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
-        p.join()
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
@@ -147,7 +141,7 @@ class CosyVoiceModel:
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # p.join()
+            p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
@@ -160,7 +154,7 @@ class CosyVoiceModel:
                 yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            # p.join()
+            p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index bcbaeb5..f82eaae 100755
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
+                cfg_dphi_dt = self.estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
@@ -96,14 +96,6 @@ class ConditionalCFM(BASECFM):
             t = t + dt
 
         return sol[-1]
 
-    # TODO
-    def forward_estimator(self):
-        if isinstance(self.estimator, trt):
-            assert self.training is False, 'tensorrt cannot be used in training'
-            return xxx
-        else:
-            return self.estimator.forward
-
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
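-- 
Note for whoever picks the TRT work back up: a minimal sketch, assuming the
flow estimator is exported to ONNX and served through onnxruntime's TensorRT
execution provider, of how the reverted load_trt()/forward_estimator() hooks
could be filled in. Everything below that is not in the diff (the ONNX path,
input names, provider list) is an assumption, not this repo's actual design.

    import onnxruntime as ort
    import torch

    # Would live on CosyVoiceModel, replacing the deleted stub.
    def load_trt(self, estimator_onnx_path):
        # Hypothetical path to an ONNX export of flow.decoder.estimator;
        # TRT falls back to plain CUDA when no engine can be built.
        self.flow.decoder.session = ort.InferenceSession(
            estimator_onnx_path,
            providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider'])

    # Would live on ConditionalCFM, replacing the deleted forward_estimator.
    def forward_estimator(self, x, mask, mu, t, spks, cond):
        session = getattr(self, 'session', None)
        if session is None:
            # No TRT session loaded: fall back to the eager PyTorch module.
            return self.estimator(x, mask, mu, t, spks, cond)
        assert self.training is False, 'tensorrt cannot be used in training'
        # Input names must match the exported ONNX graph; these are guesses.
        # None-valued inputs (e.g. spks) would need zero-filling first.
        feeds = {'x': x, 'mask': mask, 'mu': mu, 't': t,
                 'spks': spks, 'cond': cond}
        [dphi_dt] = session.run(
            None, {k: v.cpu().numpy() for k, v in feeds.items()})
        return torch.from_numpy(dphi_dt).to(x.device)

With a dispatch like this, the solver loop in flow_matching.py could go back
to calling self.forward_estimator(...) (the very calls this patch reverts to
self.estimator(...)) without changing behaviour when no engine is loaded.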