export onnx

This commit is contained in:
禾息
2024-09-03 11:06:24 +08:00
parent 18599be8d5
commit fadb22086f
5 changed files with 318 additions and 164 deletions

View File

@@ -14,6 +14,8 @@
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
import onnxruntime as ort
import numpy as np
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
@@ -29,6 +31,8 @@ class ConditionalCFM(BASECFM):
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.estimator_context = None # for tensorrt
self.session = None # for onnx
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
@@ -101,28 +105,47 @@ class ConditionalCFM(BASECFM):
if self.estimator is not None:
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
assert self.training is False, 'tensorrt cannot be used in training'
bs = x.shape[0]
hs = x.shape[1]
seq_len = x.shape[2]
# assert bs == 1 and hs == 80
ret = torch.empty_like(x)
self.estimator_context.set_input_shape("x", x.shape)
self.estimator_context.set_input_shape("mask", mask.shape)
self.estimator_context.set_input_shape("mu", mu.shape)
self.estimator_context.set_input_shape("t", t.shape)
self.estimator_context.set_input_shape("spks", spks.shape)
self.estimator_context.set_input_shape("cond", cond.shape)
bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
for i in range(len(bindings)):
self.estimator_context.set_tensor_address(names[i], bindings[i])
# elif self.estimator_context is not None:
# assert self.training is False, 'tensorrt cannot be used in training'
# bs = x.shape[0]
# hs = x.shape[1]
# seq_len = x.shape[2]
# # assert bs == 1 and hs == 80
# ret = torch.empty_like(x)
# self.estimator_context.set_input_shape("x", x.shape)
# self.estimator_context.set_input_shape("mask", mask.shape)
# self.estimator_context.set_input_shape("mu", mu.shape)
# self.estimator_context.set_input_shape("t", t.shape)
# self.estimator_context.set_input_shape("spks", spks.shape)
# self.estimator_context.set_input_shape("cond", cond.shape)
# # Create a list of bindings
# bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())]
# # Execute the inference
# self.estimator_context.execute_v2(bindings=bindings)
# return ret
else:
x_np = x.cpu().numpy()
mask_np = mask.cpu().numpy()
mu_np = mu.cpu().numpy()
t_np = t.cpu().numpy()
spks_np = spks.cpu().numpy()
cond_np = cond.cpu().numpy()
ort_inputs = {
'x': x_np,
'mask': mask_np,
'mu': mu_np,
't': t_np,
'spks': spks_np,
'cond': cond_np
}
output = self.session.run(None, ort_inputs)[0]
return torch.tensor(output, dtype=x.dtype, device=x.device)
handle = torch.cuda.current_stream().cuda_stream
self.estimator_context.execute_async_v3(stream_handle=handle)
return ret
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss