diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 32b243c..4a89fb1 100644 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -158,9 +158,12 @@ class CausalAttnProcessor2_0(AttnProcessor2_0): key_cache = attn.to_k(encoder_hidden_states) value_cache = attn.to_v(encoder_hidden_states) - # NOTE always concat cache for interface compatibility - key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) - value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) + # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2) + if cache.size(0) != 0: + key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) + value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) + else: + key, value = key_cache, value_cache cache = torch.stack([key_cache, value_cache], dim=3) inner_dim = key.shape[-1]