This commit is contained in:
lyuxiang.lx
2025-02-06 16:07:13 +08:00
parent 24f796a2b1
commit 2a3e033ee1
17 changed files with 187 additions and 135 deletions

View File

@@ -133,13 +133,13 @@ class ConditionalCFM(BASECFM):
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]) is True
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -244,9 +244,9 @@ class CausalConditionalCFM(ConditionalCFM):
sol = []
# estimator cache for each step
down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x.device)
up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
@@ -302,12 +302,43 @@ class CausalConditionalCFM(ConditionalCFM):
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
x.data_ptr(),
cache['down_blocks_conv_cache'].data_ptr(),
down_blocks_kv_cache_out.data_ptr(),
cache['mid_blocks_conv_cache'].data_ptr(),
mid_blocks_kv_cache_out.data_ptr(),
cache['up_blocks_conv_cache'].data_ptr(),
up_blocks_kv_cache_out.data_ptr(),
cache['final_blocks_conv_cache'].data_ptr()]) is True
cache = (cache['down_blocks_conv_cache'],
down_blocks_kv_cache_out,
cache['mid_blocks_conv_cache'],
mid_blocks_kv_cache_out,
cache['up_blocks_conv_cache'],
up_blocks_kv_cache_out,
cache['final_blocks_conv_cache'])
return x, cache