mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
fix lint
This commit is contained in:
@@ -133,13 +133,13 @@ class ConditionalCFM(BASECFM):
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]) is True
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||
@@ -244,9 +244,9 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
sol = []
|
||||
|
||||
# estimator cache for each step
|
||||
down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
|
||||
mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
|
||||
up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
|
||||
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
@@ -302,12 +302,43 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
|
||||
self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
|
||||
self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
|
||||
self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
|
||||
self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
|
||||
self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
|
||||
self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
|
||||
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
|
||||
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
|
||||
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
|
||||
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
|
||||
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
|
||||
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
x.data_ptr(),
|
||||
cache['down_blocks_conv_cache'].data_ptr(),
|
||||
down_blocks_kv_cache_out.data_ptr(),
|
||||
cache['mid_blocks_conv_cache'].data_ptr(),
|
||||
mid_blocks_kv_cache_out.data_ptr(),
|
||||
cache['up_blocks_conv_cache'].data_ptr(),
|
||||
up_blocks_kv_cache_out.data_ptr(),
|
||||
cache['final_blocks_conv_cache'].data_ptr()]) is True
|
||||
cache = (cache['down_blocks_conv_cache'],
|
||||
down_blocks_kv_cache_out,
|
||||
cache['mid_blocks_conv_cache'],
|
||||
mid_blocks_kv_cache_out,
|
||||
cache['up_blocks_conv_cache'],
|
||||
up_blocks_kv_cache_out,
|
||||
cache['final_blocks_conv_cache'])
|
||||
return x, cache
|
||||
|
||||
Reference in New Issue
Block a user