remove flow_cache

This commit is contained in:
lyuxiang.lx
2025-05-23 12:50:47 +08:00
parent 88f467a8ac
commit 68100c267a
14 changed files with 365 additions and 955 deletions

View File

@@ -126,21 +126,26 @@ class ConditionalCFM(BASECFM):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator(x, mask, mu, t, spks, cond)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]) is True
estimator, trt_engine = self.estimator.acquire_estimator()
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator)
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -191,7 +196,7 @@ class CausalConditionalCFM(ConditionalCFM):
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
@@ -210,136 +215,9 @@ class CausalConditionalCFM(ConditionalCFM):
shape: (batch_size, n_feats, mel_timesteps)
"""
offset = cache.pop('offset')
z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
z = z[:, :, offset:]
offset += mu.size(2)
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
cache['offset'] = offset
return mel, cache
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
cache_step = {k: v[step - 1] for k, v in cache.items()}
dphi_dt, cache_step = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in,
cache_step
)
# NOTE if smaller than flow_cache_size, means last chunk, no need to cache
if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1].float(), cache
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
if isinstance(self.estimator, torch.nn.Module):
x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
else:
estimator, trt_engine = self.estimator.acquire_estimator()
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
x.data_ptr(),
cache['down_blocks_conv_cache'].data_ptr(),
down_blocks_kv_cache_out.data_ptr(),
cache['mid_blocks_conv_cache'].data_ptr(),
mid_blocks_kv_cache_out.data_ptr(),
cache['up_blocks_conv_cache'].data_ptr(),
up_blocks_kv_cache_out.data_ptr(),
cache['final_blocks_conv_cache'].data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator)
cache = (cache['down_blocks_conv_cache'],
down_blocks_kv_cache_out,
cache['mid_blocks_conv_cache'],
mid_blocks_kv_cache_out,
cache['up_blocks_conv_cache'],
up_blocks_kv_cache_out,
cache['final_blocks_conv_cache'])
return x, cache
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None