add trt script TODO

This commit is contained in:
lyuxiang.lx
2024-08-29 10:44:04 +08:00
parent 8b097f7625
commit f1e374a9bb
5 changed files with 31 additions and 7 deletions

View File

@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.estimator(
cfg_dphi_dt = self.forward_estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
@@ -96,6 +96,14 @@ class ConditionalCFM(BASECFM):
return sol[-1]
# TODO
def forward_estimator(self):
if isinstance(self.estimator, trt):
assert self.training is False, 'tensorrt cannot be used in training'
return xxx
else:
return self.estimator.forward
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss