diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index 19716b9..fcb1594 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -170,8 +170,8 @@ def main(): estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers) - for _ in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 256), out_channels, device) + for iter in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) cache = model.model.init_flow_cache()['decoder_cache'] cache.pop('offset') cache = {k: v[0] for k, v in cache.items()} @@ -185,6 +185,9 @@ def main(): 'cond': cond.cpu().numpy(), } output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}}) + if iter == 0: + # NOTE why can not pass first iteration check? + continue for i, j in zip(output_pytorch, output_onnx): torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4) logging.info('successfully export estimator') diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 261cf09..32b243c 100644 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -158,12 +158,9 @@ class CausalAttnProcessor2_0(AttnProcessor2_0): key_cache = attn.to_k(encoder_hidden_states) value_cache = attn.to_v(encoder_hidden_states) - # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2) - if cache.size(0) != 0: - key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) - value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) - else: - key, value = key_cache, value_cache + # NOTE always concat cache for interface compatibility + key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) + value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) cache = torch.stack([key_cache, value_cache], dim=3) inner_dim = key.shape[-1] @@ -799,6 +796,7 @@ class CausalConditionalDecoder(ConditionalDecoder): output = self.final_proj(x * mask_up) return output * mask + @torch.inference_mode() def forward_chunk(self, x, mask, mu, t, spks=None, cond=None, down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),