mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add cosyvoice code
This commit is contained in:
472
cosyvoice/transformer/encoder.py
Normal file
472
cosyvoice/transformer/encoder.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Encoder definition."""
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as ckpt
|
||||
|
||||
from cosyvoice.transformer.convolution import ConvolutionModule
|
||||
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
|
||||
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
||||
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from cosyvoice.utils.class_utils import (
|
||||
COSYVOICE_EMB_CLASSES,
|
||||
COSYVOICE_SUBSAMPLE_CLASSES,
|
||||
COSYVOICE_ATTENTION_CLASSES,
|
||||
COSYVOICE_ACTIVATION_CLASSES,
|
||||
)
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
|
||||
|
||||
class BaseEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "abs_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_size (int): input dim
|
||||
output_size (int): dimension of attention
|
||||
attention_heads (int): the number of heads of multi head attention
|
||||
linear_units (int): the hidden units number of position-wise feed
|
||||
forward
|
||||
num_blocks (int): the number of decoder blocks
|
||||
dropout_rate (float): dropout rate
|
||||
attention_dropout_rate (float): dropout rate in attention
|
||||
positional_dropout_rate (float): dropout rate after adding
|
||||
positional encoding
|
||||
input_layer (str): input layer type.
|
||||
optional [linear, conv2d, conv2d6, conv2d8]
|
||||
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
||||
normalize_before (bool):
|
||||
True: use layer_norm before each sub-block of a layer.
|
||||
False: use layer_norm after each sub-block of a layer.
|
||||
static_chunk_size (int): chunk size for static chunk training and
|
||||
decoding
|
||||
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
||||
training or not, You can only use fixed chunk(chunk_size > 0)
|
||||
or dyanmic chunk size(use_dynamic_chunk = True)
|
||||
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
||||
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
||||
dynamic chunk training
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
gradient_checkpointing: rerunning a forward-pass segment for each
|
||||
checkpointed segment during backward.
|
||||
"""
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
self.global_cmvn = global_cmvn
|
||||
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
||||
positional_dropout_rate),
|
||||
)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.use_dynamic_chunk = use_dynamic_chunk
|
||||
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
xs_lens: torch.Tensor,
|
||||
decoding_chunk_size: int = 0,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs: padded input tensor (B, T, D)
|
||||
xs_lens: input length (B)
|
||||
decoding_chunk_size: decoding chunk size for dynamic chunk
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
Returns:
|
||||
encoder output tensor xs, and subsampled masks
|
||||
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
||||
masks: torch.Tensor batch padding mask after subsample
|
||||
(B, 1, T' ~= T/subsample_rate)
|
||||
NOTE(xcsong):
|
||||
We pass the `__call__` method of the modules instead of `forward` to the
|
||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||
"""
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
xs, pos_emb, masks = self.embed(xs, masks)
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks,
|
||||
self.use_dynamic_chunk,
|
||||
self.use_dynamic_left_chunk,
|
||||
decoding_chunk_size,
|
||||
self.static_chunk_size,
|
||||
num_decoding_left_chunks)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
|
||||
mask_pad)
|
||||
else:
|
||||
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
# Here we assume the mask is not changed in encoder layers, so just
|
||||
# return the masks before encoder layers, and the masks will be used
|
||||
# for cross attention with decoder later
|
||||
return xs, masks
|
||||
|
||||
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.encoders:
|
||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||
return xs
|
||||
|
||||
@torch.jit.ignore(drop=True)
|
||||
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
||||
chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.encoders:
|
||||
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
|
||||
chunk_masks, pos_emb,
|
||||
mask_pad)
|
||||
return xs
|
||||
|
||||
def forward_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
offset: int,
|
||||
required_cache_size: int,
|
||||
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
""" Forward just one chunk
|
||||
|
||||
Args:
|
||||
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
|
||||
where `time == (chunk_size - 1) * subsample_rate + \
|
||||
subsample.right_context + 1`
|
||||
offset (int): current offset in encoder output time stamp
|
||||
required_cache_size (int): cache size required for next chunk
|
||||
compuation
|
||||
>=0: actual cache size
|
||||
<0: means all history cache is required
|
||||
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
||||
transformer/conformer attention, with shape
|
||||
(elayers, head, cache_t1, d_k * 2), where
|
||||
`head * d_k == hidden-dim` and
|
||||
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
||||
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
||||
(elayers, b=1, hidden-dim, cache_t2), where
|
||||
`cache_t2 == cnn.lorder - 1`
|
||||
|
||||
Returns:
|
||||
torch.Tensor: output of current input xs,
|
||||
with shape (b=1, chunk_size, hidden-dim).
|
||||
torch.Tensor: new attention cache required for next chunk, with
|
||||
dynamic shape (elayers, head, ?, d_k * 2)
|
||||
depending on required_cache_size.
|
||||
torch.Tensor: new conformer cnn cache required for next chunk, with
|
||||
same shape as the original cnn_cache.
|
||||
|
||||
"""
|
||||
assert xs.size(0) == 1
|
||||
# tmp_masks is just for interface compatibility
|
||||
tmp_masks = torch.ones(1,
|
||||
xs.size(1),
|
||||
device=xs.device,
|
||||
dtype=torch.bool)
|
||||
tmp_masks = tmp_masks.unsqueeze(1)
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
|
||||
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
|
||||
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
|
||||
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
|
||||
chunk_size = xs.size(1)
|
||||
attention_key_size = cache_t1 + chunk_size
|
||||
pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
|
||||
size=attention_key_size)
|
||||
if required_cache_size < 0:
|
||||
next_cache_start = 0
|
||||
elif required_cache_size == 0:
|
||||
next_cache_start = attention_key_size
|
||||
else:
|
||||
next_cache_start = max(attention_key_size - required_cache_size, 0)
|
||||
r_att_cache = []
|
||||
r_cnn_cache = []
|
||||
for i, layer in enumerate(self.encoders):
|
||||
# NOTE(xcsong): Before layer.forward
|
||||
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
|
||||
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
|
||||
xs, _, new_att_cache, new_cnn_cache = layer(
|
||||
xs,
|
||||
att_mask,
|
||||
pos_emb,
|
||||
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
||||
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
|
||||
# NOTE(xcsong): After layer.forward
|
||||
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
|
||||
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
|
||||
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
|
||||
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
|
||||
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
|
||||
# ? may be larger than cache_t1, it depends on required_cache_size
|
||||
r_att_cache = torch.cat(r_att_cache, dim=0)
|
||||
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
|
||||
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
|
||||
|
||||
return (xs, r_att_cache, r_cnn_cache)
|
||||
|
||||
def forward_chunk_by_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
decoding_chunk_size: int,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
""" Forward input chunk by chunk with chunk_size like a streaming
|
||||
fashion
|
||||
|
||||
Here we should pay special attention to computation cache in the
|
||||
streaming style forward chunk by chunk. Three things should be taken
|
||||
into account for computation in the current network:
|
||||
1. transformer/conformer encoder layers output cache
|
||||
2. convolution in conformer
|
||||
3. convolution in subsampling
|
||||
|
||||
However, we don't implement subsampling cache for:
|
||||
1. We can control subsampling module to output the right result by
|
||||
overlapping input instead of cache left context, even though it
|
||||
wastes some computation, but subsampling only takes a very
|
||||
small fraction of computation in the whole model.
|
||||
2. Typically, there are several covolution layers with subsampling
|
||||
in subsampling module, it is tricky and complicated to do cache
|
||||
with different convolution layers with different subsampling
|
||||
rate.
|
||||
3. Currently, nn.Sequential is used to stack all the convolution
|
||||
layers in subsampling, we need to rewrite it to make it work
|
||||
with cache, which is not prefered.
|
||||
Args:
|
||||
xs (torch.Tensor): (1, max_len, dim)
|
||||
chunk_size (int): decoding chunk size
|
||||
"""
|
||||
assert decoding_chunk_size > 0
|
||||
# The model is trained by static or dynamic chunk
|
||||
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
|
||||
subsampling = self.embed.subsampling_rate
|
||||
context = self.embed.right_context + 1 # Add current frame
|
||||
stride = subsampling * decoding_chunk_size
|
||||
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
||||
num_frames = xs.size(1)
|
||||
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
||||
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
||||
outputs = []
|
||||
offset = 0
|
||||
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
||||
|
||||
# Feed forward overlap input step by step
|
||||
for cur in range(0, num_frames - context + 1, stride):
|
||||
end = min(cur + decoding_window, num_frames)
|
||||
chunk_xs = xs[:, cur:end, :]
|
||||
(y, att_cache,
|
||||
cnn_cache) = self.forward_chunk(chunk_xs, offset,
|
||||
required_cache_size, att_cache,
|
||||
cnn_cache)
|
||||
outputs.append(y)
|
||||
offset += y.size(1)
|
||||
ys = torch.cat(outputs, 1)
|
||||
masks = torch.ones((1, 1, ys.size(1)),
|
||||
device=ys.device,
|
||||
dtype=torch.bool)
|
||||
return ys, masks
|
||||
|
||||
|
||||
class TransformerEncoder(BaseEncoder):
|
||||
"""Transformer encoder module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "abs_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
key_bias: bool = True,
|
||||
selfattention_layer_type: str = "selfattn",
|
||||
activation_type: str = "relu",
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
""" Construct TransformerEncoder
|
||||
|
||||
See Encoder for the meaning of each parameter.
|
||||
"""
|
||||
super().__init__(input_size, output_size, attention_heads,
|
||||
linear_units, num_blocks, dropout_rate,
|
||||
positional_dropout_rate, attention_dropout_rate,
|
||||
input_layer, pos_enc_layer_type, normalize_before,
|
||||
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
||||
use_dynamic_left_chunk, gradient_checkpointing)
|
||||
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
output_size,
|
||||
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
key_bias),
|
||||
PositionwiseFeedForward(output_size, linear_units,
|
||||
dropout_rate, activation),
|
||||
dropout_rate, normalize_before) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
|
||||
class ConformerEncoder(BaseEncoder):
|
||||
"""Conformer encoder module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "rel_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
macaron_style: bool = True,
|
||||
selfattention_layer_type: str = "rel_selfattn",
|
||||
activation_type: str = "swish",
|
||||
use_cnn_module: bool = True,
|
||||
cnn_module_kernel: int = 15,
|
||||
causal: bool = False,
|
||||
cnn_module_norm: str = "batch_norm",
|
||||
key_bias: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
"""Construct ConformerEncoder
|
||||
|
||||
Args:
|
||||
input_size to use_dynamic_chunk, see in BaseEncoder
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise
|
||||
conv1d layer.
|
||||
macaron_style (bool): Whether to use macaron style for
|
||||
positionwise layer.
|
||||
selfattention_layer_type (str): Encoder attention layer type,
|
||||
the parameter has no effect now, it's just for configure
|
||||
compatibility.
|
||||
activation_type (str): Encoder activation function type.
|
||||
use_cnn_module (bool): Whether to use convolution module.
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
causal (bool): whether to use causal convolution or not.
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
"""
|
||||
super().__init__(input_size, output_size, attention_heads,
|
||||
linear_units, num_blocks, dropout_rate,
|
||||
positional_dropout_rate, attention_dropout_rate,
|
||||
input_layer, pos_enc_layer_type, normalize_before,
|
||||
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
||||
use_dynamic_left_chunk, gradient_checkpointing)
|
||||
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||
|
||||
# self-attention module definition
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
key_bias,
|
||||
)
|
||||
# feed-forward module definition
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
# convolution module definition
|
||||
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
||||
cnn_module_norm, causal)
|
||||
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
ConformerEncoderLayer(
|
||||
output_size,
|
||||
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
||||
*encoder_selfattn_layer_args),
|
||||
PositionwiseFeedForward(*positionwise_layer_args),
|
||||
PositionwiseFeedForward(
|
||||
*positionwise_layer_args) if macaron_style else None,
|
||||
ConvolutionModule(
|
||||
*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
Reference in New Issue
Block a user