Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
add flow unified training
This commit is contained in:
@@ -255,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
xs_lens: torch.Tensor,
|
||||
decoding_chunk_size: int = 0,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
+ streaming: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
@@ -286,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
xs, pos_emb, masks = self.embed(xs, masks)
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks,
|
||||
- self.use_dynamic_chunk,
- self.use_dynamic_left_chunk,
- decoding_chunk_size,
- self.static_chunk_size,
- num_decoding_left_chunks)
+ self.use_dynamic_chunk if streaming is True else False,
+ self.use_dynamic_left_chunk if streaming is True else False,
+ decoding_chunk_size if streaming is True else 0,
+ self.static_chunk_size if streaming is True else 0,
+ num_decoding_left_chunks if streaming is True else -1)
|
||||
# lookahead + conformer encoder
|
||||
xs, _ = self.pre_lookahead_layer(xs)
|
||||
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
@@ -304,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
xs, pos_emb, masks = self.up_embed(xs, masks)
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks,
|
||||
- self.use_dynamic_chunk,
- self.use_dynamic_left_chunk,
- decoding_chunk_size,
- self.static_chunk_size * self.up_layer.stride,
- num_decoding_left_chunks)
+ self.use_dynamic_chunk if streaming is True else False,
+ self.use_dynamic_left_chunk if streaming is True else False,
+ decoding_chunk_size if streaming is True else 0,
+ self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+ num_decoding_left_chunks if streaming is True else -1)
|
||||
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
|
||||
if self.normalize_before:
|
||||
|
||||
Reference in New Issue
Block a user