add flow decoder tensorrt infer

zhoubofan.zbf
2024-08-29 23:35:07 +08:00
parent 1d881df8b2
commit 5f21aef786
5 changed files with 149 additions and 19 deletions

View File

@@ -159,7 +159,7 @@ class ConditionalDecoder(nn.Module):
             _type_: _description_
         """
-        t = self.time_embeddings(t)
+        t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
         x = pack([x, mu], "b * t")[0]

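The new `.to(t.dtype)` keeps the time embedding in the caller's dtype: sinusoidal time embeddings are typically computed in float32 internally, so without the cast a half-precision `time_mlp` would receive a float32 input once the decoder runs in fp16 (which the TensorRT changes below appear to target). A minimal sketch of the effect; the embedding body here is illustrative, not the repo's implementation:

    import math
    import torch

    def time_embeddings(t, dim=8):
        # Sinusoidal embedding computed in float32 (mirrors the internal upcast of
        # Matcha-style SinusoidalPosEmb modules; illustrative only, not repo code)
        half = dim // 2
        freqs = torch.exp(-math.log(10000.0) / (half - 1) * torch.arange(half).float())
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)

    t = torch.rand(2, dtype=torch.float16)
    print(time_embeddings(t).dtype)              # torch.float32 -> dtype mismatch for an fp16 time_mlp
    print(time_embeddings(t).to(t.dtype).dtype)  # torch.float16 -> matches the committed change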
View File

@@ -30,6 +30,9 @@ class ConditionalCFM(BASECFM):
         # Just change the architecture of the estimator here
         self.estimator = estimator
+        self.estimator_context = None
+        self.estimator_engine = None

     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
         """Forward diffusion
@@ -50,7 +53,7 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
         z = torch.randn_like(mu) * temperature
-        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
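As context, the pre-existing cosine scheduler branch warps this (now dtype-matched) uniform grid so that Euler steps are smaller early in the trajectory; a quick standalone check:

    import torch
    t_span = torch.linspace(0, 1, 5)                 # tensor([0.00, 0.25, 0.50, 0.75, 1.00])
    warped = 1 - torch.cos(t_span * 0.5 * torch.pi)  # approx. [0.000, 0.076, 0.293, 0.617, 1.000]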
@@ -71,6 +74,7 @@ class ConditionalCFM(BASECFM):
             cond: Not used but kept for future purposes
         """
         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)

         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         # Or in future might add like a return_all_steps flag
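The added unsqueeze promotes `t` from a 0-d element of `t_span` to a rank-1 tensor of shape (1,), presumably so the estimator's "t" input has an explicit batch dimension (it is later passed to `set_input_shape("t", t.shape)` in the hunk below). A tiny illustration:

    import torch
    t_span = torch.linspace(0, 1, 11)
    t = t_span[0]            # torch.Size([])  (0-d scalar)
    t = t.unsqueeze(dim=0)   # torch.Size([1]) (rank-1, batch of one)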
@@ -96,13 +100,30 @@ class ConditionalCFM(BASECFM):
         return sol[-1]

-    # TODO
-    def forward_estimator(self):
-        if isinstance(self.estimator, trt):
-            return xxx
+    def forward_estimator(self, x, mask, mu, t, spks, cond):
+        if self.estimator_context is not None:
+            assert self.training is False, 'tensorrt cannot be used in training'
+            bs = x.shape[0]
+            hs = x.shape[1]
+            seq_len = x.shape[2]
+            # assert bs == 1 and hs == 80
+            ret = torch.empty_like(x)
+            self.estimator_context.set_input_shape("x", x.shape)
+            self.estimator_context.set_input_shape("mask", mask.shape)
+            self.estimator_context.set_input_shape("mu", mu.shape)
+            self.estimator_context.set_input_shape("t", t.shape)
+            self.estimator_context.set_input_shape("spks", spks.shape)
+            self.estimator_context.set_input_shape("cond", cond.shape)
+            bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
+            for i in range(len(bindings)):
+                self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i])
+            handle = torch.cuda.current_stream().cuda_stream
+            self.estimator_context.execute_async_v3(stream_handle=handle)
+            return ret
         else:
-            return self.estimator.forward
+            return self.estimator.forward(x, mask, mu, t, spks, cond)

     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss