mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add flow decoder tensorrt infer
This commit is contained in:
@@ -30,6 +30,9 @@ class ConditionalCFM(BASECFM):
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
|
||||
self.estimator_context = None
|
||||
self.estimator_engine = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
"""Forward diffusion
|
||||
@@ -50,7 +53,7 @@ class ConditionalCFM(BASECFM):
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
z = torch.randn_like(mu) * temperature
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
||||
@@ -71,6 +74,7 @@ class ConditionalCFM(BASECFM):
|
||||
cond: Not used but kept for future purposes
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
t = t.unsqueeze(dim=0)
|
||||
|
||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||
# Or in future might add like a return_all_steps flag
|
||||
@@ -96,13 +100,30 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
return sol[-1]
|
||||
|
||||
# TODO
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
    """Evaluate the estimator, through TensorRT when a context is attached.

    If ``self.estimator_context`` is ``None`` the call is forwarded
    unchanged to ``self.estimator.forward``. Otherwise the six inputs are
    bound to the TensorRT execution context and the engine is enqueued on
    the current CUDA stream.

    Args:
        x, mask, mu, t, spks, cond: Estimator inputs, passed through in this
            order. In the TensorRT branch each one must be a tensor, since
            ``.shape`` and ``.data_ptr()`` are read from it.
            NOTE(review): ``spks``/``cond`` being ``None`` is therefore not
            supported in that branch -- confirm callers never pass ``None``.
            NOTE(review): a previously commented-out assertion hinted at
            batch size 1 and 80 channels for ``x`` -- confirm against the
            engine's optimisation profile.

    Returns:
        The result of ``self.estimator.forward(...)`` in the PyTorch branch;
        in the TensorRT branch, a new contiguous tensor with the shape and
        dtype of ``x`` that the engine writes its output into. The engine
        is launched asynchronously, so the data is only valid in stream
        order on the current CUDA stream.

    Raises:
        RuntimeError: If TensorRT reports that the inference could not be
            enqueued.
    """
    if self.estimator_context is None:
        return self.estimator.forward(x, mask, mu, t, spks, cond)

    context = self.estimator_context

    # TensorRT only receives a raw device pointer and assumes densely packed
    # memory, so strided views must be materialised first. The local
    # references also keep any temporary copies alive until the enqueue.
    named_inputs = (
        ("x", x.contiguous()),
        ("mask", mask.contiguous()),
        ("mu", mu.contiguous()),
        ("t", t.contiguous()),
        ("spks", spks.contiguous()),
        ("cond", cond.contiguous()),
    )
    for input_name, tensor in named_inputs:
        context.set_input_shape(input_name, tensor.shape)

    # Allocate from the contiguous copy so the output buffer is contiguous
    # too (empty_like preserves the strides of its argument).
    ret = torch.empty_like(named_inputs[0][1])

    # Addresses are bound by engine I/O index: the six inputs in the order
    # above, followed by the single output.
    buffers = [tensor for _, tensor in named_inputs]
    buffers.append(ret)
    for index, tensor in enumerate(buffers):
        tensor_name = self.estimator_engine.get_tensor_name(index)
        context.set_tensor_address(tensor_name, tensor.data_ptr())

    stream_handle = torch.cuda.current_stream().cuda_stream
    # execute_async_v3 returns False on failure; without this check the
    # caller would silently receive the uninitialised buffer.
    if not context.execute_async_v3(stream_handle=stream_handle):
        raise RuntimeError('tensorrt estimator inference failed to enqueue')
    return ret
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
||||
Reference in New Issue
Block a user