mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add onnx export
This commit is contained in:
@@ -31,8 +31,6 @@ class ConditionalCFM(BASECFM):
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.estimator_context = None # for tensorrt
|
||||
self.session = None # for onnx
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
@@ -82,10 +80,10 @@ class ConditionalCFM(BASECFM):
|
||||
sol = []
|
||||
|
||||
for step in range(1, len(t_span)):
|
||||
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
||||
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
if self.inference_cfg_rate > 0:
|
||||
cfg_dphi_dt = self.estimator(
|
||||
cfg_dphi_dt = self.forward_estimator(
|
||||
x, mask,
|
||||
torch.zeros_like(mu), t,
|
||||
torch.zeros_like(spks) if spks is not None else None,
|
||||
@@ -102,51 +100,20 @@ class ConditionalCFM(BASECFM):
|
||||
return sol[-1]
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
|
||||
if self.estimator is not None:
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
# elif self.estimator_context is not None:
|
||||
# assert self.training is False, 'tensorrt cannot be used in training'
|
||||
# bs = x.shape[0]
|
||||
# hs = x.shape[1]
|
||||
# seq_len = x.shape[2]
|
||||
# # assert bs == 1 and hs == 80
|
||||
# ret = torch.empty_like(x)
|
||||
# self.estimator_context.set_input_shape("x", x.shape)
|
||||
# self.estimator_context.set_input_shape("mask", mask.shape)
|
||||
# self.estimator_context.set_input_shape("mu", mu.shape)
|
||||
# self.estimator_context.set_input_shape("t", t.shape)
|
||||
# self.estimator_context.set_input_shape("spks", spks.shape)
|
||||
# self.estimator_context.set_input_shape("cond", cond.shape)
|
||||
|
||||
# # Create a list of bindings
|
||||
# bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())]
|
||||
|
||||
# # Execute the inference
|
||||
# self.estimator_context.execute_v2(bindings=bindings)
|
||||
# return ret
|
||||
else:
|
||||
x_np = x.cpu().numpy()
|
||||
mask_np = mask.cpu().numpy()
|
||||
mu_np = mu.cpu().numpy()
|
||||
t_np = t.cpu().numpy()
|
||||
spks_np = spks.cpu().numpy()
|
||||
cond_np = cond.cpu().numpy()
|
||||
|
||||
ort_inputs = {
|
||||
'x': x_np,
|
||||
'mask': mask_np,
|
||||
'mu': mu_np,
|
||||
't': t_np,
|
||||
'spks': spks_np,
|
||||
'cond': cond_np
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
|
||||
output = self.session.run(None, ort_inputs)[0]
|
||||
|
||||
output = self.estimator.run(None, ort_inputs)[0]
|
||||
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user