add flow decoder cache

This commit is contained in:
lyuxiang.lx
2025-01-23 16:48:13 +08:00
parent 190840b8dc
commit 1c062ab381
21 changed files with 1601 additions and 214 deletions

View File

@@ -34,7 +34,7 @@ class ConditionalCFM(BASECFM):
self.lock = threading.Lock()
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
"""Forward diffusion
Args:
@@ -54,19 +54,19 @@ class ConditionalCFM(BASECFM):
"""
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
cache_size = flow_cache.shape[2]
cache_size = cache.shape[2]
# fix prompt and overlap part mu and z
if cache_size != 0:
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
z[:, :, :cache_size] = cache[:, :, :, 0]
mu[:, :, :cache_size] = cache[:, :, :, 1]
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
cache = torch.stack([z_cache, mu_cache], dim=-1)
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
@@ -123,7 +123,7 @@ class ConditionalCFM(BASECFM):
def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
return self.estimator(x, mask, mu, t, spks, cond)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -181,6 +181,9 @@ class ConditionalCFM(BASECFM):
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
if loss.isnan():
print(123)
pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
return loss, y
@@ -190,7 +193,7 @@ class CausalConditionalCFM(ConditionalCFM):
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
"""Forward diffusion
Args:
@@ -209,9 +212,105 @@ class CausalConditionalCFM(ConditionalCFM):
shape: (batch_size, n_feats, mel_timesteps)
"""
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
offset = cache.pop('offset')
z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
z = z[:, :, offset:]
offset += mu.size(2)
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
cache['offset'] = offset
return mel, cache
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
    """
    Fixed euler solver for ODEs, with classifier-free guidance and a
    per-step estimator cache.

    Args:
        x (torch.Tensor): random noise
        t_span (torch.Tensor): n_timesteps interpolated
            shape: (n_timesteps + 1,)
        mu (torch.Tensor): output of encoder
            shape: (batch_size, n_feats, mel_timesteps)
        mask (torch.Tensor): output_mask
            shape: (batch_size, 1, mel_timesteps)
        spks (torch.Tensor, optional): speaker ids. Defaults to None.
            shape: (batch_size, spk_emb_dim)
        cond: conditioning tensor, copied into the first guidance row and
            forwarded to the estimator
        cache (dict): estimator cache. Every value is indexed by solver step
            on its leading axis. Must contain the keys
            'down_blocks_conv_cache', 'down_blocks_kv_cache',
            'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
            'up_blocks_conv_cache', 'up_blocks_kv_cache' and
            'final_blocks_conv_cache'.

    Returns:
        tuple: (last solver state cast to float32, the same ``cache`` dict).
        The conv entries are overwritten in place step by step; the kv
        entries are replaced by the previous kv cache with this call's
        entries appended along axis 4.

    NOTE(review): the working buffers are allocated as (2, 80, mel_timesteps),
    so this path presumably expects batch_size == 1 and n_feats == 80 — confirm.
    """
    def _empty_kv_like(previous):
        # The new entries are concatenated onto `previous` along axis 4, so
        # every other dimension, the dtype and the device have to match it.
        # Deriving them from `previous` avoids a CPU allocation + copy and
        # keeps a half-precision cache from being promoted to float32.
        shape = list(previous.shape)
        shape[4] = x.size(2)
        return previous.new_zeros(shape)

    t, dt = t_span[0], t_span[1] - t_span[0]
    t = t.unsqueeze(dim=0)

    # I am storing this because I can later plot it by putting a debugger here and saving it to a file
    # Or in future might add like a return_all_steps flag
    sol = []

    # estimator kv cache produced by this call, one slot per solver step
    down_blocks_kv_cache_new = _empty_kv_like(cache['down_blocks_kv_cache'])
    mid_blocks_kv_cache_new = _empty_kv_like(cache['mid_blocks_kv_cache'])
    up_blocks_kv_cache_new = _empty_kv_like(cache['up_blocks_kv_cache'])

    # Do not use concat, it may cause memory format changed and trt infer with wrong results!
    x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
    mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
    spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
    cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)

    for step in range(1, len(t_span)):
        # Classifier-Free Guidance inference introduced in VoiceBox.
        # Row 0 carries mu / spks / cond; row 1 keeps them at zero.
        x_in[:] = x
        mask_in[:] = mask
        mu_in[0] = mu
        t_in[:] = t
        spks_in[0] = spks
        cond_in[0] = cond

        # Hand the estimator only this step's slice of every cache entry.
        cache_step = {k: v[step - 1] for k, v in cache.items()}
        dphi_dt, cache_step = self.forward_estimator(
            x_in, mask_in,
            mu_in, t_in,
            spks_in,
            cond_in,
            cache_step
        )

        # The estimator returns its caches positionally:
        # (down conv, down kv, mid conv, mid kv, up conv, up kv, final conv).
        cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
        down_blocks_kv_cache_new[step - 1] = cache_step[1]
        cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
        mid_blocks_kv_cache_new[step - 1] = cache_step[3]
        cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
        up_blocks_kv_cache_new[step - 1] = cache_step[5]
        cache['final_blocks_conv_cache'][step - 1] = cache_step[6]

        # Split the two guidance rows and blend them.
        dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
        dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)

        x = x + dt * dphi_dt
        t = t + dt
        sol.append(x)
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t

    # Append this call's kv entries to what was already cached.
    cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
    cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
    cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
    return sol[-1].float(), cache
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
    """
    Run one estimator step, either through the torch module or through a
    TensorRT-style execution context.

    Args:
        x, mask, mu, t, spks, cond: estimator inputs, forwarded unchanged.
        cache (dict): keyword arguments passed to ``forward_chunk`` on the
            torch path.

    Returns:
        tuple: ``(x, cache)``. On the torch path ``x`` is the first value
        returned by ``forward_chunk`` and ``cache`` is a 7-tuple of the
        remaining returned values, in order. On the engine path ``x`` is the
        buffer the engine was told to write into and ``cache`` is the input
        ``cache`` object, unchanged.
    """
    if isinstance(self.estimator, torch.nn.Module):
        x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
        cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
    else:
        # Materialise contiguous tensors in locals before taking raw pointers:
        # a pointer taken from a temporary `.contiguous()` copy is not backed
        # by a live tensor by the time the engine runs. `x` doubles as the
        # output buffer, so input and output must be the same contiguous tensor.
        x = x.contiguous()
        bindings = [
            x,
            mask.contiguous(),
            mu.contiguous(),
            t.contiguous(),
            spks.contiguous(),
            cond.contiguous(),
        ]
        frames = x.size(2)
        with self.lock:
            self.estimator.set_input_shape('x', (2, 80, frames))
            self.estimator.set_input_shape('mask', (2, 1, frames))
            self.estimator.set_input_shape('mu', (2, 80, frames))
            self.estimator.set_input_shape('t', (2,))
            self.estimator.set_input_shape('spks', (2, 80))
            self.estimator.set_input_shape('cond', (2, 80, frames))
            # run trt engine; the last pointer is the output, written over x
            self.estimator.execute_v2([tensor.data_ptr() for tensor in bindings] + [x.data_ptr()])
        # NOTE(review): this branch neither feeds `cache` to the engine nor
        # receives updated cache entries; `cache` is returned as passed in,
        # while the torch branch returns a 7-tuple — confirm callers cope.
    return x, cache