This commit is contained in:
lyuxiang.lx
2024-09-05 16:15:34 +08:00
parent eeebc45313
commit 90433f5373
35 changed files with 189 additions and 122 deletions

View File

@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
"""Construct an PositionalEncoding object."""
super(EspnetRelPositionalEncoding, self).__init__()
self.d_model = d_model
@@ -289,6 +289,6 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
"""
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
]
return pos_emb