mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
revert trt TODO
This commit is contained in:
@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
|
||||
sol = []
|
||||
|
||||
for step in range(1, len(t_span)):
|
||||
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
|
||||
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
if self.inference_cfg_rate > 0:
|
||||
cfg_dphi_dt = self.forward_estimator(
|
||||
cfg_dphi_dt = self.estimator(
|
||||
x, mask,
|
||||
torch.zeros_like(mu), t,
|
||||
torch.zeros_like(spks) if spks is not None else None,
|
||||
@@ -96,14 +96,6 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
return sol[-1]
|
||||
|
||||
# TODO
|
||||
def forward_estimator(self):
|
||||
if isinstance(self.estimator, trt):
|
||||
assert self.training is False, 'tensorrt cannot be used in training'
|
||||
return xxx
|
||||
else:
|
||||
return self.estimator.forward
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user