support online onnx to trt conversion

This commit is contained in:
huzetao.hzt
2025-01-07 17:20:06 +08:00
parent 5d12ced727
commit b6a1116d15
4 changed files with 146 additions and 26 deletions

View File

@@ -120,24 +120,7 @@ class ConditionalCFM(BASECFM):
return sol[-1].float()
def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
return x
return self.estimator.forward(x, mask, mu, t, spks, cond)
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss