This commit is contained in:
lyuxiang.lx
2025-02-06 16:07:13 +08:00
parent 24f796a2b1
commit 2a3e033ee1
17 changed files with 187 additions and 135 deletions

View File

@@ -56,7 +56,7 @@ class Upsample1D(nn.Module):
# In this mode, first repeat interpolate, than conv with stride=1
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
if conv_cache.size(2) == 0:
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
@@ -287,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk if streaming is True else False,
self.use_dynamic_left_chunk if streaming is True else False,
decoding_chunk_size if streaming is True else 0,
self.static_chunk_size if streaming is True else 0,
num_decoding_left_chunks if streaming is True else -1)
self.use_dynamic_chunk if streaming is True else False,
self.use_dynamic_left_chunk if streaming is True else False,
decoding_chunk_size if streaming is True else 0,
self.static_chunk_size if streaming is True else 0,
num_decoding_left_chunks if streaming is True else -1)
# lookahead + conformer encoder
xs, _ = self.pre_lookahead_layer(xs)
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -305,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
xs, pos_emb, masks = self.up_embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk if streaming is True else False,
self.use_dynamic_left_chunk if streaming is True else False,
decoding_chunk_size if streaming is True else 0,
self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
num_decoding_left_chunks if streaming is True else -1)
self.use_dynamic_chunk if streaming is True else False,
self.use_dynamic_left_chunk if streaming is True else False,
decoding_chunk_size if streaming is True else 0,
self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
num_decoding_left_chunks if streaming is True else -1)
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before: