add onnx export

This commit is contained in:
lyuxiang.lx
2024-09-04 18:15:33 +08:00
parent d8197de4cc
commit 2ce724045b
6 changed files with 105 additions and 280 deletions

View File

@@ -31,8 +31,6 @@ class ConditionalCFM(BASECFM):
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.estimator_context = None # for tensorrt
self.session = None # for onnx
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
@@ -82,10 +80,10 @@ class ConditionalCFM(BASECFM):
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.estimator(
cfg_dphi_dt = self.forward_estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
@@ -102,51 +100,20 @@ class ConditionalCFM(BASECFM):
return sol[-1]
def forward_estimator(self, x, mask, mu, t, spks, cond):
    """Run the vector-field estimator on one solver step.

    Dispatches on the type of ``self.estimator``: a ``torch.nn.Module`` is
    called directly; anything else is treated as an inference session that
    exposes ``run(output_names, input_feed)`` and consumes NumPy arrays
    (presumably an ONNX Runtime session -- confirm against the loader).

    Args:
        x: current sample tensor; its dtype and device are reused for the
            returned tensor in the session branch.
        mask: mask tensor forwarded to the estimator.
        mu: conditioning tensor forwarded to the estimator.
        t: time-step tensor forwarded to the estimator.
        spks: optional speaker tensor, may be ``None``.
        cond: optional extra condition tensor, may be ``None``.

    Returns:
        The estimator output as a tensor. In the session branch the first
        session output is copied back to ``x``'s dtype and device.
    """
    if isinstance(self.estimator, torch.nn.Module):
        return self.estimator.forward(x, mask, mu, t, spks, cond)

    named_inputs = {
        'x': x,
        'mask': mask,
        'mu': mu,
        't': t,
        'spks': spks,
        'cond': cond,
    }
    # Callers may pass spks/cond as None; calling .cpu() on None would raise
    # AttributeError, so absent inputs are left out of the feed and the
    # session itself decides whether they were required.
    ort_inputs = {
        name: tensor.cpu().numpy()
        for name, tensor in named_inputs.items()
        if tensor is not None
    }
    output = self.estimator.run(None, ort_inputs)[0]
    return torch.tensor(output, dtype=x.dtype, device=x.device)
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss