use static_chunk_size in flow training

This commit is contained in:
lyuxiang.lx
2025-04-07 22:34:45 +08:00
parent 7902d1c17f
commit d9ffd592f6

View File

@@ -286,12 +286,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk if streaming is True else False,
self.use_dynamic_left_chunk if streaming is True else False,
decoding_chunk_size if streaming is True else 0,
self.static_chunk_size if streaming is True else 0,
num_decoding_left_chunks if streaming is True else -1)
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
# lookahead + conformer encoder
xs, _ = self.pre_lookahead_layer(xs)
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,12 +299,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
xs, pos_emb, masks = self.up_embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk if streaming is True else False,
self.use_dynamic_left_chunk if streaming is True else False,
decoding_chunk_size if streaming is True else 0,
self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
num_decoding_left_chunks if streaming is True else -1)
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before: