fix cache bug

lyuxiang.lx
2025-01-24 11:07:26 +08:00
parent 1c062ab381
commit aea75207dd
2 changed files with 20 additions and 22 deletions


@@ -396,6 +396,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
        encoders_kv_cache_list = []
        for index, layer in enumerate(self.encoders):
            xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
            encoders_kv_cache_list.append(encoders_kv_cache_new)
        encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
        # upsample
@@ -426,4 +427,4 @@ class UpsampleConformerEncoder(torch.nn.Module):
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
-       return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache_new, upsample_offset, upsample_conv_cache, upsample_kv_cache_new)
+       return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
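
What the visible hunks show: in the chunked forward pass, each conformer layer's updated key/value cache is appended to encoders_kv_cache_list and the list is stacked back into encoders_kv_cache, and the return statement now hands back that stacked tensor (and upsample_kv_cache) instead of encoders_kv_cache_new / upsample_kv_cache_new, which only hold whatever the last loop iteration produced. Returning only the last layer's cache would feed every layer the wrong history on the next chunk. Below is a minimal sketch of this per-layer cache bookkeeping; the toy layer, shapes, and names are placeholders for illustration, not the real UpsampleConformerEncoder API.

import torch

# Toy stand-in for one conformer layer (hypothetical, not the CosyVoice layer API):
# it returns the transformed chunk plus an updated per-layer cache, here simply the
# last `cache_len` frames of (old cache ++ current chunk).
class ToyLayer(torch.nn.Module):
    def __init__(self, dim, cache_len=4):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.cache_len = cache_len

    def forward(self, xs, kv_cache):
        xs = self.proj(xs)
        new_cache = torch.cat([kv_cache, xs], dim=1)[:, -self.cache_len:]
        return xs, new_cache

num_layers, batch, chunk, dim, cache_len = 3, 1, 2, 8, 4
layers = torch.nn.ModuleList([ToyLayer(dim, cache_len) for _ in range(num_layers)])

# One cache slot per layer, stacked along dim 0: (num_layers, batch, cache_len, dim)
kv_cache = torch.zeros(num_layers, batch, cache_len, dim)

def forward_chunk(xs, kv_cache):
    cache_list = []
    for index, layer in enumerate(layers):
        xs, new_cache = layer(xs, kv_cache[index])  # each layer reads its own slot...
        cache_list.append(new_cache)                # ...and its update is collected
    # Re-stack so slot `index` still belongs to layer `index` on the next chunk;
    # returning only the last `new_cache` would lose every other layer's history.
    return xs, torch.stack(cache_list, dim=0)

xs, kv_cache = forward_chunk(torch.randn(batch, chunk, dim), kv_cache)  # chunk 1
xs, kv_cache = forward_chunk(torch.randn(batch, chunk, dim), kv_cache)  # chunk 2 reuses caches

The essential point is that the cache tensor is indexed by layer on the way in and rebuilt per layer on the way out, so slot index always belongs to layer index across successive chunks.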