This commit is contained in:
lyuxiang.lx
2025-08-19 18:53:18 +08:00
parent e3c2400abb
commit da41f6175b
5 changed files with 986 additions and 13 deletions

View File

@@ -64,17 +64,18 @@ class Upsample1D(nn.Module):
class PreLookaheadLayer(nn.Module):
def __init__(self, channels: int, pre_lookahead_len: int = 1):
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
in_channels, channels,
kernel_size=pre_lookahead_len + 1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
channels, channels,
channels, in_channels,
kernel_size=3, stride=1, padding=0,
)
@@ -199,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,