mirror of https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
0
funasr_local/models/__init__.py
Normal file
187
funasr_local/models/ctc.py
Normal file
@@ -0,0 +1,187 @@
import logging

import torch
import torch.nn.functional as F
from typeguard import check_argument_types


class CTC(torch.nn.Module):
    """CTC module.

    Args:
        odim: dimension of outputs
        encoder_output_size: number of encoder projection units
        dropout_rate: dropout rate (0.0 ~ 1.0)
        ctc_type: builtin, warpctc, or gtnctc
        reduce: reduce the CTC loss into a scalar
    """

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad

        if self.ctc_type == "builtin":
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            if ignore_nan_grad:
                logging.warning("ignore_nan_grad option is not supported for warp_ctc")
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)

        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction

            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                f'ctc_type must be "builtin", "warpctc" or "gtnctc": {self.ctc_type}'
            )

        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
        if self.ctc_type == "builtin":
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if loss.requires_grad and self.ignore_nan_grad:
                # ctc_grad: (L, B, O)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Return as is
                    logging.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss"
                    )
                elif size != th_pred.size(1):
                    logging.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss."
                    )

                    # Create mask for target
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s : s + le] = 0
                        s += le

                    # Calc loss again using masked data
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
            else:
                size = th_pred.size(1)

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        elif self.ctc_type == "warpctc":
            # warpctc only supports float32
            th_pred = th_pred.to(dtype=torch.float32)

            th_target = th_target.cpu().int()
            th_ilen = th_ilen.cpu().int()
            th_olen = th_olen.cpu().int()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if self.reduce:
                # NOTE: sum() is needed to keep consistency since warpctc
                # return as tensor w/ shape (1,)
                # but builtin return as tensor w/o shape (scalar).
                loss = loss.sum()
            return loss

        elif self.ctc_type == "gtnctc":
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")

        else:
            raise NotImplementedError

    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))

        if self.ctc_type == "gtnctc":
            # gtn expects list form for ys
            ys_true = [y[y != -1] for y in ys_pad]  # parse padded ys
        else:
            # ys_hat: (B, L, D) -> (L, B, D)
            ys_hat = ys_hat.transpose(0, 1)
            # (B, L) -> (BxL,)
            ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])

        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )

        return loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.softmax(self.ctc_lo(hs_pad), dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
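For orientation, a minimal usage sketch of the CTC module above; the batch, length, and vocabulary sizes are illustrative assumptions, not values used by this repository.

# Hedged usage sketch for CTC (all sizes are illustrative assumptions).
import torch

from funasr_local.models.ctc import CTC

B, Tmax, D, Lmax, vocab = 2, 50, 256, 10, 100
ctc = CTC(odim=vocab, encoder_output_size=D, dropout_rate=0.1)

hs_pad = torch.randn(B, Tmax, D)             # padded encoder states (B, Tmax, D)
hlens = torch.tensor([50, 42])               # valid frames per utterance
ys_pad = torch.randint(1, vocab, (B, Lmax))  # padded target ids (0 is blank)
ys_lens = torch.tensor([10, 7])              # valid labels per utterance

with torch.no_grad():                        # skip the nan-grad workaround path
    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)   # batch-averaged scalar loss
greedy = ctc.argmax(hs_pad)                  # (B, Tmax) frame-level argmax ids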
160
funasr_local/models/data2vec.py
Normal file
@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class Data2VecPretrainModel(AbsESPnetModel):
    """Data2Vec Pretrain model"""

    def __init__(
        self,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
    ):
        assert check_argument_types()

        super().__init__()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        self.num_updates = 0

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape)

        self.encoder.set_num_updates(self.num_updates)

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths)

        losses = encoder_out["losses"]
        loss = sum(losses.values())
        sample_size = encoder_out["sample_size"]
        loss = loss.sum() / sample_size

        target_var = float(encoder_out["target_var"])
        pred_var = float(encoder_out["pred_var"])
        ema_decay = float(encoder_out["ema_decay"])

        stats = dict(
            loss=torch.clone(loss.detach()),
            target_var=target_var,
            pred_var=pred_var,
            ema_decay=ema_decay,
        )

        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        if min(speech_lengths) == max(speech_lengths):  # for clipping, set speech_lengths as None
            speech_lengths = None
        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)

        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def get_num_updates(self):
        return self.num_updates
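The "for data-parallel" slice in `_extract_feats` is easy to miss; the sketch below isolates that one trimming step with dummy tensors (shapes are illustrative assumptions).

# Sketch of the data-parallel trimming in _extract_feats above: each replica
# may receive extra right-padding from the global batch, so the batch is cut
# to its own longest utterance before the frontend runs.
import torch

speech = torch.randn(2, 100)                # padded to the global max length
speech_lengths = torch.tensor([60, 40])     # this replica's true lengths
speech = speech[:, : speech_lengths.max()]  # -> (2, 60): drop useless padding
assert speech.shape == (2, 60)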
@@ -0,0 +1,334 @@
import random

import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs


def build_attention_list(
    eprojs: int,
    dunits: int,
    atype: str = "location",
    num_att: int = 1,
    num_encs: int = 1,
    aheads: int = 4,
    adim: int = 320,
    awin: int = 5,
    aconv_chans: int = 10,
    aconv_filts: int = 100,
    han_mode: bool = False,
    han_type=None,
    han_heads: int = 4,
    han_dim: int = 320,
    han_conv_chans: int = -1,
    han_conv_filts: int = 100,
    han_win: int = 5,
):

    att_list = torch.nn.ModuleList()
    if num_encs == 1:
        for i in range(num_att):
            att = initial_att(
                atype,
                eprojs,
                dunits,
                aheads,
                adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            att = initial_att(
                han_type,
                eprojs,
                dunits,
                han_heads,
                han_dim,
                han_win,
                han_conv_chans,
                han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            att_list = torch.nn.ModuleList()
            for idx in range(num_encs):
                att = initial_att(
                    atype[idx],
                    eprojs,
                    dunits,
                    aheads[idx],
                    adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        raise ValueError(
            "Number of encoders needs to be at least one. {}".format(num_encs)
        )
    return att_list


class RNNDecoder(AbsDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        sampling_probability: float = 0.0,
        dropout: float = 0.0,
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        att_conf: dict = get_default_kwargs(build_attention_list),
    ):
        # FIXME(kamo): The parts of num_spk should be refactored more
        assert check_argument_types()
        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        super().__init__()
        eprojs = encoder_output_size
        self.dtype = rnn_type
        self.dunits = hidden_size
        self.dlayers = num_layers
        self.context_residual = context_residual
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.odim = vocab_size
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # for multilingual translation
        self.replace_sos = replace_sos

        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [
                torch.nn.LSTMCell(hidden_size, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size, hidden_size)
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
            # NOTE: dropout is applied only for the vertical connections
            # see https://arxiv.org/pdf/1409.2329.pdf

        if context_residual:
            self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        self.att_list = build_attention_list(
            eprojs=eprojs, dunits=hidden_size, **att_conf
        )

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]),
                    (z_prev[i], c_prev[i]),
                )
        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        # to support multiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att_list) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        # get dim, length info
        olength = ys_in_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att_list[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att_list[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in_pad)
                att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                # utt x (zdim + hdim)
                ey = torch.cat((eys[:, i, :], att_c), dim=1)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1)
        z_all = self.output(z_all)
        z_all.masked_fill_(
            make_pad_mask(ys_in_lens, z_all, 1),
            0,
        )
        return z_all, ys_in_lens

    def init_state(self, x):
        # to support multiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att_list) - 1)
        if self.num_encs == 1:
            a = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
        # to support multiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att_list[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att_list[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )
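The scheduled-sampling branch inside RNNDecoder.forward can be read in isolation; the following hedged sketch reproduces just that decision with stand-in modules (all sizes are illustrative, and torch.argmax stands in for the np.argmax round-trip in the real code).

# Sketch of scheduled sampling: with probability `sampling_probability`,
# step i is fed the model's own previous prediction instead of the
# ground-truth token embedding (teacher forcing).
import random
import torch

sampling_probability = 0.2
vocab, hidden = 10, 8
output = torch.nn.Linear(hidden, vocab)
embed = torch.nn.Embedding(vocab, hidden)

prev_hidden = torch.randn(1, hidden)  # plays the role of z_all[-1]
gold_embed = torch.randn(1, hidden)   # plays the role of eys[:, i, :]

if random.random() < sampling_probability:
    z_out = output(prev_hidden)       # re-score the previous decoder state
    z_out = z_out.argmax(dim=1)       # pick the predicted token id
    ey = embed(z_out)                 # ...and embed it as this step's input
else:
    ey = gold_embed                   # teacher forcing: use the gold embedding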
@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""

from typing import List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug


class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type.
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct an RNNTDecoder object."""
        super().__init__()

        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        self.score_cache = {}

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label ID sequence lengths. (B,)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)

        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        dec_out, states = self.rnn_forward(dec_embed, states)
        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        # NOTE: torch.LongTensor does not accept a device keyword,
        # so the (B, 1) batch of last labels is built with torch.tensor.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create decoder hidden states.

        Args:
            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
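A small sketch of the beam-search state bookkeeping defined above: select_state slices one hypothesis out of the batched states and create_batch_states re-stacks them; the assertion checks they are inverses (layer, batch, and unit counts are illustrative).

# Roundtrip of select_state / create_batch_states on dummy (h, c) states.
import torch

N, B, D_dec = 2, 3, 4
h = torch.randn(N, B, D_dec)
c = torch.randn(N, B, D_dec)

# per-hypothesis views, as select_state(states, idx) returns them
singles = [(h[:, i : i + 1, :], c[:, i : i + 1, :]) for i in range(B)]

# rebatched along dim=1, as create_batch_states(new_states) returns them
h_new = torch.cat([s[0] for s in singles], dim=1)
c_new = torch.cat([s[1] for s in singles], dim=1)

assert torch.equal(h_new, h) and torch.equal(c_new, c)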
0
funasr_local/models/decoder/__init__.py
Normal file
19
funasr_local/models/decoder/abs_decoder.py
Normal file
@@ -0,0 +1,19 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch

from funasr_local.modules.scorers.scorer_interface import ScorerInterface


class AbsDecoder(torch.nn.Module, ScorerInterface, ABC):
    @abstractmethod
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
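A hedged sketch of what this interface demands of a subclass; ToyDecoder below is a stand-in for illustration, not a decoder from this repository.

# Minimal AbsDecoder subclass: implement forward and return
# (decoded scores, output lengths).
import torch

from funasr_local.models.decoder.abs_decoder import AbsDecoder


class ToyDecoder(AbsDecoder):
    def __init__(self, vocab_size: int, encoder_output_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(encoder_output_size, vocab_size)

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens):
        # score every encoder frame; real decoders attend over hs_pad instead
        return self.proj(hs_pad), ys_in_lens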
776
funasr_local/models/decoder/contextual_decoder.py
Normal file
@@ -0,0 +1,776 @@
from typing import List
from typing import Tuple
import logging
import torch
import torch.nn as nn
import numpy as np

from funasr_local.modules.streaming_utils import utils as myutils
from funasr_local.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types

from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.modules.repeat import repeat
from funasr_local.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder


class ContextualDecoderLayer(nn.Module):
    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a ContextualDecoderLayer object."""
        super(ContextualDecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        # tgt = self.dropout(tgt)
        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
        if self.training:
            cache = None
        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
        x = residual + self.dropout(x)
        x_self_attn = x

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = self.src_attn(x, memory, memory_mask)
        x_src_attn = x

        x = residual + self.dropout(x)
        return x, tgt_mask, x_self_attn, x_src_attn


class ContextualBiasDecoder(nn.Module):
    def __init__(
        self,
        size,
        src_attn,
        dropout_rate,
        normalize_before=True,
    ):
        """Construct a ContextualBiasDecoder object."""
        super(ContextualBiasDecoder, self).__init__()
        self.size = size
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        x = tgt
        if self.src_attn is not None:
            if self.normalize_before:
                x = self.norm3(x)
            x = self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache


class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        if input_layer == 'none':
            self.embed = None
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num - 1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.bias_decoder = ContextualBiasDecoder(
            size=attention_dim,
            src_attn=MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            dropout_rate=dropout_rate,
            normalize_before=True,
        )
        self.bias_output = torch.nn.Conv1d(attention_dim * 2, attention_dim, 1, bias=False)
        self.last_decoder = ContextualDecoderLayer(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )

        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related
        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.bias_decoder(
            x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask
        )

        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )

        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(x)
        return x, olens
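The bias-fusion step in the forward pass above reduces to a concatenation plus a 1x1 convolution; a self-contained sketch with illustrative dimensions:

# The source-attention stream and the contextual (bias) attention stream are
# concatenated on the feature axis and projected back from 2*D to D by the
# 1x1 Conv1d `bias_output`.
import torch

B, T, D = 2, 5, 8
x_src_attn = torch.randn(B, T, D)   # source-attention output
cx = torch.randn(B, T, D)           # contextual bias-attention output
bias_output = torch.nn.Conv1d(2 * D, D, 1, bias=False)

x = torch.cat([x_src_attn, cx], dim=2)               # (B, T, 2D)
x = bias_output(x.transpose(1, 2)).transpose(1, 2)   # back to (B, T, D)
assert x.shape == (B, T, D)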
    def gen_tf2torch_map_dict(self):

        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {

            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)

            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 2, 0),
            },  # (256,1,31),(1,31,256,1)

            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)

            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)

            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)

            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)

            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch): {
                "name": "{}/w_embs".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (4235,256),(4235,256)

            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                "squeeze": [None, None],
                "transpose": [(1, 0), None],
            },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                         "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                "squeeze": [None, None],
                "transpose": [None, None],
            },  # (4235,),(4235,)

            ## clas decoder
            # src att
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": (2, 1, 0),
            },  # (1024,256),(1,256,1024)

        }
        return map_dict_local

    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "last_decoder":
                    layeridx = 15
                    name_q = name.replace("last_decoder", "decoders.layeridx")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )

                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)

                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )

                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")

                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(
                                name, name_tf, var_dict_torch[name].size(), data_tf.size()
                            )
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape
                            )
                        )
                elif names[1] == "bias_decoder":
                    name_q = name

                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[
|
||||
name].size(),
|
||||
data_tf.size())
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info(
|
||||
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
|
||||
elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
|
||||
name_tf = map_dict[name]["name"]
|
||||
if isinstance(name_tf, list):
|
||||
idx_list = 0
|
||||
if name_tf[idx_list] in var_dict_tf.keys():
|
||||
pass
|
||||
else:
|
||||
idx_list = 1
|
||||
data_tf = var_dict_tf[name_tf[idx_list]]
|
||||
if map_dict[name]["squeeze"][idx_list] is not None:
|
||||
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
|
||||
if map_dict[name]["transpose"][idx_list] is not None:
|
||||
data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
|
||||
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
|
||||
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
|
||||
var_dict_torch[
|
||||
name].size(),
|
||||
data_tf.size())
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
|
||||
name_tf[idx_list],
|
||||
var_dict_tf[name_tf[
|
||||
idx_list]].shape))
|
||||
|
||||
else:
|
||||
data_tf = var_dict_tf[name_tf]
|
||||
if map_dict[name]["squeeze"] is not None:
|
||||
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
|
||||
if map_dict[name]["transpose"] is not None:
|
||||
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
|
||||
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
|
||||
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
|
||||
var_dict_torch[
|
||||
name].size(),
|
||||
data_tf.size())
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info(
|
||||
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
elif names[1] == "after_norm":
|
||||
name_tf = map_dict[name]["name"]
|
||||
data_tf = var_dict_tf[name_tf]
|
||||
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info(
|
||||
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
elif names[1] == "embed_concat_ffn":
|
||||
layeridx = int(names[2])
|
||||
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
|
||||
|
||||
layeridx_bias = 0
|
||||
layeridx += layeridx_bias
|
||||
if "decoders." in name:
|
||||
decoder_layeridx_sets.add(layeridx)
|
||||
if name_q in map_dict.keys():
|
||||
name_v = map_dict[name_q]["name"]
|
||||
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
|
||||
data_tf = var_dict_tf[name_tf]
|
||||
if map_dict[name_q]["squeeze"] is not None:
|
||||
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
|
||||
if map_dict[name_q]["transpose"] is not None:
|
||||
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
|
||||
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
|
||||
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
|
||||
var_dict_torch[
|
||||
name].size(),
|
||||
data_tf.size())
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info(
|
||||
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
return var_dict_torch_update
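    # NOTE(editor): hedged illustration, not part of the original source.
    # The "squeeze"/"transpose" fields above map TF conv1d kernels onto torch
    # Linear weights, e.g. a (1, 256, 1024) TF kernel becomes a (1024, 256)
    # torch weight:
    #
    #     import numpy as np
    #     import torch
    #
    #     kernel_tf = np.zeros((1, 256, 1024), dtype=np.float32)  # (1, in, out)
    #     w = np.squeeze(kernel_tf, axis=0)    # "squeeze": 0       -> (256, 1024)
    #     w = np.transpose(w, (1, 0))          # "transpose": (1, 0) -> (1024, 256)
    #     linear = torch.nn.Linear(256, 1024)
    #     assert linear.weight.shape == torch.from_numpy(w).shape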
334 funasr_local/models/decoder/rnn_decoder.py Normal file
@@ -0,0 +1,334 @@
import random

import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs


def build_attention_list(
    eprojs: int,
    dunits: int,
    atype: str = "location",
    num_att: int = 1,
    num_encs: int = 1,
    aheads: int = 4,
    adim: int = 320,
    awin: int = 5,
    aconv_chans: int = 10,
    aconv_filts: int = 100,
    han_mode: bool = False,
    han_type=None,
    han_heads: int = 4,
    han_dim: int = 320,
    han_conv_chans: int = -1,
    han_conv_filts: int = 100,
    han_win: int = 5,
):

    att_list = torch.nn.ModuleList()
    if num_encs == 1:
        for i in range(num_att):
            att = initial_att(
                atype,
                eprojs,
                dunits,
                aheads,
                adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            att = initial_att(
                han_type,
                eprojs,
                dunits,
                han_heads,
                han_dim,
                han_win,
                han_conv_chans,
                han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            att_list = torch.nn.ModuleList()
            for idx in range(num_encs):
                att = initial_att(
                    atype[idx],
                    eprojs,
                    dunits,
                    aheads[idx],
                    adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        raise ValueError(
            "Number of encoders needs to be at least one: {}".format(num_encs)
        )
    return att_list


class RNNDecoder(AbsDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        sampling_probability: float = 0.0,
        dropout: float = 0.0,
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        att_conf: dict = get_default_kwargs(build_attention_list),
    ):
        # FIXME(kamo): The parts of num_spk should be refactored more
        assert check_argument_types()
        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        super().__init__()
        eprojs = encoder_output_size
        self.dtype = rnn_type
        self.dunits = hidden_size
        self.dlayers = num_layers
        self.context_residual = context_residual
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.odim = vocab_size
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # for multilingual translation
        self.replace_sos = replace_sos

        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [
                torch.nn.LSTMCell(hidden_size, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size, hidden_size)
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        # NOTE: dropout is applied only for the vertical connections
        # see https://arxiv.org/pdf/1409.2329.pdf

        if context_residual:
            self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        self.att_list = build_attention_list(
            eprojs=eprojs, dunits=hidden_size, **att_conf
        )

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]),
                    (z_prev[i], c_prev[i]),
                )
        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list
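    # NOTE(editor): hedged sketch, not part of the original source. rnn_forward
    # steps a stack of *Cell modules one frame at a time; with plain PyTorch:
    #
    #     import torch
    #
    #     cells = [torch.nn.LSTMCell(8, 8), torch.nn.LSTMCell(8, 8)]
    #     z = [torch.zeros(2, 8), torch.zeros(2, 8)]  # hidden per layer (B=2)
    #     c = [torch.zeros(2, 8), torch.zeros(2, 8)]  # cell state per layer
    #     ey = torch.randn(2, 8)                      # one input frame
    #     z[0], c[0] = cells[0](ey, (z[0], c[0]))
    #     z[1], c[1] = cells[1](z[0], (z[1], c[1]))   # vertical connection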

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        # to support multiple-encoder ASR mode; in single-encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att_list) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        # get dim, length info
        olength = ys_in_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att_list[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att_list[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in_pad)
                att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                # utt x (zdim + hdim)
                ey = torch.cat((eys[:, i, :], att_c), dim=1)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1)
        z_all = self.output(z_all)
        z_all.masked_fill_(
            make_pad_mask(ys_in_lens, z_all, 1),
            0,
        )
        return z_all, ys_in_lens
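    # NOTE(editor): hedged sketch, not part of the original source. The
    # sampling_probability branch above implements scheduled sampling: with
    # probability p the previous *predicted* token is embedded instead of the
    # ground-truth token, e.g.:
    #
    #     import random, torch
    #
    #     p = 0.5
    #     prev_logits = torch.randn(2, 10)         # (B, vocab) from step i-1
    #     gold = torch.zeros(2, dtype=torch.long)  # ground-truth ids at step i
    #     emb = torch.nn.Embedding(10, 4)
    #     if random.random() < p:
    #         tok = prev_logits.argmax(dim=1)      # feed back own prediction
    #     else:
    #         tok = gold                           # teacher forcing
    #     ey = emb(tok)                            # (B, 4) input embedding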

    def init_state(self, x):
        # to support multiple-encoder ASR mode; in single-encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att_list) - 1)
        if self.num_encs == 1:
            a = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
        # to support multiple-encoder ASR mode; in single-encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att_list[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att_list[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )
258 funasr_local/models/decoder/rnnt_decoder.py Normal file
@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""

from typing import List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug


class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type.
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct an RNNTDecoder object."""
        super().__init__()

        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        self.score_cache = {}

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label ID sequence lengths. (B,)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)

        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        dec_out, states = self.rnn_forward(dec_embed, states)
        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state
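    # NOTE(editor): hedged sketch, not part of the original source. score()
    # memoizes one-step outputs keyed by the label prefix, so two beam-search
    # hypotheses sharing a prefix reuse the same decoder step:
    #
    #     cache = {}
    #     key = "_".join(map(str, [3, 7, 2]))        # label_sequence -> "3_7_2"
    #     if key not in cache:
    #         cache[key] = ("dec_out", "dec_state")  # placeholders here
    #     dec_out, dec_state = cache[key]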

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        # NOTE: the legacy torch.LongTensor(data, device=...) constructor fails
        # on non-CPU devices; build the tensor with torch.tensor instead.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create decoder hidden states.

        Args:
            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
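    # NOTE(editor): hedged sketch, not part of the original source.
    # select_state/create_batch_states round-trip per-hypothesis LSTM states
    # along the batch axis (dim=1 of the (N, B, D) layout):
    #
    #     import torch
    #
    #     h = torch.randn(2, 3, 4)   # (layers N=2, batch B=3, D=4)
    #     c = torch.randn(2, 3, 4)
    #     singles = [(h[:, i:i + 1, :], c[:, i:i + 1, :]) for i in range(3)]
    #     h2 = torch.cat([s[0] for s in singles], dim=1)
    #     assert torch.equal(h, h2)  # batching is lossless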
1484 funasr_local/models/decoder/sanm_decoder.py Normal file
File diff suppressed because it is too large
37 funasr_local/models/decoder/sv_decoder.py Normal file
@@ -0,0 +1,37 @@
import torch
from torch.nn import functional as F
from funasr_local.models.decoder.abs_decoder import AbsDecoder


class DenseDecoder(AbsDecoder):
    def __init__(
        self,
        vocab_size,
        encoder_output_size,
        num_nodes_resnet1: int = 256,
        num_nodes_last_layer: int = 256,
        batchnorm_momentum: float = 0.5,
    ):
        super(DenseDecoder, self).__init__()
        self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)

        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)

        self.output_dense = torch.nn.Linear(num_nodes_last_layer, vocab_size, bias=False)

    def forward(self, features):
        embeddings = {}
        features = self.resnet1_dense(features)
        embeddings["resnet1_dense"] = features
        features = F.relu(features)
        features = self.resnet1_bn(features)

        features = self.resnet2_dense(features)
        embeddings["resnet2_dense"] = features
        features = F.relu(features)
        features = self.resnet2_bn(features)

        features = self.output_dense(features)
        return features, embeddings
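    # NOTE(editor): hedged usage sketch, not part of the original source.
    # DenseDecoder maps utterance embeddings to speaker logits and also exposes
    # the pre-activation of each dense layer:
    #
    #     import torch
    #
    #     dec = DenseDecoder(vocab_size=100, encoder_output_size=512)
    #     dec.eval()  # freeze BatchNorm running stats for the demo
    #     logits, embeddings = dec(torch.randn(4, 512))
    #     print(logits.shape)                       # torch.Size([4, 100])
    #     print(embeddings["resnet2_dense"].shape)  # torch.Size([4, 256])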
766 funasr_local/models/decoder/transformer_decoder.py Normal file
@@ -0,0 +1,766 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder definition."""
from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types

from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.modules.attention import MultiHeadedAttention
from funasr_local.modules.dynamic_conv import DynamicConvolution
from funasr_local.modules.dynamic_conv2d import DynamicConvolution2D
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.lightconv import LightweightConvolution
from funasr_local.modules.lightconv2d import LightweightConvolution2D
from funasr_local.modules.mask import subsequent_mask
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention (encoder-decoder attention)
            module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)

    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
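    # NOTE(editor): hedged sketch, not part of the original source. The cache
    # branch above recomputes only the newest query position during incremental
    # decoding and then re-appends the cached prefix:
    #
    #     import torch
    #
    #     cache = torch.randn(2, 4, 8)  # outputs for the first 4 positions
    #     tgt = torch.randn(2, 5, 8)    # prefix plus one new position
    #     tgt_q = tgt[:, -1:, :]        # attend from the new position only
    #     x_new = tgt_q                 # stand-in for the attention/FFN stack
    #     x = torch.cat([cache, x_new], dim=1)
    #     assert x.shape == (2, 5, 8)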


class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        # Must be set by the subclass
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True
                olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
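    # NOTE(editor): hedged sketch, not part of the original source. The target
    # mask above is the AND of a padding mask and a causal (subsequent) mask;
    # with plain PyTorch:
    #
    #     import torch
    #
    #     ys_in_lens = torch.tensor([3, 2])                     # B=2, L=3
    #     pad = torch.arange(3)[None, :] < ys_in_lens[:, None]  # (B, L)
    #     causal = torch.tril(torch.ones(3, 3, dtype=torch.bool))
    #     tgt_mask = pad[:, None, :] & causal[None, :, :]       # (B, L, L)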

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (1.2 included)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
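    # NOTE(editor): hedged sketch, not part of the original source. batch_score
    # transposes the per-hypothesis layer caches from [batch][layer] to
    # [layer][batch] and back, which is just a nested-list transpose:
    #
    #     states = [["b0l0", "b0l1"], ["b1l0", "b1l1"]]   # [batch][layer]
    #     by_layer = [list(col) for col in zip(*states)]  # [layer][batch]
    #     assert by_layer[1][0] == "b0l1"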


class TransformerDecoder(BaseTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )


class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True
                olens: (batch, )
        """
        tgt = ys_in_pad
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        # x = self.embed(tgt)
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if embeds_outputs is not None:
            return x, olens, embeds_outputs
        else:
            return x, olens


class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        assert check_argument_types()
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )


class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        assert check_argument_types()
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )


class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        assert check_argument_types()
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size

        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                DynamicConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )


class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        assert check_argument_types()
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                "conv_kernel_length must have equal number of values to num_blocks: "
                f"{len(conv_kernel_length)} != {num_blocks}"
            )
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size

        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                DynamicConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
458 funasr_local/models/e2e_asr.py Normal file
@@ -0,0 +1,458 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from typeguard import check_argument_types

from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr_local.models.ctc import CTC
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.modules.add_sos_eos import add_sos_eos
from funasr_local.modules.e2e_asr_common import ErrorCalculator
from funasr_local.modules.nets_utils import th_accuracy
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class ESPnetASRModel(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # blank, sos and eos token IDs
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.error_calculator = None

        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
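            # e.g. with interctc_weight=0.3, the encoder loss becomes
            # 0.7 * final-layer CTC + 0.3 * (mean of intermediate-layer CTC)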
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood (nll) from transformer-decoder

        To avoid OOM, this function separates the input into batches.
        It then calls nll for each batch and combines the results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, the number of samples per batch when computing
                nll; lower it to avoid OOM, or raise it to trade GPU memory
                for speed
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
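A minimal, standalone sketch of the hybrid loss interpolation implemented in forward() above, loss = w * loss_ctc + (1 - w) * loss_att; the scalar tensors here are toy values, not real model outputs.

import torch

def hybrid_ctc_att_loss(
    loss_ctc: torch.Tensor, loss_att: torch.Tensor, ctc_weight: float
) -> torch.Tensor:
    # Edge cases mirror the branches in ESPnetASRModel.forward.
    if ctc_weight == 0.0:
        return loss_att
    if ctc_weight == 1.0:
        return loss_ctc
    return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att

print(hybrid_ctc_att_loss(torch.tensor(2.0), torch.tensor(1.0), 0.3))  # tensor(1.3000)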
249
funasr_local/models/e2e_asr_common.py
Normal file
@@ -0,0 +1,249 @@
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Common functions for ASR."""

import json
import logging
import sys

from itertools import groupby
import numpy as np
import six


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    :param ended_hyps: list of hypotheses that have already emitted <eos>
    :param i: current output length (decoding step)
    :param M: number of most recent hypothesis lengths to check
    :param D_end: score-gap threshold for declaring the search finished
    :return: True if decoding can stop
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
    for m in six.moves.range(M):
        # get ended_hyps whose length is i - m
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True
            )[0]
            if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                count += 1

    if count == M:
        return True
    else:
        return False
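A small usage sketch for end_detect with toy hypotheses (scores and sequences are made up): detection fires only when, at each of the M most recent lengths, the best hypothesis falls more than |D_end| below the global best score.

toy_hyps = [
    {"yseq": [1, 7, 7, 7, 2], "score": -3.2},   # best ended hypothesis (length 5)
    {"yseq": [1, 7, 7, 2], "score": -40.0},     # length 4, far behind
    {"yseq": [1, 7, 2], "score": -45.0},        # length 3, far behind
    {"yseq": [1, 2], "score": -50.0},           # length 2, far behind
]
print(end_detect(toy_hyps, i=4))  # True: lengths 4, 3 and 2 all lag by more than 10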
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param odim: output (vocabulary) dimension
    :param lsm_type: label smoothing type; currently only "unigram" is supported
    :param blank: blank symbol id
    :param transcript: path to a json transcript used to count label unigrams
    :return: label distribution
    """
    if transcript is not None:
        with open(transcript, "rb") as f:
            trans_json = json.load(f)["utts"]

    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type
        )
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        labelcount[odim - 1] = len(trans_json)  # count <eos>: one per utterance
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()

    return labeldist


def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
    """Return the output size of the VGG frontend.

    :param idim: input feature dimension
    :param in_channel: input channel size
    :param out_channel: output channel size
    :return: output size
    :rtype: int
    """
    idim = idim / in_channel
    idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 1st max pooling
    idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 2nd max pooling
    return int(idim) * out_channel  # number of channels
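A worked check of get_vgg2l_odim: with 83-dim features split across 3 input channels, the two /2 max-poolings reduce 83/3 ≈ 27.7 bins to ceil(ceil(27.7 / 2) / 2) = 7 frequency bins, times 128 output channels.

assert get_vgg2l_odim(83, in_channel=3, out_channel=128) == 7 * 128  # 896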
class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list: token inventory (list of characters)
    :param sym_space: space symbol
    :param sym_blank: blank symbol
    :return:
    """

    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level WER score
        :rtype: float
        :return: sentence-level CER score
        :rtype: float
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype: float
        """
        import editdistance

        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat, seq_true = [], []
            for idx in y_hat:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_hat.append(self.char_list[idx])

            for idx in y_true:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_true.append(self.char_list[idx])

            hyp_chars = "".join(seq_hat)
            ref_chars = "".join(seq_true)
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: token list of prediction
        :rtype: list
        :return: token list of reference
        :rtype: list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype: float
        """
        import editdistance

        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype: float
        """
        import editdistance

        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)
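A hedged usage sketch for ErrorCalculator (requires the third-party editdistance package); the toy char_list and index tensors below are made up, with -1 as the padding id, matching the class above.

import torch

toy_char_list = ["<blank>", "a", "b", "c", "<space>"]
calc = ErrorCalculator(toy_char_list, "<space>", "<blank>", report_cer=True, report_wer=True)
ys_hat = torch.tensor([[1, 2, 3]])      # decoded "abc"
ys_pad = torch.tensor([[1, 2, 2, -1]])  # reference "abb", -1 is padding
cer, wer = calc(ys_hat, ys_pad)
print(cer, wer)  # 0.333... (1 char edit over 3 ref chars) and 1.0 (1 word edit over 1)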
326
funasr_local/models/e2e_asr_mfcca.py
Normal file
@@ -0,0 +1,326 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import math
import random

import torch
from typeguard import check_argument_types

from funasr_local.modules.e2e_asr_common import ErrorCalculator
from funasr_local.modules.nets_utils import th_accuracy
from funasr_local.modules.add_sos_eos import add_sos_eos
from funasr_local.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr_local.models.ctc import CTC
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class MFCCA(AbsESPnetModel):
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA: Multi-Frame Cross-Channel Attention for multi-speaker ASR in multi-party meeting scenarios
    https://arxiv.org/abs/2210.05265
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.token_list = token_list.copy()

        self.mask_ratio = mask_ratio

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.rnnt_decoder = rnnt_decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        # Channel-masking augmentation: with probability mask_ratio, keep only
        # a random subset of the 8 input channels.
        if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
            rate_num = random.random()
            if rate_num <= self.mask_ratio:
                retain_channel = math.ceil(random.random() * 8)
                if retain_channel > 1:
                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                else:
                    speech = speech[:, :, torch.randperm(8)[0]]
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)

        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths, _ = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        if encoder_out.dim() == 4:
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        else:
            assert encoder_out.size(1) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
            channel_size = 1
        return feats, feats_lengths, channel_size

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        if encoder_out.dim() == 4:
            encoder_out = encoder_out.mean(1)
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        raise NotImplementedError
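A standalone sketch of the channel-masking augmentation used in MFCCA.forward above, assuming the fixed 8-channel layout from that code; `random_channel_mask` is an illustrative name, not part of the file.

import math
import random
import torch

def random_channel_mask(speech: torch.Tensor, mask_ratio: float) -> torch.Tensor:
    # speech: (batch, samples, channels=8)
    if speech.dim() == 3 and speech.size(2) == 8 and random.random() <= mask_ratio:
        retain = math.ceil(random.random() * 8)  # 1..8 channels survive
        keep = torch.randperm(8)[:retain].sort().values
        # A single surviving channel collapses to a 2-D (batch, samples) tensor,
        # mirroring the behaviour of the code above.
        speech = speech[:, :, keep] if retain > 1 else speech[:, :, keep[0]]
    return speech

x = torch.randn(2, 16000, 8)
print(random_channel_mask(x, mask_ratio=1.0).shape)  # (2, 16000) or (2, 16000, k)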
1764
funasr_local/models/e2e_asr_paraformer.py
Normal file
File diff suppressed because it is too large
1015
funasr_local/models/e2e_asr_transducer.py
Normal file
File diff suppressed because it is too large
253
funasr_local/models/e2e_diar_eend_ola.py
Normal file
@@ -0,0 +1,253 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr_local.models.frontend.wav_frontend import WavFrontendMel23
from funasr_local.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr_local.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr_local.modules.eend_ola.utils.power import generate_mapping_dict
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    pass
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


def pad_attractor(att, max_n_speakers):
    C, D = att.shape
    if C < max_n_speakers:
        att = torch.cat(
            [att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)],
            dim=0,
        )
    return att
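A quick check of pad_attractor: a 3-speaker attractor matrix is zero-padded up to max_n_speakers rows.

att = torch.ones(3, 4)
padded = pad_attractor(att, max_n_speakers=8)
print(padded.shape)            # torch.Size([8, 4])
print(padded[3:].abs().sum())  # tensor(0.): padded rows are zeros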
class DiarEENDOLAModel(AbsESPnetModel):
    """EEND-OLA diarization model"""

    def __init__(
        self,
        frontend: WavFrontendMel23,
        encoder: EENDOLATransformerEncoder,
        encoder_decoder_attractor: EncoderDecoderAttractor,
        n_units: int = 256,
        max_n_speaker: int = 8,
        attractor_loss_weight: float = 1.0,
        mapping_dict=None,
        **kwargs,
    ):
        assert check_argument_types()

        super().__init__()
        self.frontend = frontend
        self.enc = encoder
        self.eda = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet
        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)

    def forward_encoder(self, xs, ilens):
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.enc(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb

    def forward_post_net(self, logits, ilens):
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(
            logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False
        )
        outputs, (_, _) = self.postnet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, padding_value=-1, total_length=maxlen
        )[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def estimate_sequential(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        n_speakers: int = None,
        shuffle: bool = True,
        threshold: float = 0.5,
        **kwargs,
    ):
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.eda.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]
            )
        else:
            attractors, probs = self.eda.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers]
                attractors_active.append(att)
            elif threshold is not None:
                # first frame whose existence probability drops below the threshold
                silence = torch.nonzero(p < threshold).squeeze(-1)
                n_spk = silence[0] if silence.numel() > 0 else None
                att = att[:n_spk]
                attractors_active.append(att)
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active
        ]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [
            self.recover_y_from_powerlabel(logit, raw_n_speaker)
            for logit, raw_n_speaker in zip(logits, raw_n_speakers)
        ]

        return ys, emb, attractors, raw_n_speakers

    def recover_y_from_powerlabel(self, logit, n_speaker):
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.inv_mapping_func(i) for i in pred]
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)
        ).to(logit.device).to(torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions

    def inv_mapping_func(self, label):
        if not isinstance(label, int):
            label = int(label)
        if label in self.mapping_dict['label2dec'].keys():
            num = self.mapping_dict['label2dec'][label]
        else:
            num = -1
        return num

    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        pass
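A standalone sketch of the power-set decoding step inside recover_y_from_powerlabel above: an integer label is the bit-encoding of which speakers are active, so the reversed binary string gives the per-speaker 0/1 decisions.

max_n_speaker = 8
num = 5  # speakers 0 and 2 active: 2**0 + 2**2
decision = bin(num)[2:].zfill(max_n_speaker)[::-1]
print([int(b) for b in decision])  # [1, 0, 1, 0, 0, 0, 0, 0]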
494
funasr_local/models/e2e_diar_sond.py
Normal file
@@ -0,0 +1,494 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple, List

import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr_local.utils.misc import int2vec

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class DiarSondModel(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China
    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
    https://arxiv.org/abs/2211.10243
    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
    https://arxiv.org/abs/2303.05397
    """

    def __init__(
        self,
        vocab_size: int,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
        cd_scorer: Optional[torch.nn.Module],
        decoder: torch.nn.Module,
        token_list: list,
        lsm_weight: float = 0.1,
        length_normalized_loss: bool = False,
        max_spk_num: int = 16,
        label_aggregator: Optional[torch.nn.Module] = None,
        normalize_speech_speaker: bool = False,
        ignore_id: int = -1,
        speaker_discrimination_loss_weight: float = 1.0,
        inter_score_loss_weight: float = 0.0,
        inputs_type: str = "raw",
    ):
        assert check_argument_types()

        super().__init__()

        self.encoder = encoder
        self.speaker_encoder = speaker_encoder
        self.ci_scorer = ci_scorer
        self.cd_scorer = cd_scorer
        self.normalize = normalize
        self.frontend = frontend
        self.specaug = specaug
        self.label_aggregator = label_aggregator
        self.decoder = decoder
        self.token_list = token_list
        self.max_spk_num = max_spk_num
        self.normalize_speech_speaker = normalize_speech_speaker
        self.ignore_id = ignore_id
        self.criterion_diar = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        self.pse_embedding = self.generate_pse_embedding()
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0
        self.inputs_type = inputs_type

    def generate_pse_embedding(self):
        # NOTE: np.float is deprecated in recent NumPy; the builtin float is used here.
        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=float)
        for idx, pse_label in enumerate(self.token_list):
            emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=float)
            embedding[idx] = emb
        return torch.from_numpy(embedding)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss

        Args:
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,) default None for the chunk iterator,
                            because the chunk-iterator does not
                            have the speech_lengths returned.
                            see in
                            espnet2/iterators/chunk_iter_factory.py
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            binary_labels: (Batch, frames, max_spk_num)
            binary_labels_lengths: (Batch,)
        """
        assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]
        self.forward_steps = self.forward_steps + 1
        if self.pse_embedding.device != speech.device:
            self.pse_embedding = self.pse_embedding.to(speech.device)
            self.power_weight = self.power_weight.to(speech.device)
            self.int_token_arr = self.int_token_arr.to(speech.device)

        # 1. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths,
            profile, profile_lengths,
            return_inter_outputs=True
        )
        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs

        # 2. Aggregate time-domain labels to match forward outputs
        if self.label_aggregator is not None:
            binary_labels, binary_labels_lengths = self.label_aggregator(
                binary_labels, binary_labels_lengths
            )
        # 3. Calculate power-set encoding (PSE) labels
        raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
        pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
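        # Worked example (illustrative): a frame with binary_labels row
        # [1, 0, 1, 0, ...] has power label 1*2**0 + 1*2**2 = 5, and pse_labels
        # stores the position of "5" in token_list rather than the raw value.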
        # If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slightly less than the
        # length of 'pse_labels'. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
        if length_diff <= length_diff_tolerance:
            min_len = min(pred.shape[1], pse_labels.shape[1])
            pse_labels = pse_labels[:, :min_len]
            pred = pred[:, :min_len]
            cd_score = cd_score[:, :min_len]
            ci_score = ci_score[:, :min_len]

        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))

        (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(
            pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
            label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
            length=binary_labels_lengths
        )

        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
                speech_miss / speech_scored,
                speech_falarm / speech_scored,
                speaker_miss / speaker_scored,
                speaker_falarm / speaker_scored,
                speaker_error / speaker_scored,
                correct / num_frames,
                (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
            )
        else:
            sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0

        stats = dict(
            loss=loss.detach(),
            loss_diar=loss_diar.detach() if loss_diar is not None else None,
            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
            loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
            loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
            fa=fa,
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def classification_loss(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        prediction_lengths: torch.Tensor
    ) -> torch.Tensor:
        mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
        pad_labels = labels.masked_fill(
            mask.to(predictions.device),
            value=self.ignore_id
        )
        loss = self.criterion_diar(predictions.contiguous(), pad_labels)

        return loss

    def speaker_discrimination_loss(
        self,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor
    ) -> torch.Tensor:
        profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
        mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))

        eps = 1e-12
        coding_norm = torch.linalg.norm(
            profile * profile_mask + (1 - profile_mask) * eps,
            dim=2, keepdim=True
        ) * profile_mask
        # profile: Batch, N, dim
        cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()

        return loss
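    # Note on speaker_discrimination_loss (an illustrative reading, not from the
    # original comments): `mask` zeros out the diagonal and any absent (all-zero)
    # profiles, so the ReLU term only penalizes positive cosine similarity
    # between two different, present speaker embeddings, pushing enrolled
    # profiles apart.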
|
||||
|
||||
def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
|
||||
mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
|
||||
padding_labels = pse_labels.masked_fill(
|
||||
mask.to(pse_labels.device),
|
||||
value=0
|
||||
).to(pse_labels)
|
||||
multi_labels = F.embedding(padding_labels, self.pse_embedding)
|
||||
|
||||
return multi_labels
|
||||
|
||||
def internal_score_loss(
|
||||
self,
|
||||
cd_score: torch.Tensor,
|
||||
ci_score: torch.Tensor,
|
||||
pse_labels: torch.Tensor,
|
||||
pse_labels_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
|
||||
ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
|
||||
cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
|
||||
return ci_loss, cd_loss
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
profile: torch.Tensor = None,
|
||||
profile_lengths: torch.Tensor = None,
|
||||
binary_labels: torch.Tensor = None,
|
||||
binary_labels_lengths: torch.Tensor = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode_speaker(
|
||||
self,
|
||||
profile: torch.Tensor,
|
||||
profile_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
with autocast(False):
|
||||
if profile.shape[1] < self.max_spk_num:
|
||||
profile = F.pad(profile, [0, 0, 0, self.max_spk_num-profile.shape[1], 0, 0], "constant", 0.0)
|
||||
profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()
|
||||
profile = F.normalize(profile, dim=2)
|
||||
if self.speaker_encoder is not None:
|
||||
profile = self.speaker_encoder(profile, profile_lengths)[0]
|
||||
return profile * profile_mask, profile_lengths
|
||||
else:
|
||||
return profile, profile_lengths
|
||||
|
||||
def encode_speech(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.encoder is not None and self.inputs_type == "raw":
|
||||
speech, speech_lengths = self.encode(speech, speech_lengths)
|
||||
speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
|
||||
speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
|
||||
return speech * speech_mask, speech_lengths
|
||||
else:
|
||||
return speech, speech_lengths
|
||||
|
||||
@staticmethod
|
||||
def concate_speech_ivc(
|
||||
speech: torch.Tensor,
|
||||
ivc: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
nn, tt = ivc.shape[1], speech.shape[1]
|
||||
speech = speech.unsqueeze(dim=1) # B x 1 x T x D
|
||||
speech = speech.expand(-1, nn, -1, -1) # B x N x T x D
|
||||
ivc = ivc.unsqueeze(dim=2) # B x N x 1 x D
|
||||
ivc = ivc.expand(-1, -1, tt, -1) # B x N x T x D
|
||||
sd_in = torch.cat([speech, ivc], dim=3) # B x N x T x 2D
|
||||
return sd_in
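# Shape sketch (comment added for clarity; B, N, T, D values are illustrative):
# with B=2 utterances, N=4 profiles, T=100 frames, D=256 dims,
#   speech: (2, 100, 256) -> unsqueeze + expand -> (2, 4, 100, 256)
#   ivc:    (2, 4, 256)   -> unsqueeze + expand -> (2, 4, 100, 256)
#   sd_in = cat(..., dim=3)                     -> (2, 4, 100, 512)
# i.e. every frame is paired with every candidate speaker embedding.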
|
||||
|
||||
def calc_similarity(
|
||||
self,
|
||||
speech_encoder_outputs: torch.Tensor,
|
||||
speaker_encoder_outputs: torch.Tensor,
|
||||
seq_len: torch.Tensor = None,
|
||||
spk_len: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
|
||||
d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
|
||||
if self.normalize_speech_speaker:
|
||||
speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
|
||||
speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
|
||||
ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
|
||||
ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
|
||||
ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
|
||||
ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
|
||||
cd_simi = self.cd_scorer(ge_in, ge_len)[0]
|
||||
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
|
||||
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
|
||||
|
||||
if isinstance(self.ci_scorer, AbsEncoder):
|
||||
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
|
||||
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
|
||||
else:
|
||||
ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
|
||||
|
||||
return ci_simi, cd_simi
|
||||
|
||||
def post_net_forward(self, simi, seq_len):
|
||||
logits = self.decoder(simi, seq_len)[0]
|
||||
|
||||
return logits
|
||||
|
||||
def prediction_forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
profile: torch.Tensor,
|
||||
profile_lengths: torch.Tensor,
|
||||
return_inter_outputs: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[list]]:
|
||||
# speech encoding
|
||||
speech, speech_lengths = self.encode_speech(speech, speech_lengths)
|
||||
# speaker encoding
|
||||
profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
|
||||
# calculating similarity
|
||||
ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
|
||||
similarity = torch.cat([cd_simi, ci_simi], dim=2)
|
||||
# post net forward
|
||||
logits = self.post_net_forward(similarity, speech_lengths)
|
||||
|
||||
if return_inter_outputs:
|
||||
return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)]
|
||||
return logits
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch,)
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim)
|
||||
encoder_outputs = self.encoder(feats, feats_lengths)
|
||||
encoder_out, encoder_out_lens = encoder_outputs[:2]
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = speech.shape[0]
|
||||
speech_lengths = (
|
||||
speech_lengths
|
||||
if speech_lengths is not None
|
||||
else torch.ones(batch_size).int() * speech.shape[1]
|
||||
)
|
||||
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
|
||||
@staticmethod
|
||||
def calc_diarization_error(pred, label, length):
|
||||
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
|
||||
|
||||
(batch_size, max_len, num_output) = label.size()
|
||||
# mask the padding part
|
||||
mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy()
|
||||
|
||||
# pred and label have the shape (batch_size, max_len, num_output)
|
||||
label_np = label.data.cpu().numpy().astype(int)
|
||||
pred_np = (pred.data.cpu().numpy() > 0).astype(int)
|
||||
label_np = label_np * mask
|
||||
pred_np = pred_np * mask
|
||||
length = length.data.cpu().numpy()
|
||||
|
||||
# compute speech activity detection error
|
||||
n_ref = np.sum(label_np, axis=2)
|
||||
n_sys = np.sum(pred_np, axis=2)
|
||||
speech_scored = float(np.sum(n_ref > 0))
|
||||
speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
|
||||
speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))
|
||||
|
||||
# compute speaker diarization error
|
||||
speaker_scored = float(np.sum(n_ref))
|
||||
speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
|
||||
speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
|
||||
n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
|
||||
speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
|
||||
correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
|
||||
num_frames = np.sum(length)
|
||||
return (
|
||||
correct,
|
||||
num_frames,
|
||||
speech_scored,
|
||||
speech_miss,
|
||||
speech_falarm,
|
||||
speaker_scored,
|
||||
speaker_miss,
|
||||
speaker_falarm,
|
||||
speaker_error,
|
||||
)
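# Illustrative usage (comment added for clarity): forward() turns these
# counters into the reported metrics as
#   acc = correct / num_frames
#   der = (speaker_miss + speaker_falarm + speaker_error) / speaker_scored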
|
||||
274
funasr_local/models/e2e_sv.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.layers.abs_normalize import AbsNormalize
|
||||
from funasr_local.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr_local.models.ctc import CTC
|
||||
from funasr_local.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr_local.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr_local.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr_local.modules.add_sos_eos import add_sos_eos
|
||||
from funasr_local.modules.e2e_asr_common import ErrorCalculator
|
||||
from funasr_local.modules.nets_utils import th_accuracy
|
||||
from funasr_local.torch_utils.device_funcs import force_gatherable
|
||||
from funasr_local.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
# Nothing to do if torch<1.6.0
|
||||
@contextmanager
|
||||
def autocast(enabled=True):
|
||||
yield
|
||||
|
||||
|
||||
class ESPnetSVModel(AbsESPnetModel):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
pooling_layer: torch.nn.Module,
|
||||
decoder: AbsDecoder,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
super().__init__()
|
||||
# note that eos is the same as sos (equivalent ID)
|
||||
self.vocab_size = vocab_size
|
||||
self.token_list = token_list.copy()
|
||||
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.preencoder = preencoder
|
||||
self.postencoder = postencoder
|
||||
self.encoder = encoder
|
||||
self.pooling_layer = pooling_layer
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
loss_ctc, cer_ctc = None, None
|
||||
loss_transducer, cer_transducer, wer_transducer = None, None, None
|
||||
stats = dict()
|
||||
|
||||
# 1. CTC branch
|
||||
if self.ctc_weight != 0.0:
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
intermediate_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect intermediate CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
if self.use_transducer_decoder:
|
||||
# 2a. Transducer decoder branch
|
||||
(
|
||||
loss_transducer,
|
||||
cer_transducer,
|
||||
wer_transducer,
|
||||
) = self._calc_transducer_loss(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
text,
|
||||
)
|
||||
|
||||
if loss_ctc is not None:
|
||||
loss = loss_transducer + (self.ctc_weight * loss_ctc)
|
||||
else:
|
||||
loss = loss_transducer
|
||||
|
||||
# Collect Transducer branch stats
|
||||
stats["loss_transducer"] = (
|
||||
loss_transducer.detach() if loss_transducer is not None else None
|
||||
)
|
||||
stats["cer_transducer"] = cer_transducer
|
||||
stats["wer_transducer"] = wer_transducer
|
||||
|
||||
else:
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
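# Worked example (comment added for clarity): with ctc_weight = 0.3,
#   loss = 0.3 * loss_ctc + 0.7 * loss_att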
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
stats["acc"] = acc_att
|
||||
stats["cer"] = cer_att
|
||||
stats["wer"] = wer_att
|
||||
|
||||
# Collect total loss stats
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if self.extract_feats_in_collect_stats:
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
else:
|
||||
# Generate dummy stats if extract_feats_in_collect_stats is False
|
||||
logging.warning(
|
||||
"Generating dummy stats for feats and feats_lengths, "
|
||||
"because encoder_conf.extract_feats_in_collect_stats is "
|
||||
f"{self.extract_feats_in_collect_stats}"
|
||||
)
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim) -> (Batch, Channel, Length2, Dim2)
|
||||
encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
|
||||
|
||||
# Post-encoder, e.g. NLU
|
||||
if self.postencoder is not None:
|
||||
encoder_out, encoder_out_lens = self.postencoder(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
175
funasr_local/models/e2e_tp.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr_local.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr_local.models.predictor.cif import mae_loss
|
||||
from funasr_local.modules.add_sos_eos import add_sos_eos
|
||||
from funasr_local.modules.nets_utils import make_pad_mask, pad_list
|
||||
from funasr_local.torch_utils.device_funcs import force_gatherable
|
||||
from funasr_local.train.abs_espnet_model import AbsESPnetModel
|
||||
from funasr_local.models.predictor.cif import CifPredictorV3
|
||||
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
# Nothing to do if torch<1.6.0
|
||||
@contextmanager
|
||||
def autocast(enabled=True):
|
||||
yield
|
||||
|
||||
|
||||
class TimestampPredictor(AbsESPnetModel):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frontend: Optional[AbsFrontend],
|
||||
encoder: AbsEncoder,
|
||||
predictor: CifPredictorV3,
|
||||
predictor_bias: int = 0,
|
||||
token_list=None,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
super().__init__()
|
||||
# note that eos is the same as sos (equivalent ID)
|
||||
|
||||
self.frontend = frontend
|
||||
self.encoder = encoder
|
||||
self.encoder.interctc_use_conditioning = False
|
||||
|
||||
self.predictor = predictor
|
||||
self.predictor_bias = predictor_bias
|
||||
self.criterion_pre = mae_loss()
|
||||
self.token_list = token_list
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
batch_size = speech.shape[0]
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
speech = speech[:, :speech_lengths.max()]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
if self.predictor_bias == 1:
|
||||
_, text = add_sos_eos(text, 1, 2, -1)
|
||||
text_lengths = text_lengths + self.predictor_bias
|
||||
_, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
|
||||
|
||||
# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
|
||||
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
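# Note (comment added for clarity): criterion_pre is an MAE (L1) loss, so this
# trains the CIF predictor to make its accumulated token count
# pre_token_length2 match the reference token count text_lengths.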
|
||||
|
||||
loss = loss_pre
|
||||
stats = dict()
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
|
||||
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
|
||||
encoder_out_mask,
|
||||
token_num)
|
||||
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if self.extract_feats_in_collect_stats:
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
else:
|
||||
# Generate dummy stats if extract_feats_in_collect_stats is False
|
||||
logging.warning(
|
||||
"Generating dummy stats for feats and feats_lengths, "
|
||||
"because encoder_conf.extract_feats_in_collect_stats is "
|
||||
f"{self.extract_feats_in_collect_stats}"
|
||||
)
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
1075
funasr_local/models/e2e_uni_asr.py
Normal file
File diff suppressed because it is too large
665
funasr_local/models/e2e_vad.py
Normal file
@@ -0,0 +1,665 @@
|
||||
from enum import Enum
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
from funasr_local.models.encoder.fsmn_encoder import FSMN
|
||||
|
||||
|
||||
class VadStateMachine(Enum):
|
||||
kVadInStateStartPointNotDetected = 1
|
||||
kVadInStateInSpeechSegment = 2
|
||||
kVadInStateEndPointDetected = 3
|
||||
|
||||
|
||||
class FrameState(Enum):
|
||||
kFrameStateInvalid = -1
|
||||
kFrameStateSpeech = 1
|
||||
kFrameStateSil = 0
|
||||
|
||||
|
||||
# final voice/unvoice state per frame
|
||||
class AudioChangeState(Enum):
|
||||
kChangeStateSpeech2Speech = 0
|
||||
kChangeStateSpeech2Sil = 1
|
||||
kChangeStateSil2Sil = 2
|
||||
kChangeStateSil2Speech = 3
|
||||
kChangeStateNoBegin = 4
|
||||
kChangeStateInvalid = 5
|
||||
|
||||
|
||||
class VadDetectMode(Enum):
|
||||
kVadSingleUtteranceDetectMode = 0
|
||||
kVadMutipleUtteranceDetectMode = 1
|
||||
|
||||
|
||||
class VADXOptions:
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int = 16000,
|
||||
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
|
||||
snr_mode: int = 0,
|
||||
max_end_silence_time: int = 800,
|
||||
max_start_silence_time: int = 3000,
|
||||
do_start_point_detection: bool = True,
|
||||
do_end_point_detection: bool = True,
|
||||
window_size_ms: int = 200,
|
||||
sil_to_speech_time_thres: int = 150,
|
||||
speech_to_sil_time_thres: int = 150,
|
||||
speech_2_noise_ratio: float = 1.0,
|
||||
do_extend: int = 1,
|
||||
lookback_time_start_point: int = 200,
|
||||
lookahead_time_end_point: int = 100,
|
||||
max_single_segment_time: int = 60000,
|
||||
nn_eval_block_size: int = 8,
|
||||
dcd_block_size: int = 4,
|
||||
snr_thres: float = -100.0,
|
||||
noise_frame_num_used_for_snr: int = 100,
|
||||
decibel_thres: float = -100.0,
|
||||
speech_noise_thres: float = 0.6,
|
||||
fe_prior_thres: float = 1e-4,
|
||||
silence_pdf_num: int = 1,
|
||||
sil_pdf_ids: List[int] = [0],
|
||||
speech_noise_thresh_low: float = -0.1,
|
||||
speech_noise_thresh_high: float = 0.3,
|
||||
output_frame_probs: bool = False,
|
||||
frame_in_ms: int = 10,
|
||||
frame_length_ms: int = 25,
|
||||
):
|
||||
self.sample_rate = sample_rate
|
||||
self.detect_mode = detect_mode
|
||||
self.snr_mode = snr_mode
|
||||
self.max_end_silence_time = max_end_silence_time
|
||||
self.max_start_silence_time = max_start_silence_time
|
||||
self.do_start_point_detection = do_start_point_detection
|
||||
self.do_end_point_detection = do_end_point_detection
|
||||
self.window_size_ms = window_size_ms
|
||||
self.sil_to_speech_time_thres = sil_to_speech_time_thres
|
||||
self.speech_to_sil_time_thres = speech_to_sil_time_thres
|
||||
self.speech_2_noise_ratio = speech_2_noise_ratio
|
||||
self.do_extend = do_extend
|
||||
self.lookback_time_start_point = lookback_time_start_point
|
||||
self.lookahead_time_end_point = lookahead_time_end_point
|
||||
self.max_single_segment_time = max_single_segment_time
|
||||
self.nn_eval_block_size = nn_eval_block_size
|
||||
self.dcd_block_size = dcd_block_size
|
||||
self.snr_thres = snr_thres
|
||||
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
|
||||
self.decibel_thres = decibel_thres
|
||||
self.speech_noise_thres = speech_noise_thres
|
||||
self.fe_prior_thres = fe_prior_thres
|
||||
self.silence_pdf_num = silence_pdf_num
|
||||
self.sil_pdf_ids = sil_pdf_ids
|
||||
self.speech_noise_thresh_low = speech_noise_thresh_low
|
||||
self.speech_noise_thresh_high = speech_noise_thresh_high
|
||||
self.output_frame_probs = output_frame_probs
|
||||
self.frame_in_ms = frame_in_ms
|
||||
self.frame_length_ms = frame_length_ms
|
||||
|
||||
|
||||
class E2EVadSpeechBufWithDoa(object):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
def __init__(self):
|
||||
self.start_ms = 0
|
||||
self.end_ms = 0
|
||||
self.buffer = []
|
||||
self.contain_seg_start_point = False
|
||||
self.contain_seg_end_point = False
|
||||
self.doa = 0
|
||||
|
||||
def Reset(self):
|
||||
self.start_ms = 0
|
||||
self.end_ms = 0
|
||||
self.buffer = []
|
||||
self.contain_seg_start_point = False
|
||||
self.contain_seg_end_point = False
|
||||
self.doa = 0
|
||||
|
||||
|
||||
class E2EVadFrameProb(object):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
def __init__(self):
|
||||
self.noise_prob = 0.0
|
||||
self.speech_prob = 0.0
|
||||
self.score = 0.0
|
||||
self.frame_id = 0
|
||||
self.frm_state = 0
|
||||
|
||||
|
||||
class WindowDetector(object):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
|
||||
speech_to_sil_time: int, frame_size_ms: int):
|
||||
self.window_size_ms = window_size_ms
|
||||
self.sil_to_speech_time = sil_to_speech_time
|
||||
self.speech_to_sil_time = speech_to_sil_time
|
||||
self.frame_size_ms = frame_size_ms
|
||||
|
||||
self.win_size_frame = int(window_size_ms / frame_size_ms)
|
||||
self.win_sum = 0
|
||||
self.win_state = [0] * self.win_size_frame # initialize the window
|
||||
|
||||
self.cur_win_pos = 0
|
||||
self.pre_frame_state = FrameState.kFrameStateSil
|
||||
self.cur_frame_state = FrameState.kFrameStateSil
|
||||
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
|
||||
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
|
||||
|
||||
self.voice_last_frame_count = 0
|
||||
self.noise_last_frame_count = 0
|
||||
self.hydre_frame_count = 0
|
||||
|
||||
def Reset(self) -> None:
|
||||
self.cur_win_pos = 0
|
||||
self.win_sum = 0
|
||||
self.win_state = [0] * self.win_size_frame
|
||||
self.pre_frame_state = FrameState.kFrameStateSil
|
||||
self.cur_frame_state = FrameState.kFrameStateSil
|
||||
self.voice_last_frame_count = 0
|
||||
self.noise_last_frame_count = 0
|
||||
self.hydre_frame_count = 0
|
||||
|
||||
def GetWinSize(self) -> int:
|
||||
return int(self.win_size_frame)
|
||||
|
||||
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
|
||||
cur_frame_state = FrameState.kFrameStateSil
|
||||
if frameState == FrameState.kFrameStateSpeech:
|
||||
cur_frame_state = 1
|
||||
elif frameState == FrameState.kFrameStateSil:
|
||||
cur_frame_state = 0
|
||||
else:
|
||||
return AudioChangeState.kChangeStateInvalid
|
||||
self.win_sum -= self.win_state[self.cur_win_pos]
|
||||
self.win_sum += cur_frame_state
|
||||
self.win_state[self.cur_win_pos] = cur_frame_state
|
||||
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
|
||||
|
||||
if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
|
||||
self.pre_frame_state = FrameState.kFrameStateSpeech
|
||||
return AudioChangeState.kChangeStateSil2Speech
|
||||
|
||||
if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
|
||||
self.pre_frame_state = FrameState.kFrameStateSil
|
||||
return AudioChangeState.kChangeStateSpeech2Sil
|
||||
|
||||
if self.pre_frame_state == FrameState.kFrameStateSil:
|
||||
return AudioChangeState.kChangeStateSil2Sil
|
||||
if self.pre_frame_state == FrameState.kFrameStateSpeech:
|
||||
return AudioChangeState.kChangeStateSpeech2Speech
|
||||
return AudioChangeState.kChangeStateInvalid
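# Worked example (comment added for clarity, using the default VADXOptions):
# window_size_ms=200 and frame_size_ms=10 give a 20-frame ring buffer;
# sil_to_speech_time=150 gives sil_to_speech_frmcnt_thres=15, so a sil->speech
# change fires once at least 15 of the last 20 frames are speech.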
|
||||
|
||||
def FrameSizeMs(self) -> int:
|
||||
return int(self.frame_size_ms)
|
||||
|
||||
|
||||
class E2EVadModel(nn.Module):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
|
||||
super(E2EVadModel, self).__init__()
|
||||
self.vad_opts = VADXOptions(**vad_post_args)
|
||||
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
|
||||
self.vad_opts.sil_to_speech_time_thres,
|
||||
self.vad_opts.speech_to_sil_time_thres,
|
||||
self.vad_opts.frame_in_ms)
|
||||
self.encoder = encoder
|
||||
# init variables
|
||||
self.is_final = False
|
||||
self.data_buf_start_frame = 0
|
||||
self.frm_cnt = 0
|
||||
self.latest_confirmed_speech_frame = 0
|
||||
self.lastest_confirmed_silence_frame = -1
|
||||
self.continous_silence_frame_count = 0
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
||||
self.confirmed_start_frame = -1
|
||||
self.confirmed_end_frame = -1
|
||||
self.number_end_time_detected = 0
|
||||
self.sil_frame = 0
|
||||
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
||||
self.noise_average_decibel = -100.0
|
||||
self.pre_end_silence_detected = False
|
||||
self.next_seg = True
|
||||
|
||||
self.output_data_buf = []
|
||||
self.output_data_buf_offset = 0
|
||||
self.frame_probs = []
|
||||
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
|
||||
self.speech_noise_thres = self.vad_opts.speech_noise_thres
|
||||
self.scores = None
|
||||
self.max_time_out = False
|
||||
self.decibel = []
|
||||
self.data_buf = None
|
||||
self.data_buf_all = None
|
||||
self.waveform = None
|
||||
self.ResetDetection()
|
||||
self.frontend = frontend
|
||||
|
||||
def AllResetDetection(self):
|
||||
self.is_final = False
|
||||
self.data_buf_start_frame = 0
|
||||
self.frm_cnt = 0
|
||||
self.latest_confirmed_speech_frame = 0
|
||||
self.lastest_confirmed_silence_frame = -1
|
||||
self.continous_silence_frame_count = 0
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
||||
self.confirmed_start_frame = -1
|
||||
self.confirmed_end_frame = -1
|
||||
self.number_end_time_detected = 0
|
||||
self.sil_frame = 0
|
||||
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
||||
self.noise_average_decibel = -100.0
|
||||
self.pre_end_silence_detected = False
|
||||
self.next_seg = True
|
||||
|
||||
self.output_data_buf = []
|
||||
self.output_data_buf_offset = 0
|
||||
self.frame_probs = []
|
||||
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
|
||||
self.speech_noise_thres = self.vad_opts.speech_noise_thres
|
||||
self.scores = None
|
||||
self.max_time_out = False
|
||||
self.decibel = []
|
||||
self.data_buf = None
|
||||
self.data_buf_all = None
|
||||
self.waveform = None
|
||||
self.ResetDetection()
|
||||
|
||||
def ResetDetection(self):
|
||||
self.continous_silence_frame_count = 0
|
||||
self.latest_confirmed_speech_frame = 0
|
||||
self.lastest_confirmed_silence_frame = -1
|
||||
self.confirmed_start_frame = -1
|
||||
self.confirmed_end_frame = -1
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
||||
self.windows_detector.Reset()
|
||||
self.sil_frame = 0
|
||||
self.frame_probs = []
|
||||
|
||||
def ComputeDecibel(self) -> None:
|
||||
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
|
||||
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
|
||||
if self.data_buf_all is None:
|
||||
self.data_buf_all = self.waveform[0]  # self.data_buf points to self.waveform[0]
|
||||
self.data_buf = self.data_buf_all
|
||||
else:
|
||||
self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
|
||||
for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
|
||||
self.decibel.append(
|
||||
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
|
||||
0.000001))
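# Numeric sketch (comment added for clarity, assuming the default 16 kHz rate):
#   frame_sample_length = 25 ms * 16000 / 1000 = 400 samples
#   frame_shift_length  = 10 ms * 16000 / 1000 = 160 samples
#   decibel[t] = 10 * log10(frame_energy + 1e-6), one value per 10 ms hop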
|
||||
|
||||
def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
|
||||
scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
|
||||
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
|
||||
self.vad_opts.nn_eval_block_size = scores.shape[1]
|
||||
self.frm_cnt += scores.shape[1] # count total frames
|
||||
if self.scores is None:
|
||||
self.scores = scores # the first calculation
|
||||
else:
|
||||
self.scores = torch.cat((self.scores, scores), dim=1)
|
||||
|
||||
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
|
||||
while self.data_buf_start_frame < frame_idx:
|
||||
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
|
||||
self.data_buf_start_frame += 1
|
||||
self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
|
||||
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
|
||||
|
||||
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
|
||||
last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
|
||||
self.PopDataBufTillFrame(start_frm)
|
||||
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
|
||||
if last_frm_is_end_point:
|
||||
extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
|
||||
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
|
||||
expected_sample_number += int(extra_sample)
|
||||
if end_point_is_sent_end:
|
||||
expected_sample_number = max(expected_sample_number, len(self.data_buf))
|
||||
if len(self.data_buf) < expected_sample_number:
|
||||
print('error in calling pop data_buf\n')
|
||||
|
||||
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
|
||||
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
|
||||
self.output_data_buf[-1].Reset()
|
||||
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
|
||||
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
|
||||
self.output_data_buf[-1].doa = 0
|
||||
cur_seg = self.output_data_buf[-1]
|
||||
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
|
||||
print('warning\n')
|
||||
out_pos = len(cur_seg.buffer) # cur_seg.buffer is not modified here for now
|
||||
data_to_pop = 0
|
||||
if end_point_is_sent_end:
|
||||
data_to_pop = expected_sample_number
|
||||
else:
|
||||
data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
|
||||
if data_to_pop > len(self.data_buf):
|
||||
print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
|
||||
data_to_pop = len(self.data_buf)
|
||||
expected_sample_number = len(self.data_buf)
|
||||
|
||||
cur_seg.doa = 0
|
||||
for sample_cpy_out in range(0, data_to_pop):
|
||||
# cur_seg.buffer[out_pos ++] = data_buf_.back();
|
||||
out_pos += 1
|
||||
for sample_cpy_out in range(data_to_pop, expected_sample_number):
|
||||
# cur_seg.buffer[out_pos++] = data_buf_.back()
|
||||
out_pos += 1
|
||||
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
|
||||
print('Something wrong with the VAD algorithm\n')
|
||||
self.data_buf_start_frame += frm_cnt
|
||||
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
|
||||
if first_frm_is_start_point:
|
||||
cur_seg.contain_seg_start_point = True
|
||||
if last_frm_is_end_point:
|
||||
cur_seg.contain_seg_end_point = True
|
||||
|
||||
def OnSilenceDetected(self, valid_frame: int):
|
||||
self.lastest_confirmed_silence_frame = valid_frame
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
||||
self.PopDataBufTillFrame(valid_frame)
|
||||
# silence_detected_callback_
|
||||
# pass
|
||||
|
||||
def OnVoiceDetected(self, valid_frame: int) -> None:
|
||||
self.latest_confirmed_speech_frame = valid_frame
|
||||
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
|
||||
|
||||
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
|
||||
if self.vad_opts.do_start_point_detection:
|
||||
pass
|
||||
if self.confirmed_start_frame != -1:
|
||||
print('not reset vad properly\n')
|
||||
else:
|
||||
self.confirmed_start_frame = start_frame
|
||||
|
||||
if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
||||
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
|
||||
|
||||
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
|
||||
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
|
||||
self.OnVoiceDetected(t)
|
||||
if self.vad_opts.do_end_point_detection:
|
||||
pass
|
||||
if self.confirmed_end_frame != -1:
|
||||
print('not reset vad properly\n')
|
||||
else:
|
||||
self.confirmed_end_frame = end_frame
|
||||
if not fake_result:
|
||||
self.sil_frame = 0
|
||||
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
|
||||
self.number_end_time_detected += 1
|
||||
|
||||
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
|
||||
if is_final_frame:
|
||||
self.OnVoiceEnd(cur_frm_idx, False, True)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
|
||||
def GetLatency(self) -> int:
|
||||
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
|
||||
|
||||
def LatencyFrmNumAtStartPoint(self) -> int:
|
||||
vad_latency = self.windows_detector.GetWinSize()
|
||||
if self.vad_opts.do_extend:
|
||||
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
|
||||
return vad_latency
|
||||
|
||||
def GetFrameState(self, t: int) -> FrameState:
|
||||
frame_state = FrameState.kFrameStateInvalid
|
||||
cur_decibel = self.decibel[t]
|
||||
cur_snr = cur_decibel - self.noise_average_decibel
|
||||
# for each frame, calc log posterior probability of each state
|
||||
if cur_decibel < self.vad_opts.decibel_thres:
|
||||
frame_state = FrameState.kFrameStateSil
|
||||
self.DetectOneFrame(frame_state, t, False)
|
||||
return frame_state
|
||||
|
||||
sum_score = 0.0
|
||||
noise_prob = 0.0
|
||||
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
|
||||
if len(self.sil_pdf_ids) > 0:
|
||||
assert len(self.scores) == 1 # only batch_size = 1 is supported
|
||||
sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
|
||||
sum_score = sum(sil_pdf_scores)
|
||||
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
|
||||
total_score = 1.0
|
||||
sum_score = total_score - sum_score
|
||||
speech_prob = math.log(sum_score)
|
||||
if self.vad_opts.output_frame_probs:
|
||||
frame_prob = E2EVadFrameProb()
|
||||
frame_prob.noise_prob = noise_prob
|
||||
frame_prob.speech_prob = speech_prob
|
||||
frame_prob.score = sum_score
|
||||
frame_prob.frame_id = t
|
||||
self.frame_probs.append(frame_prob)
|
||||
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
|
||||
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
|
||||
frame_state = FrameState.kFrameStateSpeech
|
||||
else:
|
||||
frame_state = FrameState.kFrameStateSil
|
||||
else:
|
||||
frame_state = FrameState.kFrameStateSil
|
||||
if self.noise_average_decibel < -99.9:
|
||||
self.noise_average_decibel = cur_decibel
|
||||
else:
|
||||
self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
|
||||
self.vad_opts.noise_frame_num_used_for_snr
|
||||
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
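# Note (comment added for clarity): with N = noise_frame_num_used_for_snr
# (100 by default) this is the running-average update
#   avg <- avg + (cur_decibel - avg) / N
# i.e. a smoothed estimate of the noise floor in decibels.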
|
||||
|
||||
return frame_state
|
||||
|
||||
def forward(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
||||
is_final: bool = False
|
||||
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||
self.waveform = waveform # compute decibel for each frame
|
||||
self.ComputeDecibel()
|
||||
self.ComputeScores(feats, in_cache)
|
||||
if not is_final:
|
||||
self.DetectCommonFrames()
|
||||
else:
|
||||
self.DetectLastFrames()
|
||||
segments = []
|
||||
for batch_num in range(0, feats.shape[0]): # only batch_size = 1 is supported for now
|
||||
segment_batch = []
|
||||
if len(self.output_data_buf) > 0:
|
||||
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
||||
if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
|
||||
i].contain_seg_end_point):
|
||||
continue
|
||||
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
|
||||
segment_batch.append(segment)
|
||||
self.output_data_buf_offset += 1 # need update this parameter
|
||||
if segment_batch:
|
||||
segments.append(segment_batch)
|
||||
if is_final:
|
||||
# reset class variables and clear the dict for the next query
|
||||
self.AllResetDetection()
|
||||
return segments, in_cache
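# Note (comment added for clarity): each segment is [start_ms, end_ms] in
# milliseconds, grouped per batch element; until is_final is True, only
# segments with both a detected start point and end point are emitted.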
|
||||
|
||||
def forward_online(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
||||
is_final: bool = False, max_end_sil: int = 800
|
||||
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
|
||||
self.waveform = waveform # compute decibel for each frame
|
||||
|
||||
self.ComputeScores(feats, in_cache)
|
||||
self.ComputeDecibel()
|
||||
if not is_final:
|
||||
self.DetectCommonFrames()
|
||||
else:
|
||||
self.DetectLastFrames()
|
||||
segments = []
|
||||
for batch_num in range(0, feats.shape[0]): # only batch_size = 1 is supported for now
|
||||
segment_batch = []
|
||||
if len(self.output_data_buf) > 0:
|
||||
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
||||
if not self.output_data_buf[i].contain_seg_start_point:
|
||||
continue
|
||||
if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
|
||||
continue
|
||||
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
|
||||
if self.output_data_buf[i].contain_seg_end_point:
|
||||
end_ms = self.output_data_buf[i].end_ms
|
||||
self.next_seg = True
|
||||
self.output_data_buf_offset += 1
|
||||
else:
|
||||
end_ms = -1
|
||||
self.next_seg = False
|
||||
segment = [start_ms, end_ms]
|
||||
segment_batch.append(segment)
|
||||
if segment_batch:
|
||||
segments.append(segment_batch)
|
||||
if is_final:
|
||||
# reset class variables and clear the dict for the next query
|
||||
self.AllResetDetection()
|
||||
return segments, in_cache
|
||||
|
||||
def DetectCommonFrames(self) -> int:
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
||||
return 0
|
||||
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
||||
frame_state = FrameState.kFrameStateInvalid
|
||||
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
|
||||
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
|
||||
|
||||
return 0
|
||||
|
||||
def DetectLastFrames(self) -> int:
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
||||
return 0
|
||||
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
||||
frame_state = FrameState.kFrameStateInvalid
|
||||
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
|
||||
if i != 0:
|
||||
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
|
||||
else:
|
||||
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
|
||||
|
||||
return 0
|
||||
|
||||
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
|
||||
tmp_cur_frm_state = FrameState.kFrameStateInvalid
|
||||
if cur_frm_state == FrameState.kFrameStateSpeech:
|
||||
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
|
||||
tmp_cur_frm_state = FrameState.kFrameStateSpeech
|
||||
else:
|
||||
tmp_cur_frm_state = FrameState.kFrameStateSil
|
||||
elif cur_frm_state == FrameState.kFrameStateSil:
|
||||
tmp_cur_frm_state = FrameState.kFrameStateSil
|
||||
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
|
||||
frm_shift_in_ms = self.vad_opts.frame_in_ms
|
||||
if AudioChangeState.kChangeStateSil2Speech == state_change:
|
||||
silence_frame_count = self.continous_silence_frame_count
|
||||
self.continous_silence_frame_count = 0
|
||||
self.pre_end_silence_detected = False
|
||||
start_frame = 0
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
||||
start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
|
||||
self.OnVoiceStart(start_frame)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
|
||||
for t in range(start_frame + 1, cur_frm_idx + 1):
|
||||
self.OnVoiceDetected(t)
|
||||
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
||||
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
|
||||
self.OnVoiceDetected(t)
|
||||
if cur_frm_idx - self.confirmed_start_frame + 1 > \
|
||||
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
|
||||
self.OnVoiceEnd(cur_frm_idx, False, False)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
elif not is_final_frame:
|
||||
self.OnVoiceDetected(cur_frm_idx)
|
||||
else:
|
||||
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
||||
else:
|
||||
pass
|
||||
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
|
||||
self.continous_silence_frame_count = 0
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
||||
pass
|
||||
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
||||
if cur_frm_idx - self.confirmed_start_frame + 1 > \
|
||||
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
|
||||
self.OnVoiceEnd(cur_frm_idx, False, False)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
elif not is_final_frame:
|
||||
self.OnVoiceDetected(cur_frm_idx)
|
||||
else:
|
||||
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
||||
else:
|
||||
pass
|
||||
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
|
||||
self.continous_silence_frame_count = 0
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
||||
if cur_frm_idx - self.confirmed_start_frame + 1 > \
|
||||
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
|
||||
self.max_time_out = True
|
||||
self.OnVoiceEnd(cur_frm_idx, False, False)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
elif not is_final_frame:
|
||||
self.OnVoiceDetected(cur_frm_idx)
|
||||
else:
|
||||
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
||||
else:
|
||||
pass
|
||||
elif AudioChangeState.kChangeStateSil2Sil == state_change:
|
||||
self.continous_silence_frame_count += 1
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
||||
# silence timeout, return zero length decision
|
||||
if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
|
||||
self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
|
||||
or (is_final_frame and self.number_end_time_detected == 0):
|
||||
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
|
||||
self.OnSilenceDetected(t)
|
||||
self.OnVoiceStart(0, True)
|
||||
self.OnVoiceEnd(0, True, False)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
else:
|
||||
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
|
||||
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
|
||||
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
||||
if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
|
||||
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
|
||||
if self.vad_opts.do_extend:
|
||||
lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
|
||||
lookback_frame -= 1
|
||||
lookback_frame = max(0, lookback_frame)
|
||||
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
elif cur_frm_idx - self.confirmed_start_frame + 1 > \
|
||||
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
|
||||
self.OnVoiceEnd(cur_frm_idx, False, False)
|
||||
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
||||
elif self.vad_opts.do_extend and not is_final_frame:
|
||||
if self.continous_silence_frame_count <= int(
|
||||
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
|
||||
self.OnVoiceDetected(cur_frm_idx)
|
||||
else:
|
||||
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
|
||||
else:
|
||||
pass
|
||||
|
||||
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
|
||||
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
|
||||
self.ResetDetection()
|
||||
0
funasr_local/models/encoder/__init__.py
Normal file
21
funasr_local/models/encoder/abs_encoder.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsEncoder(torch.nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def output_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
1238
funasr_local/models/encoder/conformer_encoder.py
Normal file
File diff suppressed because it is too large
577
funasr_local/models/encoder/data2vec_encoder.py
Normal file
@@ -0,0 +1,577 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr_local.modules.data2vec.data_utils import compute_mask_indices
|
||||
from funasr_local.modules.data2vec.ema_module import EMAModule
|
||||
from funasr_local.modules.data2vec.grad_multiply import GradMultiply
|
||||
from funasr_local.modules.data2vec.wav2vec2 import (
|
||||
ConvFeatureExtractionModel,
|
||||
TransformerEncoder,
|
||||
)
|
||||
from funasr_local.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
def get_annealed_rate(start, end, curr_step, total_steps):
|
||||
r = end - start
|
||||
pct_remaining = 1 - curr_step / total_steps
|
||||
return end - r * pct_remaining
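# Worked example (comment added for clarity): annealing the EMA decay from
# start=0.999 to end=0.9999 over total_steps=100000:
#   get_annealed_rate(0.999, 0.9999, 50000, 100000) == 0.99945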
|
||||
|
||||
|
||||
class Data2VecEncoder(AbsEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
# for ConvFeatureExtractionModel
|
||||
input_size: int = None,
|
||||
extractor_mode: str = None,
|
||||
conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
|
||||
# for Transformer Encoder
|
||||
## model architecture
|
||||
layer_type: str = "transformer",
|
||||
layer_norm_first: bool = False,
|
||||
encoder_layers: int = 12,
|
||||
encoder_embed_dim: int = 768,
|
||||
encoder_ffn_embed_dim: int = 3072,
|
||||
encoder_attention_heads: int = 12,
|
||||
activation_fn: str = "gelu",
|
||||
## dropouts
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.0,
|
||||
encoder_layerdrop: float = 0.0,
|
||||
dropout_input: float = 0.0,
|
||||
dropout_features: float = 0.0,
|
||||
## grad settings
|
||||
feature_grad_mult: float = 1.0,
|
||||
## masking
|
||||
mask_prob: float = 0.65,
|
||||
mask_length: int = 10,
|
||||
mask_selection: str = "static",
|
||||
mask_other: int = 0,
|
||||
no_mask_overlap: bool = False,
|
||||
mask_min_space: int = 1,
|
||||
require_same_masks: bool = True, # if set as True, collate_fn should be clipping
|
||||
mask_dropout: float = 0.0,
|
||||
## channel masking
|
||||
mask_channel_length: int = 10,
|
||||
mask_channel_prob: float = 0.0,
|
||||
mask_channel_before: bool = False,
|
||||
mask_channel_selection: str = "static",
|
||||
mask_channel_other: int = 0,
|
||||
no_mask_channel_overlap: bool = False,
|
||||
mask_channel_min_space: int = 1,
|
||||
## positional embeddings
|
||||
conv_pos: int = 128,
|
||||
conv_pos_groups: int = 16,
|
||||
pos_conv_depth: int = 1,
|
||||
max_positions: int = 100000,
|
||||
# EMA module
|
||||
average_top_k_layers: int = 8,
|
||||
layer_norm_target_layer: bool = False,
|
||||
instance_norm_target_layer: bool = False,
|
||||
instance_norm_targets: bool = False,
|
||||
layer_norm_targets: bool = False,
|
||||
batch_norm_target_layer: bool = False,
|
||||
group_norm_target_layer: bool = False,
|
||||
ema_decay: float = 0.999,
|
||||
ema_end_decay: float = 0.9999,
|
||||
ema_anneal_end_step: int = 100000,
|
||||
ema_transformer_only: bool = True,
|
||||
ema_layers_only: bool = True,
|
||||
min_target_var: float = 0.1,
|
||||
min_pred_var: float = 0.01,
|
||||
# Loss
|
||||
loss_beta: float = 0.0,
|
||||
loss_scale: float = None,
|
||||
# FP16 optimization
|
||||
required_seq_len_multiple: int = 2,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
# ConvFeatureExtractionModel
|
||||
self.conv_feature_layers = conv_feature_layers
|
||||
feature_enc_layers = eval(conv_feature_layers)
|
||||
self.extractor_embed = feature_enc_layers[-1][0]
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=extractor_mode,
|
||||
in_d=input_size,
|
||||
)
|
||||
|
||||
# Transformer Encoder
|
||||
## model architecture
|
||||
self.layer_type = layer_type
|
||||
self.layer_norm_first = layer_norm_first
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_embed_dim = encoder_embed_dim
|
||||
self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.activation_fn = activation_fn
|
||||
## dropout
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.dropout_input = dropout_input
|
||||
self.dropout_features = dropout_features
|
||||
## grad settings
|
||||
self.feature_grad_mult = feature_grad_mult
|
||||
## masking
|
||||
self.mask_prob = mask_prob
|
||||
self.mask_length = mask_length
|
||||
self.mask_selection = mask_selection
|
||||
self.mask_other = mask_other
|
||||
self.no_mask_overlap = no_mask_overlap
|
||||
self.mask_min_space = mask_min_space
|
||||
self.require_same_masks = require_same_masks # if set as True, collate_fn should be clipping
|
||||
self.mask_dropout = mask_dropout
|
||||
## channel masking
|
||||
self.mask_channel_length = mask_channel_length
|
||||
self.mask_channel_prob = mask_channel_prob
|
||||
self.mask_channel_before = mask_channel_before
|
||||
self.mask_channel_selection = mask_channel_selection
|
||||
self.mask_channel_other = mask_channel_other
|
||||
self.no_mask_channel_overlap = no_mask_channel_overlap
|
||||
self.mask_channel_min_space = mask_channel_min_space
|
||||
## positional embeddings
|
||||
self.conv_pos = conv_pos
|
||||
self.conv_pos_groups = conv_pos_groups
|
||||
self.pos_conv_depth = pos_conv_depth
|
||||
self.max_positions = max_positions
|
||||
self.mask_emb = nn.Parameter(torch.FloatTensor(self.encoder_embed_dim).uniform_())
|
||||
self.encoder = TransformerEncoder(
|
||||
dropout=self.dropout,
|
||||
encoder_embed_dim=self.encoder_embed_dim,
|
||||
required_seq_len_multiple=required_seq_len_multiple,
|
||||
pos_conv_depth=self.pos_conv_depth,
|
||||
conv_pos=self.conv_pos,
|
||||
conv_pos_groups=self.conv_pos_groups,
|
||||
# transformer layers
|
||||
layer_type=self.layer_type,
|
||||
encoder_layers=self.encoder_layers,
|
||||
encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
|
||||
encoder_attention_heads=self.encoder_attention_heads,
|
||||
attention_dropout=self.attention_dropout,
|
||||
activation_dropout=self.activation_dropout,
|
||||
activation_fn=self.activation_fn,
|
||||
layer_norm_first=self.layer_norm_first,
|
||||
encoder_layerdrop=self.encoder_layerdrop,
|
||||
max_positions=self.max_positions,
|
||||
)
|
||||
## projections and dropouts
|
||||
self.post_extract_proj = nn.Linear(self.extractor_embed, self.encoder_embed_dim)
|
||||
self.dropout_input = nn.Dropout(self.dropout_input)
|
||||
self.dropout_features = nn.Dropout(self.dropout_features)
|
||||
self.layer_norm = torch.nn.LayerNorm(self.extractor_embed)
|
||||
self.final_proj = nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)
|
||||
|
||||
# EMA module
|
||||
self.average_top_k_layers = average_top_k_layers
|
||||
self.layer_norm_target_layer = layer_norm_target_layer
|
||||
self.instance_norm_target_layer = instance_norm_target_layer
|
||||
self.instance_norm_targets = instance_norm_targets
|
||||
self.layer_norm_targets = layer_norm_targets
|
||||
self.batch_norm_target_layer = batch_norm_target_layer
|
||||
self.group_norm_target_layer = group_norm_target_layer
|
||||
self.ema_decay = ema_decay
|
||||
self.ema_end_decay = ema_end_decay
|
||||
self.ema_anneal_end_step = ema_anneal_end_step
|
||||
self.ema_transformer_only = ema_transformer_only
|
||||
self.ema_layers_only = ema_layers_only
|
||||
self.min_target_var = min_target_var
|
||||
self.min_pred_var = min_pred_var
|
||||
self.ema = None
|
||||
|
||||
# Loss
|
||||
self.loss_beta = loss_beta
|
||||
self.loss_scale = loss_scale
|
||||
|
||||
# FP16 optimization
|
||||
self.required_seq_len_multiple = required_seq_len_multiple
|
||||
|
||||
self.num_updates = 0
|
||||
|
||||
logging.info("Data2VecEncoder settings: {}".format(self.__dict__))
|
||||
|
||||
def make_ema_teacher(self):
|
||||
skip_keys = set()
|
||||
if self.ema_layers_only:
|
||||
self.ema_transformer_only = True
|
||||
for k, _ in self.encoder.pos_conv.named_parameters():
|
||||
skip_keys.add(f"pos_conv.{k}")
|
||||
|
||||
self.ema = EMAModule(
|
||||
self.encoder if self.ema_transformer_only else self,
|
||||
ema_decay=self.ema_decay,
|
||||
ema_fp32=True,
|
||||
skip_keys=skip_keys,
|
||||
)
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
if self.ema is None and self.final_proj is not None:
|
||||
logging.info("Making EMA Teacher")
|
||||
self.make_ema_teacher()
|
||||
elif self.training and self.ema is not None:
|
||||
if self.ema_decay != self.ema_end_decay:
|
||||
if num_updates >= self.ema_anneal_end_step:
|
||||
decay = self.ema_end_decay
|
||||
else:
|
||||
decay = get_annealed_rate(
|
||||
self.ema_decay,
|
||||
self.ema_end_decay,
|
||||
num_updates,
|
||||
self.ema_anneal_end_step,
|
||||
)
|
||||
self.ema.set_decay(decay)
|
||||
if self.ema.get_decay() < 1:
|
||||
self.ema.step(self.encoder if self.ema_transformer_only else self)
|
||||
|
||||
self.num_updates = num_updates
|
||||
|
||||
def apply_mask(
|
||||
self,
|
||||
x,
|
||||
padding_mask,
|
||||
mask_indices=None,
|
||||
mask_channel_indices=None,
|
||||
):
|
||||
B, T, C = x.shape
|
||||
|
||||
if self.mask_channel_prob > 0 and self.mask_channel_before:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = (
|
||||
torch.from_numpy(mask_channel_indices)
|
||||
.to(x.device)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, T, -1)
|
||||
)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
if self.mask_prob > 0:
|
||||
if mask_indices is None:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=1,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space,
|
||||
require_same_masks=self.require_same_masks,
|
||||
mask_dropout=self.mask_dropout,
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||
x[mask_indices] = self.mask_emb
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0 and not self.mask_channel_before:
|
||||
if mask_channel_indices is None:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = (
|
||||
torch.from_numpy(mask_channel_indices)
|
||||
.to(x.device)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, T, -1)
|
||||
)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
|
||||
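# Sketch of the masking contract above (comment added for clarity, not in the
# original file): for x of shape (B, T, C), time masking replaces the sampled
# positions with the learned self.mask_emb vector, while channel masking
# zeroes whole feature dimensions across all timesteps; mask_channel_before
# only controls whether channel masking runs before or after time masking.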
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
return torch.floor((input_length - kernel_size).to(torch.float32) / stride + 1)
|
||||
|
||||
conv_cfg_list = eval(self.conv_feature_layers)
|
||||
|
||||
for i in range(len(conv_cfg_list)):
|
||||
input_lengths = _conv_out_length(
|
||||
input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
|
||||
)
|
||||
|
||||
return input_lengths.to(torch.long)
|
||||
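# Worked example (assumes the default conv_feature_layers
# "[(512,2,2)] + [(512,2,2)]"): each layer maps L -> floor((L - 2) / 2 + 1),
# so 100 input frames become 50 after the first conv and 25 after the second:
#   _get_feat_extract_output_lengths(torch.tensor([100]))  # -> tensor([25])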
|
||||
def forward(
|
||||
self,
|
||||
xs_pad,
|
||||
ilens=None,
|
||||
mask=False,
|
||||
features_only=True,
|
||||
layer=None,
|
||||
mask_indices=None,
|
||||
mask_channel_indices=None,
|
||||
padding_count=None,
|
||||
):
|
||||
# create padding_mask by ilens
|
||||
if ilens is not None:
|
||||
padding_mask = make_pad_mask(lengths=ilens).to(xs_pad.device)
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
features = xs_pad
|
||||
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(features)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(features)
|
||||
|
||||
features = features.transpose(1, 2)
|
||||
|
||||
features = self.layer_norm(features)
|
||||
|
||||
orig_padding_mask = padding_mask
|
||||
|
||||
if padding_mask is not None:
|
||||
input_lengths = (1 - padding_mask.long()).sum(-1)
|
||||
# apply conv formula to get real output_lengths
|
||||
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
|
||||
|
||||
padding_mask = torch.zeros(
|
||||
features.shape[:2], dtype=features.dtype, device=features.device
|
||||
)
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
padding_mask[
|
||||
(
|
||||
torch.arange(padding_mask.shape[0], device=padding_mask.device),
|
||||
output_lengths - 1,
|
||||
)
|
||||
] = 1
|
||||
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
pre_encoder_features = None
|
||||
if self.ema_transformer_only:
|
||||
pre_encoder_features = features.clone()
|
||||
|
||||
features = self.dropout_input(features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(
|
||||
features,
|
||||
padding_mask,
|
||||
mask_indices=mask_indices,
|
||||
mask_channel_indices=mask_channel_indices,
|
||||
)
|
||||
else:
|
||||
x = features
|
||||
mask_indices = None
|
||||
|
||||
x, layer_results = self.encoder(
|
||||
x,
|
||||
padding_mask=padding_mask,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
if features_only:
|
||||
encoder_out_lens = (1 - padding_mask.long()).sum(1) if padding_mask is not None else None
|
||||
return x, encoder_out_lens, None
|
||||
|
||||
result = {
|
||||
"losses": {},
|
||||
"padding_mask": padding_mask,
|
||||
"x": x,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
self.ema.model.eval()
|
||||
|
||||
if self.ema_transformer_only:
|
||||
y, layer_results = self.ema.model.extract_features(
|
||||
pre_encoder_features,
|
||||
padding_mask=padding_mask,
|
||||
min_layer=self.encoder_layers - self.average_top_k_layers,
|
||||
)
|
||||
y = {
|
||||
"x": y,
|
||||
"padding_mask": padding_mask,
|
||||
"layer_results": layer_results,
|
||||
}
|
||||
else:
|
||||
y = self.ema.model.extract_features(
|
||||
source=xs_pad,
|
||||
padding_mask=orig_padding_mask,
|
||||
mask=False,
|
||||
)
|
||||
|
||||
target_layer_results = [l[2] for l in y["layer_results"]]
|
||||
|
||||
permuted = False
|
||||
if self.instance_norm_target_layer or self.batch_norm_target_layer:
|
||||
target_layer_results = [
|
||||
tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
|
||||
]
|
||||
permuted = True
|
||||
|
||||
if self.batch_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.batch_norm(
|
||||
tl.float(), running_mean=None, running_var=None, training=True
|
||||
)
|
||||
for tl in target_layer_results
|
||||
]
|
||||
|
||||
if self.instance_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.instance_norm(tl.float()) for tl in target_layer_results
|
||||
]
|
||||
|
||||
if permuted:
|
||||
target_layer_results = [
|
||||
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
|
||||
]
|
||||
|
||||
if self.group_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.layer_norm(tl.float(), tl.shape[-2:])
|
||||
for tl in target_layer_results
|
||||
]
|
||||
|
||||
if self.layer_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.layer_norm(tl.float(), tl.shape[-1:])
|
||||
for tl in target_layer_results
|
||||
]
|
||||
|
||||
y = sum(target_layer_results) / len(target_layer_results)
|
||||
|
||||
if self.layer_norm_targets:
|
||||
y = F.layer_norm(y.float(), y.shape[-1:])
|
||||
|
||||
if self.instance_norm_targets:
|
||||
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
if not permuted:
|
||||
y = y.transpose(0, 1)
|
||||
|
||||
y = y[mask_indices]
|
||||
|
||||
x = x[mask_indices]
|
||||
x = self.final_proj(x)
|
||||
|
||||
sz = x.size(-1)
|
||||
|
||||
if self.loss_beta == 0:
|
||||
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
|
||||
else:
|
||||
loss = F.smooth_l1_loss(
|
||||
x.float(), y.float(), reduction="none", beta=self.loss_beta
|
||||
).sum(dim=-1)
|
||||
|
||||
if self.loss_scale is not None:
|
||||
scale = self.loss_scale
|
||||
else:
|
||||
scale = 1 / math.sqrt(sz)
|
||||
|
||||
result["losses"]["regression"] = loss.sum() * scale
|
||||
|
||||
if "sample_size" not in result:
|
||||
result["sample_size"] = loss.numel()
|
||||
|
||||
with torch.no_grad():
|
||||
result["target_var"] = self.compute_var(y)
|
||||
result["pred_var"] = self.compute_var(x.float())
|
||||
|
||||
if self.num_updates > 5000 and result["target_var"] < self.min_target_var:
|
||||
logging.error(
|
||||
f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
|
||||
)
|
||||
raise Exception(
|
||||
f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
|
||||
)
|
||||
if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
|
||||
logging.error(
|
||||
f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
|
||||
)
|
||||
raise Exception(
|
||||
f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
|
||||
)
|
||||
|
||||
if self.ema is not None:
|
||||
result["ema_decay"] = self.ema.get_decay() * 1000
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def compute_var(y):
|
||||
y = y.view(-1, y.size(-1))
|
||||
if dist.is_initialized():
|
||||
zc = torch.tensor(y.size(0)).cuda()
|
||||
zs = y.sum(dim=0)
|
||||
zss = (y ** 2).sum(dim=0)
|
||||
|
||||
dist.all_reduce(zc)
|
||||
dist.all_reduce(zs)
|
||||
dist.all_reduce(zss)
|
||||
|
||||
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
|
||||
return torch.sqrt(var + 1e-6).mean()
|
||||
else:
|
||||
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
||||
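# Derivation note (added, not in the original): the distributed branch
# recovers the unbiased variance from the all-reduced count/sum/sum-of-squares,
#   var = (zss - zs**2 / zc) / (zc - 1)
# which expands to zss / (zc - 1) - zs**2 / (zc * (zc - 1)) as written above.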
|
||||
def extract_features(
|
||||
self, xs_pad, ilens, mask=False, layer=None
|
||||
):
|
||||
res = self.forward(
|
||||
xs_pad,
|
||||
ilens,
|
||||
mask=mask,
|
||||
features_only=True,
|
||||
layer=layer,
|
||||
)
|
||||
return res
|
||||
|
||||
def remove_pretraining_modules(self, last_layer=None):
|
||||
self.final_proj = None
|
||||
self.ema = None
|
||||
if last_layer is not None:
|
||||
self.encoder.layers = nn.ModuleList(
|
||||
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
|
||||
)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self.encoder_embed_dim
|
||||
686
funasr_local/models/encoder/ecapa_tdnn_encoder.py
Normal file
@@ -0,0 +1,686 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class _BatchNorm1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_shape=None,
|
||||
input_size=None,
|
||||
eps=1e-05,
|
||||
momentum=0.1,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
combine_batch_time=False,
|
||||
skip_transpose=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.combine_batch_time = combine_batch_time
|
||||
self.skip_transpose = skip_transpose
|
||||
|
||||
if input_size is None and skip_transpose:
|
||||
input_size = input_shape[1]
|
||||
elif input_size is None:
|
||||
input_size = input_shape[-1]
|
||||
|
||||
self.norm = nn.BatchNorm1d(
|
||||
input_size,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
track_running_stats=track_running_stats,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shape_or = x.shape
|
||||
if self.combine_batch_time:
|
||||
if x.ndim == 3:
|
||||
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
|
||||
else:
|
||||
x = x.reshape(
|
||||
shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
|
||||
)
|
||||
|
||||
elif not self.skip_transpose:
|
||||
x = x.transpose(-1, 1)
|
||||
|
||||
x_n = self.norm(x)
|
||||
|
||||
if self.combine_batch_time:
|
||||
x_n = x_n.reshape(shape_or)
|
||||
elif not self.skip_transpose:
|
||||
x_n = x_n.transpose(1, -1)
|
||||
|
||||
return x_n
|
||||
|
||||
|
||||
class _Conv1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
input_shape=None,
|
||||
in_channels=None,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding="same",
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="reflect",
|
||||
skip_transpose=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
self.unsqueeze = False
|
||||
self.skip_transpose = skip_transpose
|
||||
|
||||
if input_shape is None and in_channels is None:
|
||||
raise ValueError("Must provide one of input_shape or in_channels")
|
||||
|
||||
if in_channels is None:
|
||||
in_channels = self._check_input_shape(input_shape)
|
||||
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
self.kernel_size,
|
||||
stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
padding=0,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if not self.skip_transpose:
|
||||
x = x.transpose(1, -1)
|
||||
|
||||
if self.unsqueeze:
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
if self.padding == "same":
|
||||
x = self._manage_padding(
|
||||
x, self.kernel_size, self.dilation, self.stride
|
||||
)
|
||||
|
||||
elif self.padding == "causal":
|
||||
num_pad = (self.kernel_size - 1) * self.dilation
|
||||
x = F.pad(x, (num_pad, 0))
|
||||
|
||||
elif self.padding == "valid":
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Padding must be 'same', 'valid' or 'causal'. Got "
|
||||
+ self.padding
|
||||
)
|
||||
|
||||
wx = self.conv(x)
|
||||
|
||||
if self.unsqueeze:
|
||||
wx = wx.squeeze(1)
|
||||
|
||||
if not self.skip_transpose:
|
||||
wx = wx.transpose(1, -1)
|
||||
|
||||
return wx
|
||||
|
||||
def _manage_padding(
|
||||
self, x, kernel_size: int, dilation: int, stride: int,
|
||||
):
|
||||
# Detecting input shape
|
||||
L_in = x.shape[-1]
|
||||
|
||||
# Time padding
|
||||
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
|
||||
|
||||
# Applying padding
|
||||
x = F.pad(x, padding, mode=self.padding_mode)
|
||||
|
||||
return x
|
||||
|
||||
def _check_input_shape(self, shape):
|
||||
"""Checks the input shape and returns the number of input channels.
|
||||
"""
|
||||
|
||||
if len(shape) == 2:
|
||||
self.unsqueeze = True
|
||||
in_channels = 1
|
||||
elif self.skip_transpose:
|
||||
in_channels = shape[1]
|
||||
elif len(shape) == 3:
|
||||
in_channels = shape[2]
|
||||
else:
|
||||
raise ValueError(
|
||||
"conv1d expects 2d, 3d inputs. Got " + str(len(shape))
|
||||
)
|
||||
|
||||
# Kernel size must be odd
|
||||
if self.kernel_size % 2 == 0:
|
||||
raise ValueError(
|
||||
"The field kernel size must be an odd number. Got %s."
|
||||
% (self.kernel_size)
|
||||
)
|
||||
return in_channels
|
||||
|
||||
|
||||
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
|
||||
if stride > 1:
|
||||
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
|
||||
L_out = stride * (n_steps - 1) + kernel_size * dilation
|
||||
padding = [kernel_size // 2, kernel_size // 2]
|
||||
|
||||
else:
|
||||
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
|
||||
|
||||
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
|
||||
return padding
|
||||
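# Worked example (illustrative): with L_in=100, kernel_size=3, dilation=1 and
# stride=1, L_out = (100 - 2 - 1) // 1 + 1 = 98 and padding = [1, 1], so the
# reflect-pad in _manage_padding restores the original length of 100.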
|
||||
|
||||
# Skip transpose as much as possible for efficiency
|
||||
class Conv1d(_Conv1d):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(skip_transpose=True, *args, **kwargs)
|
||||
|
||||
|
||||
class BatchNorm1d(_BatchNorm1d):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(skip_transpose=True, *args, **kwargs)
|
||||
|
||||
|
||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
||||
assert len(length.shape) == 1
|
||||
|
||||
if max_len is None:
|
||||
max_len = length.max().long().item() # using arange to generate mask
|
||||
mask = torch.arange(
|
||||
max_len, device=length.device, dtype=length.dtype
|
||||
).expand(len(length), max_len) < length.unsqueeze(1)
|
||||
|
||||
if dtype is None:
|
||||
dtype = length.dtype
|
||||
|
||||
if device is None:
|
||||
device = length.device
|
||||
|
||||
mask = torch.as_tensor(mask, dtype=dtype, device=device)
|
||||
return mask
|
||||
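# Minimal usage sketch (hypothetical values):
#   length_to_mask(torch.tensor([2, 3]), max_len=4)
# returns [[1, 1, 0, 0], [1, 1, 1, 0]]; callers below pass relative lengths
# scaled by the absolute max, i.e. length_to_mask(lengths * L, max_len=L).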
|
||||
|
||||
class TDNNBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
dilation,
|
||||
activation=nn.ReLU,
|
||||
groups=1,
|
||||
):
|
||||
super(TDNNBlock, self).__init__()
|
||||
self.conv = Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
self.activation = activation()
|
||||
self.norm = BatchNorm1d(input_size=out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.activation(self.conv(x)))
|
||||
|
||||
|
||||
class Res2NetBlock(torch.nn.Module):
|
||||
"""An implementation of Res2NetBlock w/ dilation.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
in_channels : int
|
||||
The number of channels expected in the input.
|
||||
out_channels : int
|
||||
The number of output channels.
|
||||
scale : int
|
||||
The scale of the Res2Net block.
|
||||
kernel_size: int
|
||||
The kernel size of the Res2Net block.
|
||||
dilation : int
|
||||
The dilation of the Res2Net block.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
||||
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
|
||||
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
|
||||
>>> out_tensor.shape
|
||||
torch.Size([8, 120, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
|
||||
):
|
||||
super(Res2NetBlock, self).__init__()
|
||||
assert in_channels % scale == 0
|
||||
assert out_channels % scale == 0
|
||||
|
||||
in_channel = in_channels // scale
|
||||
hidden_channel = out_channels // scale
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TDNNBlock(
|
||||
in_channel,
|
||||
hidden_channel,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
)
|
||||
for i in range(scale - 1)
|
||||
]
|
||||
)
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
y = []
|
||||
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
|
||||
if i == 0:
|
||||
y_i = x_i
|
||||
elif i == 1:
|
||||
y_i = self.blocks[i - 1](x_i)
|
||||
else:
|
||||
y_i = self.blocks[i - 1](x_i + y_i)
|
||||
y.append(y_i)
|
||||
y = torch.cat(y, dim=1)
|
||||
return y
|
||||
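# Flow note for the loop above (comment added for clarity): the input is
# chunked into `scale` groups along channels; group 0 passes through
# unchanged, group 1 goes through its TDNN block, and each later group is
# summed with the previous output first, growing the receptive field per group.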
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
"""An implementation of squeeze-and-excitation block.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
in_channels : int
|
||||
The number of input channels.
|
||||
se_channels : int
|
||||
The number of output channels after squeeze.
|
||||
out_channels : int
|
||||
The number of output channels.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
||||
>>> se_layer = SEBlock(64, 16, 64)
|
||||
>>> lengths = torch.rand((8,))
|
||||
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
|
||||
>>> out_tensor.shape
|
||||
torch.Size([8, 120, 64])
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, se_channels, out_channels):
|
||||
super(SEBlock, self).__init__()
|
||||
|
||||
self.conv1 = Conv1d(
|
||||
in_channels=in_channels, out_channels=se_channels, kernel_size=1
|
||||
)
|
||||
self.relu = torch.nn.ReLU(inplace=True)
|
||||
self.conv2 = Conv1d(
|
||||
in_channels=se_channels, out_channels=out_channels, kernel_size=1
|
||||
)
|
||||
self.sigmoid = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x, lengths=None):
|
||||
L = x.shape[-1]
|
||||
if lengths is not None:
|
||||
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
|
||||
mask = mask.unsqueeze(1)
|
||||
total = mask.sum(dim=2, keepdim=True)
|
||||
s = (x * mask).sum(dim=2, keepdim=True) / total
|
||||
else:
|
||||
s = x.mean(dim=2, keepdim=True)
|
||||
|
||||
s = self.relu(self.conv1(s))
|
||||
s = self.sigmoid(self.conv2(s))
|
||||
|
||||
return s * x
|
||||
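# Gating note (added): `s` is a per-channel sigmoid weight computed from the
# (optionally length-masked) temporal mean, so the block rescales channels
# and broadcasts over time rather than mixing timesteps.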
|
||||
|
||||
class AttentiveStatisticsPooling(nn.Module):
|
||||
"""This class implements an attentive statistic pooling layer for each channel.
|
||||
It returns the concatenated mean and std of the input tensor.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
channels: int
|
||||
The number of input channels.
|
||||
attention_channels: int
|
||||
The number of attention channels.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
||||
>>> asp_layer = AttentiveStatisticsPooling(64)
|
||||
>>> lengths = torch.rand((8,))
|
||||
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
|
||||
>>> out_tensor.shape
|
||||
torch.Size([8, 1, 128])
|
||||
"""
|
||||
|
||||
def __init__(self, channels, attention_channels=128, global_context=True):
|
||||
super().__init__()
|
||||
|
||||
self.eps = 1e-12
|
||||
self.global_context = global_context
|
||||
if global_context:
|
||||
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
|
||||
else:
|
||||
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
|
||||
self.tanh = nn.Tanh()
|
||||
self.conv = Conv1d(
|
||||
in_channels=attention_channels, out_channels=channels, kernel_size=1
|
||||
)
|
||||
|
||||
def forward(self, x, lengths=None):
|
||||
"""Calculates mean and std for a batch (input tensor).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Tensor of shape [N, C, L].
|
||||
"""
|
||||
L = x.shape[-1]
|
||||
|
||||
def _compute_statistics(x, m, dim=2, eps=self.eps):
|
||||
mean = (m * x).sum(dim)
|
||||
std = torch.sqrt(
|
||||
(m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
|
||||
)
|
||||
return mean, std
|
||||
|
||||
if lengths is None:
|
||||
lengths = torch.ones(x.shape[0], device=x.device)
|
||||
|
||||
# Make binary mask of shape [N, 1, L]
|
||||
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Expand the temporal context of the pooling layer by allowing the
|
||||
# self-attention to look at global properties of the utterance.
|
||||
if self.global_context:
|
||||
# torch.std is unstable for backward computation
|
||||
# https://github.com/pytorch/pytorch/issues/4320
|
||||
total = mask.sum(dim=2, keepdim=True).float()
|
||||
mean, std = _compute_statistics(x, mask / total)
|
||||
mean = mean.unsqueeze(2).repeat(1, 1, L)
|
||||
std = std.unsqueeze(2).repeat(1, 1, L)
|
||||
attn = torch.cat([x, mean, std], dim=1)
|
||||
else:
|
||||
attn = x
|
||||
|
||||
# Apply layers
|
||||
attn = self.conv(self.tanh(self.tdnn(attn)))
|
||||
|
||||
# Filter out zero-paddings
|
||||
attn = attn.masked_fill(mask == 0, float("-inf"))
|
||||
|
||||
attn = F.softmax(attn, dim=2)
|
||||
mean, std = _compute_statistics(x, attn)
|
||||
# Append mean and std of the batch
|
||||
pooled_stats = torch.cat((mean, std), dim=1)
|
||||
pooled_stats = pooled_stats.unsqueeze(2)
|
||||
|
||||
return pooled_stats
|
||||
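# Shape note (added): for x of shape (N, C, L) the attention-weighted mean
# and std are concatenated along channels, so the pooled output is (N, 2*C, 1).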
|
||||
|
||||
class SERes2NetBlock(nn.Module):
|
||||
"""An implementation of building block in ECAPA-TDNN, i.e.,
|
||||
TDNN-Res2Net-TDNN-SEBlock.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
out_channels: int
|
||||
The number of output channels.
|
||||
res2net_scale: int
|
||||
The scale of the Res2Net block.
|
||||
kernel_size: int
|
||||
The kernel size of the TDNN blocks.
|
||||
dilation: int
|
||||
The dilation of the Res2Net block.
|
||||
activation : torch class
|
||||
A class for constructing the activation layers.
|
||||
groups: int
|
||||
Number of blocked connections from input channels to output channels.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
|
||||
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
|
||||
>>> out = conv(x).transpose(1, 2)
|
||||
>>> out.shape
|
||||
torch.Size([8, 120, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
res2net_scale=8,
|
||||
se_channels=128,
|
||||
kernel_size=1,
|
||||
dilation=1,
|
||||
activation=torch.nn.ReLU,
|
||||
groups=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.tdnn1 = TDNNBlock(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
dilation=1,
|
||||
activation=activation,
|
||||
groups=groups,
|
||||
)
|
||||
self.res2net_block = Res2NetBlock(
|
||||
out_channels, out_channels, res2net_scale, kernel_size, dilation
|
||||
)
|
||||
self.tdnn2 = TDNNBlock(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
dilation=1,
|
||||
activation=activation,
|
||||
groups=groups,
|
||||
)
|
||||
self.se_block = SEBlock(out_channels, se_channels, out_channels)
|
||||
|
||||
self.shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x, lengths=None):
|
||||
residual = x
|
||||
if self.shortcut:
|
||||
residual = self.shortcut(x)
|
||||
|
||||
x = self.tdnn1(x)
|
||||
x = self.res2net_block(x)
|
||||
x = self.tdnn2(x)
|
||||
x = self.se_block(x, lengths)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class ECAPA_TDNN(torch.nn.Module):
|
||||
"""An implementation of the speaker embedding model in a paper.
|
||||
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
|
||||
TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
activation : torch class
|
||||
A class for constructing the activation layers.
|
||||
channels : list of ints
|
||||
Output channels for TDNN/SERes2Net layer.
|
||||
kernel_sizes : list of ints
|
||||
List of kernel sizes for each layer.
|
||||
dilations : list of ints
|
||||
List of dilations for kernels in each layer.
|
||||
lin_neurons : int
|
||||
Number of neurons in linear layers.
|
||||
groups : list of ints
|
||||
List of groups for kernels in each layer.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> input_feats = torch.rand([5, 120, 80])
|
||||
>>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192, window_size=None)
|
||||
>>> outputs = compute_embedding(input_feats)
|
||||
>>> outputs.shape
|
||||
torch.Size([5, 192])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
lin_neurons=192,
|
||||
activation=torch.nn.ReLU,
|
||||
channels=[512, 512, 512, 512, 1536],
|
||||
kernel_sizes=[5, 3, 3, 3, 1],
|
||||
dilations=[1, 2, 3, 4, 1],
|
||||
attention_channels=128,
|
||||
res2net_scale=8,
|
||||
se_channels=128,
|
||||
global_context=True,
|
||||
groups=[1, 1, 1, 1, 1],
|
||||
window_size=20,
|
||||
window_shift=1,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
assert len(channels) == len(kernel_sizes)
|
||||
assert len(channels) == len(dilations)
|
||||
self.channels = channels
|
||||
self.blocks = nn.ModuleList()
|
||||
self.window_size = window_size
|
||||
self.window_shift = window_shift
|
||||
|
||||
# The initial TDNN layer
|
||||
self.blocks.append(
|
||||
TDNNBlock(
|
||||
input_size,
|
||||
channels[0],
|
||||
kernel_sizes[0],
|
||||
dilations[0],
|
||||
activation,
|
||||
groups[0],
|
||||
)
|
||||
)
|
||||
|
||||
# SE-Res2Net layers
|
||||
for i in range(1, len(channels) - 1):
|
||||
self.blocks.append(
|
||||
SERes2NetBlock(
|
||||
channels[i - 1],
|
||||
channels[i],
|
||||
res2net_scale=res2net_scale,
|
||||
se_channels=se_channels,
|
||||
kernel_size=kernel_sizes[i],
|
||||
dilation=dilations[i],
|
||||
activation=activation,
|
||||
groups=groups[i],
|
||||
)
|
||||
)
|
||||
|
||||
# Multi-layer feature aggregation
|
||||
self.mfa = TDNNBlock(
|
||||
channels[-1],
|
||||
channels[-1],
|
||||
kernel_sizes[-1],
|
||||
dilations[-1],
|
||||
activation,
|
||||
groups=groups[-1],
|
||||
)
|
||||
|
||||
# Attentive Statistical Pooling
|
||||
self.asp = AttentiveStatisticsPooling(
|
||||
channels[-1],
|
||||
attention_channels=attention_channels,
|
||||
global_context=global_context,
|
||||
)
|
||||
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
|
||||
|
||||
# Final linear transformation
|
||||
self.fc = Conv1d(
|
||||
in_channels=channels[-1] * 2,
|
||||
out_channels=lin_neurons,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def windowed_pooling(self, x, lengths=None):
|
||||
# x: Batch, Channel, Time
|
||||
tt = x.shape[2]
|
||||
num_chunk = int(math.ceil(tt / self.window_shift))
|
||||
pad = self.window_size // 2
|
||||
x = F.pad(x, (pad, pad, 0, 0), "reflect")
|
||||
stat_list = []
|
||||
|
||||
for i in range(num_chunk):
|
||||
# B x C
|
||||
st, ed = i * self.window_shift, i * self.window_shift + self.window_size
|
||||
chunk = self.asp(x[:, :, st: ed],
|
||||
lengths=torch.clamp(lengths - i, 0, self.window_size)
|
||||
if lengths is not None else None)
|
||||
chunk = self.asp_bn(chunk)
|
||||
chunk = self.fc(chunk)
|
||||
stat_list.append(chunk)
|
||||
|
||||
return torch.cat(stat_list, dim=2)
|
||||
|
||||
def forward(self, x, lengths=None):
|
||||
"""Returns the embedding vector.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Tensor of shape (batch, time, channel).
|
||||
lengths: torch.Tensor
|
||||
Tensor of shape (batch, )
|
||||
"""
|
||||
# Minimize transpose for efficiency
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
xl = []
|
||||
for layer in self.blocks:
|
||||
try:
|
||||
x = layer(x, lengths=lengths)
|
||||
except TypeError:
|
||||
x = layer(x)
|
||||
xl.append(x)
|
||||
|
||||
# Multi-layer feature aggregation
|
||||
x = torch.cat(xl[1:], dim=1)
|
||||
x = self.mfa(x)
|
||||
|
||||
if self.window_size is None:
|
||||
# Attentive Statistical Pooling
|
||||
x = self.asp(x, lengths=lengths)
|
||||
x = self.asp_bn(x)
|
||||
# Final linear transformation
|
||||
x = self.fc(x)
|
||||
# x = x.transpose(1, 2)
|
||||
x = x.squeeze(2) # -> B, C
|
||||
else:
|
||||
x = self.windowed_pooling(x, lengths)
|
||||
x = x.transpose(1, 2) # -> B, T, C
|
||||
return x
|
||||
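# Hedged usage sketch (mirrors the class docstring; window_size=None disables
# the sliding-window pooling path):
#   model = ECAPA_TDNN(80, lin_neurons=192, window_size=None)
#   feats = torch.rand(5, 120, 80)   # (batch, time, feature)
#   emb = model(feats)               # -> (5, 192) utterance embeddings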
270
funasr_local/models/encoder/encoder_layer_mfcca.py
Normal file
@@ -0,0 +1,270 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder self-attention layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
|
||||
from funasr_local.modules.layer_norm import LayerNorm
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
|
||||
class Encoder_Conformer_Layer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
||||
can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
conv_module (torch.nn.Module): Convolution module instance.
|
||||
`ConvolutionModule` instance can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
feed_forward_macaron,
|
||||
conv_module,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
cca_pos=0,
|
||||
):
|
||||
"""Construct an Encoder_Conformer_Layer object."""
|
||||
super(Encoder_Conformer_Layer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.feed_forward_macaron = feed_forward_macaron
|
||||
self.conv_module = conv_module
|
||||
self.norm_ff = LayerNorm(size) # for the FNN module
|
||||
self.norm_mha = LayerNorm(size) # for the MHA module
|
||||
if feed_forward_macaron is not None:
|
||||
self.norm_ff_macaron = LayerNorm(size)
|
||||
self.ff_scale = 0.5
|
||||
else:
|
||||
self.ff_scale = 1.0
|
||||
if self.conv_module is not None:
|
||||
self.norm_conv = LayerNorm(size) # for the CNN module
|
||||
self.norm_final = LayerNorm(size) # for the final output of the block
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
self.cca_pos = cca_pos
|
||||
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, x_input, mask, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
|
||||
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
|
||||
- w/o pos emb: Tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
if isinstance(x_input, tuple):
|
||||
x, pos_emb = x_input[0], x_input[1]
|
||||
else:
|
||||
x, pos_emb = x_input, None
|
||||
# whether to use macaron style
|
||||
if self.feed_forward_macaron is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if self.cca_pos < 2:
|
||||
if pos_emb is not None:
|
||||
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
||||
else:
|
||||
x_att = self.self_attn(x_q, x, x, mask)
|
||||
else:
|
||||
x_att = self.self_attn(x_q, x, x, mask)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, x_att), dim=-1)
|
||||
x = residual + self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(x_att)
|
||||
if not self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
# convolution module
|
||||
if self.conv_module is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
x = residual + self.dropout(self.conv_module(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
|
||||
# feed forward module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
|
||||
if self.conv_module is not None:
|
||||
x = self.norm_final(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
if pos_emb is not None:
|
||||
return (x, pos_emb), mask
|
||||
|
||||
return x, mask
|
||||
|
||||
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
||||
can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
conv_module (torch.nn.Module): Convolution module instance.
|
||||
`ConvolutionModule` instance can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn_cros_channel,
|
||||
self_attn_conformer,
|
||||
feed_forward_csa,
|
||||
feed_forward_macaron_csa,
|
||||
conv_module_csa,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayer, self).__init__()
|
||||
|
||||
self.encoder_cros_channel_atten = self_attn_cros_channel
|
||||
self.encoder_csa = Encoder_Conformer_Layer(
|
||||
size,
|
||||
self_attn_conformer,
|
||||
feed_forward_csa,
|
||||
feed_forward_macaron_csa,
|
||||
conv_module_csa,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
cca_pos=0)
|
||||
self.norm_mha = LayerNorm(size) # for the MHA module
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
|
||||
def forward(self, x_input, mask, channel_size, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
|
||||
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
|
||||
- w/o pos emb: Tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
if isinstance(x_input, tuple):
|
||||
x, pos_emb = x_input[0], x_input[1]
|
||||
else:
|
||||
x, pos_emb = x_input, None
|
||||
residual = x
|
||||
x = self.norm_mha(x)
|
||||
t_leng = x.size(1)
|
||||
d_dim = x.size(2)
|
||||
x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)  # x_new: (B, T, C, D), later flattened to (B*T, C, D)
|
||||
x_k_v = x_new.new_empty(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
|
||||
pad_before = torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3), dtype=x_new.dtype, device=x_new.device)
|
||||
pad_after = torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3), dtype=x_new.dtype, device=x_new.device)
|
||||
x_pad = torch.cat([pad_before, x_new, pad_after], 1)
|
||||
x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]
|
||||
x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]
|
||||
x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]
|
||||
x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]
|
||||
x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]
|
||||
x_new = x_new.reshape(-1, channel_size, d_dim)
|
||||
x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
|
||||
x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
|
||||
x_att = x_att.reshape(-1, t_leng, channel_size, d_dim).transpose(1, 2).reshape(-1, t_leng, d_dim)
|
||||
x = residual + self.dropout(x_att)
|
||||
if pos_emb is not None:
|
||||
x_input = (x, pos_emb)
|
||||
else:
|
||||
x_input = x
|
||||
x_input, mask = self.encoder_csa(x_input, mask)
|
||||
|
||||
|
||||
return x_input, mask, channel_size
|
||||
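# Note on the cross-channel step above (added for clarity): after regrouping
# to (B, T, C, D), the sequence is zero-padded by two frames on each side and
# a key/value stack of the 5 neighbouring frames (t-2 .. t+2) is built, so
# each channel at frame t attends over all channels of those 5 frames before
# the standard conformer layer encoder_csa runs.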
301
funasr_local/models/encoder/fsmn_encoder.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from typing import Tuple, Dict
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LinearTransform(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(LinearTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.linear = nn.Linear(input_dim, output_dim, bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.linear(input)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AffineTransform(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(AffineTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.linear = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.linear(input)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RectifiedLinear(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(RectifiedLinear, self).__init__()
|
||||
self.dim = input_dim
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.relu(input)
|
||||
return out
|
||||
|
||||
|
||||
class FSMNBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
lorder=None,
|
||||
rorder=None,
|
||||
lstride=1,
|
||||
rstride=1,
|
||||
):
|
||||
super(FSMNBlock, self).__init__()
|
||||
|
||||
self.dim = input_dim
|
||||
|
||||
if lorder is None:
|
||||
return
|
||||
|
||||
self.lorder = lorder
|
||||
self.rorder = rorder
|
||||
self.lstride = lstride
|
||||
self.rstride = rstride
|
||||
|
||||
self.conv_left = nn.Conv2d(
|
||||
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
|
||||
|
||||
if self.rorder > 0:
|
||||
self.conv_right = nn.Conv2d(
|
||||
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
|
||||
else:
|
||||
self.conv_right = None
|
||||
|
||||
def forward(self, input: torch.Tensor, cache: torch.Tensor):
|
||||
x = torch.unsqueeze(input, 1)
|
||||
x_per = x.permute(0, 3, 2, 1) # B D T C
|
||||
|
||||
cache = cache.to(x_per.device)
|
||||
y_left = torch.cat((cache, x_per), dim=2)
|
||||
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
|
||||
y_left = self.conv_left(y_left)
|
||||
out = x_per + y_left
|
||||
|
||||
if self.conv_right is not None:
|
||||
# maybe need to check
|
||||
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
|
||||
y_right = y_right[:, :, self.rstride:, :]
|
||||
y_right = self.conv_right(y_right)
|
||||
out += y_right
|
||||
|
||||
out_per = out.permute(0, 3, 2, 1)
|
||||
output = out_per.squeeze(1)
|
||||
|
||||
return output, cache
|
||||
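# Cache note (added): `cache` carries the last (lorder - 1) * lstride frames
# in (B, D, T, 1) layout, so concatenating it in front of the current chunk
# supplies exactly the left context that the depthwise conv_left consumes.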
|
||||
|
||||
class BasicBlock(nn.Sequential):
|
||||
def __init__(self,
|
||||
linear_dim: int,
|
||||
proj_dim: int,
|
||||
lorder: int,
|
||||
rorder: int,
|
||||
lstride: int,
|
||||
rstride: int,
|
||||
stack_layer: int
|
||||
):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.lorder = lorder
|
||||
self.rorder = rorder
|
||||
self.lstride = lstride
|
||||
self.rstride = rstride
|
||||
self.stack_layer = stack_layer
|
||||
self.linear = LinearTransform(linear_dim, proj_dim)
|
||||
self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
|
||||
self.affine = AffineTransform(proj_dim, linear_dim)
|
||||
self.relu = RectifiedLinear(linear_dim, linear_dim)
|
||||
|
||||
def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
|
||||
x1 = self.linear(input) # B T D
|
||||
cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
|
||||
if cache_layer_name not in in_cache:
|
||||
in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
|
||||
x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
|
||||
x3 = self.affine(x2)
|
||||
x4 = self.relu(x3)
|
||||
return x4
|
||||
|
||||
|
||||
class FsmnStack(nn.Sequential):
|
||||
def __init__(self, *args):
|
||||
super(FsmnStack, self).__init__(*args)
|
||||
|
||||
def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
|
||||
x = input
|
||||
for module in self._modules.values():
|
||||
x = module(x, in_cache)
|
||||
return x
|
||||
|
||||
|
||||
'''
|
||||
FSMN net for keyword spotting
|
||||
input_dim: input dimension
|
||||
linear_dim: fsmn input dimension
|
||||
proj_dim: fsmn projection dimension
|
||||
lorder: fsmn left order
|
||||
rorder: fsmn right order
|
||||
num_syn: output dimension
|
||||
fsmn_layers: no. of sequential fsmn layers
|
||||
'''
|
||||
|
||||
|
||||
class FSMN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
input_affine_dim: int,
|
||||
fsmn_layers: int,
|
||||
linear_dim: int,
|
||||
proj_dim: int,
|
||||
lorder: int,
|
||||
rorder: int,
|
||||
lstride: int,
|
||||
rstride: int,
|
||||
output_affine_dim: int,
|
||||
output_dim: int
|
||||
):
|
||||
super(FSMN, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.input_affine_dim = input_affine_dim
|
||||
self.fsmn_layers = fsmn_layers
|
||||
self.linear_dim = linear_dim
|
||||
self.proj_dim = proj_dim
|
||||
self.output_affine_dim = output_affine_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
|
||||
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
|
||||
self.relu = RectifiedLinear(linear_dim, linear_dim)
|
||||
self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
|
||||
range(fsmn_layers)])
|
||||
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
|
||||
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def fuse_modules(self):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
in_cache: Dict[str, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor (B, T, D)
|
||||
in_cache: when in_cache is not None, the forward pass runs in streaming mode. in_cache is a dict, e.g.,
|
||||
{'cache_layer_1': torch.Tensor(B, T1, D)}, where T1 equals self.lorder; pass {} for the first frame.
|
||||
"""
|
||||
|
||||
x1 = self.in_linear1(input)
|
||||
x2 = self.in_linear2(x1)
|
||||
x3 = self.relu(x2)
|
||||
x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
|
||||
x5 = self.out_linear1(x4)
|
||||
x6 = self.out_linear2(x5)
|
||||
x7 = self.softmax(x6)
|
||||
|
||||
return x7
|
||||
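# Hedged streaming sketch (hypothetical feature_chunks iterable; dims follow
# the __main__ example below): the same dict is threaded through calls so
# each BasicBlock keeps its left-context cache between chunks.
#   model = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
#   cache = {}                          # empty dict for the first chunk
#   for chunk in feature_chunks:        # each chunk: (B, T, 400)
#       posterior = model(chunk, cache) # cache is updated in place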
|
||||
|
||||
'''
|
||||
one deep fsmn layer
|
||||
dimproj: projection dimension, input and output dimension of memory blocks
|
||||
dimlinear: dimension of mapping layer
|
||||
lorder: left order
|
||||
rorder: right order
|
||||
lstride: left stride
|
||||
rstride: right stride
|
||||
'''
|
||||
|
||||
|
||||
class DFSMN(nn.Module):
|
||||
|
||||
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
|
||||
super(DFSMN, self).__init__()
|
||||
|
||||
self.lorder = lorder
|
||||
self.rorder = rorder
|
||||
self.lstride = lstride
|
||||
self.rstride = rstride
|
||||
|
||||
self.expand = AffineTransform(dimproj, dimlinear)
|
||||
self.shrink = LinearTransform(dimlinear, dimproj)
|
||||
|
||||
self.conv_left = nn.Conv2d(
|
||||
dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
|
||||
|
||||
if rorder > 0:
|
||||
self.conv_right = nn.Conv2d(
|
||||
dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
|
||||
else:
|
||||
self.conv_right = None
|
||||
|
||||
def forward(self, input):
|
||||
f1 = F.relu(self.expand(input))
|
||||
p1 = self.shrink(f1)
|
||||
|
||||
x = torch.unsqueeze(p1, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
|
||||
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
|
||||
|
||||
if self.conv_right is not None:
|
||||
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||
y_right = y_right[:, :, self.rstride:, :]
|
||||
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
|
||||
else:
|
||||
out = x_per + self.conv_left(y_left)
|
||||
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
output = input + out1.squeeze(1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
'''
|
||||
build stacked dfsmn layers
|
||||
'''
|
||||
|
||||
|
||||
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
|
||||
repeats = [
|
||||
nn.Sequential(
|
||||
DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
|
||||
for i in range(fsmn_layers)
|
||||
]
|
||||
|
||||
return nn.Sequential(*repeats)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
|
||||
print(fsmn)
|
||||
|
||||
num_params = sum(p.numel() for p in fsmn.parameters())
|
||||
print('the number of model params: {}'.format(num_params))
|
||||
x = torch.zeros(128, 200, 400) # batch-size * time * dim
|
||||
y = fsmn(x, {}) # batch-size * time * dim; forward takes a cache dict and returns a single tensor
|
||||
print('input shape: {}'.format(x.shape))
|
||||
print('output shape: {}'.format(y.shape))
|
||||
|
||||
print(fsmn.to_kaldi_net())
|
||||
450
funasr_local/models/encoder/mfcca_encoder.py
Normal file
@@ -0,0 +1,450 @@
from typing import Optional
from typing import Tuple

import logging
import math

import torch
from torch import nn
from typeguard import check_argument_types

from funasr_local.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr_local.modules.nets_utils import get_activation
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr_local.modules.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
)
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.multi_layer_conv import Conv1dLinear
from funasr_local.modules.multi_layer_conv import MultiLayeredConv1d
from funasr_local.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.subsampling import Conv2dSubsampling
from funasr_local.modules.subsampling import Conv2dSubsampling2
from funasr_local.modules.subsampling import Conv2dSubsampling6
from funasr_local.modules.subsampling import Conv2dSubsampling8
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.modules.subsampling import check_short_utt
from funasr_local.models.encoder.abs_encoder import AbsEncoder


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.

    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).

        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # 1D depthwise conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
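# Sketch: with the odd-kernel 'SAME' padding above, ConvolutionModule maps
# (batch, time, channels) to the same shape. Sizes are illustrative.
if __name__ == "__main__":
    _conv_mod = ConvolutionModule(channels=256, kernel_size=31)
    _x = torch.randn(2, 50, 256)   # (batch, time, channels)
    assert _conv_mod(_x).shape == _x.shape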
class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More details can be found in
            https://github.com/espnet/espnet/pull/2816.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)
        encoder_selfattn_layer_raw = MultiHeadedAttention
        encoder_selfattn_layer_args_raw = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        # channel-fusion stack: fuses up to 8 microphone channels into one
        self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            channel_size (torch.Tensor): Number of microphone channels per sample.
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
        if channel_size < 8:
            # tile the channel dimension up to the fixed 8 channels the conv stack expects
            repeat_num = math.ceil(8 / channel_size)
            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
        xs_pad = self.conv1(xs_pad)
        xs_pad = self.conv2(xs_pad)
        xs_pad = self.conv3(xs_pad)
        xs_pad = self.conv4(xs_pad)
        xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
        mask_tmp = masks.size(1)
        # keep one mask per utterance after channel fusion
        masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None

    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation, keeping a mid-stack hidden feature.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        num_layer = len(self.encoders)
        for idx, encoder in enumerate(self.encoders):
            xs_pad, masks = encoder(xs_pad, masks)
            if idx == num_layer // 2 - 1:
                hidden_feature = xs_pad
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
            hidden_feature = hidden_feature[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
            self.hidden_feature = self.after_norm(hidden_feature)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
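# Sketch of the channel-fusion trick in MFCCAEncoder.forward: an arbitrary
# microphone-channel count is tiled up to the fixed 8 channels that the
# conv1..conv4 stack expects, then trimmed. Values are illustrative.
if __name__ == "__main__":
    _channel_size = 3
    _xs = torch.randn(2, _channel_size, 100, 256)   # (batch, channels, time, dim)
    if _channel_size < 8:
        _repeat_num = math.ceil(8 / _channel_size)
        _xs = _xs.repeat(1, _repeat_num, 1, 1)[:, 0:8, :, :]
    assert _xs.shape[1] == 8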
38
funasr_local/models/encoder/opennmt_encoders/ci_scorers.py
Normal file
@@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F


class DotScorer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        xs_pad: torch.Tensor,
        spk_emb: torch.Tensor,
    ):
        # xs_pad: B, T, D
        # spk_emb: B, N, D
        scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
        return scores

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}


class CosScorer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        xs_pad: torch.Tensor,
        spk_emb: torch.Tensor,
    ):
        # xs_pad: B, T, D
        # spk_emb: B, N, D
        scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
        return scores

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}
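# Sketch: both scorers compare T frame vectors against N speaker embeddings and
# return a (B, T, N) score map; CosScorer is the length-normalized variant,
# bounded to [-1, 1]. Sizes are illustrative.
if __name__ == "__main__":
    _frames = torch.randn(2, 120, 256)   # (B, T, D)
    _spk_emb = torch.randn(2, 4, 256)    # (B, N, D)
    _dot = DotScorer()(_frames, _spk_emb)
    _cos = CosScorer()(_frames, _spk_emb)
    assert _dot.shape == _cos.shape == (2, 120, 4)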
277
funasr_local/models/encoder/opennmt_encoders/conv_encoder.py
Normal file
@@ -0,0 +1,277 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.modules.repeat import repeat


class EncoderLayer(nn.Module):
    def __init__(
        self,
        input_units,
        num_units,
        kernel_size=3,
        activation="tanh",
        stride=1,
        include_batch_norm=False,
        residual=False
    ):
        super().__init__()
        # asymmetric zero padding that reproduces TF-style 'SAME' framing
        left_padding = math.ceil((kernel_size - stride) / 2)
        right_padding = kernel_size - stride - left_padding
        self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.conv1d = nn.Conv1d(
            input_units,
            num_units,
            kernel_size,
            stride,
        )
        self.activation = self.get_activation(activation)
        if include_batch_norm:
            self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
        self.residual = residual
        self.include_batch_norm = include_batch_norm
        self.input_units = input_units
        self.num_units = num_units
        self.stride = stride

    @staticmethod
    def get_activation(activation):
        if activation == "tanh":
            return nn.Tanh()
        else:
            return nn.ReLU()

    def forward(self, xs_pad, ilens=None):
        outputs = self.conv1d(self.conv_padding(xs_pad))
        # the residual path is only valid when input and output shapes match
        if self.residual and self.stride == 1 and self.input_units == self.num_units:
            outputs = outputs + xs_pad

        if self.include_batch_norm:
            outputs = self.bn(outputs)

        # return ilens as well so the layer fits the repeat module interface
        return self.activation(outputs), ilens
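# Sketch of the padding arithmetic above: splitting kernel_size - stride zeros
# between left and right reproduces TF-style 'SAME' framing, so out_len equals
# in_len when stride == 1, and in_len // stride when in_len divides evenly.
if __name__ == "__main__":
    _kernel_size, _stride, _in_len = 3, 1, 101
    _left = math.ceil((_kernel_size - _stride) / 2)
    _right = _kernel_size - _stride - _left
    _out_len = (_in_len + _left + _right - _kernel_size) // _stride + 1
    assert _out_len == _in_len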
class ConvEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Convolution encoder in OpenNMT framework
    """

    def __init__(
        self,
        num_layers,
        input_units,
        num_units,
        kernel_size=3,
        dropout_rate=0.3,
        position_encoder=None,
        activation='tanh',
        auxiliary_states=True,
        out_units=None,
        out_norm=False,
        out_residual=False,
        include_batchnorm=False,
        regularization_weight=0.0,
        stride=1,
        tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
        tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = num_units

        self.num_layers = num_layers
        self.input_units = input_units
        self.num_units = num_units
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.position_encoder = position_encoder
        self.out_units = out_units
        self.auxiliary_states = auxiliary_states
        self.out_norm = out_norm
        self.activation = activation
        self.out_residual = out_residual
        self.include_batch_norm = include_batchnorm
        self.regularization_weight = regularization_weight
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        if isinstance(stride, int):
            self.stride = [stride] * self.num_layers
        else:
            self.stride = stride
        self.downsample_rate = 1
        for s in self.stride:
            self.downsample_rate *= s

        self.dropout = nn.Dropout(dropout_rate)
        self.cnn_a = repeat(
            self.num_layers,
            lambda lnum: EncoderLayer(
                input_units if lnum == 0 else num_units,
                num_units,
                kernel_size,
                activation,
                self.stride[lnum],
                include_batchnorm,
                residual=True if lnum > 0 else False
            )
        )

        if self.out_units is not None:
            left_padding = math.ceil((kernel_size - stride) / 2)
            right_padding = kernel_size - stride - left_padding
            self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
            self.conv_out = nn.Conv1d(
                num_units,
                out_units,
                kernel_size,
            )

        if self.out_norm:
            self.after_norm = LayerNorm(out_units)

    def output_size(self) -> int:
        return self.num_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        if self.dropout_rate > 0:
            inputs = self.dropout(inputs)

        outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)

        if self.out_units is not None:
            outputs = self.conv_out(self.out_padding(outputs))

        outputs = outputs.transpose(1, 2)
        if self.out_norm:
            outputs = self.after_norm(outputs)

        if self.out_residual:
            outputs = outputs + inputs

        return outputs, ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },

            "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
        }
        if self.out_units is not None:
            # add output layer
            map_dict_local.update({
                "{}.conv_out.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },  # tf: (1, 256, 256) -> torch: (256, 256, 1)
                "{}.conv_out.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": None,
                     },  # tf: (256,) -> torch: (256,)
            })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
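# Sketch of the single conversion step encoded by the map dicts above: TF
# conv1d kernels are stored as (kernel_size, in_channels, out_channels), torch
# as (out_channels, in_channels, kernel_size); hence "transpose": (2, 1, 0).
# Array contents are illustrative.
if __name__ == "__main__":
    _kernel_tf = np.zeros((3, 256, 512), dtype=np.float32)
    _kernel_torch = torch.from_numpy(np.transpose(_kernel_tf, (2, 1, 0)).copy())
    assert tuple(_kernel_torch.shape) == (512, 256, 3)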
335
funasr_local/models/encoder/opennmt_encoders/fsmn_encoder.py
Normal file
@@ -0,0 +1,335 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.modules.repeat import repeat
from funasr_local.modules.multi_layer_conv import FsmnFeedForward


class FsmnBlock(torch.nn.Module):
    def __init__(
        self,
        n_feat,
        dropout_rate,
        kernel_size,
        fsmn_shift=0,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
                                    padding=0, groups=n_feat, bias=False)
        # asymmetric padding: fsmn_shift > 0 moves the depthwise window left,
        # trading lookahead for extra history (i.e. it controls the delay)
        left_padding = (kernel_size - 1) // 2
        if fsmn_shift > 0:
            left_padding = left_padding + fsmn_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk

        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        # residual memory: the depthwise conv output is added back to the input
        x = x + inputs
        x = self.dropout(x)
        return x * mask
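# Sketch of the padding arithmetic above: fsmn_shift moves the depthwise window
# to the left, so the block sees more history and less future context (i.e. it
# reduces the delay contributed by lookahead). Numbers are illustrative.
if __name__ == "__main__":
    _kernel_size, _fsmn_shift = 11, 3
    _left = (_kernel_size - 1) // 2 + _fsmn_shift   # history frames seen: 8
    _right = _kernel_size - 1 - _left               # future frames seen: 2
    assert (_left, _right) == (8, 2)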
class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        in_size,
        size,
        feed_forward,
        fsmn_block,
        dropout_rate=0.0
    ):
        super().__init__()
        self.in_size = in_size
        self.size = size
        self.ffn = feed_forward
        self.memory = fsmn_block
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        xs_pad: torch.Tensor,
        mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # xs_pad in Batch, Time, Dim

        context = self.ffn(xs_pad)[0]
        memory = self.memory(context, mask)

        memory = self.dropout(memory)
        if self.in_size == self.size:
            return memory + xs_pad, mask

        return memory, mask


class FsmnEncoder(AbsEncoder):
    """Encoder using FSMN layers.
    """

    def __init__(self,
                 in_units,
                 filter_size,
                 fsmn_num_layers,
                 dnn_num_layers,
                 num_memory_units=512,
                 ffn_inner_dim=2048,
                 dropout_rate=0.0,
                 shift=0,
                 position_encoder=None,
                 sample_rate=1,
                 out_units=None,
                 tf2torch_tensor_name_prefix_torch="post_net",
                 tf2torch_tensor_name_prefix_tf="EAND/post_net"
                 ):
        """Initializes the parameters of the encoder.

        Args:
            filter_size: the total order of each memory block
            fsmn_num_layers: The number of FSMN layers.
            dnn_num_layers: The number of DNN layers.
            num_memory_units: The number of memory units.
            ffn_inner_dim: The number of units of the inner linear transformation
                in the feed forward layer.
            dropout_rate: The probability to drop units from the outputs.
            shift: left padding, to control the delay.
            position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
                apply on inputs, or ``None``.
        """
        super(FsmnEncoder, self).__init__()
        self.in_units = in_units
        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.dnn_num_layers = dnn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout_rate = dropout_rate
        self.shift = shift
        if not isinstance(shift, list):
            self.shift = [shift for _ in range(self.fsmn_num_layers)]
        self.sample_rate = sample_rate
        if not isinstance(sample_rate, list):
            self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
        self.position_encoder = position_encoder
        self.dropout = nn.Dropout(dropout_rate)
        self.out_units = out_units
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        self.fsmn_layers = repeat(
            self.fsmn_num_layers,
            lambda lnum: EncoderLayer(
                in_units if lnum == 0 else num_memory_units,
                num_memory_units,
                FsmnFeedForward(
                    in_units if lnum == 0 else num_memory_units,
                    ffn_inner_dim,
                    num_memory_units,
                    1,
                    dropout_rate
                ),
                FsmnBlock(
                    num_memory_units,
                    dropout_rate,
                    filter_size,
                    self.shift[lnum]
                )
            ),
        )

        self.dnn_layers = repeat(
            dnn_num_layers,
            lambda lnum: FsmnFeedForward(
                num_memory_units,
                ffn_inner_dim,
                num_memory_units,
                1,
                dropout_rate,
            )
        )
        if out_units is not None:
            self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)

    def output_size(self) -> int:
        return self.num_memory_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        inputs = self.dropout(inputs)
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        inputs = self.fsmn_layers(inputs, masks)[0]
        inputs = self.dnn_layers(inputs)[0]

        if self.out_units is not None:
            inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)

        return inputs, ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            # for fsmn_layers
            "{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)

            # for dnn_layers
            "{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },

        }
        if self.out_units is not None:
            # add output layer
            map_dict_local.update({
                "{}.conv1d.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },
                "{}.conv1d.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": None,
                     },
            })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
@@ -0,0 +1,480 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging

import torch
import torch.nn as nn
from funasr_local.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
from funasr_local.modules.embedding import SinusoidalPositionEncoder
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.multi_layer_conv import Conv1dLinear
from funasr_local.modules.multi_layer_conv import MultiLayeredConv1d
from funasr_local.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.subsampling import Conv2dSubsampling
from funasr_local.modules.subsampling import Conv2dSubsampling2
from funasr_local.modules.subsampling import Conv2dSubsampling6
from funasr_local.modules.subsampling import Conv2dSubsampling8
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.modules.subsampling import check_short_utt
from funasr_local.models.ctc import CTC
from funasr_local.models.encoder.abs_encoder import AbsEncoder


class EncoderLayer(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            torch.Tensor: Cache tensor, passed through unchanged.
            torch.Tensor: Chunk attention mask, passed through unchanged.

        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # keep the same return arity as the normal path, so the cache and
            # chunk mask keep flowing through stacked layers
            return x, mask, cache, mask_att_chunk_encoder

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_att_chunk_encoder
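# Sketch of the stochastic-depth rule used above: at training time the residual
# branch is dropped with probability p and otherwise rescaled by 1 / (1 - p),
# which keeps the expected layer output unchanged. Shapes are illustrative.
if __name__ == "__main__":
    _p = 0.3                               # stochastic_depth_rate
    _residual = torch.randn(2, 10, 256)
    _branch = torch.randn(2, 10, 256)      # stands in for self_attn / feed_forward
    if torch.rand(1).item() < _p:
        _out = _residual                   # skip the branch entirely
    else:
        _out = _residual + _branch / (1 - _p)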
class SelfAttentionEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Self attention encoder in OpenNMT framework
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
        out_units=None,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "null":
            self.embed = None
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # the first block may take a different input size, so it is built separately
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                output_size,
                MultiHeadSelfAttention(
                    attention_heads,
                    output_size,
                    output_size,
                    attention_dropout_rate,
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ) if lnum > 0 else EncoderLayer(
                input_size,
                output_size,
                MultiHeadSelfAttention(
                    attention_heads,
                    input_size if input_layer == "pe" or input_layer == "null" else output_size,
                    output_size,
                    attention_dropout_rate,
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.out_units = out_units
        if out_units is not None:
            self.output_linear = nn.Linear(output_size, out_units)

    def output_size(self) -> int:
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module, only needed for intermediate-CTC conditioning.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # scale inputs by sqrt(d_model), as in the standard Transformer
        xs_pad = xs_pad * self.output_size() ** 0.5
        if self.embed is None:
            pass
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        xs_pad = self.dropout(xs_pad)
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        # self-conditioned CTC: feed intermediate CTC posteriors back in
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        if self.out_units is not None:
            xs_pad = self.output_linear(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
def gen_tf2torch_map_dict(self):
|
||||
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
|
||||
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
|
||||
map_dict_local = {
|
||||
# cicd
|
||||
# torch: conv1d.weight in "out_channel in_channel kernel_size"
|
||||
# tf : conv1d.weight in "kernel_size in_channel out_channel"
|
||||
# torch: linear.weight in "out_channel in_channel"
|
||||
# tf : dense.weight in "in_channel out_channel"
|
||||
"{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
"{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
"{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
|
||||
"squeeze": 0,
|
||||
"transpose": (1, 0),
|
||||
}, # (768,256),(1,256,768)
|
||||
"{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (768,),(768,)
|
||||
"{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
|
||||
"squeeze": 0,
|
||||
"transpose": (1, 0),
|
||||
}, # (256,256),(1,256,256)
|
||||
"{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
# ffn
|
||||
"{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
"{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
"{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
|
||||
"squeeze": 0,
|
||||
"transpose": (1, 0),
|
||||
}, # (1024,256),(1,256,1024)
|
||||
"{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (1024,),(1024,)
|
||||
"{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
|
||||
"squeeze": 0,
|
||||
"transpose": (1, 0),
|
||||
}, # (256,1024),(1,1024,256)
|
||||
"{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
# out norm
|
||||
"{}.after_norm.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
"{}.after_norm.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
}
|
||||
if self.out_units is not None:
|
||||
map_dict_local.update({
|
||||
"{}.output_linear.weight".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
|
||||
"squeeze": 0,
|
||||
"transpose": (1, 0),
|
||||
},
|
||||
"{}.output_linear.bias".format(tensor_name_prefix_torch):
|
||||
{"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
|
||||
"squeeze": None,
|
||||
"transpose": None,
|
||||
}, # (256,),(256,)
|
||||
})
|
||||
|
||||
return map_dict_local
|
||||
|
||||
def convert_tf2torch(self,
|
||||
var_dict_tf,
|
||||
var_dict_torch,
|
||||
):
|
||||
|
||||
map_dict = self.gen_tf2torch_map_dict()
|
||||
|
||||
var_dict_torch_update = dict()
|
||||
for name in sorted(var_dict_torch.keys(), reverse=False):
|
||||
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
|
||||
# process special (first and last) layers
|
||||
if name in map_dict:
|
||||
name_tf = map_dict[name]["name"]
|
||||
data_tf = var_dict_tf[name_tf]
|
||||
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
|
||||
if map_dict[name]["squeeze"] is not None:
|
||||
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
|
||||
if map_dict[name]["transpose"] is not None:
|
||||
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
|
||||
assert var_dict_torch[name].size() == data_tf.size(), \
|
||||
"{}, {}, {} != {}".format(name, name_tf,
|
||||
var_dict_torch[name].size(), data_tf.size())
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
|
||||
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
|
||||
))
|
||||
# process general layers
|
||||
else:
|
||||
# self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
|
||||
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
|
||||
layeridx = int(names[2])
|
||||
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
|
||||
if name_q in map_dict.keys():
|
||||
name_v = map_dict[name_q]["name"]
|
||||
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
|
||||
data_tf = var_dict_tf[name_tf]
|
||||
if map_dict[name_q]["squeeze"] is not None:
|
||||
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
|
||||
if map_dict[name_q]["transpose"] is not None:
|
||||
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
|
||||
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
|
||||
assert var_dict_torch[name].size() == data_tf.size(), \
|
||||
"{}, {}, {} != {}".format(name, name_tf,
|
||||
var_dict_torch[name].size(), data_tf.size())
|
||||
var_dict_torch_update[name] = data_tf
|
||||
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
|
||||
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
|
||||
))
|
||||
else:
|
||||
logging.warning("{} is missed from tf checkpoint".format(name))
|
||||
|
||||
return var_dict_torch_update
|
||||
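# Illustrative sketch (not part of the original file): what a single map-dict
# entry does to a TF "conv1d/kernel" before it is copied into a torch Linear
# weight. Shapes follow the "(768,256),(1,256,768)" comments above; the zeros
# array is a placeholder.
import numpy as np
import torch

data_tf = np.zeros((1, 256, 768), dtype=np.float32)  # TF: (kernel_size, in, out)
data_tf = np.squeeze(data_tf, axis=0)                # "squeeze": 0      -> (256, 768)
data_tf = np.transpose(data_tf, (1, 0))              # "transpose": (1, 0) -> (768, 256)
weight = torch.from_numpy(data_tf)                   # torch linear_q_k_v.weight layout
assert weight.shape == (768, 256)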
853
funasr_local/models/encoder/resnet34_encoder.py
Normal file
@@ -0,0 +1,853 @@
import torch
from torch.nn import functional as F
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple, Optional
from funasr_local.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
import logging
import numpy as np


class BasicLayer(torch.nn.Module):

    def __init__(self, in_filters: int, filters: int, stride: int, bn_momentum: float = 0.5):
        super().__init__()
        self.stride = stride
        self.in_filters = in_filters
        self.filters = filters

        self.bn1 = torch.nn.BatchNorm2d(in_filters, eps=1e-3, momentum=bn_momentum, affine=True)
        self.relu1 = torch.nn.ReLU()
        self.conv1 = torch.nn.Conv2d(in_filters, filters, 3, stride, bias=False)

        self.bn2 = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
        self.relu2 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(filters, filters, 3, 1, bias=False)

        if in_filters != filters or stride > 1:
            self.conv_sc = torch.nn.Conv2d(in_filters, filters, 1, stride, bias=False)
            self.bn_sc = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)

    def proper_padding(self, x, stride):
        # align the padding with tf.layers.conv2d(padding="same")
        if stride == 1:
            return F.pad(x, (1, 1, 1, 1), "constant", 0)
        elif stride == 2:
            h, w = x.size(2), x.size(3)
            # pad order is (left, right, top, bottom)
            return F.pad(x, (w % 2, 1, h % 2, 1), "constant", 0)
    def forward(self, xs_pad, ilens):
        identity = xs_pad
        if self.in_filters != self.filters or self.stride > 1:
            identity = self.conv_sc(identity)
            identity = self.bn_sc(identity)

        xs_pad = self.relu1(self.bn1(xs_pad))
        xs_pad = self.proper_padding(xs_pad, self.stride)
        xs_pad = self.conv1(xs_pad)

        xs_pad = self.relu2(self.bn2(xs_pad))
        xs_pad = self.proper_padding(xs_pad, 1)
        xs_pad = self.conv2(xs_pad)

        if self.stride == 2:
            # frame lengths shrink with the stride, rounding up like "same" padding
            ilens = (ilens + 1) // self.stride

        return xs_pad + identity, ilens
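# Quick illustrative check (not part of the original file) of the proper_padding
# scheme in BasicLayer above: with the asymmetric pad, a stride-2 3x3 convolution
# reproduces TF's "same" output size ceil(n / 2) for both even and odd inputs.
import torch
import torch.nn.functional as F

_conv = torch.nn.Conv2d(1, 1, 3, stride=2, bias=False)
for n in (7, 8):
    _x = torch.zeros(1, 1, n, n)
    _x = F.pad(_x, (n % 2, 1, n % 2, 1))  # (left, right, top, bottom)
    assert _conv(_x).shape[-1] == (n + 1) // 2  # ceil(n / 2)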
class BasicBlock(torch.nn.Module):
    def __init__(self, in_filters, filters, num_layer, stride, bn_momentum=0.5):
        super().__init__()
        self.num_layer = num_layer

        for i in range(num_layer):
            layer = BasicLayer(in_filters if i == 0 else filters, filters,
                               stride if i == 0 else 1, bn_momentum)
            self.add_module("layer_{}".format(i), layer)

    def forward(self, xs_pad, ilens):
        for i in range(self.num_layer):
            xs_pad, ilens = self._modules["layer_{}".format(i)](xs_pad, ilens)

        return xs_pad, ilens


class ResNet34(AbsEncoder):
    def __init__(
        self,
        input_size,
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
    ):
        super(ResNet34, self).__init__()

        self.use_head_conv = use_head_conv
        self.use_head_maxpool = use_head_maxpool
        self.num_nodes_pooling_layer = num_nodes_pooling_layer
        self.layers_in_block = layers_in_block
        self.filters_in_block = filters_in_block
        self.input_size = input_size

        pre_filters = filters_in_block[0]
        if use_head_conv:
            self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
            self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)

        if use_head_maxpool:
            self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)

        for i in range(len(layers_in_block)):
            if i == 0:
                in_filters = pre_filters if self.use_head_conv else 1
            else:
                in_filters = filters_in_block[i - 1]

            block = BasicBlock(in_filters,
                               filters=filters_in_block[i],
                               num_layer=layers_in_block[i],
                               stride=1 if i == 0 else 2,
                               bn_momentum=batchnorm_momentum)
            self.add_module("block_{}".format(i), block)

        self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
        self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)

        # three of the four blocks use stride 2, so time is downsampled by 8
        self.time_ds_ratio = 8

    def output_size(self) -> int:
        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        features = xs_pad
        assert features.size(-1) == self.input_size, \
            "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
        features = torch.unsqueeze(features, dim=1)
        if self.use_head_conv:
            features = self.pre_conv(features)
            features = self.pre_conv_bn(features)
            features = F.relu(features)

        if self.use_head_maxpool:
            features = self.head_maxpool(features)

        resnet_outs, resnet_out_lens = features, ilens
        for i in range(len(self.layers_in_block)):
            block = self._modules["block_{}".format(i)]
            resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)

        features = self.resnet0_dense(resnet_outs)
        features = F.relu(features)
        features = self.resnet0_bn(features)

        return features, resnet_out_lens
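# Hedged usage sketch (illustrative, not part of the original file): feed a
# batch of 80-dim fbank frames through ResNet34 above and check the 8x time
# downsampling. Feature values are random placeholders.
import torch

_enc = ResNet34(input_size=80)
_enc.eval()
_xs = torch.randn(2, 96, 80)          # (batch, time, feat)
_ilens = torch.tensor([96, 64])
_feats, _olens = _enc(_xs, _ilens)
print(_feats.shape, _olens)           # torch.Size([2, 256, 12, 10]) tensor([12,  8])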
# Note: for training, this implementation is not equivalent to TF because of the
# kernel_regularizer in tf.layers.
# TODO: implement the kernel_regularizer in torch, either as a manual loss term
# or via weight_decay in the optimizer.
class ResNet34_SP_L2Reg(AbsEncoder):
    def __init__(
        self,
        input_size,
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
        tf_train_steps=720000,
    ):
        super(ResNet34_SP_L2Reg, self).__init__()

        self.use_head_conv = use_head_conv
        self.use_head_maxpool = use_head_maxpool
        self.num_nodes_pooling_layer = num_nodes_pooling_layer
        self.layers_in_block = layers_in_block
        self.filters_in_block = filters_in_block
        self.input_size = input_size
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tf_train_steps = tf_train_steps

        pre_filters = filters_in_block[0]
        if use_head_conv:
            self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
            self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)

        if use_head_maxpool:
            self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)

        for i in range(len(layers_in_block)):
            if i == 0:
                in_filters = pre_filters if self.use_head_conv else 1
            else:
                in_filters = filters_in_block[i - 1]

            block = BasicBlock(in_filters,
                               filters=filters_in_block[i],
                               num_layer=layers_in_block[i],
                               stride=1 if i == 0 else 2,
                               bn_momentum=batchnorm_momentum)
            self.add_module("block_{}".format(i), block)

        # unlike ResNet34, the frequency axis is folded into the channels, so the
        # pooling projection is a Conv1d over (C * F / 8) input channels
        self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
        self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)

        self.time_ds_ratio = 8

    def output_size(self) -> int:
        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        features = xs_pad
        assert features.size(-1) == self.input_size, \
            "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
        features = torch.unsqueeze(features, dim=1)
        if self.use_head_conv:
            features = self.pre_conv(features)
            features = self.pre_conv_bn(features)
            features = F.relu(features)

        if self.use_head_maxpool:
            features = self.head_maxpool(features)

        resnet_outs, resnet_out_lens = features, ilens
        for i in range(len(self.layers_in_block)):
            block = self._modules["block_{}".format(i)]
            resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)

        # (B, C, T, F) -> (B, F * C, T) before the Conv1d projection
        bb, cc, tt, ff = resnet_outs.shape
        resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff * cc, tt])
        features = self.resnet0_dense(resnet_outs)
        features = F.relu(features)
        features = self.resnet0_bn(features)

        return features, resnet_out_lens
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = self.tf_train_steps
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     # resnet0_dense is a Conv1d here, the later ones are Linear layers
                     "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })

        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })

        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        # num_batches_tracked is a scalar counter; build a 0-dim int64
                        # tensor from the stored step count (torch.Tensor(int) would
                        # allocate an uninitialized tensor of that *size* instead)
                        var_dict_torch_update[name] = torch.tensor(map_dict[name], dtype=torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigned to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missing from the tf checkpoint".format(name))

        return var_dict_torch_update
class ResNet34Diar(ResNet34):
    def __init__(
        self,
        input_size,
        embedding_node="resnet1_dense",
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        num_nodes_resnet1=256,
        num_nodes_last_layer=256,
        pooling_type="window_shift",
        pool_size=20,
        stride=1,
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
        """
        Author: Speech Lab, Alibaba Group, China

        SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
        https://arxiv.org/abs/2211.10243
        """

        super(ResNet34Diar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
            batchnorm_momentum=batchnorm_momentum,
            use_head_maxpool=use_head_maxpool,
            num_nodes_pooling_layer=num_nodes_pooling_layer,
            layers_in_block=layers_in_block,
            filters_in_block=filters_in_block,
        )

        self.embedding_node = embedding_node
        self.num_nodes_resnet1 = num_nodes_resnet1
        self.num_nodes_last_layer = num_nodes_last_layer
        self.pooling_type = pooling_type
        self.pool_size = pool_size
        self.stride = stride
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        # statistic pooling concatenates mean and std, hence the factor of 2
        self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)

        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)

    def output_size(self) -> int:
        if self.embedding_node.startswith("resnet1"):
            return self.num_nodes_resnet1
        elif self.embedding_node.startswith("resnet2"):
            return self.num_nodes_last_layer

        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

        endpoints = OrderedDict()
        res_out, ilens = super().forward(xs_pad, ilens)
        endpoints["resnet0_bn"] = res_out
        if self.pooling_type == "frame_gsp":
            features = statistic_pooling(res_out, ilens, (3, ))
        else:
            features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride)
            features = features.transpose(1, 2)
        endpoints["pooling"] = features

        features = self.resnet1_dense(features)
        endpoints["resnet1_dense"] = features
        features = F.relu(features)
        endpoints["resnet1_relu"] = features
        features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet1_bn"] = features

        features = self.resnet2_dense(features)
        endpoints["resnet2_dense"] = features
        features = F.relu(features)
        endpoints["resnet2_relu"] = features
        features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet2_bn"] = features

        # every intermediate is recorded so embedding_node can pick any tap point
        return endpoints[self.embedding_node], ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = 300000
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })

        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })

        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        # scalar step counter; torch.tensor avoids torch.Tensor(int),
                        # which would allocate an uninitialized tensor of that size
                        var_dict_torch_update[name] = torch.tensor(map_dict[name], dtype=torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigned to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missing from the tf checkpoint".format(name))

        return var_dict_torch_update
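# Hedged usage sketch (illustrative, not part of the original file): pick an
# embedding tap point by name. The pooled time length depends on pool_size and
# stride inside windowed_statistic_pooling, so only the last dim is asserted here.
import torch

_diar = ResNet34Diar(input_size=80, embedding_node="resnet1_bn")
_diar.eval()
_emb, _olens, _ = _diar(torch.randn(2, 160, 80), torch.tensor([160, 120]))
assert _emb.shape[-1] == _diar.output_size()   # (batch, pooled_time, num_nodes_resnet1)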
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
    def __init__(
        self,
        input_size,
        embedding_node="resnet1_dense",
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        num_nodes_resnet1=256,
        num_nodes_last_layer=256,
        pooling_type="window_shift",
        pool_size=20,
        stride=1,
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
        """
        Author: Speech Lab, Alibaba Group, China

        TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
        https://arxiv.org/abs/2303.05397
        """

        super(ResNet34SpL2RegDiar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
            batchnorm_momentum=batchnorm_momentum,
            use_head_maxpool=use_head_maxpool,
            num_nodes_pooling_layer=num_nodes_pooling_layer,
            layers_in_block=layers_in_block,
            filters_in_block=filters_in_block,
        )

        self.embedding_node = embedding_node
        self.num_nodes_resnet1 = num_nodes_resnet1
        self.num_nodes_last_layer = num_nodes_last_layer
        self.pooling_type = pooling_type
        self.pool_size = pool_size
        self.stride = stride
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)

        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)

    def output_size(self) -> int:
        if self.embedding_node.startswith("resnet1"):
            return self.num_nodes_resnet1
        elif self.embedding_node.startswith("resnet2"):
            return self.num_nodes_last_layer

        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

        endpoints = OrderedDict()
        res_out, ilens = super().forward(xs_pad, ilens)
        endpoints["resnet0_bn"] = res_out
        # the parent encoder outputs (B, C, T), so pooling happens over dim 2
        if self.pooling_type == "frame_gsp":
            features = statistic_pooling(res_out, ilens, (2, ))
        else:
            features, ilens = windowed_statistic_pooling(res_out, ilens, (2, ), self.pool_size, self.stride)
            features = features.transpose(1, 2)
        endpoints["pooling"] = features

        features = self.resnet1_dense(features)
        endpoints["resnet1_dense"] = features
        features = F.relu(features)
        endpoints["resnet1_relu"] = features
        features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet1_bn"] = features

        features = self.resnet2_dense(features)
        endpoints["resnet2_dense"] = features
        features = F.relu(features)
        endpoints["resnet2_relu"] = features
        features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet2_bn"] = features

        return endpoints[self.embedding_node], ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = 720000
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })

        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigned to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missing from the tf checkpoint".format(name))

        return var_dict_torch_update
115
funasr_local/models/encoder/rnn_encoder.py
Normal file
@@ -0,0 +1,115 @@
from typing import Optional
from typing import Sequence
from typing import Tuple

import numpy as np
import torch
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.rnn.encoders import RNN
from funasr_local.modules.rnn.encoders import RNNP
from funasr_local.models.encoder.abs_encoder import AbsEncoder


class RNNEncoder(AbsEncoder):
    """RNNEncoder class.

    Args:
        input_size: The number of expected features in the input
        output_size: The number of output features
        hidden_size: The number of hidden features
        bidirectional: If ``True``, becomes a bidirectional LSTM
        use_projection: Whether to use a projection layer
        num_layers: Number of recurrent layers
        dropout: dropout probability

    """

    def __init__(
        self,
        input_size: int,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        use_projection: bool = True,
        num_layers: int = 4,
        hidden_size: int = 320,
        output_size: int = 320,
        dropout: float = 0.0,
        subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.use_projection = use_projection

        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported rnn_type={rnn_type}")

        if subsample is None:
            subsample = np.ones(num_layers + 1, dtype=int)
        else:
            subsample = subsample[:num_layers]
            # Prepend a 1 because no subsampling is applied before the first layer,
            # and pad with 1 on the right if fewer factors than layers were given
            subsample = np.pad(
                np.array(subsample, dtype=int),
                [1, num_layers - len(subsample)],
                mode="constant",
                constant_values=1,
            )

        rnn_type = ("b" if bidirectional else "") + rnn_type
        if use_projection:
            self.enc = torch.nn.ModuleList(
                [
                    RNNP(
                        input_size,
                        num_layers,
                        hidden_size,
                        output_size,
                        subsample,
                        dropout,
                        typ=rnn_type,
                    )
                ]
            )

        else:
            self.enc = torch.nn.ModuleList(
                [
                    RNN(
                        input_size,
                        num_layers,
                        hidden_size,
                        output_size,
                        dropout,
                        typ=rnn_type,
                    )
                ]
            )

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # zero out the padded frames; masked_fill_ is in-place, masked_fill is not
        if self.use_projection:
            xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
        else:
            xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)
        return xs_pad, ilens, current_states
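# Illustrative check (not part of the original file) of the subsample padding in
# RNNEncoder.__init__ above: the default (2, 2, 1, 1) with num_layers=4 becomes
# [1, 2, 2, 1, 1], i.e. no subsampling before the first layer and 4x overall.
import numpy as np

_num_layers = 4
_subsample = (2, 2, 1, 1)[:_num_layers]
_subsample = np.pad(np.array(_subsample, dtype=int),
                    [1, _num_layers - len(_subsample)],
                    mode="constant", constant_values=1)
print(_subsample)  # [1 2 2 1 1]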
1213
funasr_local/models/encoder/sanm_encoder.py
Normal file
File diff suppressed because it is too large
684
funasr_local/models/encoder/transformer_encoder.py
Normal file
@@ -0,0 +1,684 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Transformer encoder definition."""

from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types
import logging

from funasr_local.models.ctc import CTC
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.modules.attention import MultiHeadedAttention
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.multi_layer_conv import Conv1dLinear
from funasr_local.modules.multi_layer_conv import MultiLayeredConv1d
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.nets_utils import rename_state_dict
from funasr_local.modules.dynamic_conv import DynamicConvolution
from funasr_local.modules.dynamic_conv2d import DynamicConvolution2D
from funasr_local.modules.lightconv import LightweightConvolution
from funasr_local.modules.lightconv2d import LightweightConvolution2D
from funasr_local.modules.subsampling import Conv2dSubsampling
from funasr_local.modules.subsampling import Conv2dSubsampling2
from funasr_local.modules.subsampling import Conv2dSubsampling6
from funasr_local.modules.subsampling import Conv2dSubsampling8
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.modules.subsampling import check_short_utt


class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer is applied, i.e. x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip the residual computation and return
            its input as-is with the given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        skip_layer = False
        # with stochastic depth, the residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
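# Illustrative check (not part of the original file) of the stochastic-depth
# scaling in EncoderLayer.forward above: when the layer is kept (probability
# 1 - p), its update is scaled by 1 / (1 - p), so the expected update over many
# draws stays close to f(x).
import torch

_p = 0.3
_fx = torch.ones(10000)
_keep = torch.rand(10000) >= _p                       # layer executes with prob 1 - p
_update = torch.where(_keep, _fx / (1 - _p), torch.zeros_like(_fx))
print(_update.mean())                                  # ~1.0, matching the comment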
class TransformerEncoder(AbsEncoder):
    """Transformer encoder module.

    Args:
        input_size: input dim
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        input_layer: input layer type
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer is applied, i.e. x -> x + att(x).
        positionwise_layer_type: linear or conv1d
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: not used for now
        Returns:
            position-embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
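# Hedged usage sketch (illustrative, not part of the original file): run the
# encoder on random features. Assuming the ESPnet-style "conv2d" front end, the
# time axis is subsampled roughly 4x.
import torch

_enc = TransformerEncoder(input_size=80, output_size=256, num_blocks=2)
_enc.eval()
_out, _olens, _ = _enc(torch.randn(2, 100, 80), torch.tensor([100, 60]))
print(_out.shape, _olens)  # e.g. torch.Size([2, 24, 256]) tensor([24, 14])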
|
||||
|
||||
def _pre_hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
# https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
|
||||
rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
|
||||
# https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
|
||||
rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
|
||||
|
||||
|
||||
class TransformerEncoder_s0(torch.nn.Module):
    """Transformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of kernels of convolution. Only used when
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        conv_kernel_length (Union[int, str]): Kernel size str of convolution
            (e.g. 71_71_71_71_71_71). Only used when selfattention_layer_type
            == "lightconv*" or "dynamicconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used when
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear will be applied, i.e. x -> x + att(x).
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        selfattention_layer_type (str): Encoder attention layer type.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): Indices of intermediate CTC
            layers. Indices start from 1.
            If not None, intermediate outputs are returned (which changes the
            return type signature).

    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        conv_wshare=4,
        conv_kernel_length="11",
        conv_usebias=False,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        selfattention_layer_type="selfattn",
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
    ):
        """Construct an Encoder object."""
        super(TransformerEncoder_s0, self).__init__()
        self._register_load_state_dict_pre_hook(_pre_hook)

        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d-scaled-pos-enc":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 6
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 8
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
            positionwise_layer_type,
            attention_dim,
            linear_units,
            dropout_rate,
            positionwise_conv_kernel_size,
        )
        if selfattention_layer_type in [
            "selfattn",
            "rel_selfattn",
            "legacy_rel_selfattn",
        ]:
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                )
            ] * num_blocks
        elif selfattention_layer_type == "lightconv":
            logging.info("encoder self-attention layer type = lightweight convolution")
            encoder_selfattn_layer = LightweightConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "lightconv2d":
            logging.info(
                "encoder self-attention layer "
                "type = lightweight convolution 2-dimensional"
            )
            encoder_selfattn_layer = LightweightConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv":
            logging.info("encoder self-attention layer type = dynamic convolution")
            encoder_selfattn_layer = DynamicConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv2d":
            logging.info(
                "encoder self-attention layer type = dynamic convolution 2-dimensional"
            )
            encoder_selfattn_layer = DynamicConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        else:
            raise NotImplementedError(selfattention_layer_type)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers
        self.use_conditioning = True if ctc_softmax is not None else False
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(
                conditioning_layer_dim, attention_dim
            )

    def get_positionwise_layer(
        self,
        positionwise_layer_type="linear",
        attention_dim=256,
        linear_units=2048,
        dropout_rate=0.1,
        positionwise_conv_kernel_size=1,
    ):
        """Define positionwise layer."""
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        return positionwise_layer, positionwise_layer_args

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).

        """
        if isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)

                if (
                    self.intermediate_layers is not None
                    and layer_idx + 1 in self.intermediate_layers
                ):
                    encoder_output = xs
                    # intermediate branches also require normalization.
                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)
                    intermediate_outputs.append(encoder_output)

                    if self.use_conditioning:
                        intermediate_result = self.ctc_softmax(encoder_output)
                        xs = xs + self.conditioning_layer(intermediate_result)

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks

    def forward_one_step(self, xs, masks, cache=None):
        """Encode input frame.

        Args:
            xs (torch.Tensor): Input tensor.
            masks (torch.Tensor): Mask tensor.
            cache (List[torch.Tensor]): List of cache tensors.

        Returns:
            torch.Tensor: Output tensor.
            torch.Tensor: Mask tensor.
            List[torch.Tensor]: List of new cache tensors.

        """
        if isinstance(self.embed, Conv2dSubsampling):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
        if cache is None:
            cache = [None for _ in range(len(self.encoders))]
        new_cache = []
        for c, e in zip(cache, self.encoders):
            xs, masks = e(xs, masks, cache=c)
            new_cache.append(xs)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks, new_cache
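
A minimal smoke test for the encoder above (illustrative only; it assumes the layer classes imported earlier in this file are importable, and uses the default "conv2d" input layer, which subsamples time by 4):

import torch

encoder = TransformerEncoder_s0(idim=80, attention_dim=256, num_blocks=6)
xs = torch.randn(4, 100, 80)                      # (batch, time, idim)
masks = torch.ones(4, 1, 100, dtype=torch.bool)   # True marks valid frames
ys, out_masks = encoder(xs, masks)                # intermediate_layers=None -> 2-tuple
print(ys.shape)                                   # roughly (4, 24, 256) after 4x subsampling
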
0
funasr_local/models/frontend/__init__.py
Normal file
17
funasr_local/models/frontend/abs_frontend.py
Normal file
@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch


class AbsFrontend(torch.nn.Module, ABC):
    @abstractmethod
    def output_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
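
The abstract class above only pins down the frontend contract: padded waveforms plus lengths in, padded features plus lengths out, with output_size() reporting the feature dimension. A toy conforming subclass (illustrative, not part of this file):

import torch
from typing import Tuple

class IdentityFrontend(AbsFrontend):
    """Treats each raw sample as a 1-dim feature; for illustration only."""

    def output_size(self) -> int:
        return 1

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (B, T) -> (B, T, 1); lengths pass through unchanged
        return input.unsqueeze(-1), input_lengths
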
258
funasr_local/models/frontend/default.py
Normal file
@@ -0,0 +1,258 @@
import copy
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types

from funasr_local.layers.log_mel import LogMel
from funasr_local.layers.stft import Stft
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.utils.get_default_kwargs import get_default_kwargs


class DefaultFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "default"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        return input_feats, feats_lens

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> torch.Tensor:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens

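A sketch of driving DefaultFrontend on raw 16 kHz audio (illustrative; frontend_conf=None disables the WPE/beamformer stage, so only STFT, power spectrum and log-mel run):

import torch

frontend = DefaultFrontend(fs=16000, n_mels=80, frontend_conf=None)
speech = torch.randn(2, 16000)             # (batch, samples), 1 s each
lengths = torch.tensor([16000, 12000])
feats, feats_lens = frontend(speech, lengths)
print(feats.shape)                         # (2, 126, 80): T = samples // 128 + 1 with center=True
print(frontend.output_size())              # 80
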
class MultiChannelFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
        frame_length: int = None,
        frame_shift: int = None,
        lfr_m: int = None,
        lfr_n: int = None,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "multichannelfrontend"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            if isinstance(input, ComplexTensor):
                input_stft = input
            else:
                input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)
        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)
        bt = input_feats.size(0)
        if input_feats.dim() == 4:
            channel_size = input_feats.size(2)
            # batch * channel * T * D: fold the channel axis into the batch axis
            # (the feature dim is fixed at 80 mel bins here)
            input_feats = input_feats.transpose(1, 2).reshape(
                bt * channel_size, -1, 80
            ).contiguous()
            # batch * channel
            feats_lens = feats_lens.repeat(1, channel_size).squeeze()
        else:
            channel_size = 1
        return input_feats, feats_lens, channel_size

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> torch.Tensor:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
51
funasr_local/models/frontend/eend_ola_feature.py
Normal file
@@ -0,0 +1,51 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features

import librosa
import numpy as np


def transform(Y, dtype=np.float32):
    Y = np.abs(Y)
    n_fft = 2 * (Y.shape[1] - 1)
    sr = 8000
    n_mels = 23
    mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
    Y = np.dot(Y ** 2, mel_basis.T)
    Y = np.log10(np.maximum(Y, 1e-10))
    mean = np.mean(Y, axis=0)
    Y = Y - mean
    return Y.astype(dtype)


def subsample(Y, T, subsampling=1):
    Y_ss = Y[::subsampling]
    T_ss = T[::subsampling]
    return Y_ss, T_ss


def splice(Y, context_size=0):
    Y_pad = np.pad(
        Y,
        [(context_size, context_size), (0, 0)],
        'constant')
    Y_spliced = np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(Y_pad),
        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
        (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
    return Y_spliced


def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    fft_size = 1 << (frame_size - 1).bit_length()
    if len(data) % frame_shift == 0:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T[:-1]
    else:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T
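
splice() above concatenates each frame with its +/- context_size neighbors via numpy stride tricks; a quick shape check (illustrative):

import numpy as np

Y = np.arange(12, dtype=np.float32).reshape(4, 3)   # 4 frames, 3 dims
Y_spliced = splice(Y, context_size=2)
print(Y_spliced.shape)    # (4, 15): each frame now stacks 2*2 + 1 = 5 frames
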
146
funasr_local/models/frontend/fused.py
Normal file
@@ -0,0 +1,146 @@
from typing import Tuple

import numpy as np
import torch
from typeguard import check_argument_types

from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.s3prl import S3prlFrontend


class FusedFrontends(AbsFrontend):
    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):

        assert check_argument_types()
        super().__init__()
        self.align_method = (
            align_method  # fusing method: linear_projection only for now
        )
        self.proj_dim = proj_dim  # dim of the projection done on each frontend
        self.frontends = []  # list of the frontends to combine

        for i, frontend in enumerate(frontends):
            frontend_type = frontend["frontend_type"]
            if frontend_type == "default":
                n_mels, fs, n_fft, win_length, hop_length = (
                    frontend.get("n_mels", 80),
                    fs,
                    frontend.get("n_fft", 512),
                    frontend.get("win_length"),
                    frontend.get("hop_length", 128),
                )
                window, center, normalized, onesided = (
                    frontend.get("window", "hann"),
                    frontend.get("center", True),
                    frontend.get("normalized", False),
                    frontend.get("onesided", True),
                )
                fmin, fmax, htk, apply_stft = (
                    frontend.get("fmin", None),
                    frontend.get("fmax", None),
                    frontend.get("htk", False),
                    frontend.get("apply_stft", True),
                )

                self.frontends.append(
                    DefaultFrontend(
                        n_mels=n_mels,
                        n_fft=n_fft,
                        fs=fs,
                        win_length=win_length,
                        hop_length=hop_length,
                        window=window,
                        center=center,
                        normalized=normalized,
                        onesided=onesided,
                        fmin=fmin,
                        fmax=fmax,
                        htk=htk,
                        apply_stft=apply_stft,
                    )
                )
            elif frontend_type == "s3prl":
                frontend_conf, download_dir, multilayer_feature = (
                    frontend.get("frontend_conf"),
                    frontend.get("download_dir"),
                    frontend.get("multilayer_feature"),
                )
                self.frontends.append(
                    S3prlFrontend(
                        fs=fs,
                        frontend_conf=frontend_conf,
                        download_dir=download_dir,
                        multilayer_feature=multilayer_feature,
                    )
                )

            else:
                raise NotImplementedError  # frontends are only default or s3prl

        self.frontends = torch.nn.ModuleList(self.frontends)

        self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
        self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
        if torch.cuda.is_available():
            dev = "cuda"
        else:
            dev = "cpu"
        if self.align_method == "linear_projection":
            self.projection_layers = [
                torch.nn.Linear(
                    in_features=frontend.output_size(),
                    out_features=self.factors[i] * self.proj_dim,
                )
                for i, frontend in enumerate(self.frontends)
            ]
            self.projection_layers = torch.nn.ModuleList(self.projection_layers)
            self.projection_layers = self.projection_layers.to(torch.device(dev))

    def output_size(self) -> int:
        return len(self.frontends) * self.proj_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # step 0: get all frontends features
        self.feats = []
        for frontend in self.frontends:
            with torch.no_grad():
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])

        if (
            self.align_method == "linear_projection"
        ):  # TODO(Dan): to add other align methods

            # first step: projections
            self.feats_proj = []
            for i, frontend in enumerate(self.frontends):
                input_feats = self.feats[i][0]
                self.feats_proj.append(self.projection_layers[i](input_feats))

            # 2nd step: reshape
            self.feats_reshaped = []
            for i, frontend in enumerate(self.frontends):
                input_feats_proj = self.feats_proj[i]
                bs, nf, dim = input_feats_proj.shape
                input_feats_reshaped = torch.reshape(
                    input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
                )
                self.feats_reshaped.append(input_feats_reshaped)

            # 3rd step: drop the few last frames
            m = min([x.shape[1] for x in self.feats_reshaped])
            self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]

            input_feats = torch.cat(
                self.feats_final, dim=-1
            )  # change the input size of the preencoder: proj_dim * n_frontends
            feats_lens = torch.ones_like(self.feats[0][1]) * (m)

        else:
            raise NotImplementedError

        return input_feats, feats_lens
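
The gcd/factor bookkeeping above aligns frontends with different hop lengths onto a common frame grid before concatenation; a worked numeric example of that arithmetic (the numbers below are illustrative assumptions):

# Suppose two frontends with hop lengths 128 (default) and 160 (s3prl):
#   gcd = np.gcd(128, 160) = 32, factors = [128 // 32, 160 // 32] = [4, 5]
# With proj_dim = 100:
#   default: (B, T1, odim1) -> Linear -> (B, T1, 400) -> reshape -> (B, 4*T1, 100)
#   s3prl:   (B, T2, odim2) -> Linear -> (B, T2, 500) -> reshape -> (B, 5*T2, 100)
# Both streams now carry one frame per 32 samples; after trimming to the
# shorter length m, torch.cat gives (B, m, 200) = (B, m, n_frontends * proj_dim).
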
143
funasr_local/models/frontend/s3prl.py
Normal file
@@ -0,0 +1,143 @@
import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import torch
from typeguard import check_argument_types

from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.modules.nets_utils import pad_list
from funasr_local.utils.get_default_kwargs import get_default_kwargs


def base_s3prl_setup(args):
    args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
    args.upstream_model_config = getattr(args, "upstream_model_config", None)
    args.upstream_refresh = getattr(args, "upstream_refresh", False)
    args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
    args.init_ckpt = getattr(args, "init_ckpt", None)
    args.verbose = getattr(args, "verbose", False)
    args.tile_factor = getattr(args, "tile_factor", 1)
    return args


class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation frontend structure for ASR."""

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        download_dir: str = None,
        multilayer_feature: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Get S3PRL upstream model."""
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        s3prl_path = None
        python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
        for p in python_path_list:
            if p.endswith("s3prl"):
                s3prl_path = p
                break
        assert s3prl_path is not None

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.

        Input - sequence of representations
            shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
            shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        if self.args.tile_factor != 1:
            feats = self._tile_representations(feats)

        input_feats = pad_list(feats, 0.0)
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

        # Saving CUDA Memory
        del feats

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
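
Using this frontend requires a local s3prl checkout whose path ends in "s3prl" on PYTHONPATH (see the path scan in _get_upstream). A sketch under that assumption, with "wav2vec2" as an example upstream name:

# export PYTHONPATH=/path/to/s3prl:$PYTHONPATH   (assumed setup)
import torch

frontend = S3prlFrontend(
    fs=16000,
    frontend_conf={"upstream": "wav2vec2"},   # any s3prl upstream name
    multilayer_feature=True,
)
wavs = torch.randn(2, 16000)
lens = torch.tensor([16000, 12000])
feats, feats_lens = frontend(wavs, lens)
assert feats.shape[-1] == frontend.output_size()
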
503
funasr_local/models/frontend/wav_frontend.py
Normal file
@@ -0,0 +1,503 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple

import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types

import funasr_local.models.frontend.eend_ola_feature as eend_ola_feature
from funasr_local.models.frontend.abs_frontend import AbsFrontend


def load_cmvn(cmvn_file):
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    vars_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                means_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
    return cmvn

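load_cmvn expects a Kaldi-nnet-style text file where the line after an <AddShift> tag carries the shifts (negated means) and the line after a <Rescale> tag the scales (inverse standard deviations), each as "<LearnRateCoef> 0 [ v1 v2 ... ]". A fabricated two-dimensional round-trip under that assumption (values are arbitrary):

import tempfile

cmvn_text = (
    "<AddShift> 2 2\n"
    "<LearnRateCoef> 0 [ -8.31 -8.28 ]\n"
    "<Rescale> 2 2\n"
    "<LearnRateCoef> 0 [ 0.24 0.25 ]\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".mvn", delete=False) as f:
    f.write(cmvn_text)
cmvn = load_cmvn(f.name)
print(cmvn)   # row 0 = shifts (means), row 1 = scales (vars)
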
def apply_cmvn(inputs, cmvn):  # noqa
    """
    Apply CMVN with mvn data
    """

    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape

    means = cmvn[0:1, :dim]
    vars = cmvn[1:2, :dim]
    inputs += means.to(device)
    inputs *= vars.to(device)

    return inputs.type(torch.float32)

def apply_lfr(inputs, lfr_m, lfr_n):
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)

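apply_lfr stacks lfr_m consecutive frames and advances by lfr_n frames, after left-padding with (lfr_m - 1) // 2 copies of the first frame. A shape check with lfr_m=7, lfr_n=6 (the low-frame-rate values commonly used in FunASR configs):

import torch

feats = torch.randn(100, 80)                # 100 frames of 80-dim fbank
lfr = apply_lfr(feats, lfr_m=7, lfr_n=6)
print(lfr.shape)                            # (17, 560): ceil(100 / 6) frames of 7 * 80 dims
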
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    """

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    def forward(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs,
                              snip_edges=self.snip_edges)

            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens

    def forward_fbank(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs)

            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens

    def forward_lfr_cmvn(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens

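End-to-end, the offline frontend above is used roughly as follows (illustrative; no CMVN file, dither disabled for reproducibility, and waveforms assumed in [-1, 1] since upsacle_samples rescales them to 16-bit range):

import torch

frontend = WavFrontend(fs=16000, n_mels=80, lfr_m=7, lfr_n=6, dither=0.0)
wav = torch.rand(1, 16000) * 2 - 1          # 1 s of audio in [-1, 1]
wav_lens = torch.tensor([16000])
feats, feats_lens = frontend(wav, wav_lens)
print(feats.shape, frontend.output_size())  # (1, 17, 560), 560
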
class WavFrontendOnline(AbsFrontend):
    """Conventional frontend structure for streaming ASR/VAD.
    """

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.frame_sample_length = int(self.frame_length * self.fs / 1000)
        self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.waveforms = None
        self.reserve_waveforms = None
        self.fbanks = None
        self.fbanks_lens = None
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
        self.input_cache = None
        self.lfr_splice_cache = []

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    @staticmethod
    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
        """
        Apply CMVN with mvn data
        """

        device = inputs.device
        dtype = inputs.dtype
        frame, dim = inputs.shape

        means = np.tile(cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(cmvn[1:2, :dim], (frame, 1))
        inputs += torch.from_numpy(means).type(dtype).to(device)
        inputs *= torch.from_numpy(vars).type(dtype).to(device)

        return inputs.type(torch.float32)

    @staticmethod
    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
            torch.Tensor, torch.Tensor, int]:
        """
        Apply LFR to the data; the inputs tensor has already been concatenated
        with the cache tensor.
        """

        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n))  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n:]).view(-1)
                    for _ in range(num_padding):
                        frame = torch.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and break the loop
                    splice_idx = i
                    break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = torch.vstack(LFR_inputs)
        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0

    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        if self.input_cache is None:
            self.input_cache = torch.empty(0)
        input = torch.cat((self.input_cache, input), dim=1)
        frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
        # update self.input_cache
        self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                # keep the exact wave samples that were used for fbank extraction
                waveforms.append(
                    waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
                waveform = waveform * (1 << 15)
                waveform = waveform.unsqueeze(0)
                mat = kaldi.fbank(waveform,
                                  num_mel_bins=self.n_mels,
                                  frame_length=self.frame_length,
                                  frame_shift=self.frame_shift,
                                  dither=self.dither,
                                  energy_floor=0.0,
                                  window_type=self.window,
                                  sample_frequency=self.fs)

                feat_length = mat.size(0)
                feats.append(mat)
                feats_lens.append(feat_length)

            waveforms = torch.stack(waveforms)
            feats_lens = torch.as_tensor(feats_lens)
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        self.fbanks = feats_pad
        import copy
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.fbanks, self.fbanks_lens

    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # self.lfr_splice_cache is updated via self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
                                                                                     is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.shape[0]
        assert batch_size == 1, 'online feature extraction currently supports batch_size == 1 only'
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
                (self.reserve_waveforms, waveforms), dim=1)
            if not self.lfr_splice_cache:  # initialize the splice cache
                for i in range(batch_size):
                    self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
            # the number of input frames plus self.lfr_splice_cache[0].shape[0] must reach self.lfr_m
            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[:,
                                         :-(self.frame_sample_length - self.frame_shift_sample_length)]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                return torch.empty(0), feats_lengths
        else:
            if is_final:
                self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                feats = torch.stack(self.lfr_splice_cache)
                feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        return self.waveforms

    def cache_reset(self):
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

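The streaming variant keeps waveform, fbank and LFR caches across calls; a sketch of chunked usage (illustrative; batch size must be 1, and is_final flushes the caches on the last chunk):

import torch

frontend = WavFrontendOnline(fs=16000, n_mels=80, lfr_m=7, lfr_n=6, dither=0.0)
stream = torch.rand(1, 16000) * 2 - 1
chunk = 1600                                   # 100 ms per call
for start in range(0, stream.shape[1], chunk):
    piece = stream[:, start:start + chunk]
    is_final = start + chunk >= stream.shape[1]
    lens = torch.tensor([piece.shape[1]])
    feats, feats_lens = frontend(piece, lens, is_final=is_final)
    if feats.numel():                          # empty until enough frames are buffered
        print(feats.shape)
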
class WavFrontendMel23(AbsFrontend):
    """Conventional frontend structure for ASR.
    """

    def __init__(
        self,
        fs: int = 16000,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.n_mels = 23

    def output_size(self) -> int:
        return self.n_mels * (2 * self.lfr_m + 1)

    def forward(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform.numpy()
            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
            mat = eend_ola_feature.transform(mat)
            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
            mat = mat[::self.lfr_n]
            mat = torch.from_numpy(mat)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
180
funasr_local/models/frontend/wav_frontend_kaldifeat.py
Normal file
@@ -0,0 +1,180 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.

from typing import Tuple

import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
# import kaldifeat


def load_cmvn(cmvn_file):
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    vars_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                means_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    # np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype
    means = np.array(means_list).astype(np.float64)
    vars = np.array(vars_list).astype(np.float64)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn)
    return cmvn


def apply_cmvn(inputs, cmvn_file):  # noqa
    """
    Apply CMVN with mvn data
    """

    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape

    cmvn = load_cmvn(cmvn_file)
    means = np.tile(cmvn[0:1, :dim], (frame, 1))
    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)
    inputs *= torch.from_numpy(vars).type(dtype).to(device)

    return inputs.type(torch.float32)


def apply_lfr(inputs, lfr_m, lfr_n):
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)


# class WavFrontend_kaldifeat(AbsFrontend):
#     """Conventional frontend structure for ASR.
#     """
#
#     def __init__(
#             self,
#             cmvn_file: str = None,
#             fs: int = 16000,
#             window: str = 'hamming',
#             n_mels: int = 80,
#             frame_length: int = 25,
#             frame_shift: int = 10,
#             lfr_m: int = 1,
#             lfr_n: int = 1,
#             dither: float = 1.0,
#             snip_edges: bool = True,
#             upsacle_samples: bool = True,
#             device: str = 'cpu',
#             **kwargs,
#     ):
#         super().__init__()
#
#         opts = kaldifeat.FbankOptions()
#         opts.device = device
#         opts.frame_opts.samp_freq = fs
#         opts.frame_opts.dither = dither
#         opts.frame_opts.window_type = window
#         opts.frame_opts.frame_shift_ms = float(frame_shift)
#         opts.frame_opts.frame_length_ms = float(frame_length)
#         opts.mel_opts.num_bins = n_mels
#         opts.energy_floor = 0
#         opts.frame_opts.snip_edges = snip_edges
#         opts.mel_opts.debug_mel = False
#         self.opts = opts
#         self.fbank_fn = None
#         self.fbank_beg_idx = 0
#         self.reset_fbank_status()
#
#         self.lfr_m = lfr_m
#         self.lfr_n = lfr_n
#         self.cmvn_file = cmvn_file
#         self.upsacle_samples = upsacle_samples
#
#     def output_size(self) -> int:
#         return self.n_mels * self.lfr_m
#
#     def forward_fbank(
#             self,
#             input: torch.Tensor,
#             input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         batch_size = input.size(0)
#         feats = []
#         feats_lens = []
#         for i in range(batch_size):
#             waveform_length = input_lengths[i]
#             waveform = input[i][:waveform_length]
#             waveform = waveform * (1 << 15)
#
#             self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
#             frames = self.fbank_fn.num_frames_ready
#             frames_cur = frames - self.fbank_beg_idx
#             mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
#                 device=self.opts.device)
#             for i in range(self.fbank_beg_idx, frames):
#                 mat[i, :] = self.fbank_fn.get_frame(i)
#             self.fbank_beg_idx += frames_cur
#
#             feat_length = mat.size(0)
#             feats.append(mat)
#             feats_lens.append(feat_length)
#
#         feats_lens = torch.as_tensor(feats_lens)
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
#
#     def reset_fbank_status(self):
#         self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
#         self.fbank_beg_idx = 0
#
#     def forward_lfr_cmvn(
#             self,
#             input: torch.Tensor,
#             input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         batch_size = input.size(0)
#         feats = []
#         feats_lens = []
#         for i in range(batch_size):
#             mat = input[i, :input_lengths[i], :]
#             if self.lfr_m != 1 or self.lfr_n != 1:
#                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
#             if self.cmvn_file is not None:
#                 mat = apply_cmvn(mat, self.cmvn_file)
#             feat_length = mat.size(0)
#             feats.append(mat)
#             feats_lens.append(feat_length)
#
#         feats_lens = torch.as_tensor(feats_lens)
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
81
funasr_local/models/frontend/windowing.py
Normal file
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Sliding Window for raw audio input data."""

from funasr_local.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple


class SlidingWindow(AbsFrontend):
    """Sliding Window.

    Provides a sliding window over a batched continuous raw audio tensor.
    Optionally, provides padding (currently not implemented).
    Combine this module with a pre-encoder compatible with raw audio data,
    for example Sinc convolutions.

    Known issues:
    The output length is calculated incorrectly if the audio is shorter than win_length.
    WARNING: trailing values are discarded - padding is not implemented yet.
    There is currently no additional window function applied to input values.
    """

    def __init__(
        self,
        win_length: int = 400,
        hop_length: int = 160,
        channels: int = 1,
        padding: int = None,
        fs=None,
    ):
        """Initialize.

        Args:
            win_length: Length of frame.
            hop_length: Relative starting point of next frame.
            channels: Number of input channels.
            padding: Padding (placeholder, currently not implemented).
            fs: Sampling rate (placeholder for compatibility, not used).
        """
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.win_length = win_length
        self.hop_length = hop_length
        self.channels = channels
        self.padding = padding

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a sliding window on the input.

        Args:
            input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
            input_lengths: Input lengths within batch.

        Returns:
            Tensor: Output with dimensions (B, T, C, D), with D=win_length.
            Tensor: Output lengths within batch.
        """
        input_size = input.size()
        B = input_size[0]
        T = input_size[1]
        C = self.channels
        D = self.win_length
        # (B, T, C) --> (T, B, C)
        continuous = input.view(B, T, C).permute(1, 0, 2)
        windowed = continuous.unfold(0, D, self.hop_length)
        # (T, B, C, D) --> (B, T, C, D)
        output = windowed.permute(1, 0, 2, 3).contiguous()
        # After unfold(), windowed lengths change:
        output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
        return output, output_lengths

    def output_size(self) -> int:
        """Return output length of feature dimension D, i.e. the window length."""
        return self.win_length
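A minimal usage sketch for the class above (sizes are illustrative): with the defaults win_length=400 and hop_length=160 at 16 kHz, one second of audio yields (16000 - 400) // 160 + 1 = 98 frames.

import torch

frontend = SlidingWindow(win_length=400, hop_length=160, channels=1)
wav = torch.randn(2, 16000, 1)              # (B, T, C) raw audio batch
wav_lengths = torch.tensor([16000, 12000])
frames, frame_lengths = frontend(wav, wav_lengths)
print(frames.shape)      # torch.Size([2, 98, 1, 400]) == (B, T', C, D)
print(frame_lengths)     # tensor([98, 73])
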
1
funasr_local/models/joint_net/__init__.py
Normal file
@@ -0,0 +1 @@
61
funasr_local/models/joint_net/joint_network.py
Normal file
@@ -0,0 +1,61 @@
"""Transducer joint network implementation."""

import torch

from funasr_local.modules.nets_utils import get_activation


class JointNetwork(torch.nn.Module):
    """Transducer joint network module.

    Args:
        output_size: Output size.
        encoder_size: Encoder output size.
        decoder_size: Decoder output size.
        joint_space_size: Joint space size.
        joint_activation_type: Type of activation for the joint network.

    """

    def __init__(
        self,
        output_size: int,
        encoder_size: int,
        decoder_size: int,
        joint_space_size: int = 256,
        joint_activation_type: str = "tanh",
    ) -> None:
        """Construct a JointNetwork object."""
        super().__init__()

        self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
        self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)

        self.lin_out = torch.nn.Linear(joint_space_size, output_size)

        self.joint_activation = get_activation(joint_activation_type)

    def forward(
        self,
        enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """Joint computation of encoder and decoder hidden state sequences.

        Args:
            enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
            dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)

        Returns:
            joint_out: Joint output state sequences. (B, T, U, D_out)

        """
        if project_input:
            joint_out = self.joint_activation(
                self.lin_enc(enc_out) + self.lin_dec(dec_out)
            )
        else:
            joint_out = self.joint_activation(enc_out + dec_out)
        return self.lin_out(joint_out)
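A minimal sketch of the broadcast the joint computation relies on: the encoder and decoder projections meet on a (B, T, U, joint_space_size) grid. Sizes here are illustrative.

import torch

joint = JointNetwork(output_size=500, encoder_size=256, decoder_size=128)
enc_out = torch.randn(2, 50, 1, 256)   # (B, T, 1, D_enc)
dec_out = torch.randn(2, 1, 10, 128)   # (B, 1, U, D_dec)
logits = joint(enc_out, dec_out)       # projections broadcast over T and U
print(logits.shape)                    # torch.Size([2, 50, 10, 500])
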
0
funasr_local/models/pooling/__init__.py
Normal file
98
funasr_local/models/pooling/statistic_pooling.py
Normal file
@@ -0,0 +1,98 @@
import torch
from typing import Tuple
from typing import Union
from funasr_local.modules.nets_utils import make_non_pad_mask
from torch.nn import functional as F
import math

VAR2STD_EPSILON = 1e-12


class StatisticPooling(torch.nn.Module):
    def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
        super(StatisticPooling, self).__init__()
        if isinstance(pooling_dim, int):
            pooling_dim = (pooling_dim, )
        self.pooling_dim = pooling_dim
        self.eps = eps

    def forward(self, xs_pad, ilens=None):
        # xs_pad in (Batch, Channel, Time, Frequency)

        if ilens is None:
            masks = torch.ones_like(xs_pad).to(xs_pad)
        else:
            masks = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
        mean = (torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) /
                torch.sum(masks, dim=self.pooling_dim, keepdim=True))
        squared_difference = torch.pow(xs_pad - mean, 2.0)
        variance = (torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) /
                    torch.sum(masks, dim=self.pooling_dim, keepdim=True))
        for i in reversed(self.pooling_dim):
            mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)

        mask = torch.less_equal(variance, self.eps).float()
        variance = (1.0 - mask) * variance + mask * self.eps
        stddev = torch.sqrt(variance)

        stat_pooling = torch.cat([mean, stddev], dim=1)

        return stat_pooling

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}


def statistic_pooling(
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        pooling_dim: Tuple = (2, 3)
) -> torch.Tensor:
    # xs_pad in (Batch, Channel, Time, Frequency)

    if ilens is None:
        seq_mask = torch.ones_like(xs_pad).to(xs_pad)
    else:
        seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
    mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
            torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
    squared_difference = torch.pow(xs_pad - mean, 2.0)
    variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
                torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
    for i in reversed(pooling_dim):
        mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)

    value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float()
    variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON
    stddev = torch.sqrt(variance)

    stat_pooling = torch.cat([mean, stddev], dim=1)

    return stat_pooling


def windowed_statistic_pooling(
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        pooling_dim: Tuple = (2, 3),
        pooling_size: int = 20,
        pooling_stride: int = 1
) -> Tuple[torch.Tensor, int]:
    # xs_pad in (Batch, Channel, Time, Frequency)

    tt = xs_pad.shape[2]
    num_chunk = int(math.ceil(tt / pooling_stride))
    pad = pooling_size // 2
    if len(xs_pad.shape) == 4:
        features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
    else:
        features = F.pad(xs_pad, (pad, pad), "reflect")
    stat_list = []

    for i in range(num_chunk):
        # B x C
        st, ed = i * pooling_stride, i * pooling_stride + pooling_size
        stat = statistic_pooling(features[:, :, st:ed], pooling_dim=pooling_dim)
        stat_list.append(stat.unsqueeze(2))

    # B x C x T
    return torch.cat(stat_list, dim=2), ilens / pooling_stride
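A minimal sketch of the shape contract above, assuming a (Batch, Channel, Time, Frequency) input: mean and standard deviation are concatenated on the channel dimension, so the output width doubles.

import torch

pool = StatisticPooling(pooling_dim=(2, 3))
x = torch.randn(4, 512, 100, 1)   # (B, C, T, F)
print(pool(x).shape)              # torch.Size([4, 1024]) == (B, 2*C)
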
0
funasr_local/models/postencoder/__init__.py
Normal file
17
funasr_local/models/postencoder/abs_postencoder.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsPostEncoder(torch.nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def output_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, input: torch.Tensor, input_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
# 2021, University of Stuttgart; Pavel Denisov
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Hugging Face Transformers PostEncoder."""

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from typeguard import check_argument_types
from typing import Tuple

import copy
import logging
import torch

try:
    from transformers import AutoModel

    is_transformers_available = True
except ImportError:
    is_transformers_available = False


class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
    """Hugging Face Transformers PostEncoder."""

    def __init__(
        self,
        input_size: int,
        model_name_or_path: str,
    ):
        """Initialize the module."""
        assert check_argument_types()
        super().__init__()

        if not is_transformers_available:
            raise ImportError(
                "`transformers` is not available. Please install it via `pip install"
                " transformers` or `cd /path/to/espnet/tools && . ./activate_python.sh"
                " && ./installers/install_transformers.sh`."
            )

        model = AutoModel.from_pretrained(model_name_or_path)

        if hasattr(model, "encoder"):
            self.transformer = model.encoder
        else:
            self.transformer = model

        if hasattr(self.transformer, "embed_tokens"):
            del self.transformer.embed_tokens
        if hasattr(self.transformer, "wte"):
            del self.transformer.wte
        if hasattr(self.transformer, "word_embedding"):
            del self.transformer.word_embedding

        self.pretrained_params = copy.deepcopy(self.transformer.state_dict())

        if (
            self.transformer.config.is_encoder_decoder
            or self.transformer.config.model_type in ["xlnet", "t5"]
        ):
            self.use_inputs_embeds = True
            self.extend_attention_mask = False
        elif self.transformer.config.model_type == "gpt2":
            self.use_inputs_embeds = True
            self.extend_attention_mask = True
        else:
            self.use_inputs_embeds = False
            self.extend_attention_mask = True

        self.linear_in = torch.nn.Linear(
            input_size, self.transformer.config.hidden_size
        )

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward."""
        input = self.linear_in(input)

        args = {"return_dict": True}

        mask = (~make_pad_mask(input_lengths)).to(input.device).float()

        if self.extend_attention_mask:
            args["attention_mask"] = _extend_attention_mask(mask)
        else:
            args["attention_mask"] = mask

        if self.use_inputs_embeds:
            args["inputs_embeds"] = input
        else:
            args["hidden_states"] = input

        if self.transformer.config.model_type == "mpnet":
            args["head_mask"] = [None for _ in self.transformer.layer]

        output = self.transformer(**args).last_hidden_state

        return output, input_lengths

    def reload_pretrained_parameters(self):
        self.transformer.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Transformers model parameters reloaded!")

    def output_size(self) -> int:
        """Get the output size."""
        return self.transformer.config.hidden_size


def _extend_attention_mask(mask: torch.Tensor) -> torch.Tensor:
    mask = mask[:, None, None, :]
    mask = (1.0 - mask) * -10000.0
    return mask
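A minimal sketch of what _extend_attention_mask produces: padded positions become a large negative bias of shape (B, 1, 1, T) that gets added to the attention logits.

import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])   # one sequence, last step is padding
print(_extend_attention_mask(mask))
# tensor([[[[    -0.,     -0., -10000.]]]])
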
0
funasr_local/models/predictor/__init__.py
Normal file
748
funasr_local/models/predictor/cif.py
Normal file
@@ -0,0 +1,748 @@
import torch
from torch import nn
import logging
import numpy as np
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.streaming_utils.utils import sequence_mask


class CifPredictor(nn.Module):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
        super(CifPredictor, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor

    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()

class CifPredictorV2(nn.Module):
    def __init__(self,
                 idim,
                 l_order,
                 r_order,
                 threshold=1.0,
                 dropout=0.1,
                 smooth_factor=1.0,
                 noise_threshold=0,
                 tail_threshold=0.0,
                 tf2torch_tensor_name_prefix_torch="predictor",
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 tail_mask=True,
                 ):
        super(CifPredictorV2, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tail_mask = tail_mask

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            if self.tail_mask:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
            else:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_chunk(self, hidden, cache=None):
        batch_size, len_time, hidden_size = hidden.shape
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)

        alphas = alphas.squeeze(-1)

        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []

        if cache is not None and "chunk_size" in cache:
            alphas[:, :cache["chunk_size"][0]] = 0.0
            alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
            hidden = torch.cat((hidden, tail_hidden), dim=1)
            alphas = torch.cat((alphas, tail_alphas), dim=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = torch.zeros((hidden_size), device=hidden.device)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.threshold:
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    frames += (self.threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.threshold
                    frames = integrate * hidden[b][t]

            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(torch.tensor(len(list_frame), device=alphas.device))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)

        max_token_len = max(token_length)
        if max_token_len == 0:
            return hidden, torch.stack(token_length, 0)
        list_ls = []
        for b in range(batch_size):
            pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_frames[b] = torch.stack(list_frames[b])
                list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))

        cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
        return torch.stack(list_ls, 0), torch.stack(token_length, 0)

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            if b > 1:
                alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
            else:
                alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor

    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()

    def gen_tf2torch_map_dict(self):

        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.cif_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1,256),(1,256,1)
            "{}.cif_output.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1,),(1,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                    name, name_tf, var_dict_torch[name].size(), data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

        return var_dict_torch_update


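A sketch of the streaming cache that forward_chunk above reads and writes. The key semantics are inferred from the indexing in the method (hypothetical sizes; the exact meaning of "chunk_size" follows whatever the caller supplies):

cache = {
    # alphas[:, :chunk_size[0]] and alphas[:, chunk_size[0]+chunk_size[1]:]
    # are zeroed, so only the middle span of frames is scored in this call
    "chunk_size": [5, 10, 5],
    # set on the final call so the tail_threshold weight can flush the
    # last partially integrated token
    "last_chunk": False,
    # "cif_alphas" / "cif_hidden" are written back by forward_chunk and carry
    # the unfired remainder of the integration into the next chunk
}
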
class mae_loss(nn.Module):

    def __init__(self, normalize_length=False):
        super(mae_loss, self).__init__()
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction='sum')

    def forward(self, token_length, pre_token_length):
        loss_token_normalizer = token_length.size(0)
        if self.normalize_length:
            loss_token_normalizer = token_length.sum().type(torch.float32)
        loss = self.criterion(token_length, pre_token_length)
        loss = loss / loss_token_normalizer
        return loss


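A minimal sketch of the length loss above: an L1 distance between reference and predicted token counts, averaged over the batch by default (or over the total token count when normalize_length=True).

import torch

criterion = mae_loss()
ref_lengths = torch.tensor([4.0, 6.0])       # reference token counts
pred_lengths = torch.tensor([3.5, 6.5])      # predictor estimates
print(criterion(ref_lengths, pred_lengths))  # (0.5 + 0.5) / 2 = tensor(0.5000)
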
def cif(hidden, alphas, threshold):
    batch_size, len_time, hidden_size = hidden.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=hidden.device),
                                integrate)
        cur = torch.where(fire_place,
                          distribution_completion,
                          alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
                            remainds[:, None] * hidden[:, t, :],
                            frame)

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires


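A hand-checkable trace of the integrate-and-fire loop above (illustrative numbers, threshold 1.0): the weights accumulate until the running integral crosses the threshold; the firing frame contributes only the missing weight to the current token, and its remainder seeds the next one.

import torch

alphas = torch.tensor([[0.4, 0.5, 0.3, 0.6, 0.7]])
hidden = torch.randn(1, 5, 8)
embeds, fires = cif(hidden, alphas, threshold=1.0)
print(fires)         # tensor([[0.4000, 0.9000, 1.2000, 0.8000, 1.5000]])
print(embeds.shape)  # torch.Size([1, 2, 8]): fires at t=2 and t=4 give 2 tokens
# first token = 0.4*h0 + 0.5*h1 + 0.1*h2 -- its weights sum to the threshold
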
def cif_wo_hidden(alphas, threshold):
    batch_size, len_time = alphas.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=alphas.device)
    # intermediate vars along time
    list_fires = []

    for t in range(len_time):
        alpha = alphas[:, t]

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=alphas.device),
                                integrate)

    fires = torch.stack(list_fires, 1)
    return fires


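cif_wo_hidden runs the same recursion on the weights alone, which is enough to locate token boundaries, as the timestamp logic in CifPredictorV3 below does. A minimal sketch:

import torch

fires = cif_wo_hidden(torch.tensor([[0.4, 0.5, 0.3, 0.6, 0.7]]), threshold=1.0)
print(fires)                     # tensor([[0.4000, 0.9000, 1.2000, 0.8000, 1.5000]])
print((fires >= 1.0).nonzero())  # boundary frames: t=2 and t=4
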
class CifPredictorV3(nn.Module):
    def __init__(self,
                 idim,
                 l_order,
                 r_order,
                 threshold=1.0,
                 dropout=0.1,
                 smooth_factor=1.0,
                 noise_threshold=0,
                 tail_threshold=0.0,
                 tf2torch_tensor_name_prefix_torch="predictor",
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 smooth_factor2=1.0,
                 noise_threshold2=0,
                 upsample_times=5,
                 upsample_type="cnn",
                 use_cif1_cnn=True,
                 tail_mask=True,
                 ):
        super(CifPredictorV3, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        self.upsample_times = upsample_times
        self.upsample_type = upsample_type
        self.use_cif1_cnn = use_cif1_cnn
        if self.upsample_type == 'cnn':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.cif_output2 = nn.Linear(idim, 1)
        elif self.upsample_type == 'cnn_blstm':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
            self.cif_output2 = nn.Linear(idim * 2, 1)
        elif self.upsample_type == 'cnn_attn':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            from funasr_local.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
            from funasr_local.modules.attention import MultiHeadedAttention
            from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForward
            positionwise_layer_args = (
                idim,
                idim * 2,
                0.1,
            )
            self.self_attn = TransformerEncoderLayer(
                idim,
                MultiHeadedAttention(4, idim, 0.1),
                PositionwiseFeedForward(*positionwise_layer_args),
                0.1,
                True,   # normalize_before
                False,  # concat_after
            )
            self.cif_output2 = nn.Linear(idim, 1)
        self.smooth_factor2 = smooth_factor2
        self.noise_threshold2 = noise_threshold2

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))

        # alphas2 is an extra head for timestamp prediction
        if not self.use_cif1_cnn:
            _output = context
        else:
            _output = output
        if self.upsample_type == 'cnn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
        elif self.upsample_type == 'cnn_blstm':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, (_, _) = self.blstm(output2)
        elif self.upsample_type == 'cnn_attn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, _ = self.self_attn(output2, mask)
        # import pdb; pdb.set_trace()
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # repeat the mask in the T dimension to match the upsampled length
        if mask is not None:
            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
            mask2 = mask2.unsqueeze(-1)
            alphas2 = alphas2 * mask2
        alphas2 = alphas2.squeeze(-1)
        token_num2 = alphas2.sum(-1)

        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)

        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
        return acoustic_embeds, token_num, alphas, cif_peak, token_num2

    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
        h = hidden
        b = hidden.shape[0]
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))

        # alphas2 is an extra head for timestamp prediction
        if not self.use_cif1_cnn:
            _output = context
        else:
            _output = output
        if self.upsample_type == 'cnn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
        elif self.upsample_type == 'cnn_blstm':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, (_, _) = self.blstm(output2)
        elif self.upsample_type == 'cnn_attn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, _ = self.self_attn(output2, mask)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # repeat the mask in the T dimension to match the upsampled length
        if mask is not None:
            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
            mask2 = mask2.unsqueeze(-1)
            alphas2 = alphas2 * mask2
        alphas2 = alphas2.squeeze(-1)
        _token_num = alphas2.sum(-1)
        if token_num is not None:
            alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
        # re-downsample
        ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
        ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
        # upsampled alphas and cif_peak
        us_alphas = alphas2
        us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor

    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
0
funasr_local/models/preencoder/__init__.py
Normal file
17
funasr_local/models/preencoder/abs_preencoder.py
Normal file
@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch


class AbsPreEncoder(torch.nn.Module, ABC):
    @abstractmethod
    def output_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
38
funasr_local/models/preencoder/linear.py
Normal file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# 2021, Carnegie Mellon University; Xuankai Chang
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Linear Projection."""

from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from typeguard import check_argument_types
from typing import Tuple

import torch


class LinearProjection(AbsPreEncoder):
    """Linear Projection Preencoder."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
    ):
        """Initialize the module."""
        assert check_argument_types()
        super().__init__()

        self.output_dim = output_size
        self.linear_out = torch.nn.Linear(input_size, output_size)

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward."""
        output = self.linear_out(input)
        return output, input_lengths  # no state in this layer

    def output_size(self) -> int:
        """Get the output size."""
        return self.output_dim
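A trivial usage sketch (illustrative sizes): the pre-encoder only projects the feature dimension and passes lengths through unchanged.

import torch

proj = LinearProjection(input_size=80, output_size=256)
y, y_lens = proj(torch.randn(2, 100, 80), torch.tensor([100, 60]))
print(y.shape, y_lens)   # torch.Size([2, 100, 256]) tensor([100,  60])
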
282
funasr_local/models/preencoder/sinc.py
Normal file
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python3
|
||||
# 2020, Technische Universität München; Ludwig Kürzinger
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Sinc convolutions for raw audio input."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr_local.layers.sinc_conv import LogCompression
|
||||
from funasr_local.layers.sinc_conv import SincConv
|
||||
import humanfriendly
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
|
||||
class LightweightSincConvs(AbsPreEncoder):
|
||||
"""Lightweight Sinc Convolutions.
|
||||
|
||||
Instead of using precomputed features, end-to-end speech recognition
|
||||
can also be done directly from raw audio using sinc convolutions, as
|
||||
described in "Lightweight End-to-End Speech Recognition from Raw Audio
|
||||
Data Using Sinc-Convolutions" by Kürzinger et al.
|
||||
https://arxiv.org/abs/2010.07597
|
||||
|
||||
To use Sinc convolutions in your model instead of the default f-bank
|
||||
frontend, set this module as your pre-encoder with `preencoder: sinc`
|
||||
and use the input of the sliding window frontend with
|
||||
`frontend: sliding_window` in your yaml configuration file.
|
||||
So that the process flow is:
|
||||
|
||||
Frontend (SlidingWindow) -> SpecAug -> Normalization ->
|
||||
Pre-encoder (LightweightSincConvs) -> Encoder -> Decoder
|
||||
|
||||
Note that this method also performs data augmentation in time domain
|
||||
(vs. in spectral domain in the default frontend).
|
||||
Use `plot_sinc_filters.py` to visualize the learned Sinc filters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: Union[int, str, float] = 16000,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 256,
|
||||
activation_type: str = "leakyrelu",
|
||||
dropout_type: str = "dropout",
|
||||
windowing_type: str = "hamming",
|
||||
scale_type: str = "mel",
|
||||
):
|
||||
"""Initialize the module.
|
||||
|
||||
Args:
|
||||
fs: Sample rate.
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels (for each input channel).
|
||||
activation_type: Choice of activation function.
|
||||
dropout_type: Choice of dropout function.
|
||||
windowing_type: Choice of windowing function.
|
||||
scale_type: Choice of filter-bank initialization scale.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
if isinstance(fs, str):
|
||||
fs = humanfriendly.parse_size(fs)
|
||||
self.fs = fs
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.activation_type = activation_type
|
||||
self.dropout_type = dropout_type
|
||||
self.windowing_type = windowing_type
|
||||
self.scale_type = scale_type
|
||||
|
||||
self.choices_dropout = {
|
||||
"dropout": torch.nn.Dropout,
|
||||
"spatial": SpatialDropout,
|
||||
"dropout2d": torch.nn.Dropout2d,
|
||||
}
|
||||
if dropout_type not in self.choices_dropout:
|
||||
raise NotImplementedError(
|
||||
f"Dropout type has to be one of "
|
||||
f"{list(self.choices_dropout.keys())}",
|
||||
)
|
||||
|
||||
self.choices_activation = {
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"relu": torch.nn.ReLU,
|
||||
}
|
||||
if activation_type not in self.choices_activation:
|
||||
raise NotImplementedError(
|
||||
f"Activation type has to be one of "
|
||||
f"{list(self.choices_activation.keys())}",
|
||||
)
|
||||
|
||||
# initialization
|
||||
self._create_sinc_convs()
|
||||
# Sinc filters require custom initialization
|
||||
self.espnet_initialization_fn()
|
||||
|
||||
def _create_sinc_convs(self):
|
||||
blocks = OrderedDict()
|
||||
|
||||
# SincConvBlock
|
||||
out_channels = 128
|
||||
self.filters = SincConv(
|
||||
self.in_channels,
|
||||
out_channels,
|
||||
kernel_size=101,
|
||||
stride=1,
|
||||
fs=self.fs,
|
||||
window_func=self.windowing_type,
|
||||
scale_type=self.scale_type,
|
||||
)
|
||||
block = OrderedDict(
|
||||
[
|
||||
("Filters", self.filters),
|
||||
("LogCompression", LogCompression()),
|
||||
("BatchNorm", torch.nn.BatchNorm1d(out_channels, affine=True)),
|
||||
("AvgPool", torch.nn.AvgPool1d(2)),
|
||||
]
|
||||
)
|
||||
blocks["SincConvBlock"] = torch.nn.Sequential(block)
|
||||
in_channels = out_channels
|
||||
|
||||
# First convolutional block, connects the sinc output to the front-end "body"
|
||||
out_channels = 128
|
||||
blocks["DConvBlock1"] = self.gen_lsc_block(
|
||||
in_channels,
|
||||
out_channels,
|
||||
depthwise_kernel_size=25,
|
||||
depthwise_stride=2,
|
||||
pointwise_groups=0,
|
||||
avgpool=True,
|
||||
dropout_probability=0.1,
|
||||
)
|
||||
in_channels = out_channels
|
||||
|
||||
# Second convolutional block, multiple convolutional layers
|
||||
out_channels = self.out_channels
|
||||
for layer in [2, 3, 4]:
|
||||
blocks[f"DConvBlock{layer}"] = self.gen_lsc_block(
|
||||
in_channels, out_channels, depthwise_kernel_size=9, depthwise_stride=1
|
||||
)
|
||||
in_channels = out_channels
|
||||
|
||||
# Third Convolutional block, acts as coupling to encoder
|
||||
out_channels = self.out_channels
|
||||
blocks["DConvBlock5"] = self.gen_lsc_block(
|
||||
in_channels,
|
||||
out_channels,
|
||||
depthwise_kernel_size=7,
|
||||
depthwise_stride=1,
|
||||
pointwise_groups=0,
|
||||
)
|
||||
|
||||
self.blocks = torch.nn.Sequential(blocks)
|
||||
|
||||
    def gen_lsc_block(
        self,
        in_channels: int,
        out_channels: int,
        depthwise_kernel_size: int = 9,
        depthwise_stride: int = 1,
        depthwise_groups=None,
        pointwise_groups=0,
        dropout_probability: float = 0.15,
        avgpool=False,
    ):
        """Generate a convolutional block for lightweight sinc convolutions.

        Each block consists of either a depthwise or a depthwise-separable
        convolution together with dropout and a (batch-)normalization layer,
        and an optional average-pooling layer.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            depthwise_kernel_size: Kernel size of the depthwise convolution.
            depthwise_stride: Stride of the depthwise convolution.
            depthwise_groups: Number of groups of the depthwise convolution.
            pointwise_groups: Number of groups of the pointwise convolution.
            dropout_probability: Dropout probability in the block.
            avgpool: If True, an AvgPool layer is inserted.

        Returns:
            torch.nn.Sequential: Neural network building block.
        """
        block = OrderedDict()
        if not depthwise_groups:
            # GCD(in_channels, out_channels) to prevent size mismatches
            depthwise_groups, r = in_channels, out_channels
            while r != 0:
                depthwise_groups, r = r, depthwise_groups % r
        block["depthwise"] = torch.nn.Conv1d(
            in_channels,
            out_channels,
            depthwise_kernel_size,
            depthwise_stride,
            groups=depthwise_groups,
        )
        if pointwise_groups:
            block["pointwise"] = torch.nn.Conv1d(
                out_channels, out_channels, 1, 1, groups=pointwise_groups
            )
        block["activation"] = self.choices_activation[self.activation_type]()
        block["batchnorm"] = torch.nn.BatchNorm1d(out_channels, affine=True)
        if avgpool:
            block["avgpool"] = torch.nn.AvgPool1d(2)
        block["dropout"] = self.choices_dropout[self.dropout_type](dropout_probability)
        return torch.nn.Sequential(block)

    def espnet_initialization_fn(self):
        """Initialize sinc filters with filterbank values."""
        self.filters.init_filters()
        for block in self.blocks:
            for layer in block:
                if isinstance(layer, torch.nn.BatchNorm1d) and layer.affine:
                    layer.weight.data[:] = 1.0
                    layer.bias.data[:] = 0.0

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply lightweight sinc convolutions.

        The input shall be formatted as (B, T, C_in, D_in)
        with B as batch size, T as time dimension, C_in as channels,
        and D_in as feature dimension.

        The output will then be (B, T, C_out*D_out)
        with C_out and D_out as output dimensions.

        The current module structure only handles D_in=400, so that D_out=1.
        Remark for the multichannel case: C_out is the number of out_channels
        given at initialization multiplied by C_in.
        """
        # Transform input data: (B, T, C_in, D_in) -> (B*T, C_in, D_in)
        B, T, C_in, D_in = input.size()
        input_frames = input.view(B * T, C_in, D_in)
        output_frames = self.blocks.forward(input_frames)

        # Transform output data: (B*T, C_out, D_out) -> (B, T, C_out*D_out)
        _, C_out, D_out = output_frames.size()
        output_frames = output_frames.view(B, T, C_out * D_out)
        return output_frames, input_lengths  # no state in this layer

    def output_size(self) -> int:
        """Get the output size."""
        return self.out_channels * self.in_channels
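
    # Usage sketch (illustrative, not part of the original file): assuming the
    # enclosing class is the ESPnet-style LightweightSincConvs frontend with
    # default settings, framed audio of shape (B, T, C_in, D_in=400) is mapped
    # to features of shape (B, T, output_size()):
    #
    #     frontend = LightweightSincConvs()                  # assumed defaults
    #     frames = torch.randn(8, 120, 1, 400)               # (B, T, C_in, D_in)
    #     lengths = torch.full((8,), 120, dtype=torch.long)
    #     feats, feat_lengths = frontend(frames, lengths)
    #     assert feats.shape == (8, 120, frontend.output_size())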
|
||||
|
||||
|
||||
class SpatialDropout(torch.nn.Module):
|
||||
"""Spatial dropout module.
|
||||
|
||||
Apply dropout to full channels on tensors of input (B, C, D)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dropout_probability: float = 0.15,
|
||||
shape: Optional[Union[tuple, list]] = None,
|
||||
):
|
||||
"""Initialize.
|
||||
|
||||
Args:
|
||||
dropout_probability: Dropout probability.
|
||||
shape (tuple, list): Shape of input tensors.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
if shape is None:
|
||||
shape = (0, 2, 1)
|
||||
self.dropout = torch.nn.Dropout2d(dropout_probability)
|
||||
self.shape = (shape,)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward of spatial dropout module."""
|
||||
y = x.permute(*self.shape)
|
||||
y = self.dropout(y)
|
||||
return y.permute(*self.shape)
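
# Usage sketch (illustrative, not part of the original file): the module
# permutes (B, C, D) before and after Dropout2d, so whole slices are zeroed
# together rather than independent elements:
#
#     drop = SpatialDropout(dropout_probability=0.5)
#     x = torch.randn(4, 128, 100)       # (B, C, D)
#     y = drop(x)                        # same shape; dropout applied in training mode
#     assert y.shape == x.shape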
0
funasr_local/models/specaug/__init__.py
Normal file
18
funasr_local/models/specaug/abs_specaug.py
Normal file
@@ -0,0 +1,18 @@
from typing import Optional
from typing import Tuple

import torch


class AbsSpecAug(torch.nn.Module):
    """Abstract class for spectrogram augmentation.

    The process flow:

    Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
    """

    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError
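
# Minimal sketch (illustrative, not part of the original file): a no-op
# augmentation satisfying the AbsSpecAug interface; a concrete subclass only
# has to implement forward(x, x_lengths) -> (x, x_lengths):
#
#     class NoSpecAug(AbsSpecAug):
#         def forward(self, x, x_lengths=None):
#             return x, x_lengths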
184
funasr_local/models/specaug/specaug.py
Normal file
@@ -0,0 +1,184 @@
"""SpecAugment module."""
from typing import Optional
from typing import Sequence
from typing import Union

from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.layers.mask_along_axis import MaskAlongAxis
from funasr_local.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
from funasr_local.layers.mask_along_axis import MaskAlongAxisLFR
from funasr_local.layers.time_warp import TimeWarp


class SpecAug(AbsSpecAug):
    """Implementation of SpecAug.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
        Augmentation Method for Automatic Speech Recognition"

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.

    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
    ):
        if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
            raise ValueError(
                "At least one of time_warp, time_mask, or freq_mask should be applied"
            )
        if (
            apply_time_mask
            and (time_mask_width_range is not None)
            and (time_mask_width_ratio_range is not None)
        ):
            raise ValueError(
                'Only one of "time_mask_width_range" and '
                '"time_mask_width_ratio_range" can be used'
            )
        super().__init__()
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        if apply_time_warp:
            self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
        else:
            self.time_warp = None

        if apply_freq_mask:
            self.freq_mask = MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
            )
        else:
            self.freq_mask = None

        if apply_time_mask:
            if time_mask_width_range is not None:
                self.time_mask = MaskAlongAxis(
                    dim="time",
                    mask_width_range=time_mask_width_range,
                    num_mask=num_time_mask,
                )
            elif time_mask_width_ratio_range is not None:
                self.time_mask = MaskAlongAxisVariableMaxWidth(
                    dim="time",
                    mask_width_ratio_range=time_mask_width_ratio_range,
                    num_mask=num_time_mask,
                )
            else:
                raise ValueError(
                    'Either "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" should be used.'
                )
        else:
            self.time_mask = None

    def forward(self, x, x_lengths=None):
        if self.time_warp is not None:
            x, x_lengths = self.time_warp(x, x_lengths)
        if self.freq_mask is not None:
            x, x_lengths = self.freq_mask(x, x_lengths)
        if self.time_mask is not None:
            x, x_lengths = self.time_mask(x, x_lengths)
        return x, x_lengths
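
# Usage sketch (illustrative, not part of the original file): apply SpecAug to
# a batch of log-mel spectrograms of shape (B, T, F) during training:
#
#     specaug = SpecAug(apply_time_warp=True, apply_freq_mask=True,
#                       apply_time_mask=True, time_mask_width_range=(0, 40))
#     feats = torch.randn(4, 300, 80)                    # (B, T, F)
#     lengths = torch.full((4,), 300, dtype=torch.long)
#     feats_aug, lengths = specaug(feats, lengths)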


class SpecAugLFR(AbsSpecAug):
    """Implementation of SpecAug for low-frame-rate (LFR) features.

    Args:
        lfr_rate: low frame rate.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        lfr_rate: int = 0,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
    ):
        if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
            raise ValueError(
                "At least one of time_warp, time_mask, or freq_mask should be applied"
            )
        if (
            apply_time_mask
            and (time_mask_width_range is not None)
            and (time_mask_width_ratio_range is not None)
        ):
            raise ValueError(
                'Only one of "time_mask_width_range" and '
                '"time_mask_width_ratio_range" can be used'
            )
        super().__init__()
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        if apply_time_warp:
            self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
        else:
            self.time_warp = None

        if apply_freq_mask:
            self.freq_mask = MaskAlongAxisLFR(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
                lfr_rate=lfr_rate + 1,
            )
        else:
            self.freq_mask = None

        if apply_time_mask:
            if time_mask_width_range is not None:
                self.time_mask = MaskAlongAxisLFR(
                    dim="time",
                    mask_width_range=time_mask_width_range,
                    num_mask=num_time_mask,
                    lfr_rate=lfr_rate + 1,
                )
            elif time_mask_width_ratio_range is not None:
                self.time_mask = MaskAlongAxisVariableMaxWidth(
                    dim="time",
                    mask_width_ratio_range=time_mask_width_ratio_range,
                    num_mask=num_time_mask,
                )
            else:
                raise ValueError(
                    'Either "time_mask_width_range" or '
                    '"time_mask_width_ratio_range" should be used.'
                )
        else:
            self.time_mask = None

    def forward(self, x, x_lengths=None):
        if self.time_warp is not None:
            x, x_lengths = self.time_warp(x, x_lengths)
        if self.freq_mask is not None:
            x, x_lengths = self.freq_mask(x, x_lengths)
        if self.time_mask is not None:
            x, x_lengths = self.time_mask(x, x_lengths)
        return x, x_lengths
133
funasr_local/models/target_delay_transformer.py
Normal file
@@ -0,0 +1,133 @@
from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr_local.modules.embedding import SinusoidalPositionEncoder
# from funasr_local.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr_local.models.encoder.sanm_encoder import SANMEncoder as Encoder
from funasr_local.modules.mask import subsequent_n_mask
from funasr_local.train.abs_model import AbsPunctuation


class TargetDelayTransformer(AbsPunctuation):
    """CT-Transformer punctuation model.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time
    punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            # pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    def _target_mask(self, ys_in_pad):
        # Padding mask combined with an n-step look-ahead mask;
        # used by score()/batch_score() for incremental decoding.
        ys_mask = ys_in_pad != 0
        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits from padded token-id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Lengths of input sequences. (batch,)

        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return False

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
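
# Usage sketch (illustrative, not part of the original file; vocab/punc sizes
# are made-up values):
#
#     model = TargetDelayTransformer(vocab_size=10000, punc_size=5)
#     tokens = torch.randint(1, 10000, (2, 30))          # (batch, len)
#     lengths = torch.full((2,), 30, dtype=torch.long)
#     logits, _ = model(tokens, lengths)                 # (batch, len, punc_size)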
136
funasr_local/models/vad_realtime_transformer.py
Normal file
@@ -0,0 +1,136 @@
from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr_local.modules.embedding import SinusoidalPositionEncoder
from funasr_local.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr_local.modules.mask import subsequent_n_mask
from funasr_local.train.abs_model import AbsPunctuation


class VadRealtimeTransformer(AbsPunctuation):
    """CT-Transformer punctuation model with VAD coupling.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time
    punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            # pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
            kernel_size=kernel_size,
            sanm_shfit=sanm_shfit,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    def _target_mask(self, ys_in_pad):
        # Padding mask combined with an n-step look-ahead mask;
        # used by score()/batch_score() for incremental decoding.
        ys_mask = ys_in_pad != 0
        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits from padded token-id sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Lengths of input sequences. (batch,)
            vad_indexes (torch.Tensor): VAD indexes consumed by the encoder.

        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return True

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
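
# Usage sketch (illustrative, not part of the original file; sizes and the
# vad_indexes format are assumptions):
#
#     model = VadRealtimeTransformer(vocab_size=10000, punc_size=5)
#     tokens = torch.randint(1, 10000, (2, 30))          # (batch, len)
#     lengths = torch.full((2,), 30, dtype=torch.long)
#     vad_indexes = torch.tensor([10, 20])               # hypothetical VAD positions
#     logits, _ = model(tokens, lengths, vad_indexes)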