mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 09:59:23 +08:00
fix bug
This commit is contained in:
@@ -158,9 +158,12 @@ class CausalAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
|
|
||||||
key_cache = attn.to_k(encoder_hidden_states)
|
key_cache = attn.to_k(encoder_hidden_states)
|
||||||
value_cache = attn.to_v(encoder_hidden_states)
|
value_cache = attn.to_v(encoder_hidden_states)
|
||||||
# NOTE always concat cache for interface compatibility
|
# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
|
||||||
key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
|
if cache.size(0) != 0:
|
||||||
value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
|
key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
|
||||||
|
value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
|
||||||
|
else:
|
||||||
|
key, value = key_cache, value_cache
|
||||||
cache = torch.stack([key_cache, value_cache], dim=3)
|
cache = torch.stack([key_cache, value_cache], dim=3)
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
inner_dim = key.shape[-1]
|
||||||
|
|||||||
Reference in New Issue
Block a user