mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
export onnx
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
@@ -29,6 +31,8 @@ class ConditionalCFM(BASECFM):
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.estimator_context = None # for tensorrt
|
||||
self.session = None # for onnx
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
@@ -101,28 +105,47 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
if self.estimator is not None:
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
assert self.training is False, 'tensorrt cannot be used in training'
|
||||
bs = x.shape[0]
|
||||
hs = x.shape[1]
|
||||
seq_len = x.shape[2]
|
||||
# assert bs == 1 and hs == 80
|
||||
ret = torch.empty_like(x)
|
||||
self.estimator_context.set_input_shape("x", x.shape)
|
||||
self.estimator_context.set_input_shape("mask", mask.shape)
|
||||
self.estimator_context.set_input_shape("mu", mu.shape)
|
||||
self.estimator_context.set_input_shape("t", t.shape)
|
||||
self.estimator_context.set_input_shape("spks", spks.shape)
|
||||
self.estimator_context.set_input_shape("cond", cond.shape)
|
||||
bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
|
||||
names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
|
||||
|
||||
for i in range(len(bindings)):
|
||||
self.estimator_context.set_tensor_address(names[i], bindings[i])
|
||||
# elif self.estimator_context is not None:
|
||||
# assert self.training is False, 'tensorrt cannot be used in training'
|
||||
# bs = x.shape[0]
|
||||
# hs = x.shape[1]
|
||||
# seq_len = x.shape[2]
|
||||
# # assert bs == 1 and hs == 80
|
||||
# ret = torch.empty_like(x)
|
||||
# self.estimator_context.set_input_shape("x", x.shape)
|
||||
# self.estimator_context.set_input_shape("mask", mask.shape)
|
||||
# self.estimator_context.set_input_shape("mu", mu.shape)
|
||||
# self.estimator_context.set_input_shape("t", t.shape)
|
||||
# self.estimator_context.set_input_shape("spks", spks.shape)
|
||||
# self.estimator_context.set_input_shape("cond", cond.shape)
|
||||
|
||||
# # Create a list of bindings
|
||||
# bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())]
|
||||
|
||||
# # Execute the inference
|
||||
# self.estimator_context.execute_v2(bindings=bindings)
|
||||
# return ret
|
||||
else:
|
||||
x_np = x.cpu().numpy()
|
||||
mask_np = mask.cpu().numpy()
|
||||
mu_np = mu.cpu().numpy()
|
||||
t_np = t.cpu().numpy()
|
||||
spks_np = spks.cpu().numpy()
|
||||
cond_np = cond.cpu().numpy()
|
||||
|
||||
ort_inputs = {
|
||||
'x': x_np,
|
||||
'mask': mask_np,
|
||||
'mu': mu_np,
|
||||
't': t_np,
|
||||
'spks': spks_np,
|
||||
'cond': cond_np
|
||||
}
|
||||
|
||||
output = self.session.run(None, ort_inputs)[0]
|
||||
|
||||
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
||||
|
||||
handle = torch.cuda.current_stream().cuda_stream
|
||||
self.estimator_context.execute_async_v3(stream_handle=handle)
|
||||
return ret
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
||||
Reference in New Issue
Block a user