add flow decoder cache

This commit is contained in:
lyuxiang.lx
2025-01-23 16:48:13 +08:00
parent 190840b8dc
commit 1c062ab381
21 changed files with 1601 additions and 214 deletions

View File

@@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
Returns:
torch.Tensor: Corresponding encoding
"""
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
]
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
]
elif isinstance(offset, torch.Tensor):
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
]
return pos_emb