mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add flow decoder cache
This commit is contained in:
@@ -34,7 +34,7 @@ class ConditionalCFM(BASECFM):
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
@@ -54,19 +54,19 @@ class ConditionalCFM(BASECFM):
|
||||
"""
|
||||
|
||||
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
||||
cache_size = flow_cache.shape[2]
|
||||
cache_size = cache.shape[2]
|
||||
# fix prompt and overlap part mu and z
|
||||
if cache_size != 0:
|
||||
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
||||
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
||||
z[:, :, :cache_size] = cache[:, :, :, 0]
|
||||
mu[:, :, :cache_size] = cache[:, :, :, 1]
|
||||
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
||||
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
||||
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
||||
cache = torch.stack([z_cache, mu_cache], dim=-1)
|
||||
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
"""
|
||||
@@ -123,7 +123,7 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
return self.estimator(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
@@ -181,6 +181,9 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
||||
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
||||
if loss.isnan():
|
||||
print(123)
|
||||
pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
||||
return loss, y
|
||||
|
||||
|
||||
@@ -190,7 +193,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
@@ -209,9 +212,105 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||
offset = cache.pop('offset')
|
||||
z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
|
||||
z = z[:, :, offset:]
|
||||
offset += mu.size(2)
|
||||
# fix prompt and overlap part mu and z
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||
mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
|
||||
cache['offset'] = offset
|
||||
return mel, cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
t = t.unsqueeze(dim=0)
|
||||
|
||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||
# Or in future might add like a return_all_steps flag
|
||||
sol = []
|
||||
|
||||
# estimator cache for each step
|
||||
down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
cache_step = {k: v[step - 1] for k, v in cache.items()}
|
||||
dphi_dt, cache_step = self.forward_estimator(
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in,
|
||||
cache_step
|
||||
)
|
||||
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
|
||||
down_blocks_kv_cache_new[step - 1] = cache_step[1]
|
||||
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
|
||||
mid_blocks_kv_cache_new[step - 1] = cache_step[3]
|
||||
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
|
||||
up_blocks_kv_cache_new[step - 1] = cache_step[5]
|
||||
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t_span[step + 1] - t
|
||||
cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
|
||||
cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
|
||||
cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
|
||||
return sol[-1].float(), cache
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
|
||||
cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x, cache
|
||||
|
||||
Reference in New Issue
Block a user