add files

Author: 烨玮
Date: 2025-02-20 12:17:03 +08:00
Parent: a21dd4555c
Commit: edd008441b
667 changed files with 473123 additions and 0 deletions

funasr_local/models/ctc.py
@@ -0,0 +1,187 @@
import logging
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
class CTC(torch.nn.Module):
"""CTC module.
Args:
odim: dimension of outputs
encoder_output_size: number of encoder projection units
dropout_rate: dropout rate (0.0 ~ 1.0)
ctc_type: builtin, warpctc, or gtnctc
reduce: reduce the CTC loss into a scalar
"""
def __init__(
self,
odim: int,
encoder_output_size: int,
dropout_rate: float = 0.0,
ctc_type: str = "builtin",
reduce: bool = True,
ignore_nan_grad: bool = True,
):
assert check_argument_types()
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
self.ctc_lo = torch.nn.Linear(eprojs, odim)
self.ctc_type = ctc_type
self.ignore_nan_grad = ignore_nan_grad
if self.ctc_type == "builtin":
self.ctc_loss = torch.nn.CTCLoss(reduction="none")
elif self.ctc_type == "warpctc":
import warpctc_pytorch as warp_ctc
if ignore_nan_grad:
logging.warning("ignore_nan_grad option is not supported for warp_ctc")
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
elif self.ctc_type == "gtnctc":
from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
self.ctc_loss = GTNCTCLossFunction.apply
else:
raise ValueError(
f'ctc_type must be "builtin", "warpctc", or "gtnctc": {self.ctc_type}'
)
self.reduce = reduce
def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
if self.ctc_type == "builtin":
th_pred = th_pred.log_softmax(2)
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
if loss.requires_grad and self.ignore_nan_grad:
# ctc_grad: (L, B, O)
ctc_grad = loss.grad_fn(torch.ones_like(loss))
ctc_grad = ctc_grad.sum([0, 2])
indices = torch.isfinite(ctc_grad)
size = indices.long().sum()
if size == 0:
# Return as is
logging.warning(
"All samples in this mini-batch got nan grad."
" Returning nan value instead of CTC loss"
)
elif size != th_pred.size(1):
logging.warning(
f"{th_pred.size(1) - size}/{th_pred.size(1)}"
" samples got nan grad."
" These were ignored for CTC loss."
)
# Create mask for target
target_mask = torch.full(
[th_target.size(0)],
1,
dtype=torch.bool,
device=th_target.device,
)
s = 0
for ind, le in enumerate(th_olen):
if not indices[ind]:
target_mask[s : s + le] = 0
s += le
# Calculate the loss again using the masked data
loss = self.ctc_loss(
th_pred[:, indices, :],
th_target[target_mask],
th_ilen[indices],
th_olen[indices],
)
else:
size = th_pred.size(1)
if self.reduce:
# Batch-size average
loss = loss.sum() / size
else:
loss = loss / size
return loss
elif self.ctc_type == "warpctc":
# warpctc only supports float32
th_pred = th_pred.to(dtype=torch.float32)
th_target = th_target.cpu().int()
th_ilen = th_ilen.cpu().int()
th_olen = th_olen.cpu().int()
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
if self.reduce:
# NOTE: sum() is needed to keep consistency, since warpctc
# returns a tensor with shape (1,)
# while the builtin loss returns a dimensionless (scalar) tensor.
loss = loss.sum()
return loss
elif self.ctc_type == "gtnctc":
log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")
else:
raise NotImplementedError
def forward(self, hs_pad, hlens, ys_pad, ys_lens):
"""Calculate CTC loss.
Args:
hs_pad: batch of padded hidden state sequences (B, Tmax, D)
hlens: batch of lengths of hidden state sequences (B)
ys_pad: batch of padded character id sequence tensor (B, Lmax)
ys_lens: batch of lengths of character sequence (B)
"""
# hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
if self.ctc_type == "gtnctc":
# gtn expects list form for ys
ys_true = [y[y != -1] for y in ys_pad] # parse padded ys
else:
# ys_hat: (B, L, D) -> (L, B, D)
ys_hat = ys_hat.transpose(0, 1)
# (B, L) -> (BxL,)
ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
device=hs_pad.device, dtype=hs_pad.dtype
)
return loss
def softmax(self, hs_pad):
"""softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
"""
return F.softmax(self.ctc_lo(hs_pad), dim=2)
def log_softmax(self, hs_pad):
"""log_softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
def argmax(self, hs_pad):
"""argmax of frame activations
Args:
torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: argmax applied 2d tensor (B, Tmax)
"""
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
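
A minimal usage sketch of the CTC module above, assuming funasr_local is installed so that the file path in the header is importable; all sizes and tensors below are illustrative only.

import torch
from funasr_local.models.ctc import CTC

vocab_size, enc_dim = 10, 32                      # hypothetical sizes
ctc = CTC(odim=vocab_size, encoder_output_size=enc_dim)

hs_pad = torch.randn(2, 50, enc_dim)              # (B, Tmax, D): dummy encoder outputs
hlens = torch.tensor([50, 40])                    # valid frames per utterance
ys_pad = torch.randint(1, vocab_size, (2, 12))    # (B, Lmax): target ids (0 is the CTC blank)
ys_lens = torch.tensor([12, 8])                   # valid labels per utterance

loss = ctc(hs_pad, hlens, ys_pad, ys_lens)        # scalar, batch-averaged CTC loss
post = ctc.log_softmax(hs_pad)                    # (B, Tmax, odim) frame-level log-posteriors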

@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class Data2VecPretrainModel(AbsESPnetModel):
"""Data2Vec Pretrain model"""
def __init__(
self,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
):
assert check_argument_types()
super().__init__()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.encoder = encoder
self.num_updates = 0
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
), (speech.shape, speech_lengths.shape)
self.encoder.set_num_updates(self.num_updates)
# 1. Encoder
encoder_out = self.encode(speech, speech_lengths)
losses = encoder_out["losses"]
loss = sum(losses.values())
sample_size = encoder_out["sample_size"]
loss = loss.sum() / sample_size
target_var = float(encoder_out["target_var"])
pred_var = float(encoder_out["pred_var"])
ema_decay = float(encoder_out["ema_decay"])
stats = dict(
loss=torch.clone(loss.detach()),
target_var=target_var,
pred_var=pred_var,
ema_decay=ema_decay,
)
loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
if min(speech_lengths) == max(speech_lengths):  # all utterances share the same (clipped) length, so no length mask is needed
speech_lengths = None
encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
return encoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def set_num_updates(self, num_updates):
self.num_updates = num_updates
def get_num_updates(self):
return self.num_updates
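
A minimal sketch of the encoder contract that the forward() above consumes (a dict with "losses", "sample_size", "target_var", "pred_var" and "ema_decay"). The DummyEncoder stub and the import path of Data2VecPretrainModel are assumptions for illustration, not part of the repository; a real run uses the data2vec encoder.

import torch
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.data2vec import Data2VecPretrainModel  # import path is an assumption

class DummyEncoder(AbsEncoder):
    """Stub that only mimics the output dict read by Data2VecPretrainModel.forward()."""
    def output_size(self) -> int:
        return 8
    def set_num_updates(self, num_updates):
        pass
    def forward(self, feats, lengths, mask=True, features_only=False):
        return {
            "losses": {"regression": feats.pow(2).mean()},
            "sample_size": feats.size(0),
            "target_var": 1.0,
            "pred_var": 1.0,
            "ema_decay": 0.999,
        }

model = Data2VecPretrainModel(frontend=None, specaug=None, normalize=None,
                              preencoder=None, encoder=DummyEncoder())
speech = torch.randn(2, 1600)                     # raw waveforms (no frontend, so used as-is)
speech_lengths = torch.tensor([1600, 1200])
loss, stats, weight = model(speech, speech_lengths)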

@@ -0,0 +1,334 @@
import random
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs
def build_attention_list(
eprojs: int,
dunits: int,
atype: str = "location",
num_att: int = 1,
num_encs: int = 1,
aheads: int = 4,
adim: int = 320,
awin: int = 5,
aconv_chans: int = 10,
aconv_filts: int = 100,
han_mode: bool = False,
han_type=None,
han_heads: int = 4,
han_dim: int = 320,
han_conv_chans: int = -1,
han_conv_filts: int = 100,
han_win: int = 5,
):
att_list = torch.nn.ModuleList()
if num_encs == 1:
for i in range(num_att):
att = initial_att(
atype,
eprojs,
dunits,
aheads,
adim,
awin,
aconv_chans,
aconv_filts,
)
att_list.append(att)
elif num_encs > 1: # no multi-speaker mode
if han_mode:
att = initial_att(
han_type,
eprojs,
dunits,
han_heads,
han_dim,
han_win,
han_conv_chans,
han_conv_filts,
han_mode=True,
)
return att
else:
att_list = torch.nn.ModuleList()
for idx in range(num_encs):
att = initial_att(
atype[idx],
eprojs,
dunits,
aheads[idx],
adim[idx],
awin[idx],
aconv_chans[idx],
aconv_filts[idx],
)
att_list.append(att)
else:
raise ValueError(
"Number of encoders needs to be at least one: {}".format(num_encs)
)
return att_list
class RNNDecoder(AbsDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
rnn_type: str = "lstm",
num_layers: int = 1,
hidden_size: int = 320,
sampling_probability: float = 0.0,
dropout: float = 0.0,
context_residual: bool = False,
replace_sos: bool = False,
num_encs: int = 1,
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored further
assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")
super().__init__()
eprojs = encoder_output_size
self.dtype = rnn_type
self.dunits = hidden_size
self.dlayers = num_layers
self.context_residual = context_residual
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.odim = vocab_size
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual translation
self.replace_sos = replace_sos
self.embed = torch.nn.Embedding(vocab_size, hidden_size)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(hidden_size, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
if context_residual:
self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
else:
self.output = torch.nn.Linear(hidden_size, vocab_size)
self.att_list = build_attention_list(
eprojs=eprojs, dunits=hidden_size, **att_conf
)
def zero_state(self, hs_pad):
return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for i in range(1, self.dlayers):
z_list[i], c_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]),
(z_prev[i], c_prev[i]),
)
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for i in range(1, self.dlayers):
z_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
)
return z_list, c_list
def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
# to support multiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlens = [hlens]
# attention index for the attention module
# in SPA (speaker parallel attention),
# att_idx is used to select attention module. In other cases, it is 0.
att_idx = min(strm_idx, len(self.att_list) - 1)
# hlens should be list of list of integer
hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
# get dim, length info
olength = ys_in_pad.size(1)
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
z_all = []
if self.num_encs == 1:
att_w = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in range(olength):
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att_list[idx](
hs_pad[idx],
hlens[idx],
self.dropout_dec[0](z_list[0]),
att_w_list[idx],
)
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in_pad)
att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
hs_pad_han,
hlens_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs],
)
if i > 0 and random.random() < self.sampling_probability:
z_out = self.output(z_all[-1])
z_out = np.argmax(z_out.detach().cpu(), axis=1)
z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
else:
# utt x (zdim + hdim)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
z_all = torch.stack(z_all, dim=1)
z_all = self.output(z_all)
z_all.masked_fill_(
make_pad_mask(ys_in_lens, z_all, 1),
0,
)
return z_all, ys_in_lens
def init_state(self, x):
# to support multiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
c_list = [self.zero_state(x[0].unsqueeze(0))]
z_list = [self.zero_state(x[0].unsqueeze(0))]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(x[0].unsqueeze(0)))
z_list.append(self.zero_state(x[0].unsqueeze(0)))
# TODO(karita): support strm_index for `asr_mix`
strm_index = 0
att_idx = min(strm_index, len(self.att_list) - 1)
if self.num_encs == 1:
a = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
return dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=a,
workspace=(att_idx, z_list, c_list),
)
def score(self, yseq, state, x):
# to support multiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
att_idx, z_list, c_list = state["workspace"]
vy = yseq[-1].unsqueeze(0)
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
x[0].unsqueeze(0),
[x[0].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"],
)
else:
att_w = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs):
att_c_list[idx], att_w[idx] = self.att_list[idx](
x[idx].unsqueeze(0),
[x[idx].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][idx],
)
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
h_han,
[self.num_encs],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][self.num_encs],
)
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(
ey, z_list, c_list, state["z_prev"], state["c_prev"]
)
if self.context_residual:
logits = self.output(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
)
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
return (
logp,
dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=att_w,
workspace=(att_idx, z_list, c_list),
),
)
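
A minimal teacher-forcing sketch for the RNNDecoder above with a single encoder and the default location attention; the import path and all sizes are illustrative assumptions.

import torch
from funasr_local.models.decoder.rnn_decoder import RNNDecoder  # import path is an assumption

dec = RNNDecoder(vocab_size=10, encoder_output_size=16, hidden_size=32)

hs_pad = torch.randn(2, 30, 16)                   # (B, Tmax, eprojs): dummy encoder outputs
hlens = torch.tensor([30, 25])                    # valid encoder frames per utterance
ys_in_pad = torch.randint(0, 10, (2, 7))          # (B, Lmax): sos-prefixed token ids in practice
ys_in_lens = torch.tensor([7, 5])

logits, olens = dec(hs_pad, hlens, ys_in_pad, ys_in_lens)  # (B, Lmax, vocab_size), (B,)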

@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
Args:
vocab_size: Vocabulary size.
embed_size: Embedding size.
hidden_size: Hidden size.
rnn_type: Decoder layers type.
num_layers: Number of decoder layers.
dropout_rate: Dropout rate for decoder layers.
embed_dropout_rate: Dropout rate for embedding layer.
embed_pad: Embedding padding symbol ID.
"""
def __init__(
self,
vocab_size: int,
embed_size: int = 256,
hidden_size: int = 256,
rnn_type: str = "lstm",
num_layers: int = 1,
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
"""Construct an RNNTDecoder object."""
super().__init__()
assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
self.rnn = torch.nn.ModuleList(
[rnn_class(embed_size, hidden_size, 1, batch_first=True)]
)
for _ in range(1, num_layers):
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
self.dropout_rnn = torch.nn.ModuleList(
[torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
)
self.dlayers = num_layers
self.dtype = rnn_type
self.output_size = hidden_size
self.vocab_size = vocab_size
self.device = next(self.parameters()).device
self.score_cache = {}
def forward(
self,
labels: torch.Tensor,
label_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""Encode source label sequences.
Args:
labels: Label ID sequences. (B, L)
label_lens: Label ID sequence lengths. (B,)
states: Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None) or None
Returns:
dec_out: Decoder output sequences. (B, U, D_dec)
"""
if states is None:
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
def rnn_forward(
self,
x: torch.Tensor,
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Encode source label sequences.
Args:
x: RNN input sequences. (B, D_emb)
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
Returns:
x: RNN output sequences. (B, D_dec)
(h_next, c_next): Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None)
"""
h_prev, c_prev = state
h_next, c_next = self.init_state(x.size(0))
for layer in range(self.dlayers):
if self.dtype == "lstm":
x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
layer
](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
else:
x, h_next[layer : layer + 1] = self.rnn[layer](
x, hx=h_prev[layer : layer + 1]
)
x = self.dropout_rnn[layer](x)
return x, (h_next, c_next)
def score(
self,
label: torch.Tensor,
label_sequence: List[int],
dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypothesis.
Args:
label: Previous label. (1, 1)
label_sequence: Current label sequence.
dec_state: Previous decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
"""
str_labels = "_".join(map(str, label_sequence))
if str_labels in self.score_cache:
dec_out, dec_state = self.score_cache[str_labels]
else:
dec_embed = self.embed(label)
dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
self.score_cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state
def batch_score(
self,
hyps: List[Hypothesis],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
labels = torch.tensor([[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device)  # legacy LongTensor() cannot place data on CUDA devices
dec_embed = self.embed(labels)
states = self.create_batch_states([h.dec_state for h in hyps])
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out.squeeze(1), states
def set_device(self, device: torch.device) -> None:
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self, batch_size: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
h_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
if self.dtype == "lstm":
c_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
return (h_n, c_n)
return (h_n, None)
def select_state(
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
idx: State ID to extract.
Returns:
: Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
"""
return (
states[0][:, idx : idx + 1, :],
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
)
def create_batch_states(
self,
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Create decoder hidden states.
Args:
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
Returns:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
torch.cat([s[1] for s in new_states], dim=1)
if self.dtype == "lstm"
else None,
)
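
A minimal sketch exercising the transducer prediction network above on dummy label sequences; the import path is an assumption and the sizes are illustrative.

import torch
from funasr_local.models.decoder.rnnt_decoder import RNNTDecoder  # import path is an assumption

dec = RNNTDecoder(vocab_size=10, embed_size=8, hidden_size=8)

labels = torch.randint(1, 10, (2, 5))             # (B, U): label ids, 0 is reserved for padding
dec_out = dec(labels, label_lens=torch.tensor([5, 4]))      # (B, U, D_dec)

state = dec.init_state(batch_size=1)              # ((N, 1, D_dec), (N, 1, D_dec))
step_out, state = dec.rnn_forward(dec.embed(labels[:1, :1]), state)  # one decoding step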

@@ -0,0 +1,19 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
from funasr_local.modules.scorers.scorer_interface import ScorerInterface
class AbsDecoder(torch.nn.Module, ScorerInterface, ABC):
@abstractmethod
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

@@ -0,0 +1,776 @@
from typing import List
from typing import Tuple
import logging
import torch
import torch.nn as nn
import numpy as np
from funasr_local.modules.streaming_utils import utils as myutils
from funasr_local.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types
from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.modules.repeat import repeat
from funasr_local.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
class ContextualDecoderLayer(nn.Module):
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct a ContextualDecoderLayer object."""
super(ContextualDecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
if self_attn is not None:
self.norm2 = LayerNorm(size)
if src_attn is not None:
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
# tgt = self.dropout(tgt)
if isinstance(tgt, Tuple):
tgt, _ = tgt
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
if self.training:
cache = None
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x = residual + self.dropout(x)
x_self_attn = x
residual = x
if self.normalize_before:
x = self.norm3(x)
x = self.src_attn(x, memory, memory_mask)
x_src_attn = x
x = residual + self.dropout(x)
return x, tgt_mask, x_self_attn, x_src_attn
class ContextualBiasDecoder(nn.Module):
def __init__(
self,
size,
src_attn,
dropout_rate,
normalize_before=True,
):
"""Construct a ContextualBiasDecoder object."""
super(ContextualBiasDecoder, self).__init__()
self.size = size
self.src_attn = src_attn
if src_attn is not None:
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
x = tgt
if self.src_attn is not None:
if self.normalize_before:
x = self.norm3(x)
x = self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
class ContextualParaformerDecoder(ParaformerSANMDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
if input_layer == 'none':
self.embed = None
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
# pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'none', 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
self.att_layer_num = att_layer_num
self.num_blocks = num_blocks
if sanm_shfit is None:
sanm_shfit = (kernel_size - 1) // 2
self.decoders = repeat(
att_layer_num - 1,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.dropout = nn.Dropout(dropout_rate)
self.bias_decoder = ContextualBiasDecoder(
size=attention_dim,
src_attn=MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
),
dropout_rate=dropout_rate,
normalize_before=True,
)
self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
self.last_decoder = ContextualDecoderLayer(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
)
if num_blocks - att_layer_num <= 0:
self.decoders2 = None
else:
self.decoders2 = repeat(
num_blocks - att_layer_num,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.decoders3 = repeat(
1,
lambda lnum: DecoderLayerSANM(
attention_dim,
None,
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
contextual_info: torch.Tensor,
return_hidden: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.decoders(
x, tgt_mask, memory, memory_mask
)
_, _, x_self_attn, x_src_attn = self.last_decoder(
x, tgt_mask, memory, memory_mask
)
# contextual paraformer related
contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
if self.bias_output is not None:
x = torch.cat([x_src_attn, cx], dim=2)
x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D
x = x_self_attn + self.dropout(x)
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(
x, tgt_mask, memory, memory_mask
)
x, tgt_mask, memory, memory_mask, _ = self.decoders3(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
olens = tgt_mask.sum(1)
if self.output_layer is not None and return_hidden is False:
x = self.output_layer(x)
return x, olens
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
## decoder
# ffn
"{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
# fsmn
"{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 2, 0),
}, # (256,1,31),(1,31,256,1)
# src att
"{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# dnn
"{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
# embed_concat_ffn
"{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.after_norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# in embed
"{}.embed.0.weight".format(tensor_name_prefix_torch):
{"name": "{}/w_embs".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (4235,256),(4235,256)
# out layer
"{}.output_layer.weight".format(tensor_name_prefix_torch):
{"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
"squeeze": [None, None],
"transpose": [(1, 0), None],
}, # (4235,256),(256,4235)
"{}.output_layer.bias".format(tensor_name_prefix_torch):
{"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
"seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
"squeeze": [None, None],
"transpose": [None, None],
}, # (4235,),(4235,)
## clas decoder
# src att
"{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# dnn
"{}.bias_output.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
}, # (1024,256),(1,256,1024)
}
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
decoder_layeridx_sets = set()
for name in sorted(var_dict_torch.keys(), reverse=False):
names = name.split('.')
if names[0] == self.tf2torch_tensor_name_prefix_torch:
if names[1] == "decoders":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 0
layeridx += layeridx_bias
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "last_decoder":
layeridx = 15
name_q = name.replace("last_decoder", "decoders.layeridx")
layeridx_bias = 0
layeridx += layeridx_bias
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "decoders2":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
name_q = name_q.replace("decoders2", "decoders")
layeridx_bias = len(decoder_layeridx_sets)
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "decoders3":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "bias_decoder":
name_q = name
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
name_tf = map_dict[name]["name"]
if isinstance(name_tf, list):
idx_list = 0
if name_tf[idx_list] in var_dict_tf.keys():
pass
else:
idx_list = 1
data_tf = var_dict_tf[name_tf[idx_list]]
if map_dict[name]["squeeze"][idx_list] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
if map_dict[name]["transpose"][idx_list] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
name_tf[idx_list],
var_dict_tf[name_tf[
idx_list]].shape))
else:
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
elif names[1] == "after_norm":
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
elif names[1] == "embed_concat_ffn":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
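
convert_tf2torch applies every entry of gen_tf2torch_map_dict with the same three-step recipe: substitute the layer index into the TensorFlow variable name, optionally squeeze and transpose the NumPy array, then wrap it as a float32 torch tensor. A self-contained sketch of that per-entry transform, using a made-up variable dictionary and random weights purely for illustration:

import numpy as np
import torch

# Hypothetical map entry in the same layout as gen_tf2torch_map_dict()
entry = {
    "name": "decoder/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel",
    "squeeze": 0,
    "transpose": (1, 0),
}
# Fake TF checkpoint entry: a conv1d kernel stored as (1, in_dim, out_dim)
var_dict_tf = {
    "decoder/decoder_fsmn_layer_3/decoder_ffn/conv1d/kernel":
        np.random.randn(1, 256, 1024).astype(np.float32),
}

layeridx = 3
name_tf = entry["name"].replace("layeridx", str(layeridx))
data_tf = var_dict_tf[name_tf]
if entry["squeeze"] is not None:
    data_tf = np.squeeze(data_tf, axis=entry["squeeze"])   # -> (256, 1024)
if entry["transpose"] is not None:
    data_tf = np.transpose(data_tf, entry["transpose"])    # -> (1024, 256), torch Linear layout
weight = torch.from_numpy(data_tf).type(torch.float32)     # ready to copy into the state dict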

View File

@@ -0,0 +1,334 @@
import random
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs
def build_attention_list(
eprojs: int,
dunits: int,
atype: str = "location",
num_att: int = 1,
num_encs: int = 1,
aheads: int = 4,
adim: int = 320,
awin: int = 5,
aconv_chans: int = 10,
aconv_filts: int = 100,
han_mode: bool = False,
han_type=None,
han_heads: int = 4,
han_dim: int = 320,
han_conv_chans: int = -1,
han_conv_filts: int = 100,
han_win: int = 5,
):
att_list = torch.nn.ModuleList()
if num_encs == 1:
for i in range(num_att):
att = initial_att(
atype,
eprojs,
dunits,
aheads,
adim,
awin,
aconv_chans,
aconv_filts,
)
att_list.append(att)
elif num_encs > 1: # no multi-speaker mode
if han_mode:
att = initial_att(
han_type,
eprojs,
dunits,
han_heads,
han_dim,
han_win,
han_conv_chans,
han_conv_filts,
han_mode=True,
)
return att
else:
att_list = torch.nn.ModuleList()
for idx in range(num_encs):
att = initial_att(
atype[idx],
eprojs,
dunits,
aheads[idx],
adim[idx],
awin[idx],
aconv_chans[idx],
aconv_filts[idx],
)
att_list.append(att)
else:
raise ValueError(
"Number of encoders needs to be more than one. {}".format(num_encs)
)
return att_list
class RNNDecoder(AbsDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
rnn_type: str = "lstm",
num_layers: int = 1,
hidden_size: int = 320,
sampling_probability: float = 0.0,
dropout: float = 0.0,
context_residual: bool = False,
replace_sos: bool = False,
num_encs: int = 1,
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored more more more
assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")
super().__init__()
eprojs = encoder_output_size
self.dtype = rnn_type
self.dunits = hidden_size
self.dlayers = num_layers
self.context_residual = context_residual
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.odim = vocab_size
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual translation
self.replace_sos = replace_sos
self.embed = torch.nn.Embedding(vocab_size, hidden_size)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(hidden_size, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
if context_residual:
self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
else:
self.output = torch.nn.Linear(hidden_size, vocab_size)
self.att_list = build_attention_list(
eprojs=eprojs, dunits=hidden_size, **att_conf
)
def zero_state(self, hs_pad):
return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for i in range(1, self.dlayers):
z_list[i], c_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]),
(z_prev[i], c_prev[i]),
)
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for i in range(1, self.dlayers):
z_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
)
return z_list, c_list
def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
# to support mutiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlens = [hlens]
# attention index for the attention module
# in SPA (speaker parallel attention),
# att_idx is used to select attention module. In other cases, it is 0.
att_idx = min(strm_idx, len(self.att_list) - 1)
# hlens should be list of list of integer
hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
# get dim, length info
olength = ys_in_pad.size(1)
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
z_all = []
if self.num_encs == 1:
att_w = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in range(olength):
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att_list[idx](
hs_pad[idx],
hlens[idx],
self.dropout_dec[0](z_list[0]),
att_w_list[idx],
)
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in_pad)
att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
hs_pad_han,
hlens_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs],
)
if i > 0 and random.random() < self.sampling_probability:
z_out = self.output(z_all[-1])
z_out = np.argmax(z_out.detach().cpu(), axis=1)
z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
else:
# utt x (zdim + hdim)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
z_all = torch.stack(z_all, dim=1)
z_all = self.output(z_all)
z_all.masked_fill_(
make_pad_mask(ys_in_lens, z_all, 1),
0,
)
return z_all, ys_in_lens
def init_state(self, x):
# to support multiple-encoder ASR mode; in single-encoder mode,
# convert torch.Tensor to a List of torch.Tensor
if self.num_encs == 1:
x = [x]
c_list = [self.zero_state(x[0].unsqueeze(0))]
z_list = [self.zero_state(x[0].unsqueeze(0))]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(x[0].unsqueeze(0)))
z_list.append(self.zero_state(x[0].unsqueeze(0)))
# TODO(karita): support strm_index for `asr_mix`
strm_index = 0
att_idx = min(strm_index, len(self.att_list) - 1)
if self.num_encs == 1:
a = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
return dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=a,
workspace=(att_idx, z_list, c_list),
)
def score(self, yseq, state, x):
# to support multiple-encoder ASR mode; in single-encoder mode,
# convert torch.Tensor to a List of torch.Tensor
if self.num_encs == 1:
x = [x]
att_idx, z_list, c_list = state["workspace"]
vy = yseq[-1].unsqueeze(0)
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
x[0].unsqueeze(0),
[x[0].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"],
)
else:
att_w = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs):
att_c_list[idx], att_w[idx] = self.att_list[idx](
x[idx].unsqueeze(0),
[x[idx].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][idx],
)
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
h_han,
[self.num_encs],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][self.num_encs],
)
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(
ey, z_list, c_list, state["z_prev"], state["c_prev"]
)
if self.context_residual:
logits = self.output(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
)
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
return (
logp,
dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=att_w,
workspace=(att_idx, z_list, c_list),
),
)
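# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): a minimal greedy decoding
# loop built on init_state()/score() above. `decoder` is an instance of the
# RNN attention decoder defined in this file, `enc_out` is one utterance's
# encoder output of shape (T, D_enc), and `sos_id`/`eos_id`/`max_len` are
# illustrative assumptions.
def _greedy_decode_sketch(decoder, enc_out, sos_id, eos_id, max_len=50):
    state = decoder.init_state(enc_out)
    yseq = enc_out.new_tensor([sos_id], dtype=torch.long)
    for _ in range(max_len):
        logp, state = decoder.score(yseq, state, enc_out)
        next_id = int(logp.argmax())
        yseq = torch.cat([yseq, yseq.new_tensor([next_id])])
        if next_id == eos_id:
            break
    return yseq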

View File

@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
Args:
vocab_size: Vocabulary size.
embed_size: Embedding size.
hidden_size: Hidden size.
rnn_type: Decoder layers type.
num_layers: Number of decoder layers.
dropout_rate: Dropout rate for decoder layers.
embed_dropout_rate: Dropout rate for embedding layer.
embed_pad: Embedding padding symbol ID.
"""
def __init__(
self,
vocab_size: int,
embed_size: int = 256,
hidden_size: int = 256,
rnn_type: str = "lstm",
num_layers: int = 1,
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
self.rnn = torch.nn.ModuleList(
[rnn_class(embed_size, hidden_size, 1, batch_first=True)]
)
for _ in range(1, num_layers):
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
self.dropout_rnn = torch.nn.ModuleList(
[torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
)
self.dlayers = num_layers
self.dtype = rnn_type
self.output_size = hidden_size
self.vocab_size = vocab_size
self.device = next(self.parameters()).device
self.score_cache = {}
def forward(
self,
labels: torch.Tensor,
label_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""Encode source label sequences.
Args:
labels: Label ID sequences. (B, L)
label_lens: Label ID sequence lengths. (B,)
states: Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None) or None
Returns:
dec_out: Decoder output sequences. (B, U, D_dec)
"""
if states is None:
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
def rnn_forward(
self,
x: torch.Tensor,
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Encode source label sequences.
Args:
x: RNN input sequences. (B, D_emb)
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
Returns:
x: RNN output sequences. (B, D_dec)
(h_next, c_next): Decoder hidden states.
(N, B, D_dec), (N, B, D_dec) or None)
"""
h_prev, c_prev = state
h_next, c_next = self.init_state(x.size(0))
for layer in range(self.dlayers):
if self.dtype == "lstm":
x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
layer
](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
else:
x, h_next[layer : layer + 1] = self.rnn[layer](
x, hx=h_prev[layer : layer + 1]
)
x = self.dropout_rnn[layer](x)
return x, (h_next, c_next)
def score(
self,
label: torch.Tensor,
label_sequence: List[int],
dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypothesis.
Args:
label: Previous label. (1, 1)
label_sequence: Current label sequence.
dec_state: Previous decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
"""
str_labels = "_".join(map(str, label_sequence))
if str_labels in self.score_cache:
dec_out, dec_state = self.score_cache[str_labels]
else:
dec_embed = self.embed(label)
dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
self.score_cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state
def batch_score(
self,
hyps: List[Hypothesis],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
labels = torch.tensor([[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device)
dec_embed = self.embed(labels)
states = self.create_batch_states([h.dec_state for h in hyps])
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out.squeeze(1), states
def set_device(self, device: torch.device) -> None:
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self, batch_size: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
h_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
if self.dtype == "lstm":
c_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
return (h_n, c_n)
return (h_n, None)
def select_state(
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
idx: State ID to extract.
Returns:
: Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
"""
return (
states[0][:, idx : idx + 1, :],
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
)
def create_batch_states(
self,
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Create decoder hidden states.
Args:
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
Returns:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
torch.cat([s[1] for s in new_states], dim=1)
if self.dtype == "lstm"
else None,
)
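# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): a quick shape check of
# RNNTDecoder.forward() on dummy labels; all sizes below are illustrative
# assumptions.
def _rnnt_decoder_shape_check():
    dec = RNNTDecoder(vocab_size=10, embed_size=16, hidden_size=16, num_layers=2)
    labels = torch.randint(1, 10, (4, 6))            # (B, U); ID 0 is padding
    label_lens = torch.full((4,), 6, dtype=torch.long)
    dec_out = dec(labels, label_lens)                # (B, U, D_dec) == (4, 6, 16)
    return dec_out.shape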

File diff suppressed because it is too large

View File

@@ -0,0 +1,37 @@
import torch
from torch.nn import functional as F
from funasr_local.models.decoder.abs_decoder import AbsDecoder
class DenseDecoder(AbsDecoder):
def __init__(
self,
vocab_size,
encoder_output_size,
num_nodes_resnet1: int = 256,
num_nodes_last_layer: int = 256,
batchnorm_momentum: float = 0.5,
):
super(DenseDecoder, self).__init__()
self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
self.output_dense = torch.nn.Linear(num_nodes_last_layer, vocab_size, bias=False)
def forward(self, features):
embeddings = {}
features = self.resnet1_dense(features)
embeddings["resnet1_dense"] = features
features = F.relu(features)
features = self.resnet1_bn(features)
features = self.resnet2_dense(features)
embeddings["resnet2_dense"] = features
features = F.relu(features)
features = self.resnet2_bn(features)
features = self.output_dense(features)
return features, embeddings
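# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): DenseDecoder maps
# utterance-level embeddings of shape (B, encoder_output_size) to logits and
# also returns the two intermediate dense outputs; sizes are illustrative
# assumptions.
def _dense_decoder_example():
    dec = DenseDecoder(vocab_size=100, encoder_output_size=512)
    feats = torch.randn(8, 512)
    logits, embeddings = dec(feats)                  # logits: (8, 100)
    return logits.shape, sorted(embeddings.keys())   # ['resnet1_dense', 'resnet2_dense']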

View File

@@ -0,0 +1,766 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.modules.attention import MultiHeadedAttention
from funasr_local.modules.dynamic_conv import DynamicConvolution
from funasr_local.modules.dynamic_conv2d import DynamicConvolution2D
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.lightconv import LightweightConvolution
from funasr_local.modules.lightconv2d import LightweightConvolution2D
from funasr_local.modules.mask import subsequent_mask
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Source-attention (encoder-decoder attention) module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
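# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): a shape check of a single
# pre-norm DecoderLayer; the layer sizes and the use of `None` masks are
# illustrative assumptions.
def _decoder_layer_shape_check():
    layer = DecoderLayer(
        256,
        MultiHeadedAttention(4, 256, 0.1),           # masked self-attention
        MultiHeadedAttention(4, 256, 0.1),           # source (cross) attention
        PositionwiseFeedForward(256, 1024, 0.1),
        dropout_rate=0.1,
    )
    tgt = torch.randn(2, 5, 256)
    memory = torch.randn(2, 20, 256)
    x, _, _, _ = layer(tgt, None, memory, None)      # x: (2, 5, 256)
    return x.shape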
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
# Must set by the inheritance
self.decoders = None
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(
memory_mask, (0, padlen), "constant", False
)
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
Args:
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch < 1.2
dtype=torch.bool in PyTorch >= 1.2
memory: encoded memory, float32 (batch, maxlen_in, feat)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
)
return logp.squeeze(0), state
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
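# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): an end-to-end shape check
# of TransformerDecoder.forward(); the vocabulary size and lengths are
# illustrative assumptions.
def _transformer_decoder_shape_check():
    dec = TransformerDecoder(vocab_size=100, encoder_output_size=256, num_blocks=2)
    hs_pad = torch.randn(2, 30, 256)                 # encoder memory
    hlens = torch.tensor([30, 25])
    ys_in_pad = torch.randint(1, 100, (2, 8))        # token IDs including <sos>
    ys_in_lens = torch.tensor([8, 6])
    out, olens = dec(hs_pad, hlens, ys_in_pad, ys_in_lens)
    return out.shape                                 # (2, 8, 100)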
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
embeds_id: int = -1,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.embeds_id = embeds_id
self.attention_dim = attention_dim
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(
memory_mask, (0, padlen), "constant", False
)
# NOTE: the decoder input here is already a dense embedding (e.g. the
# predictor's acoustic embedding in Paraformer), so the token embedding
# layer is skipped.
# x = self.embed(tgt)
x = tgt
embeds_outputs = None
for layer_id, decoder in enumerate(self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, memory_mask
)
if layer_id == self.embeds_id:
embeds_outputs = x
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
if embeds_outputs is not None:
return x, olens, embeds_outputs
else:
return x, olens
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)

View File

@@ -0,0 +1,458 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr_local.models.ctc import CTC
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.modules.add_sos_eos import add_sos_eos
from funasr_local.modules.e2e_asr_common import ErrorCalculator
from funasr_local.modules.nets_utils import th_accuracy
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class ESPnetASRModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__()
# NOTE: the CTC blank, <sos> and <eos> symbols are mapped to fixed reserved IDs here
self.blank_id = 0
self.sos = 1
self.eos = 2
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.interctc_weight = interctc_weight
self.token_list = token_list.copy()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.encoder = encoder
if not hasattr(self.encoder, "interctc_use_conditioning"):
self.encoder.interctc_use_conditioning = False
if self.encoder.interctc_use_conditioning:
self.encoder.conditioning_layer = torch.nn.Linear(
vocab_size, self.encoder.output_size()
)
self.error_calculator = None
# We set self.decoder = None in CTC-only mode since the decoder
# parameters were never used and PyTorch raised an exception in the
# multi-GPU experiment.
# Thanks to Jeff Farris for pointing out the issue.
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
) # [batch, seqlen, dim]
batch_size = decoder_out.size(0)
decoder_num_class = decoder_out.size(2)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
decoder_out.view(-1, decoder_num_class),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
)
nll = nll.view(batch_size, -1)
nll = nll.sum(dim=1)
assert nll.size(0) == batch_size
return nll
def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
batch_size: int, number of samples per batch when computing nll;
lower it to avoid OOM or raise it to increase
GPU utilization
"""
total_num = encoder_out.size(0)
if total_num <= batch_size:
nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
else:
nll = []
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
batch_ys_pad = ys_pad[start_idx:end_idx, :]
batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
batch_nll = self.nll(
batch_encoder_out,
batch_encoder_out_lens,
batch_ys_pad,
batch_ys_pad_lens,
)
nll.append(batch_nll)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nll)
assert nll.size(0) == total_num
return nll
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc

View File

@@ -0,0 +1,249 @@
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Common functions for ASR."""
import json
import logging
import sys
from itertools import groupby
import numpy as np
import six
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
described in Eq. (50) of S. Watanabe et al
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
:param ended_hyps:
:param i:
:param M:
:param D_end:
:return:
"""
if len(ended_hyps) == 0:
return False
count = 0
best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
for m in six.moves.range(M):
# get ended_hyps with their length is i - m
hyp_length = i - m
hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
if len(hyps_same_length) > 0:
best_hyp_same_length = sorted(
hyps_same_length, key=lambda x: x["score"], reverse=True
)[0]
if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
count += 1
if count == M:
return True
else:
return False
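# Editor's note (illustrative): end_detect() is typically called once per
# decoding step i with the list of already finished hypotheses; the search is
# declared converged when, for M consecutive hypothesis lengths, the best
# finished hypothesis of that length scores more than -D_end (10 by default)
# below the overall best finished hypothesis.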
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
"""Obtain label distribution for loss smoothing.
:param odim:
:param lsm_type:
:param blank:
:param transcript:
:return:
"""
if transcript is not None:
with open(transcript, "rb") as f:
trans_json = json.load(f)["utts"]
if lsm_type == "unigram":
assert transcript is not None, (
"transcript is required for %s label smoothing" % lsm_type
)
labelcount = np.zeros(odim)
for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
# to avoid an error when there is no text in an utterance
if len(ids) > 0:
labelcount[ids] += 1
labelcount[odim - 1] = len(transcript) # count <eos>
labelcount[labelcount == 0] = 1 # flooring
labelcount[blank] = 0 # remove counts for blank
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
else:
logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
sys.exit()
return labeldist
def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
"""Return the output size of the VGG frontend.
:param in_channel: input channel size
:param out_channel: output channel size
:return: output size
:rtype int
"""
idim = idim / in_channel
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
return int(idim) * out_channel  # number of channels
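# Editor's note (illustrative): with idim=240 and the defaults above,
# 240 / 3 = 80 -> ceil(80 / 2) = 40 -> ceil(40 / 2) = 20, so the returned
# output size is 20 * 128 = 2560.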
class ErrorCalculator(object):
"""Calculate CER and WER for E2E_ASR and CTC models during training.
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list:
:param sym_space:
:param sym_blank:
:return:
"""
def __init__(
self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
):
"""Construct an ErrorCalculator object."""
super(ErrorCalculator, self).__init__()
self.report_cer = report_cer
self.report_wer = report_wer
self.char_list = char_list
self.space = sym_space
self.blank = sym_blank
self.idx_blank = self.char_list.index(self.blank)
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad, is_ctc=False):
"""Calculate sentence-level WER/CER score.
:param torch.Tensor ys_hat: prediction (batch, seqlen)
:param torch.Tensor ys_pad: reference (batch, seqlen)
:param bool is_ctc: calculate CER score for CTC
:return: sentence-level WER score
:rtype float
:return: sentence-level CER score
:rtype float
"""
cer, wer = None, None
if is_ctc:
return self.calculate_cer_ctc(ys_hat, ys_pad)
elif not self.report_cer and not self.report_wer:
return cer, wer
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
def calculate_cer_ctc(self, ys_hat, ys_pad):
"""Calculate sentence-level CER score for CTC.
:param torch.Tensor ys_hat: prediction (batch, seqlen)
:param torch.Tensor ys_pad: reference (batch, seqlen)
:return: average sentence-level CER score
:rtype float
"""
import editdistance
cers, char_ref_lens = [], []
for i, y in enumerate(ys_hat):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat, seq_true = [], []
for idx in y_hat:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_hat.append(self.char_list[int(idx)])
for idx in y_true:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_true.append(self.char_list[int(idx)])
hyp_chars = "".join(seq_hat)
ref_chars = "".join(seq_true)
if len(ref_chars) > 0:
cers.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
return cer_ctc
def convert_to_char(self, ys_hat, ys_pad):
"""Convert index to character.
:param torch.Tensor seqs_hat: prediction (batch, seqlen)
:param torch.Tensor seqs_true: reference (batch, seqlen)
:return: token list of prediction
:rtype list
:return: token list of reference
:rtype list
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
def calculate_cer(self, seqs_hat, seqs_true):
"""Calculate sentence-level CER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level CER score
:rtype float
"""
import editdistance
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(" ", "")
ref_chars = seq_true_text.replace(" ", "")
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
def calculate_wer(self, seqs_hat, seqs_true):
"""Calculate sentence-level WER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level WER score
:rtype float
"""
import editdistance
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
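# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): CER/WER on toy index
# sequences. The character list is an illustrative assumption and the
# optional `editdistance` package must be installed.
def _error_calculator_example():
    char_list = ["<blank>", "a", "b", "c", "<space>"]
    calc = ErrorCalculator(
        char_list, "<space>", "<blank>", report_cer=True, report_wer=True
    )
    ys_hat = np.array([[1, 2, 3, -1]])               # hypothesis "abc"
    ys_pad = np.array([[1, 2, 2, -1]])               # reference  "abb"
    return calc(ys_hat, ys_pad)                      # -> (CER=1/3, WER=1.0)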

View File

@@ -0,0 +1,326 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import torch
from typeguard import check_argument_types
from funasr_local.modules.e2e_asr_common import ErrorCalculator
from funasr_local.modules.nets_utils import th_accuracy
from funasr_local.modules.add_sos_eos import add_sos_eos
from funasr_local.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr_local.models.ctc import CTC
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
import random
import math
class MFCCA(AbsESPnetModel):
"""
Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
MFCCA: Multi-Frame Cross-Channel Attention for Multi-speaker ASR in Multi-party Meeting Scenario
https://arxiv.org/abs/2210.05265
"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
decoder: AbsDecoder,
ctc: CTC,
rnnt_decoder: None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
mask_ratio: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert rnnt_decoder is None, "Not implemented"
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.token_list = token_list.copy()
self.mask_ratio = mask_ratio
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.encoder = encoder
# We set self.decoder = None in CTC-only mode since the decoder
# parameters were never used and PyTorch raised an exception in the
# multi-GPU experiment.
# Thanks to Jeff Farris for pointing out the issue.
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.rnnt_decoder = rnnt_decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)
else:
self.error_calculator = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
# Channel masking: with probability mask_ratio, keep only a random
# subset of the 8 input channels (a single retained channel drops
# the channel dimension).
if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
    if random.random() <= self.mask_ratio:
        retain_channel = math.ceil(random.random() * 8)
        if retain_channel > 1:
            speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
        else:
            speech = speech[:, :, torch.randperm(8)[0]]
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# 2a. Attention-decoder branch
if self.ctc_weight == 1.0:
loss_att, acc_att, cer_att, wer_att = None, None, None, None
else:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 2b. CTC branch
if self.ctc_weight == 0.0:
loss_ctc, cer_ctc = None, None
else:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 2c. RNN-T branch
if self.rnnt_decoder is not None:
_ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
stats = dict(
loss=loss.detach(),
loss_att=loss_att.detach() if loss_att is not None else None,
loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
acc=acc_att,
cer=cer_att,
wer=wer_att,
cer_ctc=cer_ctc,
)
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
if encoder_out.dim() == 4:
assert encoder_out.size(2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
else:
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
channel_size = 1
return feats, feats_lengths, channel_size
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
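# Teacher forcing: prepend <sos> to the decoder inputs and append <eos> to the
# targets, so both sequences are one token longer than the raw labels.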
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
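# Multichannel encoders may return (B, C, T, D); average over the channel axis
# so the CTC layer receives (B, T, D).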
if encoder_out.dim() == 4:
encoder_out = encoder_out.mean(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def _calc_rnnt_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
raise NotImplementedError

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,253 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types
from funasr_local.models.frontend.wav_frontend import WavFrontendMel23
from funasr_local.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr_local.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr_local.modules.eend_ola.utils.power import generate_mapping_dict
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
def pad_attractor(att, max_n_speakers):
C, D = att.shape
if C < max_n_speakers:
att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
return att
class DiarEENDOLAModel(AbsESPnetModel):
"""EEND-OLA diarization model"""
def __init__(
self,
frontend: WavFrontendMel23,
encoder: EENDOLATransformerEncoder,
encoder_decoder_attractor: EncoderDecoderAttractor,
n_units: int = 256,
max_n_speaker: int = 8,
attractor_loss_weight: float = 1.0,
mapping_dict=None,
**kwargs,
):
assert check_argument_types()
super().__init__()
self.frontend = frontend
self.enc = encoder
self.eda = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
self.mapping_dict = mapping_dict
# PostNet
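# The PostNet takes the per-frame speaker activity scores (one per attractor,
# padded to max_n_speaker), runs them through a single-layer LSTM and maps the
# hidden states to power-set encoded (PSE) labels plus an extra OOV class.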
self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
def forward_encoder(self, xs, ilens):
xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
pad_shape = xs.shape
xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
emb = self.enc(xs, xs_mask)
emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
return emb
def forward_post_net(self, logits, ilens):
maxlen = torch.max(ilens).to(torch.int).item()
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
outputs, (_, _) = self.postnet(logits)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
outputs = [self.output_layer(output) for output in outputs]
return outputs
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def estimate_sequential(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
n_speakers: int = None,
shuffle: bool = True,
threshold: float = 0.5,
**kwargs):
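# Inference: encode the frames, optionally shuffle the frame order before EDA
# attractor estimation, keep attractors either up to a given n_speakers or by
# thresholding their existence probabilities, score frames against the kept
# attractors, and decode power-set labels with the PostNet.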
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
attractors, probs = self.eda.estimate(
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
attractors, probs = self.eda.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
if n_speakers and n_speakers >= 0:
att = att[:n_speakers, ]
attractors_active.append(att)
elif threshold is not None:
silence = torch.nonzero(p < threshold, as_tuple=True)[0]
n_spk = silence[0] if silence.numel() else None
att = att[:n_spk, ]
attractors_active.append(att)
else:
raise NotImplementedError('n_speakers or threshold has to be given.')
raw_n_speakers = [att.shape[0] for att in attractors_active]
attractors = [
pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
for att in attractors_active]
ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
logits = self.forward_post_net(ys, speech_lengths)
ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
zip(logits, raw_n_speakers)]
return ys, emb, attractors, raw_n_speakers
def recover_y_from_powerlabel(self, logit, n_speaker):
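# Convert power-set predictions back to per-speaker binary decisions: OOV frames
# inherit the previous frame's label (or silence at t=0), each label is mapped to
# its integer speaker-set code, and the code's binary expansion (LSB = speaker 0)
# gives the active-speaker vector, truncated to the detected number of speakers.
# e.g. code 5 -> bin '101' -> zero-padded and reversed '10100000' -> speakers 0 and 2 active.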
pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
for i in oov_index:
if i > 0:
pred[i] = pred[i - 1]
else:
pred[i] = 0
pred = [self.inv_mapping_func(i) for i in pred]
decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
decisions = torch.from_numpy(
np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
torch.float32)
decisions = decisions[:, :n_speaker]
return decisions
def inv_mapping_func(self, label):
if not isinstance(label, int):
label = int(label)
if label in self.mapping_dict['label2dec'].keys():
num = self.mapping_dict['label2dec'][label]
else:
num = -1
return num
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
pass

View File

@@ -0,0 +1,494 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple, List
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr_local.utils.misc import int2vec
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class DiarSondModel(AbsESPnetModel):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
def __init__(
self,
vocab_size: int,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: torch.nn.Module,
speaker_encoder: Optional[torch.nn.Module],
ci_scorer: torch.nn.Module,
cd_scorer: Optional[torch.nn.Module],
decoder: torch.nn.Module,
token_list: list,
lsm_weight: float = 0.1,
length_normalized_loss: bool = False,
max_spk_num: int = 16,
label_aggregator: Optional[torch.nn.Module] = None,
normalize_speech_speaker: bool = False,
ignore_id: int = -1,
speaker_discrimination_loss_weight: float = 1.0,
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
):
assert check_argument_types()
super().__init__()
self.encoder = encoder
self.speaker_encoder = speaker_encoder
self.ci_scorer = ci_scorer
self.cd_scorer = cd_scorer
self.normalize = normalize
self.frontend = frontend
self.specaug = specaug
self.label_aggregator = label_aggregator
self.decoder = decoder
self.token_list = token_list
self.max_spk_num = max_spk_num
self.normalize_speech_speaker = normalize_speech_speaker
self.ignore_id = ignore_id
self.criterion_diar = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
self.pse_embedding = self.generate_pse_embedding()
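# power_weight holds [1, 2, 4, ..., 2**(max_spk_num - 1)] so that a frame-level
# binary speaker-activity vector can be collapsed into a single power-set code;
# int_token_arr holds the integer value of every PSE token for the reverse lookup.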
self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
self.forward_steps = 0
self.inputs_type = inputs_type
def generate_pse_embedding(self):
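# Build a lookup table that maps every PSE token back to its binary
# speaker-activity vector (via int2vec); it is used later to turn label and
# prediction indices into frame-level multi-speaker targets.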
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=float)
for idx, pse_label in enumerate(self.token_list):
emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=float)
embedding[idx] = emb
return torch.from_numpy(embedding)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor = None,
profile: torch.Tensor = None,
profile_lengths: torch.Tensor = None,
binary_labels: torch.Tensor = None,
binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
Args:
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,); defaults to None for the chunk iterator,
because the chunk iterator does not return speech_lengths.
See espnet2/iterators/chunk_iter_factory.py.
profile: (Batch, N_spk, dim)
profile_lengths: (Batch,)
binary_labels: (Batch, frames, max_spk_num)
binary_labels_lengths: (Batch,)
"""
assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
batch_size = speech.shape[0]
self.forward_steps = self.forward_steps + 1
if self.pse_embedding.device != speech.device:
self.pse_embedding = self.pse_embedding.to(speech.device)
self.power_weight = self.power_weight.to(speech.device)
self.int_token_arr = self.int_token_arr.to(speech.device)
# 1. Network forward
pred, inter_outputs = self.prediction_forward(
speech, speech_lengths,
profile, profile_lengths,
return_inter_outputs=True
)
(speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
# 2. Aggregate time-domain labels to match forward outputs
if self.label_aggregator is not None:
binary_labels, binary_labels_lengths = self.label_aggregator(
binary_labels, binary_labels_lengths
)
# 3. Calculate power-set encoding (PSE) labels
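# e.g. with speakers 0 and 2 active in a frame, the code is 1*1 + 0*2 + 1*4 = 5,
# and the PSE label is the index of token '5' in token_list.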
raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
# If the encoder uses a conv* input_layer (i.e., subsampling), the sequence
# length of 'pred' might be slightly shorter than that of 'pse_labels'.
# Here we force them to be equal (within a tolerance of 2 frames).
length_diff_tolerance = 2
length_diff = abs(pse_labels.shape[1] - pred.shape[1])
if length_diff <= length_diff_tolerance:
min_len = min(pred.shape[1], pse_labels.shape[1])
pse_labels = pse_labels[:, :min_len]
pred = pred[:, :min_len]
cd_score = cd_score[:, :min_len]
ci_score = ci_score[:, :min_len]
loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+ self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
(
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
) = self.calc_diarization_error(
pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
length=binary_labels_lengths
)
if speech_scored > 0 and num_frames > 0:
sad_mr, sad_fr, mi, fa, cf, acc, der = (
speech_miss / speech_scored,
speech_falarm / speech_scored,
speaker_miss / speaker_scored,
speaker_falarm / speaker_scored,
speaker_error / speaker_scored,
correct / num_frames,
(speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
)
else:
sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0
stats = dict(
loss=loss.detach(),
loss_diar=loss_diar.detach() if loss_diar is not None else None,
loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
fa=fa,
cf=cf,
acc=acc,
der=der,
forward_steps=self.forward_steps,
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def classification_loss(
self,
predictions: torch.Tensor,
labels: torch.Tensor,
prediction_lengths: torch.Tensor
) -> torch.Tensor:
mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
pad_labels = labels.masked_fill(
mask.to(predictions.device),
value=self.ignore_id
)
loss = self.criterion_diar(predictions.contiguous(), pad_labels)
return loss
def speaker_discrimination_loss(
self,
profile: torch.Tensor,
profile_lengths: torch.Tensor
) -> torch.Tensor:
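# Penalize similarity between different, non-empty speaker profiles: every valid
# off-diagonal pair contributes relu(norm * cos_similarity), pushing distinct
# profiles towards orthogonality.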
profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float() # (B, N, 1)
mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2)) # (B, N, N)
mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))
eps = 1e-12
coding_norm = torch.linalg.norm(
profile * profile_mask + (1 - profile_mask) * eps,
dim=2, keepdim=True
) * profile_mask
# profile: Batch, N, dim
cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()
return loss
def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
padding_labels = pse_labels.masked_fill(
mask.to(pse_labels.device),
value=0
).to(pse_labels)
multi_labels = F.embedding(padding_labels, self.pse_embedding)
return multi_labels
def internal_score_loss(
self,
cd_score: torch.Tensor,
ci_score: torch.Tensor,
pse_labels: torch.Tensor,
pse_labels_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
return ci_loss, cd_loss
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
profile: torch.Tensor = None,
profile_lengths: torch.Tensor = None,
binary_labels: torch.Tensor = None,
binary_labels_lengths: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode_speaker(
self,
profile: torch.Tensor,
profile_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
with autocast(False):
if profile.shape[1] < self.max_spk_num:
profile = F.pad(profile, [0, 0, 0, self.max_spk_num-profile.shape[1], 0, 0], "constant", 0.0)
profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()
profile = F.normalize(profile, dim=2)
if self.speaker_encoder is not None:
profile = self.speaker_encoder(profile, profile_lengths)[0]
return profile * profile_mask, profile_lengths
else:
return profile, profile_lengths
def encode_speech(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.encoder is not None and self.inputs_type == "raw":
speech, speech_lengths = self.encode(speech, speech_lengths)
speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
return speech * speech_mask, speech_lengths
else:
return speech, speech_lengths
@staticmethod
def concate_speech_ivc(
speech: torch.Tensor,
ivc: torch.Tensor
) -> torch.Tensor:
nn, tt = ivc.shape[1], speech.shape[1]
speech = speech.unsqueeze(dim=1) # B x 1 x T x D
speech = speech.expand(-1, nn, -1, -1) # B x N x T x D
ivc = ivc.unsqueeze(dim=2) # B x N x 1 x D
ivc = ivc.expand(-1, -1, tt, -1) # B x N x T x D
sd_in = torch.cat([speech, ivc], dim=3) # B x N x T x 2D
return sd_in
def calc_similarity(
self,
speech_encoder_outputs: torch.Tensor,
speaker_encoder_outputs: torch.Tensor,
seq_len: torch.Tensor = None,
spk_len: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
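# The context-dependent (cd) scorer sees each speaker profile tiled along time and
# concatenated with the speech frames, reshaped to (B*N, T, D_sph+D_spk); the
# context-independent (ci) scorer either reuses that concatenation (if it is an
# AbsEncoder) or compares speech and profile embeddings directly. Both produce
# per-frame, per-speaker scores of shape (B, T, N).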
bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
if self.normalize_speech_speaker:
speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
cd_simi = self.cd_scorer(ge_in, ge_len)[0]
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
if isinstance(self.ci_scorer, AbsEncoder):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
else:
ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
return ci_simi, cd_simi
def post_net_forward(self, simi, seq_len):
logits = self.decoder(simi, seq_len)[0]
return logits
def prediction_forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
profile: torch.Tensor,
profile_lengths: torch.Tensor,
return_inter_outputs: bool = False,
) -> [torch.Tensor, Optional[list]]:
# speech encoding
speech, speech_lengths = self.encode_speech(speech, speech_lengths)
# speaker encoding
profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
# calculating similarity
ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
similarity = torch.cat([cd_simi, ci_simi], dim=2)
# post net forward
logits = self.post_net_forward(similarity, speech_lengths)
if return_inter_outputs:
return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)]
return logits
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim)
encoder_outputs = self.encoder(feats, feats_lengths)
encoder_out, encoder_out_lens = encoder_outputs[:2]
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = speech.shape[0]
speech_lengths = (
speech_lengths
if speech_lengths is not None
else torch.ones(batch_size).int() * speech.shape[1]
)
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
@staticmethod
def calc_diarization_error(pred, label, length):
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
(batch_size, max_len, num_output) = label.size()
# mask the padding part
mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy()
# pred and label have the shape (batch_size, max_len, num_output)
label_np = label.data.cpu().numpy().astype(int)
pred_np = (pred.data.cpu().numpy() > 0).astype(int)
label_np = label_np * mask
pred_np = pred_np * mask
length = length.data.cpu().numpy()
# compute speech activity detection error
n_ref = np.sum(label_np, axis=2)
n_sys = np.sum(pred_np, axis=2)
speech_scored = float(np.sum(n_ref > 0))
speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))
# compute speaker diarization error
speaker_scored = float(np.sum(n_ref))
speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
num_frames = np.sum(length)
return (
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
)

View File

@@ -0,0 +1,274 @@
"""
Author: Speech Lab, Alibaba Group, China
"""
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr_local.models.ctc import CTC
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.modules.add_sos_eos import add_sos_eos
from funasr_local.modules.e2e_asr_common import ErrorCalculator
from funasr_local.modules.nets_utils import th_accuracy
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class ESPnetSVModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,
):
assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.vocab_size = vocab_size
self.token_list = token_list.copy()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.encoder = encoder
self.pooling_layer = pooling_layer
self.decoder = decoder
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
loss_transducer, cer_transducer, wer_transducer = None, None, None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
if self.use_transducer_decoder:
# 2a. Transducer decoder branch
(
loss_transducer,
cer_transducer,
wer_transducer,
) = self._calc_transducer_loss(
encoder_out,
encoder_out_lens,
text,
)
if loss_ctc is not None:
loss = loss_transducer + (self.ctc_weight * loss_ctc)
else:
loss = loss_transducer
# Collect Transducer branch stats
stats["loss_transducer"] = (
loss_transducer.detach() if loss_transducer is not None else None
)
stats["cer_transducer"] = cer_transducer
stats["wer_transducer"] = wer_transducer
else:
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim) -> (Batch, Channel, Length2, Dim2)
encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths

View File

@@ -0,0 +1,175 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.predictor.cif import mae_loss
from funasr_local.modules.add_sos_eos import add_sos_eos
from funasr_local.modules.nets_utils import make_pad_mask, pad_list
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.models.predictor.cif import CifPredictorV3
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class TimestampPredictor(AbsESPnetModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
frontend: Optional[AbsFrontend],
encoder: AbsEncoder,
predictor: CifPredictorV3,
predictor_bias: int = 0,
token_list=None,
):
assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.frontend = frontend
self.encoder = encoder
self.encoder.interctc_use_conditioning = False
self.predictor = predictor
self.predictor_bias = predictor_bias
self.criterion_pre = mae_loss()
self.token_list = token_list
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, :speech_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
if self.predictor_bias == 1:
_, text = add_sos_eos(text, 1, 2, -1)
text_lengths = text_lengths + self.predictor_bias
_, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
loss = loss_pre
stats = dict()
# Collect Attn branch stats
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
encoder_out_mask,
token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,665 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import torch
from torch import nn
import math
from funasr_local.models.encoder.fsmn_encoder import FSMN
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: float = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: float = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame  # initialize the sliding window
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
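# Sliding-window vote over the last win_size_frame frame decisions: a Sil->Speech
# change fires once the number of speech frames in the window reaches
# sil_to_speech_frmcnt_thres, Speech->Sil fires once it drops to
# speech_to_sil_frmcnt_thres, otherwise the previous state is kept.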
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
class E2EVadModel(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
self.encoder = encoder
# init variables
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
self.frontend = frontend
def AllResetDetection(self):
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
def ResetDetection(self):
self.continous_silence_frame_count = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
def ComputeDecibel(self) -> None:
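# Per-frame energy in dB: 10 * log10 of the summed squared samples in a
# frame_length_ms window, hopped every frame_in_ms, with a small epsilon to avoid
# log(0); the raw samples are also buffered in data_buf_all for segment extraction.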
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if self.data_buf_all is None:
self.data_buf_all = self.waveform[0]  # data_buf points into self.waveform[0]
self.data_buf = self.data_buf_all
else:
self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
self.decibel.append(
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
0.000001))
def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
if self.scores is None:
self.scores = scores # the first calculation
else:
self.scores = torch.cat((self.scores, scores), dim=1)
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
while self.data_buf_start_frame < frame_idx:
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
self.PopDataBufTillFrame(start_frm)
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
if last_frm_is_end_point:
extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, len(self.data_buf))
if len(self.data_buf) < expected_sample_number:
print('error in calling pop data_buf\n')
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
self.output_data_buf[-1].Reset()
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
self.output_data_buf[-1].doa = 0
cur_seg = self.output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('warning\n')
out_pos = len(cur_seg.buffer)  # cur_seg.buffer is currently left untouched here
data_to_pop = 0
if end_point_is_sent_end:
data_to_pop = expected_sample_number
else:
data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if data_to_pop > len(self.data_buf):
print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
data_to_pop = len(self.data_buf)
expected_sample_number = len(self.data_buf)
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
# cur_seg.buffer[out_pos ++] = data_buf_.back();
out_pos += 1
for sample_cpy_out in range(data_to_pop, expected_sample_number):
# cur_seg.buffer[out_pos++] = data_buf_.back()
out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('Something wrong with the VAD algorithm\n')
self.data_buf_start_frame += frm_cnt
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
if first_frm_is_start_point:
cur_seg.contain_seg_start_point = True
if last_frm_is_end_point:
cur_seg.contain_seg_end_point = True
def OnSilenceDetected(self, valid_frame: int):
self.lastest_confirmed_silence_frame = valid_frame
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame)
# silence_detected_callback_
# pass
def OnVoiceDetected(self, valid_frame: int) -> None:
self.latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
if self.vad_opts.do_start_point_detection:
pass
if self.confirmed_start_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_start_frame = start_frame
if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t)
if self.vad_opts.do_end_point_detection:
pass
if self.confirmed_end_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_end_frame = end_frame
if not fake_result:
self.sil_frame = 0
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
self.number_end_time_detected += 1
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
if is_final_frame:
self.OnVoiceEnd(cur_frm_idx, False, True)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
def GetLatency(self) -> int:
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
def LatencyFrmNumAtStartPoint(self) -> int:
vad_latency = self.windows_detector.GetWinSize()
if self.vad_opts.do_extend:
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
def GetFrameState(self, t: int) -> FrameState:
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(self.sil_pdf_ids) > 0:
            assert len(self.scores) == 1  # only batch_size = 1 is supported
sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
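            # the speech probability comes from the residual (non-silence) posterior mass;
            # the silence log-probability is weighted by speech_2_noise_ratio before comparison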
total_score = 1.0
sum_score = total_score - sum_score
speech_prob = math.log(sum_score)
if self.vad_opts.output_frame_probs:
frame_prob = E2EVadFrameProb()
frame_prob.noise_prob = noise_prob
frame_prob.speech_prob = speech_prob
frame_prob.score = sum_score
frame_prob.frame_id = t
self.frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSpeech
else:
frame_state = FrameState.kFrameStateSil
else:
frame_state = FrameState.kFrameStateSil
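        # update the running noise-floor estimate (moving average over the last
        # noise_frame_num_used_for_snr frames) that feeds the SNR check above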
if self.noise_average_decibel < -99.9:
self.noise_average_decibel = cur_decibel
else:
self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
self.vad_opts.noise_frame_num_used_for_snr
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
    def forward(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats, in_cache)
if not is_final:
self.DetectCommonFrames()
else:
self.DetectLastFrames()
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
                    self.output_data_buf_offset += 1  # need to update this offset for the next call
if segment_batch:
segments.append(segment_batch)
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments, in_cache
    def forward_online(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False, max_end_sil: int = 800
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
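        # streaming interface: each emitted segment is a [start_ms, end_ms] pair where -1 marks
        # an endpoint that is not yet known (start already reported earlier, or segment still open)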
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
self.DetectCommonFrames()
else:
self.DetectLastFrames()
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not self.output_data_buf[i].contain_seg_start_point:
continue
if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
continue
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
if self.output_data_buf[i].contain_seg_end_point:
end_ms = self.output_data_buf[i].end_ms
self.next_seg = True
self.output_data_buf_offset += 1
else:
end_ms = -1
self.next_seg = False
segment = [start_ms, end_ms]
segment_batch.append(segment)
if segment_batch:
segments.append(segment_batch)
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments, in_cache
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
return 0
def DetectLastFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
else:
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
return 0
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
frm_shift_in_ms = self.vad_opts.frame_in_ms
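        # dispatch on the smoothed state change: Sil2Speech opens a segment, Speech2Sil /
        # Speech2Speech handle in-segment bookkeeping and the max-segment-length cut,
        # Sil2Sil handles the start-silence timeout and end-point detection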
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = self.continous_silence_frame_count
self.continous_silence_frame_count = 0
self.pre_end_silence_detected = False
start_frame = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
self.OnVoiceStart(start_frame)
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
self.OnVoiceDetected(t)
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t)
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
pass
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSil2Sil == state_change:
self.continous_silence_frame_count += 1
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
or (is_final_frame and self.number_end_time_detected == 0):
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t)
self.OnVoiceStart(0, True)
                    self.OnVoiceEnd(0, True, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
else:
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if self.vad_opts.do_extend:
lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
lookback_frame -= 1
lookback_frame = max(0, lookback_frame)
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif self.vad_opts.do_extend and not is_final_frame:
if self.continous_silence_frame_count <= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()
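# Usage sketch (illustrative, not part of the original source): assuming the enclosing
# detector class is a torch.nn.Module, a full utterance can be segmented in one call,
# or a stream in repeated calls that share the same cache dict.
#
# vad = E2EVadModel(...)  # class name and constructor arguments are assumptions
# segments, cache = vad(feats, waveform, in_cache={}, is_final=True)
# # segments: [[[start_ms, end_ms], ...]] per utterance in the batch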

View File

View File

@@ -0,0 +1,21 @@
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Tuple
import torch
class AbsEncoder(torch.nn.Module, ABC):
@abstractmethod
def output_size(self) -> int:
raise NotImplementedError
@abstractmethod
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError
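# A minimal sketch (not part of the original file) of a concrete subclass, to show the
# contract: forward() returns (output, output_lengths, optional_states) and output_size()
# reports the output feature dimension. The class name and sizes below are illustrative.
#
# class LinearEncoder(AbsEncoder):
#     def __init__(self, input_size: int, output_dim: int = 256):
#         super().__init__()
#         self._output_dim = output_dim
#         self.proj = torch.nn.Linear(input_size, output_dim)
#
#     def output_size(self) -> int:
#         return self._output_dim
#
#     def forward(self, xs_pad, ilens, prev_states=None):
#         return self.proj(xs_pad), ilens, None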

File diff suppressed because it is too large

View File

@@ -0,0 +1,577 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.modules.data2vec.data_utils import compute_mask_indices
from funasr_local.modules.data2vec.ema_module import EMAModule
from funasr_local.modules.data2vec.grad_multiply import GradMultiply
from funasr_local.modules.data2vec.wav2vec2 import (
ConvFeatureExtractionModel,
TransformerEncoder,
)
from funasr_local.modules.nets_utils import make_pad_mask
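# linearly anneals a value from `start` to `end` over `total_steps` updates
# (used below to schedule the EMA teacher decay)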
def get_annealed_rate(start, end, curr_step, total_steps):
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
class Data2VecEncoder(AbsEncoder):
def __init__(
self,
# for ConvFeatureExtractionModel
input_size: int = None,
extractor_mode: str = None,
conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
# for Transformer Encoder
## model architecture
layer_type: str = "transformer",
layer_norm_first: bool = False,
encoder_layers: int = 12,
encoder_embed_dim: int = 768,
encoder_ffn_embed_dim: int = 3072,
encoder_attention_heads: int = 12,
activation_fn: str = "gelu",
## dropouts
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.0,
encoder_layerdrop: float = 0.0,
dropout_input: float = 0.0,
dropout_features: float = 0.0,
## grad settings
feature_grad_mult: float = 1.0,
## masking
mask_prob: float = 0.65,
mask_length: int = 10,
mask_selection: str = "static",
mask_other: int = 0,
no_mask_overlap: bool = False,
mask_min_space: int = 1,
require_same_masks: bool = True, # if set as True, collate_fn should be clipping
mask_dropout: float = 0.0,
## channel masking
mask_channel_length: int = 10,
mask_channel_prob: float = 0.0,
mask_channel_before: bool = False,
mask_channel_selection: str = "static",
mask_channel_other: int = 0,
no_mask_channel_overlap: bool = False,
mask_channel_min_space: int = 1,
## positional embeddings
conv_pos: int = 128,
conv_pos_groups: int = 16,
pos_conv_depth: int = 1,
max_positions: int = 100000,
# EMA module
average_top_k_layers: int = 8,
layer_norm_target_layer: bool = False,
instance_norm_target_layer: bool = False,
instance_norm_targets: bool = False,
layer_norm_targets: bool = False,
batch_norm_target_layer: bool = False,
group_norm_target_layer: bool = False,
ema_decay: float = 0.999,
ema_end_decay: float = 0.9999,
ema_anneal_end_step: int = 100000,
ema_transformer_only: bool = True,
ema_layers_only: bool = True,
min_target_var: float = 0.1,
min_pred_var: float = 0.01,
# Loss
loss_beta: float = 0.0,
loss_scale: float = None,
# FP16 optimization
required_seq_len_multiple: int = 2,
):
assert check_argument_types()
super().__init__()
# ConvFeatureExtractionModel
self.conv_feature_layers = conv_feature_layers
feature_enc_layers = eval(conv_feature_layers)
self.extractor_embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=extractor_mode,
in_d=input_size,
)
# Transformer Encoder
## model architecture
self.layer_type = layer_type
self.layer_norm_first = layer_norm_first
self.encoder_layers = encoder_layers
self.encoder_embed_dim = encoder_embed_dim
self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
self.encoder_attention_heads = encoder_attention_heads
self.activation_fn = activation_fn
## dropout
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.encoder_layerdrop = encoder_layerdrop
self.dropout_input = dropout_input
self.dropout_features = dropout_features
## grad settings
self.feature_grad_mult = feature_grad_mult
## masking
self.mask_prob = mask_prob
self.mask_length = mask_length
self.mask_selection = mask_selection
self.mask_other = mask_other
self.no_mask_overlap = no_mask_overlap
self.mask_min_space = mask_min_space
self.require_same_masks = require_same_masks # if set as True, collate_fn should be clipping
self.mask_dropout = mask_dropout
## channel masking
self.mask_channel_length = mask_channel_length
self.mask_channel_prob = mask_channel_prob
self.mask_channel_before = mask_channel_before
self.mask_channel_selection = mask_channel_selection
self.mask_channel_other = mask_channel_other
self.no_mask_channel_overlap = no_mask_channel_overlap
self.mask_channel_min_space = mask_channel_min_space
## positional embeddings
self.conv_pos = conv_pos
self.conv_pos_groups = conv_pos_groups
self.pos_conv_depth = pos_conv_depth
self.max_positions = max_positions
self.mask_emb = nn.Parameter(torch.FloatTensor(self.encoder_embed_dim).uniform_())
self.encoder = TransformerEncoder(
dropout=self.dropout,
encoder_embed_dim=self.encoder_embed_dim,
required_seq_len_multiple=required_seq_len_multiple,
pos_conv_depth=self.pos_conv_depth,
conv_pos=self.conv_pos,
conv_pos_groups=self.conv_pos_groups,
# transformer layers
layer_type=self.layer_type,
encoder_layers=self.encoder_layers,
encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
encoder_attention_heads=self.encoder_attention_heads,
attention_dropout=self.attention_dropout,
activation_dropout=self.activation_dropout,
activation_fn=self.activation_fn,
layer_norm_first=self.layer_norm_first,
encoder_layerdrop=self.encoder_layerdrop,
max_positions=self.max_positions,
)
## projections and dropouts
self.post_extract_proj = nn.Linear(self.extractor_embed, self.encoder_embed_dim)
self.dropout_input = nn.Dropout(self.dropout_input)
self.dropout_features = nn.Dropout(self.dropout_features)
self.layer_norm = torch.nn.LayerNorm(self.extractor_embed)
self.final_proj = nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)
# EMA module
self.average_top_k_layers = average_top_k_layers
self.layer_norm_target_layer = layer_norm_target_layer
self.instance_norm_target_layer = instance_norm_target_layer
self.instance_norm_targets = instance_norm_targets
self.layer_norm_targets = layer_norm_targets
self.batch_norm_target_layer = batch_norm_target_layer
self.group_norm_target_layer = group_norm_target_layer
self.ema_decay = ema_decay
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step
self.ema_transformer_only = ema_transformer_only
self.ema_layers_only = ema_layers_only
self.min_target_var = min_target_var
self.min_pred_var = min_pred_var
self.ema = None
# Loss
self.loss_beta = loss_beta
self.loss_scale = loss_scale
# FP16 optimization
self.required_seq_len_multiple = required_seq_len_multiple
self.num_updates = 0
logging.info("Data2VecEncoder settings: {}".format(self.__dict__))
def make_ema_teacher(self):
skip_keys = set()
if self.ema_layers_only:
self.ema_transformer_only = True
for k, _ in self.encoder.pos_conv.named_parameters():
skip_keys.add(f"pos_conv.{k}")
self.ema = EMAModule(
self.encoder if self.ema_transformer_only else self,
ema_decay=self.ema_decay,
ema_fp32=True,
skip_keys=skip_keys,
)
def set_num_updates(self, num_updates):
if self.ema is None and self.final_proj is not None:
logging.info("Making EMA Teacher")
self.make_ema_teacher()
elif self.training and self.ema is not None:
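            # anneal the EMA decay towards ema_end_decay, then EMA-update the teacher
            # with the current student weights (skipped once the decay reaches 1.0)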
if self.ema_decay != self.ema_end_decay:
if num_updates >= self.ema_anneal_end_step:
decay = self.ema_end_decay
else:
decay = get_annealed_rate(
self.ema_decay,
self.ema_end_decay,
num_updates,
self.ema_anneal_end_step,
)
self.ema.set_decay(decay)
if self.ema.get_decay() < 1:
self.ema.step(self.encoder if self.ema_transformer_only else self)
self.num_updates = num_updates
def apply_mask(
self,
x,
padding_mask,
mask_indices=None,
mask_channel_indices=None,
):
B, T, C = x.shape
if self.mask_channel_prob > 0 and self.mask_channel_before:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
if self.mask_prob > 0:
if mask_indices is None:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=1,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
require_same_masks=self.require_same_masks,
mask_dropout=self.mask_dropout,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0 and not self.mask_channel_before:
if mask_channel_indices is None:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
return torch.floor((input_length - kernel_size).to(torch.float32) / stride + 1)
conv_cfg_list = eval(self.conv_feature_layers)
for i in range(len(conv_cfg_list)):
input_lengths = _conv_out_length(
input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
)
return input_lengths.to(torch.long)
def forward(
self,
xs_pad,
ilens=None,
mask=False,
features_only=True,
layer=None,
mask_indices=None,
mask_channel_indices=None,
padding_count=None,
):
# create padding_mask by ilens
if ilens is not None:
padding_mask = make_pad_mask(lengths=ilens).to(xs_pad.device)
else:
padding_mask = None
features = xs_pad
if self.feature_grad_mult > 0:
features = self.feature_extractor(features)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(features)
features = features.transpose(1, 2)
features = self.layer_norm(features)
orig_padding_mask = padding_mask
if padding_mask is not None:
input_lengths = (1 - padding_mask.long()).sum(-1)
# apply conv formula to get real output_lengths
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
padding_mask = torch.zeros(
features.shape[:2], dtype=features.dtype, device=features.device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
padding_mask[
(
torch.arange(padding_mask.shape[0], device=padding_mask.device),
output_lengths - 1,
)
] = 1
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
else:
padding_mask = None
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
pre_encoder_features = None
if self.ema_transformer_only:
pre_encoder_features = features.clone()
features = self.dropout_input(features)
if mask:
x, mask_indices = self.apply_mask(
features,
padding_mask,
mask_indices=mask_indices,
mask_channel_indices=mask_channel_indices,
)
else:
x = features
mask_indices = None
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
layer=layer,
)
        if features_only:
            if padding_mask is not None:
                encoder_out_lens = (1 - padding_mask.long()).sum(1)
            else:
                encoder_out_lens = x.new_full((x.size(0),), x.size(1), dtype=torch.long)
            return x, encoder_out_lens, None
result = {
"losses": {},
"padding_mask": padding_mask,
"x": x,
}
with torch.no_grad():
self.ema.model.eval()
if self.ema_transformer_only:
y, layer_results = self.ema.model.extract_features(
pre_encoder_features,
padding_mask=padding_mask,
min_layer=self.encoder_layers - self.average_top_k_layers,
)
y = {
"x": y,
"padding_mask": padding_mask,
"layer_results": layer_results,
}
else:
y = self.ema.model.extract_features(
source=xs_pad,
padding_mask=orig_padding_mask,
mask=False,
)
target_layer_results = [l[2] for l in y["layer_results"]]
permuted = False
if self.instance_norm_target_layer or self.batch_norm_target_layer:
target_layer_results = [
tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
]
permuted = True
if self.batch_norm_target_layer:
target_layer_results = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in target_layer_results
]
if self.instance_norm_target_layer:
target_layer_results = [
F.instance_norm(tl.float()) for tl in target_layer_results
]
if permuted:
target_layer_results = [
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
]
if self.group_norm_target_layer:
target_layer_results = [
F.layer_norm(tl.float(), tl.shape[-2:])
for tl in target_layer_results
]
if self.layer_norm_target_layer:
target_layer_results = [
F.layer_norm(tl.float(), tl.shape[-1:])
for tl in target_layer_results
]
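            # the regression target is the average of the (optionally normalized)
            # outputs of the top-k teacher layers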
y = sum(target_layer_results) / len(target_layer_results)
if self.layer_norm_targets:
y = F.layer_norm(y.float(), y.shape[-1:])
if self.instance_norm_targets:
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
if not permuted:
y = y.transpose(0, 1)
y = y[mask_indices]
x = x[mask_indices]
x = self.final_proj(x)
sz = x.size(-1)
if self.loss_beta == 0:
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
else:
loss = F.smooth_l1_loss(
x.float(), y.float(), reduction="none", beta=self.loss_beta
).sum(dim=-1)
if self.loss_scale is not None:
scale = self.loss_scale
else:
scale = 1 / math.sqrt(sz)
result["losses"]["regression"] = loss.sum() * scale
if "sample_size" not in result:
result["sample_size"] = loss.numel()
with torch.no_grad():
result["target_var"] = self.compute_var(y)
result["pred_var"] = self.compute_var(x.float())
if self.num_updates > 5000 and result["target_var"] < self.min_target_var:
logging.error(
f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
)
raise Exception(
f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
)
if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
logging.error(
f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
)
raise Exception(
f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
)
if self.ema is not None:
result["ema_decay"] = self.ema.get_decay() * 1000
return result
@staticmethod
def compute_var(y):
y = y.view(-1, y.size(-1))
if dist.is_initialized():
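            # aggregate count, sum and sum of squares across workers, then apply
            # the unbiased variance formula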
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y ** 2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def extract_features(
self, xs_pad, ilens, mask=False, layer=None
):
res = self.forward(
xs_pad,
ilens,
mask=mask,
features_only=True,
layer=layer,
)
return res
def remove_pretraining_modules(self, last_layer=None):
self.final_proj = None
self.ema = None
if last_layer is not None:
self.encoder.layers = nn.ModuleList(
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
)
def output_size(self) -> int:
return self.encoder_embed_dim

View File

@@ -0,0 +1,686 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class _BatchNorm1d(nn.Module):
def __init__(
self,
input_shape=None,
input_size=None,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True,
combine_batch_time=False,
skip_transpose=False,
):
super().__init__()
self.combine_batch_time = combine_batch_time
self.skip_transpose = skip_transpose
if input_size is None and skip_transpose:
input_size = input_shape[1]
elif input_size is None:
input_size = input_shape[-1]
self.norm = nn.BatchNorm1d(
input_size,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
)
def forward(self, x):
shape_or = x.shape
if self.combine_batch_time:
if x.ndim == 3:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
else:
x = x.reshape(
shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
)
elif not self.skip_transpose:
x = x.transpose(-1, 1)
x_n = self.norm(x)
if self.combine_batch_time:
x_n = x_n.reshape(shape_or)
elif not self.skip_transpose:
x_n = x_n.transpose(1, -1)
return x_n
class _Conv1d(nn.Module):
def __init__(
self,
out_channels,
kernel_size,
input_shape=None,
in_channels=None,
stride=1,
dilation=1,
padding="same",
groups=1,
bias=True,
padding_mode="reflect",
skip_transpose=False,
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.padding = padding
self.padding_mode = padding_mode
self.unsqueeze = False
self.skip_transpose = skip_transpose
if input_shape is None and in_channels is None:
raise ValueError("Must provide one of input_shape or in_channels")
if in_channels is None:
in_channels = self._check_input_shape(input_shape)
self.conv = nn.Conv1d(
in_channels,
out_channels,
self.kernel_size,
stride=self.stride,
dilation=self.dilation,
padding=0,
groups=groups,
bias=bias,
)
def forward(self, x):
if not self.skip_transpose:
x = x.transpose(1, -1)
if self.unsqueeze:
x = x.unsqueeze(1)
if self.padding == "same":
x = self._manage_padding(
x, self.kernel_size, self.dilation, self.stride
)
elif self.padding == "causal":
num_pad = (self.kernel_size - 1) * self.dilation
x = F.pad(x, (num_pad, 0))
elif self.padding == "valid":
pass
else:
raise ValueError(
"Padding must be 'same', 'valid' or 'causal'. Got "
+ self.padding
)
wx = self.conv(x)
if self.unsqueeze:
wx = wx.squeeze(1)
if not self.skip_transpose:
wx = wx.transpose(1, -1)
return wx
def _manage_padding(
self, x, kernel_size: int, dilation: int, stride: int,
):
# Detecting input shape
L_in = x.shape[-1]
# Time padding
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
# Applying padding
x = F.pad(x, padding, mode=self.padding_mode)
return x
def _check_input_shape(self, shape):
"""Checks the input shape and returns the number of input channels.
"""
if len(shape) == 2:
self.unsqueeze = True
in_channels = 1
elif self.skip_transpose:
in_channels = shape[1]
elif len(shape) == 3:
in_channels = shape[2]
else:
raise ValueError(
"conv1d expects 2d, 3d inputs. Got " + str(len(shape))
)
# Kernel size must be odd
if self.kernel_size % 2 == 0:
raise ValueError(
"The field kernel size must be an odd number. Got %s."
% (self.kernel_size)
)
return in_channels
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
padding = [kernel_size // 2, kernel_size // 2]
else:
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
return padding
# Skip transpose as much as possible for efficiency
class Conv1d(_Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
class BatchNorm1d(_BatchNorm1d):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
def length_to_mask(length, max_len=None, dtype=None, device=None):
assert len(length.shape) == 1
if max_len is None:
max_len = length.max().long().item() # using arange to generate mask
mask = torch.arange(
max_len, device=length.device, dtype=length.dtype
).expand(len(length), max_len) < length.unsqueeze(1)
if dtype is None:
dtype = length.dtype
if device is None:
device = length.device
mask = torch.as_tensor(mask, dtype=dtype, device=device)
return mask
class TDNNBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation,
activation=nn.ReLU,
groups=1,
):
super(TDNNBlock, self).__init__()
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dilation=dilation,
groups=groups,
)
self.activation = activation()
self.norm = BatchNorm1d(input_size=out_channels)
def forward(self, x):
return self.norm(self.activation(self.conv(x)))
class Res2NetBlock(torch.nn.Module):
"""An implementation of Res2NetBlock w/ dilation.
Arguments
---------
in_channels : int
The number of channels expected in the input.
out_channels : int
The number of output channels.
scale : int
The scale of the Res2Net block.
kernel_size: int
The kernel size of the Res2Net block.
dilation : int
The dilation of the Res2Net block.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
"""
def __init__(
self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
):
super(Res2NetBlock, self).__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
in_channel = in_channels // scale
hidden_channel = out_channels // scale
self.blocks = nn.ModuleList(
[
TDNNBlock(
in_channel,
hidden_channel,
kernel_size=kernel_size,
dilation=dilation,
)
for i in range(scale - 1)
]
)
self.scale = scale
def forward(self, x):
y = []
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
if i == 0:
y_i = x_i
elif i == 1:
y_i = self.blocks[i - 1](x_i)
else:
y_i = self.blocks[i - 1](x_i + y_i)
y.append(y_i)
y = torch.cat(y, dim=1)
return y
class SEBlock(nn.Module):
"""An implementation of squeeze-and-excitation block.
Arguments
---------
in_channels : int
The number of input channels.
se_channels : int
The number of output channels after squeeze.
out_channels : int
The number of output channels.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> se_layer = SEBlock(64, 16, 64)
>>> lengths = torch.rand((8,))
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
"""
def __init__(self, in_channels, se_channels, out_channels):
super(SEBlock, self).__init__()
self.conv1 = Conv1d(
in_channels=in_channels, out_channels=se_channels, kernel_size=1
)
self.relu = torch.nn.ReLU(inplace=True)
self.conv2 = Conv1d(
in_channels=se_channels, out_channels=out_channels, kernel_size=1
)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x, lengths=None):
L = x.shape[-1]
if lengths is not None:
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
total = mask.sum(dim=2, keepdim=True)
s = (x * mask).sum(dim=2, keepdim=True) / total
else:
s = x.mean(dim=2, keepdim=True)
s = self.relu(self.conv1(s))
s = self.sigmoid(self.conv2(s))
return s * x
class AttentiveStatisticsPooling(nn.Module):
"""This class implements an attentive statistic pooling layer for each channel.
It returns the concatenated mean and std of the input tensor.
Arguments
---------
channels: int
The number of input channels.
attention_channels: int
The number of attention channels.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> asp_layer = AttentiveStatisticsPooling(64)
>>> lengths = torch.rand((8,))
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 1, 128])
"""
def __init__(self, channels, attention_channels=128, global_context=True):
super().__init__()
self.eps = 1e-12
self.global_context = global_context
if global_context:
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
else:
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
self.tanh = nn.Tanh()
self.conv = Conv1d(
in_channels=attention_channels, out_channels=channels, kernel_size=1
)
def forward(self, x, lengths=None):
"""Calculates mean and std for a batch (input tensor).
Arguments
---------
x : torch.Tensor
Tensor of shape [N, C, L].
"""
L = x.shape[-1]
def _compute_statistics(x, m, dim=2, eps=self.eps):
mean = (m * x).sum(dim)
std = torch.sqrt(
(m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
)
return mean, std
if lengths is None:
lengths = torch.ones(x.shape[0], device=x.device)
# Make binary mask of shape [N, 1, L]
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if self.global_context:
# torch.std is unstable for backward computation
# https://github.com/pytorch/pytorch/issues/4320
total = mask.sum(dim=2, keepdim=True).float()
mean, std = _compute_statistics(x, mask / total)
mean = mean.unsqueeze(2).repeat(1, 1, L)
std = std.unsqueeze(2).repeat(1, 1, L)
attn = torch.cat([x, mean, std], dim=1)
else:
attn = x
# Apply layers
attn = self.conv(self.tanh(self.tdnn(attn)))
# Filter out zero-paddings
attn = attn.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(attn, dim=2)
mean, std = _compute_statistics(x, attn)
# Append mean and std of the batch
pooled_stats = torch.cat((mean, std), dim=1)
pooled_stats = pooled_stats.unsqueeze(2)
return pooled_stats
class SERes2NetBlock(nn.Module):
"""An implementation of building block in ECAPA-TDNN, i.e.,
TDNN-Res2Net-TDNN-SEBlock.
Arguments
----------
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
res2net_scale: int
The scale of the Res2Net block.
kernel_size: int
The kernel size of the TDNN blocks.
dilation: int
The dilation of the Res2Net block.
activation : torch class
A class for constructing the activation layers.
groups: int
Number of blocked connections from input channels to output channels.
Example
-------
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
>>> out = conv(x).transpose(1, 2)
>>> out.shape
torch.Size([8, 120, 64])
"""
def __init__(
self,
in_channels,
out_channels,
res2net_scale=8,
se_channels=128,
kernel_size=1,
dilation=1,
activation=torch.nn.ReLU,
groups=1,
):
super().__init__()
self.out_channels = out_channels
self.tdnn1 = TDNNBlock(
in_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.res2net_block = Res2NetBlock(
out_channels, out_channels, res2net_scale, kernel_size, dilation
)
self.tdnn2 = TDNNBlock(
out_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.se_block = SEBlock(out_channels, se_channels, out_channels)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x, lengths=None):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.tdnn1(x)
x = self.res2net_block(x)
x = self.tdnn2(x)
x = self.se_block(x, lengths)
return x + residual
class ECAPA_TDNN(torch.nn.Module):
"""An implementation of the speaker embedding model in a paper.
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
Arguments
---------
activation : torch class
A class for constructing the activation layers.
channels : list of ints
Output channels for TDNN/SERes2Net layer.
kernel_sizes : list of ints
List of kernel sizes for each layer.
dilations : list of ints
List of dilations for kernels in each layer.
lin_neurons : int
Number of neurons in linear layers.
groups : list of ints
List of groups for kernels in each layer.
Example
-------
>>> input_feats = torch.rand([5, 120, 80])
>>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
>>> outputs = compute_embedding(input_feats)
>>> outputs.shape
torch.Size([5, 1, 192])
"""
def __init__(
self,
input_size,
lin_neurons=192,
activation=torch.nn.ReLU,
channels=[512, 512, 512, 512, 1536],
kernel_sizes=[5, 3, 3, 3, 1],
dilations=[1, 2, 3, 4, 1],
attention_channels=128,
res2net_scale=8,
se_channels=128,
global_context=True,
groups=[1, 1, 1, 1, 1],
window_size=20,
window_shift=1,
):
super().__init__()
assert len(channels) == len(kernel_sizes)
assert len(channels) == len(dilations)
self.channels = channels
self.blocks = nn.ModuleList()
self.window_size = window_size
self.window_shift = window_shift
# The initial TDNN layer
self.blocks.append(
TDNNBlock(
input_size,
channels[0],
kernel_sizes[0],
dilations[0],
activation,
groups[0],
)
)
# SE-Res2Net layers
for i in range(1, len(channels) - 1):
self.blocks.append(
SERes2NetBlock(
channels[i - 1],
channels[i],
res2net_scale=res2net_scale,
se_channels=se_channels,
kernel_size=kernel_sizes[i],
dilation=dilations[i],
activation=activation,
groups=groups[i],
)
)
# Multi-layer feature aggregation
self.mfa = TDNNBlock(
channels[-1],
channels[-1],
kernel_sizes[-1],
dilations[-1],
activation,
groups=groups[-1],
)
# Attentive Statistical Pooling
self.asp = AttentiveStatisticsPooling(
channels[-1],
attention_channels=attention_channels,
global_context=global_context,
)
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
# Final linear transformation
self.fc = Conv1d(
in_channels=channels[-1] * 2,
out_channels=lin_neurons,
kernel_size=1,
)
def windowed_pooling(self, x, lengths=None):
# x: Batch, Channel, Time
tt = x.shape[2]
num_chunk = int(math.ceil(tt / self.window_shift))
pad = self.window_size // 2
x = F.pad(x, (pad, pad, 0, 0), "reflect")
stat_list = []
        for i in range(num_chunk):
            # pool each window independently; keep the padded feature map in `x` untouched
            st, ed = i * self.window_shift, i * self.window_shift + self.window_size
            chunk = self.asp(x[:, :, st: ed],
                             lengths=torch.clamp(lengths - i, 0, self.window_size)
                             if lengths is not None else None)
            chunk = self.asp_bn(chunk)
            chunk = self.fc(chunk)  # B x lin_neurons x 1
            stat_list.append(chunk)
        return torch.cat(stat_list, dim=2)
def forward(self, x, lengths=None):
"""Returns the embedding vector.
Arguments
---------
x : torch.Tensor
Tensor of shape (batch, time, channel).
lengths: torch.Tensor
Tensor of shape (batch, )
"""
# Minimize transpose for efficiency
x = x.transpose(1, 2)
xl = []
for layer in self.blocks:
try:
x = layer(x, lengths=lengths)
except TypeError:
x = layer(x)
xl.append(x)
# Multi-layer feature aggregation
x = torch.cat(xl[1:], dim=1)
x = self.mfa(x)
if self.window_size is None:
# Attentive Statistical Pooling
x = self.asp(x, lengths=lengths)
x = self.asp_bn(x)
# Final linear transformation
x = self.fc(x)
# x = x.transpose(1, 2)
x = x.squeeze(2) # -> B, C
else:
x = self.windowed_pooling(x, lengths)
x = x.transpose(1, 2) # -> B, T, C
return x

View File

@@ -0,0 +1,270 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from funasr_local.modules.layer_norm import LayerNorm
from torch.autograd import Variable
class Encoder_Conformer_Layer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
feed_forward,
feed_forward_macaron,
conv_module,
dropout_rate,
normalize_before=True,
concat_after=False,
cca_pos=0,
):
"""Construct an Encoder_Conformer_Layer object."""
super(Encoder_Conformer_Layer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = LayerNorm(size) # for the FNN module
self.norm_mha = LayerNorm(size) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = LayerNorm(size)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = LayerNorm(size) # for the CNN module
self.norm_final = LayerNorm(size) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.cca_pos = cca_pos
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
def forward(self, x_input, mask, cache=None):
"""Compute encoded features.
Args:
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
- w/o pos emb: Tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
if isinstance(x_input, tuple):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.cca_pos<2:
if pos_emb is not None:
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x = residual + self.dropout(self.conv_module(x))
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
if pos_emb is not None:
return (x, pos_emb), mask
return x, mask
class EncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn_cros_channel,
self_attn_conformer,
feed_forward_csa,
feed_forward_macaron_csa,
conv_module_csa,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.encoder_cros_channel_atten = self_attn_cros_channel
self.encoder_csa = Encoder_Conformer_Layer(
size,
self_attn_conformer,
feed_forward_csa,
feed_forward_macaron_csa,
conv_module_csa,
dropout_rate,
normalize_before,
concat_after,
cca_pos=0)
self.norm_mha = LayerNorm(size) # for the MHA module
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x_input, mask, channel_size, cache=None):
"""Compute encoded features.
Args:
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
- w/o pos emb: Tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
if isinstance(x_input, tuple):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None
residual = x
x = self.norm_mha(x)
t_leng = x.size(1)
d_dim = x.size(2)
        x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)  # B x T x C x D
        # build the cross-channel key/value: for every frame, gather the channels of the
        # current frame and of its two left / two right temporal neighbours (zero-padded at
        # the utterance boundaries), giving 5 * channel_size key/value vectors per query
        x_k_v = x_new.new(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
        pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(x_new.type())
        pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(x_new.type())
        x_pad = torch.cat([pad_before, x_new, pad_after], 1)
        x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]
        x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]
        x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]
        x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]
        x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]
        x_new = x_new.reshape(-1, channel_size, d_dim)
        x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
        x_att = x_att.reshape(-1, t_leng, channel_size, d_dim).transpose(1, 2).reshape(-1, t_leng, d_dim)
x = residual + self.dropout(x_att)
if pos_emb is not None:
x_input = (x, pos_emb)
else:
x_input = x
x_input, mask = self.encoder_csa(x_input, mask)
        return x_input, mask, channel_size

View File

@@ -0,0 +1,301 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
def forward(self, input):
output = self.linear(input)
return output
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, input):
output = self.linear(input)
return output
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
return out
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
if self.rorder > 0:
self.conv_right = nn.Conv2d(
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
else:
self.conv_right = None
def forward(self, input: torch.Tensor, cache: torch.Tensor):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
cache = cache.to(x_per.device)
y_left = torch.cat((cache, x_per), dim=2)
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
if self.conv_right is not None:
# maybe need to check
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.conv_right(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output, cache
class BasicBlock(nn.Sequential):
def __init__(self,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
stack_layer: int
):
super(BasicBlock, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.stack_layer = stack_layer
self.linear = LinearTransform(linear_dim, proj_dim)
self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
self.affine = AffineTransform(proj_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
x1 = self.linear(input) # B T D
cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
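        # lazily create this layer's left-context cache on the first call;
        # shape: (batch, proj_dim, (lorder - 1) * lstride, 1)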
if cache_layer_name not in in_cache:
in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
x3 = self.affine(x2)
x4 = self.relu(x3)
return x4
class FsmnStack(nn.Sequential):
def __init__(self, *args):
super(FsmnStack, self).__init__(*args)
def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
x = input
for module in self._modules.values():
x = module(x, in_cache)
return x
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
class FSMN(nn.Module):
def __init__(
self,
input_dim: int,
input_affine_dim: int,
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
output_affine_dim: int,
output_dim: int
):
super(FSMN, self).__init__()
self.input_dim = input_dim
self.input_affine_dim = input_affine_dim
self.fsmn_layers = fsmn_layers
self.linear_dim = linear_dim
self.proj_dim = proj_dim
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
range(fsmn_layers)])
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
in_cache: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
{'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
x7 = self.softmax(x6)
return x7
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
super(DFSMN, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.expand = AffineTransform(dimproj, dimlinear)
self.shrink = LinearTransform(dimlinear, dimproj)
self.conv_left = nn.Conv2d(
dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
else:
self.conv_right = None
def forward(self, input):
f1 = F.relu(self.expand(input))
p1 = self.shrink(f1)
x = torch.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)
out1 = out.permute(0, 3, 2, 1)
output = input + out1.squeeze(1)
return output
'''
build stacked dfsmn layers
'''
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
repeats = [
nn.Sequential(
DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
if __name__ == '__main__':
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
    y = fsmn(x, in_cache={})  # batch-size * time * output dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))
    # print(fsmn.to_kaldi_net())  # to_kaldi_net is not defined in this module
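    # Streaming sketch (illustrative): the same `in_cache` dict is passed across chunks so
    # every FSMN layer keeps its left context; the chunk length below is an assumption.
    # cache = {}
    # for chunk in x.split(20, dim=1):
    #     posterior = fsmn(chunk, in_cache=cache)  # cache is updated in place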

View File

@@ -0,0 +1,450 @@
from typing import Optional
from typing import Tuple
import logging
import torch
from torch import nn
from typeguard import check_argument_types
from funasr_local.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr_local.modules.nets_utils import get_activation
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr_local.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
)
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.multi_layer_conv import Conv1dLinear
from funasr_local.modules.multi_layer_conv import MultiLayeredConv1d
from funasr_local.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.subsampling import Conv2dSubsampling
from funasr_local.modules.subsampling import Conv2dSubsampling2
from funasr_local.modules.subsampling import Conv2dSubsampling6
from funasr_local.modules.subsampling import Conv2dSubsampling8
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.modules.subsampling import check_short_utt
from funasr_local.models.encoder.abs_encoder import AbsEncoder
import pdb
import math
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
"""
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
def forward(self, x):
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, time)
x = nn.functional.glu(x, dim=1) # (batch, channel, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x)
return x.transpose(1, 2)
class MFCCAEncoder(AbsEncoder):
"""Conformer encoder module.
Args:
input_size (int): Input dimension.
output_size (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of encoder blocks.
dropout_rate (float): Dropout rate.
attention_dropout_rate (float): Dropout rate in attention.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
input_layer (Union[str, torch.nn.Module]): Input layer type.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
If True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
If False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
rel_pos_type (str): Whether to use the latest relative positional encoding or
the legacy one. The legacy relative positional encoding will be deprecated
in the future. More Details can be found in
https://github.com/espnet/espnet/pull/2816.
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
encoder_attn_layer_type (str): Encoder attention layer type.
activation_type (str): Encoder activation function type.
macaron_style (bool): Whether to use macaron style for positionwise layer.
use_cnn_module (bool): Whether to use convolution module.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of the convolution module.
padding_idx (int): Padding idx for input_layer=embed.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = "legacy",
pos_enc_layer_type: str = "rel_pos",
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
zero_triu: bool = False,
cnn_module_kernel: int = 31,
padding_idx: int = -1,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if rel_pos_type == "legacy":
if pos_enc_layer_type == "rel_pos":
pos_enc_layer_type = "legacy_rel_pos"
if selfattention_layer_type == "rel_selfattn":
selfattention_layer_type = "legacy_rel_selfattn"
elif rel_pos_type == "latest":
assert selfattention_layer_type != "legacy_rel_selfattn"
assert pos_enc_layer_type != "legacy_rel_pos"
else:
raise ValueError("unknown rel_pos_type: " + rel_pos_type)
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == "rel_pos":
assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "legacy_rel_pos":
assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
logging.warning(
"Using legacy_rel_pos and it will be deprecated in the future."
)
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(output_size, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "legacy_rel_selfattn":
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
logging.warning(
"Using legacy_rel_selfattn and it will be deprecated in the future."
)
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
zero_triu,
)
else:
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation)
encoder_selfattn_layer_raw = MultiHeadedAttention
encoder_selfattn_layer_args_raw = (
attention_heads,
output_size,
attention_dropout_rate,
)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
channel_size: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
# pdb.set_trace()
if channel_size < 8:
repeat_num = math.ceil(8 / channel_size)
xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
xs_pad = self.conv1(xs_pad)
xs_pad = self.conv2(xs_pad)
xs_pad = self.conv3(xs_pad)
xs_pad = self.conv4(xs_pad)
xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
mask_tmp = masks.size(1)
masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
def forward_hidden(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
num_layer = len(self.encoders)
for idx, encoder in enumerate(self.encoders):
xs_pad, masks = encoder(xs_pad, masks)
if idx == num_layer // 2 - 1:
hidden_feature = xs_pad
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
hidden_feature = hidden_feature[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
self.hidden_feature = self.after_norm(hidden_feature)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
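if __name__ == "__main__":
# A minimal shape check for the ConvolutionModule defined above (sizes are illustrative,
# not from the original recipe): the pointwise/GLU/depthwise stack keeps the
# (#batch, time, channels) layout, so input and output shapes match.
conv_module = ConvolutionModule(channels=256, kernel_size=31)
x = torch.randn(2, 50, 256)  # (#batch, time, channels)
print(conv_module(x).shape)  # torch.Size([2, 50, 256])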

View File

@@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F
class DotScorer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
xs_pad: torch.Tensor,
spk_emb: torch.Tensor,
):
# xs_pad: B, T, D
# spk_emb: B, N, D
scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
return scores
def convert_tf2torch(self, var_dict_tf, var_dict_torch):
return {}
class CosScorer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
xs_pad: torch.Tensor,
spk_emb: torch.Tensor,
):
# xs_pad: B, T, D
# spk_emb: B, N, D
scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
return scores
def convert_tf2torch(self, var_dict_tf, var_dict_torch):
return {}
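if __name__ == "__main__":
# A minimal sketch with illustrative sizes: both scorers map frame features (B, T, D)
# and speaker embeddings (B, N, D) to per-frame speaker scores (B, T, N);
# DotScorer uses unnormalized dot products, CosScorer uses cosine similarity.
frames = torch.randn(2, 100, 256)
spk_emb = torch.randn(2, 4, 256)
print(DotScorer()(frames, spk_emb).shape)  # torch.Size([2, 100, 4])
print(CosScorer()(frames, spk_emb).shape)  # torch.Size([2, 100, 4])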

View File

@@ -0,0 +1,277 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.models.encoder.abs_encoder import AbsEncoder
import math
from funasr_local.modules.repeat import repeat
class EncoderLayer(nn.Module):
def __init__(
self,
input_units,
num_units,
kernel_size=3,
activation="tanh",
stride=1,
include_batch_norm=False,
residual=False
):
super().__init__()
left_padding = math.ceil((kernel_size - stride) / 2)
right_padding = kernel_size - stride - left_padding
self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.conv1d = nn.Conv1d(
input_units,
num_units,
kernel_size,
stride,
)
self.activation = self.get_activation(activation)
if include_batch_norm:
self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
self.residual = residual
self.include_batch_norm = include_batch_norm
self.input_units = input_units
self.num_units = num_units
self.stride = stride
@staticmethod
def get_activation(activation):
if activation == "tanh":
return nn.Tanh()
else:
return nn.ReLU()
def forward(self, xs_pad, ilens=None):
outputs = self.conv1d(self.conv_padding(xs_pad))
if self.residual and self.stride == 1 and self.input_units == self.num_units:
outputs = outputs + xs_pad
if self.include_batch_norm:
outputs = self.bn(outputs)
# return a tuple so that the layer can be chained by the repeat module
return self.activation(outputs), ilens
class ConvEncoder(AbsEncoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Convolution encoder in OpenNMT framework
"""
def __init__(
self,
num_layers,
input_units,
num_units,
kernel_size=3,
dropout_rate=0.3,
position_encoder=None,
activation='tanh',
auxiliary_states=True,
out_units=None,
out_norm=False,
out_residual=False,
include_batchnorm=False,
regularization_weight=0.0,
stride=1,
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = num_units
self.num_layers = num_layers
self.input_units = input_units
self.num_units = num_units
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.position_encoder = position_encoder
self.out_units = out_units
self.auxiliary_states = auxiliary_states
self.out_norm = out_norm
self.activation = activation
self.out_residual = out_residual
self.include_batch_norm = include_batchnorm
self.regularization_weight = regularization_weight
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
if isinstance(stride, int):
self.stride = [stride] * self.num_layers
else:
self.stride = stride
self.downsample_rate = 1
for s in self.stride:
self.downsample_rate *= s
self.dropout = nn.Dropout(dropout_rate)
self.cnn_a = repeat(
self.num_layers,
lambda lnum: EncoderLayer(
input_units if lnum == 0 else num_units,
num_units,
kernel_size,
activation,
self.stride[lnum],
include_batchnorm,
residual=True if lnum > 0 else False
)
)
if self.out_units is not None:
left_padding = math.ceil((kernel_size - stride) / 2)
right_padding = kernel_size - stride - left_padding
self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.conv_out = nn.Conv1d(
num_units,
out_units,
kernel_size,
)
if self.out_norm:
self.after_norm = LayerNorm(out_units)
def output_size(self) -> int:
return self.num_units
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
inputs = xs_pad
if self.position_encoder is not None:
inputs = self.position_encoder(inputs)
if self.dropout_rate > 0:
inputs = self.dropout(inputs)
outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
if self.out_units is not None:
outputs = self.conv_out(self.out_padding(outputs))
outputs = outputs.transpose(1, 2)
if self.out_norm:
outputs = self.after_norm(outputs)
if self.out_residual:
outputs = outputs + inputs
return outputs, ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
}
if self.out_units is not None:
# add output layer
map_dict_local.update({
"{}.conv_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
"squeeze": None,
"transpose": (2, 1, 0),
}, # tf: (1, 256, 256) -> torch: (256, 256, 1)
"{}.conv_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
"squeeze": None,
"transpose": None,
}, # tf: (256,) -> torch: (256,)
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
# process special (first and last) layers
if name in map_dict:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
# process general layers
else:
# self.tf2torch_tensor_name_prefix_torch may contain ".", so handle that case here
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
logging.warning("{} is missed from tf checkpoint".format(name))
return var_dict_torch_update
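if __name__ == "__main__":
# A minimal sketch of a single convolutional EncoderLayer as it is used inside ConvEncoder
# (channels-first input, illustrative sizes): with stride 1, equal channel counts and
# residual=True, the residual path is active and the time length is preserved by the
# symmetric padding.
layer = EncoderLayer(64, 64, kernel_size=3, activation="relu", stride=1, residual=True)
feats = torch.randn(2, 64, 100)  # (batch, channels, time)
out, _ = layer(feats)
print(out.shape)  # torch.Size([2, 64, 100])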

View File

@@ -0,0 +1,335 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.models.encoder.abs_encoder import AbsEncoder
import math
from funasr_local.modules.repeat import repeat
from funasr_local.modules.multi_layer_conv import FsmnFeedForward
class FsmnBlock(torch.nn.Module):
def __init__(
self,
n_feat,
dropout_rate,
kernel_size,
fsmn_shift=0,
):
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
padding=0, groups=n_feat, bias=False)
# padding
left_padding = (kernel_size - 1) // 2
if fsmn_shift > 0:
left_padding = left_padding + fsmn_shift
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
def forward(self, inputs, mask, mask_shfit_chunk=None):
b, t, d = inputs.size()
if mask is not None:
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = self.dropout(x)
return x * mask
class EncoderLayer(torch.nn.Module):
def __init__(
self,
in_size,
size,
feed_forward,
fsmn_block,
dropout_rate=0.0
):
super().__init__()
self.in_size = in_size
self.size = size
self.ffn = feed_forward
self.memory = fsmn_block
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
xs_pad: torch.Tensor,
mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# xs_pad in Batch, Time, Dim
context = self.ffn(xs_pad)[0]
memory = self.memory(context, mask)
memory = self.dropout(memory)
if self.in_size == self.size:
return memory + xs_pad, mask
return memory, mask
class FsmnEncoder(AbsEncoder):
"""Encoder using Fsmn
"""
def __init__(self,
in_units,
filter_size,
fsmn_num_layers,
dnn_num_layers,
num_memory_units=512,
ffn_inner_dim=2048,
dropout_rate=0.0,
shift=0,
position_encoder=None,
sample_rate=1,
out_units=None,
tf2torch_tensor_name_prefix_torch="post_net",
tf2torch_tensor_name_prefix_tf="EAND/post_net"
):
"""Initializes the parameters of the encoder.
Args:
filter_size: The total order of the memory block.
fsmn_num_layers: The number of FSMN layers.
dnn_num_layers: The number of DNN layers.
num_memory_units: The number of memory units.
ffn_inner_dim: The number of units of the inner linear transformation
in the feed forward layer.
dropout_rate: The probability to drop units from the outputs.
shift: Left padding, used to control the delay.
position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
apply on inputs or ``None``.
"""
super(FsmnEncoder, self).__init__()
self.in_units = in_units
self.filter_size = filter_size
self.fsmn_num_layers = fsmn_num_layers
self.dnn_num_layers = dnn_num_layers
self.num_memory_units = num_memory_units
self.ffn_inner_dim = ffn_inner_dim
self.dropout_rate = dropout_rate
self.shift = shift
if not isinstance(shift, list):
self.shift = [shift for _ in range(self.fsmn_num_layers)]
self.sample_rate = sample_rate
if not isinstance(sample_rate, list):
self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
self.position_encoder = position_encoder
self.dropout = nn.Dropout(dropout_rate)
self.out_units = out_units
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.fsmn_layers = repeat(
self.fsmn_num_layers,
lambda lnum: EncoderLayer(
in_units if lnum == 0 else num_memory_units,
num_memory_units,
FsmnFeedForward(
in_units if lnum == 0 else num_memory_units,
ffn_inner_dim,
num_memory_units,
1,
dropout_rate
),
FsmnBlock(
num_memory_units,
dropout_rate,
filter_size,
self.shift[lnum]
)
),
)
self.dnn_layers = repeat(
dnn_num_layers,
lambda lnum: FsmnFeedForward(
num_memory_units,
ffn_inner_dim,
num_memory_units,
1,
dropout_rate,
)
)
if out_units is not None:
self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)
def output_size(self) -> int:
return self.num_memory_units
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
inputs = xs_pad
if self.position_encoder is not None:
inputs = self.position_encoder(inputs)
inputs = self.dropout(inputs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
inputs = self.fsmn_layers(inputs, masks)[0]
inputs = self.dnn_layers(inputs)[0]
if self.out_units is not None:
inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)
return inputs, ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
# for fsmn_layers
"{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 2, 0),
}, # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)
# for dnn_layers
"{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
}
if self.out_units is not None:
# add output layer
map_dict_local.update({
"{}.conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
# process special (first and last) layers
if name in map_dict:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
# process general layers
else:
# self.tf2torch_tensor_name_prefix_torch may contain ".", so handle that case here
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
logging.warning("{} is missed from tf checkpoint".format(name))
return var_dict_torch_update
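if __name__ == "__main__":
# A minimal sketch of the FsmnBlock defined above (illustrative sizes): with kernel_size=5
# and fsmn_shift=0 the memory block pads 2 frames on each side, so the time axis is
# preserved; fsmn_shift > 0 moves padding from the right to the left, i.e. it reduces the
# future context seen by the memory block.
block = FsmnBlock(n_feat=4, dropout_rate=0.0, kernel_size=5, fsmn_shift=0)
feats = torch.randn(2, 10, 4)
pad_mask = (~make_pad_mask(torch.tensor([10, 7]))[:, None, :]).to(feats.device)
print(block(feats, pad_mask).shape)  # torch.Size([2, 10, 4])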

View File

@@ -0,0 +1,480 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr_local.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
from funasr_local.modules.embedding import SinusoidalPositionEncoder
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.multi_layer_conv import Conv1dLinear
from funasr_local.modules.multi_layer_conv import MultiLayeredConv1d
from funasr_local.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.subsampling import Conv2dSubsampling
from funasr_local.modules.subsampling import Conv2dSubsampling2
from funasr_local.modules.subsampling import Conv2dSubsampling6
from funasr_local.modules.subsampling import Conv2dSubsampling8
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.modules.subsampling import check_short_utt
from funasr_local.models.ctc import CTC
from funasr_local.models.encoder.abs_encoder import AbsEncoder
class EncoderLayer(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_att_chunk_encoder
class SelfAttentionEncoder(AbsEncoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Self attention encoder in OpenNMT framework
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
out_units=None,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
elif input_layer == "null":
self.embed = None
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
output_size,
MultiHeadSelfAttention(
attention_heads,
output_size,
output_size,
attention_dropout_rate,
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
) if lnum > 0 else EncoderLayer(
input_size,
output_size,
MultiHeadSelfAttention(
attention_heads,
input_size if input_layer == "pe" or input_layer == "null" else output_size,
output_size,
attention_dropout_rate,
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.out_units = out_units
if out_units is not None:
self.output_linear = nn.Linear(output_size, out_units)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad = xs_pad * self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
xs_pad = self.dropout(xs_pad)
# encoder_outs = self.encoders0(xs_pad, masks)
# xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
if self.out_units is not None:
xs_pad = self.output_linear(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
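# Note on the intermediate-CTC branch above: when interctc_layer_idx is non-empty, the
# first return value becomes a tuple (final_output, [(layer_idx, normalized_hidden), ...]);
# with interctc_use_conditioning=True, self.conditioning_layer (expected to be assigned by
# the enclosing ASR model) projects the CTC posteriors of each tapped layer back into the
# encoder stream.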
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (768,256),(1,256,768)
"{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (768,),(768,)
"{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# ffn
"{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
"{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.after_norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
}
if self.out_units is not None:
map_dict_local.update({
"{}.output_linear.weight".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
},
"{}.output_linear.bias".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
# process special (first and last) layers
if name in map_dict:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
# convert to torch only after the numpy-side squeeze/transpose, as in the other branches
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
# process general layers
else:
# self.tf2torch_tensor_name_prefix_torch may contain ".", so handle that case here
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
logging.warning("{} is missed from tf checkpoint".format(name))
return var_dict_torch_update
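if __name__ == "__main__":
# An illustrative sketch of the "layeridx" placeholder used by convert_tf2torch above:
# a per-layer torch parameter name is generalized to a template key, and the matching
# TF variable name is recovered by substituting the layer index back in (names follow
# the default prefixes "encoder" / "seq2seq/encoder").
torch_name = "encoder.encoders.3.norm1.weight"
layeridx = int(torch_name.replace("encoder", "todo").split(".")[2])
name_q = torch_name.replace(".{}.".format(layeridx), ".layeridx.")
tf_template = "seq2seq/encoder/layer_layeridx/multi_head/LayerNorm/gamma"
print(name_q, "->", tf_template.replace("layeridx", str(layeridx)))
# encoder.encoders.layeridx.norm1.weight -> seq2seq/encoder/layer_3/multi_head/LayerNorm/gamma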

View File

@@ -0,0 +1,853 @@
import torch
from torch.nn import functional as F
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple, Optional
from funasr_local.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
import logging
import numpy as np
class BasicLayer(torch.nn.Module):
def __init__(self, in_filters: int, filters: int, stride: int, bn_momentum: float = 0.5):
super().__init__()
self.stride = stride
self.in_filters = in_filters
self.filters = filters
self.bn1 = torch.nn.BatchNorm2d(in_filters, eps=1e-3, momentum=bn_momentum, affine=True)
self.relu1 = torch.nn.ReLU()
self.conv1 = torch.nn.Conv2d(in_filters, filters, 3, stride, bias=False)
self.bn2 = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
self.relu2 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(filters, filters, 3, 1, bias=False)
if in_filters != filters or stride > 1:
self.conv_sc = torch.nn.Conv2d(in_filters, filters, 1, stride, bias=False)
self.bn_sc = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
def proper_padding(self, x, stride):
# align padding mode to tf.layers.conv2d with padding="same"
if stride == 1:
return F.pad(x, (1, 1, 1, 1), "constant", 0)
elif stride == 2:
h, w = x.size(2), x.size(3)
# (left, right, top, bottom)
return F.pad(x, (w % 2, 1, h % 2, 1), "constant", 0)
def forward(self, xs_pad, ilens):
identity = xs_pad
if self.in_filters != self.filters or self.stride > 1:
identity = self.conv_sc(identity)
identity = self.bn_sc(identity)
xs_pad = self.relu1(self.bn1(xs_pad))
xs_pad = self.proper_padding(xs_pad, self.stride)
xs_pad = self.conv1(xs_pad)
xs_pad = self.relu2(self.bn2(xs_pad))
xs_pad = self.proper_padding(xs_pad, 1)
xs_pad = self.conv2(xs_pad)
if self.stride == 2:
ilens = (ilens + 1) // self.stride
return xs_pad + identity, ilens
class BasicBlock(torch.nn.Module):
def __init__(self, in_filters, filters, num_layer, stride, bn_momentum=0.5):
super().__init__()
self.num_layer = num_layer
for i in range(num_layer):
layer = BasicLayer(in_filters if i == 0 else filters, filters,
stride if i == 0 else 1, bn_momentum)
self.add_module("layer_{}".format(i), layer)
def forward(self, xs_pad, ilens):
for i in range(self.num_layer):
xs_pad, ilens = self._modules["layer_{}".format(i)](xs_pad, ilens)
return xs_pad, ilens
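# A shape sketch for the blocks above (illustrative sizes): the first BasicLayer of a
# BasicBlock with stride 2 halves both spatial axes (proper_padding mimics TF "same"
# padding), and the frame lengths follow ilens = (ilens + 1) // 2, e.g.
#   block = BasicBlock(in_filters=32, filters=64, num_layer=3, stride=2)
#   xs, lens = block(torch.randn(4, 32, 100, 80), torch.tensor([100, 98, 57, 31]))
#   # xs: (4, 64, 50, 40), lens: tensor([50, 49, 29, 16])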
class ResNet34(AbsEncoder):
def __init__(
self,
input_size,
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
):
super(ResNet34, self).__init__()
self.use_head_conv = use_head_conv
self.use_head_maxpool = use_head_maxpool
self.num_nodes_pooling_layer = num_nodes_pooling_layer
self.layers_in_block = layers_in_block
self.filters_in_block = filters_in_block
self.input_size = input_size
pre_filters = filters_in_block[0]
if use_head_conv:
self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
if use_head_maxpool:
self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
for i in range(len(layers_in_block)):
if i == 0:
in_filters = pre_filters if self.use_head_conv else 1
else:
in_filters = filters_in_block[i-1]
block = BasicBlock(in_filters,
filters=filters_in_block[i],
num_layer=layers_in_block[i],
stride=1 if i == 0 else 2,
bn_momentum=batchnorm_momentum)
self.add_module("block_{}".format(i), block)
self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
self.time_ds_ratio = 8
def output_size(self) -> int:
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
features = xs_pad
assert features.size(-1) == self.input_size, \
"Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
features = torch.unsqueeze(features, dim=1)
if self.use_head_conv:
features = self.pre_conv(features)
features = self.pre_conv_bn(features)
features = F.relu(features)
if self.use_head_maxpool:
features = self.head_maxpool(features)
resnet_outs, resnet_out_lens = features, ilens
for i in range(len(self.layers_in_block)):
block = self._modules["block_{}".format(i)]
resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
features = self.resnet0_dense(resnet_outs)
features = F.relu(features)
features = self.resnet0_bn(features)
return features, resnet_out_lens
# Note: For training, this implementation is not equivalent to TF because of the kernel_regularizer in tf.layers.
# TODO: implement the kernel_regularizer in torch with a manual loss term or weight_decay in the optimizer
class ResNet34_SP_L2Reg(AbsEncoder):
def __init__(
self,
input_size,
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
tf_train_steps=720000,
):
super(ResNet34_SP_L2Reg, self).__init__()
self.use_head_conv = use_head_conv
self.use_head_maxpool = use_head_maxpool
self.num_nodes_pooling_layer = num_nodes_pooling_layer
self.layers_in_block = layers_in_block
self.filters_in_block = filters_in_block
self.input_size = input_size
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.tf_train_steps = tf_train_steps
pre_filters = filters_in_block[0]
if use_head_conv:
self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
if use_head_maxpool:
self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
for i in range(len(layers_in_block)):
if i == 0:
in_filters = pre_filters if self.use_head_conv else 1
else:
in_filters = filters_in_block[i-1]
block = BasicBlock(in_filters,
filters=filters_in_block[i],
num_layer=layers_in_block[i],
stride=1 if i == 0 else 2,
bn_momentum=batchnorm_momentum)
self.add_module("block_{}".format(i), block)
self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
self.time_ds_ratio = 8
def output_size(self) -> int:
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
features = xs_pad
assert features.size(-1) == self.input_size, \
"Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
features = torch.unsqueeze(features, dim=1)
if self.use_head_conv:
features = self.pre_conv(features)
features = self.pre_conv_bn(features)
features = F.relu(features)
if self.use_head_maxpool:
features = self.head_maxpool(features)
resnet_outs, resnet_out_lens = features, ilens
for i in range(len(self.layers_in_block)):
block = self._modules["block_{}".format(i)]
resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
# B, C, T, F
bb, cc, tt, ff = resnet_outs.shape
resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff*cc, tt])
features = self.resnet0_dense(resnet_outs)
features = F.relu(features)
features = self.resnet0_bn(features)
return features, resnet_out_lens
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
train_steps = self.tf_train_steps
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.pre_conv.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
}
for layer_idx in range(3):
map_dict_local.update({
"{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
},
"{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
})
for block_idx in range(len(self.layers_in_block)):
for layer_idx in range(self.layers_in_block[block_idx]):
for i in ["1", "2", "_sc"]:
map_dict_local.update({
"{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
if name in map_dict:
if "num_batches_tracked" not in name:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
                        var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
logging.info("torch tensor: {}, manually assigning to: {}".format(
name, map_dict[name]
))
else:
                    logging.warning("{} is missing from tf checkpoint".format(name))
return var_dict_torch_update
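# A minimal illustrative sketch (not part of the original conversion code) of what the
# "transpose" entries in gen_tf2torch_map_dict express: TF stores a conv2d kernel as
# (kernel_h, kernel_w, in_ch, out_ch) and a dense kernel as (in_ch, out_ch), whereas
# PyTorch expects (out_ch, in_ch, kernel_h, kernel_w) and (out_ch, in_ch). The helper
# assumes numpy (np) and torch are already imported at module level, as elsewhere in this file.
def _example_tf_conv2d_kernel_to_torch(kernel_tf_np, transpose=(3, 2, 0, 1)):
    """Map a TF-layout kernel (numpy array) onto the PyTorch layout as a float32 tensor."""
    return torch.from_numpy(np.transpose(kernel_tf_np, transpose)).type(torch.float32)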
class ResNet34Diar(ResNet34):
def __init__(
self,
input_size,
embedding_node="resnet1_dense",
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
num_nodes_resnet1=256,
num_nodes_last_layer=256,
pooling_type="window_shift",
pool_size=20,
stride=1,
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
"""
super(ResNet34Diar, self).__init__(
input_size,
use_head_conv=use_head_conv,
batchnorm_momentum=batchnorm_momentum,
use_head_maxpool=use_head_maxpool,
num_nodes_pooling_layer=num_nodes_pooling_layer,
layers_in_block=layers_in_block,
filters_in_block=filters_in_block,
)
self.embedding_node = embedding_node
self.num_nodes_resnet1 = num_nodes_resnet1
self.num_nodes_last_layer = num_nodes_last_layer
self.pooling_type = pooling_type
self.pool_size = pool_size
self.stride = stride
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
def output_size(self) -> int:
if self.embedding_node.startswith("resnet1"):
return self.num_nodes_resnet1
elif self.embedding_node.startswith("resnet2"):
return self.num_nodes_last_layer
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
endpoints = OrderedDict()
res_out, ilens = super().forward(xs_pad, ilens)
endpoints["resnet0_bn"] = res_out
if self.pooling_type == "frame_gsp":
features = statistic_pooling(res_out, ilens, (3, ))
else:
features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride)
features = features.transpose(1, 2)
endpoints["pooling"] = features
features = self.resnet1_dense(features)
endpoints["resnet1_dense"] = features
features = F.relu(features)
endpoints["resnet1_relu"] = features
features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet1_bn"] = features
features = self.resnet2_dense(features)
endpoints["resnet2_dense"] = features
features = F.relu(features)
endpoints["resnet2_relu"] = features
features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet2_bn"] = features
return endpoints[self.embedding_node], ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
train_steps = 300000
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.pre_conv.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
}
for layer_idx in range(3):
map_dict_local.update({
"{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
},
"{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
})
for block_idx in range(len(self.layers_in_block)):
for layer_idx in range(self.layers_in_block[block_idx]):
for i in ["1", "2", "_sc"]:
map_dict_local.update({
"{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
if name in map_dict:
if "num_batches_tracked" not in name:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
                        var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
logging.info("torch tensor: {}, manually assigning to: {}".format(
name, map_dict[name]
))
else:
                    logging.warning("{} is missing from tf checkpoint".format(name))
return var_dict_torch_update
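# Usage sketch for ResNet34Diar (illustrative shapes; assumes 80-dim filterbank features):
#   encoder = ResNet34Diar(input_size=80, embedding_node="resnet1_dense")
#   xs_pad = torch.randn(4, 400, 80)                 # (batch, frames, feat_dim)
#   ilens = torch.full((4,), 400, dtype=torch.long)
#   emb, olens, _ = encoder(xs_pad, ilens)           # embedding taken at `embedding_node`
# With pooling_type="window_shift", windowed_statistic_pooling is applied over pool_size
# frames with the given stride; "frame_gsp" applies statistic_pooling instead.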
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
def __init__(
self,
input_size,
embedding_node="resnet1_dense",
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
num_nodes_resnet1=256,
num_nodes_last_layer=256,
pooling_type="window_shift",
pool_size=20,
stride=1,
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
"""
Author: Speech Lab, Alibaba Group, China
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
super(ResNet34SpL2RegDiar, self).__init__(
input_size,
use_head_conv=use_head_conv,
batchnorm_momentum=batchnorm_momentum,
use_head_maxpool=use_head_maxpool,
num_nodes_pooling_layer=num_nodes_pooling_layer,
layers_in_block=layers_in_block,
filters_in_block=filters_in_block,
)
self.embedding_node = embedding_node
self.num_nodes_resnet1 = num_nodes_resnet1
self.num_nodes_last_layer = num_nodes_last_layer
self.pooling_type = pooling_type
self.pool_size = pool_size
self.stride = stride
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
def output_size(self) -> int:
if self.embedding_node.startswith("resnet1"):
return self.num_nodes_resnet1
elif self.embedding_node.startswith("resnet2"):
return self.num_nodes_last_layer
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
endpoints = OrderedDict()
res_out, ilens = super().forward(xs_pad, ilens)
endpoints["resnet0_bn"] = res_out
if self.pooling_type == "frame_gsp":
features = statistic_pooling(res_out, ilens, (2, ))
else:
features, ilens = windowed_statistic_pooling(res_out, ilens, (2, ), self.pool_size, self.stride)
features = features.transpose(1, 2)
endpoints["pooling"] = features
features = self.resnet1_dense(features)
endpoints["resnet1_dense"] = features
features = F.relu(features)
endpoints["resnet1_relu"] = features
features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet1_bn"] = features
features = self.resnet2_dense(features)
endpoints["resnet2_dense"] = features
features = F.relu(features)
endpoints["resnet2_relu"] = features
features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet2_bn"] = features
return endpoints[self.embedding_node], ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
train_steps = 720000
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.pre_conv.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
}
for layer_idx in range(3):
map_dict_local.update({
"{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
},
"{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
})
for block_idx in range(len(self.layers_in_block)):
for layer_idx in range(self.layers_in_block[block_idx]):
for i in ["1", "2", "_sc"]:
map_dict_local.update({
"{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
if name in map_dict:
if "num_batches_tracked" not in name:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
logging.info("torch tensor: {}, manually assigning to: {}".format(
name, map_dict[name]
))
else:
                    logging.warning("{} is missing from tf checkpoint".format(name))
return var_dict_torch_update

View File

@@ -0,0 +1,115 @@
from typing import Optional
from typing import Sequence
from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.rnn.encoders import RNN
from funasr_local.modules.rnn.encoders import RNNP
from funasr_local.models.encoder.abs_encoder import AbsEncoder
class RNNEncoder(AbsEncoder):
"""RNNEncoder class.
Args:
input_size: The number of expected features in the input
output_size: The number of output features
hidden_size: The number of hidden features
bidirectional: If ``True`` becomes a bidirectional LSTM
use_projection: Use projection layer or not
num_layers: Number of recurrent layers
dropout: dropout probability
"""
def __init__(
self,
input_size: int,
rnn_type: str = "lstm",
bidirectional: bool = True,
use_projection: bool = True,
num_layers: int = 4,
hidden_size: int = 320,
output_size: int = 320,
dropout: float = 0.0,
subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
self.rnn_type = rnn_type
self.bidirectional = bidirectional
self.use_projection = use_projection
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported rnn_type={rnn_type}")
if subsample is None:
            subsample = np.ones(num_layers + 1, dtype=np.int64)
else:
subsample = subsample[:num_layers]
# Append 1 at the beginning because the second or later is used
subsample = np.pad(
                np.array(subsample, dtype=np.int64),
[1, num_layers - len(subsample)],
mode="constant",
constant_values=1,
)
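            # Example (illustrative): num_layers=4, subsample=(2, 2, 1, 1) becomes
            # [1, 2, 2, 1, 1]; the prepended 1 means no subsampling before the first
            # layer, since only the second and later entries are consumed below.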
rnn_type = ("b" if bidirectional else "") + rnn_type
if use_projection:
self.enc = torch.nn.ModuleList(
[
RNNP(
input_size,
num_layers,
hidden_size,
output_size,
subsample,
dropout,
typ=rnn_type,
)
]
)
else:
self.enc = torch.nn.ModuleList(
[
RNN(
input_size,
num_layers,
hidden_size,
output_size,
dropout,
typ=rnn_type,
)
]
)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if prev_states is None:
prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)
current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)
if self.use_projection:
xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
else:
xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)
return xs_pad, ilens, current_states
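# Usage sketch (illustrative; 80-dim input features assumed):
#   enc = RNNEncoder(input_size=80, hidden_size=320, output_size=320)
#   xs = torch.randn(2, 300, 80)
#   ilens = torch.tensor([300, 250])
#   out, olens, states = enc(xs, ilens)   # out: (batch, T', 320); T' reflects `subsample`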

File diff suppressed because it is too large

View File

@@ -0,0 +1,684 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
import logging
from funasr_local.models.ctc import CTC
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.modules.attention import MultiHeadedAttention
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.multi_layer_conv import Conv1dLinear
from funasr_local.modules.multi_layer_conv import MultiLayeredConv1d
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.nets_utils import rename_state_dict
from funasr_local.modules.dynamic_conv import DynamicConvolution
from funasr_local.modules.dynamic_conv2d import DynamicConvolution2D
from funasr_local.modules.lightconv import LightweightConvolution
from funasr_local.modules.lightconv2d import LightweightConvolution2D
from funasr_local.modules.subsampling import Conv2dSubsampling
from funasr_local.modules.subsampling import Conv2dSubsampling2
from funasr_local.modules.subsampling import Conv2dSubsampling6
from funasr_local.modules.subsampling import Conv2dSubsampling8
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.modules.subsampling import check_short_utt
class EncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
During training, the layer may skip residual computation and return input
as-is with given probability.
"""
def __init__(
self,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
def forward(self, x, mask, cache=None):
"""Compute encoded features.
Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x_q, x, x, mask)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
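# Stochastic-depth note (illustrative): during training the layer returns `x` unchanged with
# probability p = stochastic_depth_rate and `x + f(x) / (1 - p)` otherwise, so the expected
# output remains `x + f(x)`; at inference the layer is always applied with coefficient 1.0.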
class TransformerEncoder(AbsEncoder):
"""Transformer encoder module.
Args:
input_size: input dim
output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
attention_dropout_rate: dropout rate in attention
positional_dropout_rate: dropout rate after adding positional encoding
input_layer: input layer type
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
        positionwise_layer_type: linear, conv1d, or conv1d-linear
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
padding_idx: padding_idx for input_layer=embed
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
MultiHeadedAttention(
attention_heads, output_size, attention_dropout_rate
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks = self.encoders(xs_pad, masks)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
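# Usage sketch (illustrative shapes; the default "conv2d" input layer subsamples time by ~4):
#   enc = TransformerEncoder(input_size=80, output_size=256, num_blocks=6)
#   xs = torch.randn(2, 200, 80)
#   ilens = torch.tensor([200, 168])
#   out, olens, _ = enc(xs, ilens)        # out: (2, ~T/4, 256)
# When interctc_layer_idx is non-empty, the first return value becomes a tuple
# (final_output, [(layer_idx, intermediate_output), ...]) as in the code above.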
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
# https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
# https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
class TransformerEncoder_s0(torch.nn.Module):
"""Transformer encoder module.
Args:
idim (int): Input dimension.
attention_dim (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
conv_wshare (int): The number of kernel of convolution. Only used in
selfattention_layer_type == "lightconv*" or "dynamiconv*".
conv_kernel_length (Union[int, str]): Kernel size str of convolution
(e.g. 71_71_71_71_71_71). Only used in selfattention_layer_type
== "lightconv*" or "dynamiconv*".
conv_usebias (bool): Whether to use bias in convolution. Only used in
selfattention_layer_type == "lightconv*" or "dynamiconv*".
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
attention_dropout_rate (float): Dropout rate in attention.
input_layer (Union[str, torch.nn.Module]): Input layer type.
pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
selfattention_layer_type (str): Encoder attention layer type.
padding_idx (int): Padding idx for input_layer=embed.
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
indices start from 1.
if not None, intermediate outputs are returned (which changes return type
signature.)
"""
def __init__(
self,
idim,
attention_dim=256,
attention_heads=4,
conv_wshare=4,
conv_kernel_length="11",
conv_usebias=False,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
selfattention_layer_type="selfattn",
padding_idx=-1,
stochastic_depth_rate=0.0,
intermediate_layers=None,
ctc_softmax=None,
conditioning_layer_dim=None,
):
"""Construct an Encoder object."""
super(TransformerEncoder_s0, self).__init__()
self._register_load_state_dict_pre_hook(_pre_hook)
self.conv_subsampling_factor = 1
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(idim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 4
elif input_layer == "conv2d-scaled-pos-enc":
self.embed = Conv2dSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
)
self.conv_subsampling_factor = 4
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 6
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 8
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
positionwise_layer_type,
attention_dim,
linear_units,
dropout_rate,
positionwise_conv_kernel_size,
)
if selfattention_layer_type in [
"selfattn",
"rel_selfattn",
"legacy_rel_selfattn",
]:
logging.info("encoder self-attention layer type = self-attention")
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = [
(
attention_heads,
attention_dim,
attention_dropout_rate,
)
] * num_blocks
elif selfattention_layer_type == "lightconv":
logging.info("encoder self-attention layer type = lightweight convolution")
encoder_selfattn_layer = LightweightConvolution
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "lightconv2d":
logging.info(
"encoder self-attention layer "
"type = lightweight convolution 2-dimensional"
)
encoder_selfattn_layer = LightweightConvolution2D
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "dynamicconv":
logging.info("encoder self-attention layer type = dynamic convolution")
encoder_selfattn_layer = DynamicConvolution
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "dynamicconv2d":
logging.info(
"encoder self-attention layer type = dynamic convolution 2-dimensional"
)
encoder_selfattn_layer = DynamicConvolution2D
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
else:
raise NotImplementedError(selfattention_layer_type)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
attention_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate * float(1 + lnum) / num_blocks,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
self.intermediate_layers = intermediate_layers
self.use_conditioning = True if ctc_softmax is not None else False
if self.use_conditioning:
self.ctc_softmax = ctc_softmax
self.conditioning_layer = torch.nn.Linear(
conditioning_layer_dim, attention_dim
)
def get_positionwise_layer(
self,
positionwise_layer_type="linear",
attention_dim=256,
linear_units=2048,
dropout_rate=0.1,
positionwise_conv_kernel_size=1,
):
"""Define positionwise layer."""
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
return positionwise_layer, positionwise_layer_args
def forward(self, xs, masks):
"""Encode input sequence.
Args:
xs (torch.Tensor): Input tensor (#batch, time, idim).
masks (torch.Tensor): Mask tensor (#batch, time).
Returns:
torch.Tensor: Output tensor (#batch, time, attention_dim).
torch.Tensor: Mask tensor (#batch, time).
"""
if isinstance(
self.embed,
(Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
if self.intermediate_layers is None:
xs, masks = self.encoders(xs, masks)
else:
intermediate_outputs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
xs, masks = encoder_layer(xs, masks)
if (
self.intermediate_layers is not None
and layer_idx + 1 in self.intermediate_layers
):
encoder_output = xs
# intermediate branches also require normalization.
if self.normalize_before:
encoder_output = self.after_norm(encoder_output)
intermediate_outputs.append(encoder_output)
if self.use_conditioning:
intermediate_result = self.ctc_softmax(encoder_output)
xs = xs + self.conditioning_layer(intermediate_result)
if self.normalize_before:
xs = self.after_norm(xs)
if self.intermediate_layers is not None:
return xs, masks, intermediate_outputs
return xs, masks
def forward_one_step(self, xs, masks, cache=None):
"""Encode input frame.
Args:
xs (torch.Tensor): Input tensor.
masks (torch.Tensor): Mask tensor.
cache (List[torch.Tensor]): List of cache tensors.
Returns:
torch.Tensor: Output tensor.
torch.Tensor: Mask tensor.
List[torch.Tensor]: List of new cache tensors.
"""
if isinstance(self.embed, Conv2dSubsampling):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
if cache is None:
cache = [None for _ in range(len(self.encoders))]
new_cache = []
for c, e in zip(cache, self.encoders):
xs, masks = e(xs, masks, cache=c)
new_cache.append(xs)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache
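# Incremental-encoding note (illustrative): `cache[i]` holds encoder layer i's output for all
# previously processed positions; each call computes only the newest position inside the
# layers and returns the concatenated per-layer outputs in `new_cache`, which should be fed
# back as `cache` on the next call.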

View File

View File

@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsFrontend(torch.nn.Module, ABC):
@abstractmethod
def output_size(self) -> int:
raise NotImplementedError
@abstractmethod
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@@ -0,0 +1,258 @@
import copy
from typing import Optional
from typing import Tuple
from typing import Union
import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types
from funasr_local.layers.log_mel import LogMel
from funasr_local.layers.stft import Stft
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.utils.get_default_kwargs import get_default_kwargs
class DefaultFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)
else:
self.stft = None
self.apply_stft = apply_stft
if frontend_conf is not None:
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = "default"
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))
input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
return input_feats, feats_lens
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens
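# Usage sketch for DefaultFrontend (illustrative; 16 kHz mono waveforms, lengths in samples):
#   frontend = DefaultFrontend(fs=16000, n_mels=80, n_fft=512, hop_length=128)
#   wav = torch.randn(2, 16000)
#   wav_lens = torch.tensor([16000, 12000])
#   feats, feat_lens = frontend(wav, wav_lens)   # feats: (batch, num_frames, 80)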
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
frame_length: int = None,
frame_shift: int = None,
lfr_m: int = None,
lfr_n: int = None,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)
else:
self.stft = None
self.apply_stft = apply_stft
if frontend_conf is not None:
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = "multichannelfrontend"
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
if isinstance(input, ComplexTensor):
input_stft = input
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
bt = input_feats.size(0)
        if input_feats.dim() == 4:
            # multi-channel case: (batch, frames, channels, n_mels) -> (batch * channels, frames, n_mels)
            channel_size = input_feats.size(2)
            input_feats = input_feats.transpose(1, 2).reshape(bt * channel_size, -1, self.n_mels).contiguous()
            # repeat the lengths once per channel so they line up with the folded batch axis
            feats_lens = feats_lens.repeat_interleave(channel_size)
else:
channel_size = 1
return input_feats, feats_lens, channel_size
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens
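# Note (illustrative): unlike DefaultFrontend, MultiChannelFrontend.forward returns a third
# value, channel_size, and folds multi-channel utterances into the batch dimension, i.e.
# (batch, frames, channels, n_mels) features become (batch * channels, frames, n_mels).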

View File

@@ -0,0 +1,51 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features
import librosa
import numpy as np
def transform(Y, dtype=np.float32):
Y = np.abs(Y)
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
mean = np.mean(Y, axis=0)
Y = Y - mean
return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
Y_ss = Y[::subsampling]
T_ss = T[::subsampling]
return Y_ss, T_ss
def splice(Y, context_size=0):
Y_pad = np.pad(
Y,
[(context_size, context_size), (0, 0)],
'constant')
Y_spliced = np.lib.stride_tricks.as_strided(
np.ascontiguousarray(Y_pad),
(Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
(Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
return Y_spliced
def stft(
data,
frame_size=1024,
frame_shift=256):
fft_size = 1 << (frame_size - 1).bit_length()
if len(data) % frame_shift == 0:
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
hop_length=frame_shift).T[:-1]
else:
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
hop_length=frame_shift).T
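# Note (illustrative): `1 << (frame_size - 1).bit_length()` rounds frame_size up to the next
# power of two (e.g. 1024 -> 1024, 400 -> 512), and the final STFT frame is dropped only when
# len(data) is an exact multiple of frame_shift.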

View File

@@ -0,0 +1,146 @@
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.s3prl import S3prlFrontend
import numpy as np
import torch
from typeguard import check_argument_types
from typing import Tuple
class FusedFrontends(AbsFrontend):
def __init__(
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):
assert check_argument_types()
super().__init__()
self.align_method = (
align_method # fusing method : linear_projection only for now
)
self.proj_dim = proj_dim # dim of the projection done on each frontend
self.frontends = [] # list of the frontends to combine
for i, frontend in enumerate(frontends):
frontend_type = frontend["frontend_type"]
if frontend_type == "default":
n_mels, fs, n_fft, win_length, hop_length = (
frontend.get("n_mels", 80),
fs,
frontend.get("n_fft", 512),
frontend.get("win_length"),
frontend.get("hop_length", 128),
)
window, center, normalized, onesided = (
frontend.get("window", "hann"),
frontend.get("center", True),
frontend.get("normalized", False),
frontend.get("onesided", True),
)
fmin, fmax, htk, apply_stft = (
frontend.get("fmin", None),
frontend.get("fmax", None),
frontend.get("htk", False),
frontend.get("apply_stft", True),
)
self.frontends.append(
DefaultFrontend(
n_mels=n_mels,
n_fft=n_fft,
fs=fs,
win_length=win_length,
hop_length=hop_length,
window=window,
center=center,
normalized=normalized,
onesided=onesided,
fmin=fmin,
fmax=fmax,
htk=htk,
apply_stft=apply_stft,
)
)
elif frontend_type == "s3prl":
frontend_conf, download_dir, multilayer_feature = (
frontend.get("frontend_conf"),
frontend.get("download_dir"),
frontend.get("multilayer_feature"),
)
self.frontends.append(
S3prlFrontend(
fs=fs,
frontend_conf=frontend_conf,
download_dir=download_dir,
multilayer_feature=multilayer_feature,
)
)
else:
raise NotImplementedError # frontends are only default or s3prl
self.frontends = torch.nn.ModuleList(self.frontends)
self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
if torch.cuda.is_available():
dev = "cuda"
else:
dev = "cpu"
if self.align_method == "linear_projection":
self.projection_layers = [
torch.nn.Linear(
in_features=frontend.output_size(),
out_features=self.factors[i] * self.proj_dim,
)
for i, frontend in enumerate(self.frontends)
]
self.projection_layers = torch.nn.ModuleList(self.projection_layers)
self.projection_layers = self.projection_layers.to(torch.device(dev))
def output_size(self) -> int:
return len(self.frontends) * self.proj_dim
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# step 0 : get all frontends features
self.feats = []
for frontend in self.frontends:
with torch.no_grad():
input_feats, feats_lens = frontend.forward(input, input_lengths)
self.feats.append([input_feats, feats_lens])
if (
self.align_method == "linear_projection"
): # TODO(Dan): to add other align methods
# first step : projections
self.feats_proj = []
for i, frontend in enumerate(self.frontends):
input_feats = self.feats[i][0]
self.feats_proj.append(self.projection_layers[i](input_feats))
# 2nd step : reshape
self.feats_reshaped = []
for i, frontend in enumerate(self.frontends):
input_feats_proj = self.feats_proj[i]
bs, nf, dim = input_feats_proj.shape
input_feats_reshaped = torch.reshape(
input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
)
self.feats_reshaped.append(input_feats_reshaped)
# 3rd step : drop the few last frames
m = min([x.shape[1] for x in self.feats_reshaped])
self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]
input_feats = torch.cat(
self.feats_final, dim=-1
) # change the input size of the preencoder : proj_dim * n_frontends
feats_lens = torch.ones_like(self.feats[0][1]) * (m)
else:
raise NotImplementedError
return input_feats, feats_lens
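# Alignment sketch (illustrative): for two frontends with hop lengths 128 and 160 samples,
# gcd = 32 and factors = (4, 5); each stream is projected to factor * proj_dim features per
# frame and reshaped onto the common 32-sample grid, then both are truncated to the shorter
# length and concatenated along the feature axis (output dim = n_frontends * proj_dim).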

View File

@@ -0,0 +1,143 @@
import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union
import humanfriendly
import torch
from typeguard import check_argument_types
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.modules.nets_utils import pad_list
from funasr_local.utils.get_default_kwargs import get_default_kwargs
def base_s3prl_setup(args):
args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
args.upstream_model_config = getattr(args, "upstream_model_config", None)
args.upstream_refresh = getattr(args, "upstream_refresh", False)
args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
args.init_ckpt = getattr(args, "init_ckpt", None)
args.verbose = getattr(args, "verbose", False)
args.tile_factor = getattr(args, "tile_factor", 1)
return args
class S3prlFrontend(AbsFrontend):
"""Speech Pretrained Representation frontend structure for ASR."""
def __init__(
self,
fs: Union[int, str] = 16000,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
download_dir: str = None,
multilayer_feature: bool = False,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
if download_dir is not None:
torch.hub.set_dir(download_dir)
self.multilayer_feature = multilayer_feature
self.upstream, self.featurizer = self._get_upstream(frontend_conf)
self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
self.output_dim = self.featurizer.output_dim
self.frontend_type = "s3prl"
self.hop_length = self.upstream.get_downsample_rates("key")
def _get_upstream(self, frontend_conf):
"""Get S3PRL upstream model."""
s3prl_args = base_s3prl_setup(
Namespace(**frontend_conf, device="cpu"),
)
self.args = s3prl_args
s3prl_path = None
python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
for p in python_path_list:
if p.endswith("s3prl"):
s3prl_path = p
break
assert s3prl_path is not None
s3prl_upstream = torch.hub.load(
s3prl_path,
s3prl_args.upstream,
ckpt=s3prl_args.upstream_ckpt,
model_config=s3prl_args.upstream_model_config,
refresh=s3prl_args.upstream_refresh,
source="local",
).to("cpu")
if getattr(
s3prl_upstream, "model", None
) is not None and s3prl_upstream.model.__class__.__name__ in [
"Wav2Vec2Model",
"HubertModel",
]:
s3prl_upstream.model.encoder.layerdrop = 0.0
from s3prl.upstream.interfaces import Featurizer
if not self.multilayer_feature:
feature_selection = "last_hidden_state"
else:
feature_selection = "hidden_states"
s3prl_featurizer = Featurizer(
upstream=s3prl_upstream,
feature_selection=feature_selection,
upstream_device="cpu",
)
return s3prl_upstream, s3prl_featurizer
def _tile_representations(self, feature):
"""Tile up the representations by `tile_factor`.
Input - sequence of representations
shape: (batch_size, seq_len, feature_dim)
Output - sequence of tiled representations
shape: (batch_size, seq_len * factor, feature_dim)
"""
assert (
len(feature.shape) == 3
), "Input argument `feature` has invalid shape: {}".format(feature.shape)
tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
tiled_feature = tiled_feature.reshape(
feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
)
return tiled_feature
def output_size(self) -> int:
return self.output_dim
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
self.upstream.eval()
with torch.no_grad():
feats = self.upstream(wavs)
feats = self.featurizer(wavs, feats)
if self.args.tile_factor != 1:
feats = self._tile_representations(feats)
input_feats = pad_list(feats, 0.0)
feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)
# Saving CUDA Memory
del feats
return input_feats, feats_lens
def reload_pretrained_parameters(self):
self.upstream.load_state_dict(self.pretrained_params)
logging.info("Pretrained S3PRL frontend model parameters reloaded!")

View File

@@ -0,0 +1,503 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import copy
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
import funasr_local.models.frontend.eend_ola_feature as eend_ola_feature
from funasr_local.models.frontend.abs_frontend import AbsFrontend
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
means = cmvn[0:1, :dim]
vars = cmvn[1:2, :dim]
inputs += means.to(device)
inputs *= vars.to(device)
return inputs.type(torch.float32)
def apply_lfr(inputs, lfr_m, lfr_n):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
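# Shape sketch for `apply_lfr` (low frame rate stacking) on a dummy fbank
# matrix; lfr_m=7 and lfr_n=6 are illustrative values: T frames of dimension D
# become ceil(T / lfr_n) frames of dimension lfr_m * D.
def _apply_lfr_shape_sketch():
    feats = torch.randn(100, 80)  # (T, n_mels)
    lfr = apply_lfr(feats, lfr_m=7, lfr_n=6)
    return lfr.shape  # torch.Size([17, 560]); 17 == ceil(100 / 6)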
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs,
snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
def forward_fbank(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
def forward_lfr_cmvn(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
class WavFrontendOnline(AbsFrontend):
"""Conventional frontend structure for streaming ASR/VAD.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.frame_sample_length = int(self.frame_length * self.fs / 1000)
self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
self.waveforms = None
self.reserve_waveforms = None
self.fbanks = None
self.fbanks_lens = None
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
self.input_cache = None
self.lfr_splice_cache = []
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@staticmethod
def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
inputs *= torch.from_numpy(vars).type(dtype).to(device)
return inputs.type(torch.float32)
@staticmethod
# inputs tensor has catted the cache tensor
# def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
# is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
torch.Tensor, torch.Tensor, int]:
"""
Apply lfr with data
"""
LFR_inputs = []
# inputs = torch.vstack((inputs_lfr_cache, inputs))
T = inputs.shape[0] # include the right context
T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
if is_final:
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
else:
# update splice_idx and break the loop
splice_idx = i
break
splice_idx = min(T - 1, splice_idx * lfr_n)
lfr_splice_cache = inputs[splice_idx:, :]
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
@staticmethod
def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
def forward_fbank(
self,
input: torch.Tensor,
input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
if self.input_cache is None:
self.input_cache = torch.empty(0)
input = torch.cat((self.input_cache, input), dim=1)
frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
# update self.input_cache: keep the tail samples not yet covered by a complete frame
self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
waveforms = torch.empty(0)
feats_pad = torch.empty(0)
feats_lens = torch.empty(0)
if frame_num:
waveforms = []
feats = []
feats_lens = []
for i in range(batch_size):
waveform = input[i]
# keep exactly the wave samples that were used for fbank extraction
waveforms.append(
waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
waveforms = torch.stack(waveforms)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
self.fbanks = feats_pad
self.fbanks_lens = copy.deepcopy(feats_lens)
return waveforms, feats_pad, feats_lens
def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
return self.fbanks, self.fbanks_lens
def forward_lfr_cmvn(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
is_final)
if self.cmvn_file is not None:
mat = self.apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
return feats_pad, feats_lens, lfr_splice_frame_idxs
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.shape[0]
assert batch_size == 1, 'online feature extraction currently supports batch_size == 1 only'
waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D
if feats.shape[0]:
# if self.reserve_waveforms is None and self.lfr_m > 1:
# self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
(self.reserve_waveforms, waveforms), dim=1)
if not self.lfr_splice_cache: # initialize the LFR splice cache
for i in range(batch_size):
self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
# proceed only when the number of input frames plus the cached splice frames reaches self.lfr_m
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache) # B T D
feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
feats_lengths += lfr_splice_cache_tensor[0].shape[0]
frame_from_waveforms = int(
(self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
if self.lfr_m == 1:
self.reserve_waveforms = None
else:
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
self.waveforms = self.waveforms[:, :sample_length]
else:
# update self.reserve_waveforms and self.lfr_splice_cache
self.reserve_waveforms = self.waveforms[:,
:-(self.frame_sample_length - self.frame_shift_sample_length)]
for i in range(batch_size):
self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
return torch.empty(0), feats_lengths
else:
if is_final:
self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
feats = torch.stack(self.lfr_splice_cache)
feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
if is_final:
self.cache_reset()
return feats, feats_lengths
def get_waveforms(self):
return self.waveforms
def cache_reset(self):
self.reserve_waveforms = None
self.input_cache = None
self.lfr_splice_cache = []
class WavFrontendMel23(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
fs: int = 16000,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.frame_length = frame_length
self.frame_shift = frame_shift
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.n_mels = 23
def output_size(self) -> int:
return self.n_mels * (2 * self.lfr_m + 1)
def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform.numpy()
mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
mat = eend_ola_feature.transform(mat)
mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
mat = mat[::self.lfr_n]
mat = torch.from_numpy(mat)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
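# Usage sketch for `WavFrontend` on a dummy one-second waveform (values are
# illustrative; cmvn_file and LFR are left at their defaults, so neither CMVN
# nor frame stacking is applied).
def _wav_frontend_sketch():
    frontend = WavFrontend(fs=16000, n_mels=80)
    wav = torch.randn(1, 16000)  # (batch, samples)
    lens = torch.tensor([16000])
    feats, feats_lens = frontend(wav, lens)
    return feats.shape, feats_lens  # roughly (1, ~100, 80) with 25 ms / 10 ms framing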

View File

@@ -0,0 +1,180 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
# import kaldifeat
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn)
return cmvn
def apply_cmvn(inputs, cmvn_file): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
cmvn = load_cmvn(cmvn_file)
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
inputs *= torch.from_numpy(vars).type(dtype).to(device)
return inputs.type(torch.float32)
def apply_lfr(inputs, lfr_m, lfr_n):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
# class WavFrontend_kaldifeat(AbsFrontend):
# """Conventional frontend structure for ASR.
# """
#
# def __init__(
# self,
# cmvn_file: str = None,
# fs: int = 16000,
# window: str = 'hamming',
# n_mels: int = 80,
# frame_length: int = 25,
# frame_shift: int = 10,
# lfr_m: int = 1,
# lfr_n: int = 1,
# dither: float = 1.0,
# snip_edges: bool = True,
# upsacle_samples: bool = True,
# device: str = 'cpu',
# **kwargs,
# ):
# super().__init__()
#
# opts = kaldifeat.FbankOptions()
# opts.device = device
# opts.frame_opts.samp_freq = fs
# opts.frame_opts.dither = dither
# opts.frame_opts.window_type = window
# opts.frame_opts.frame_shift_ms = float(frame_shift)
# opts.frame_opts.frame_length_ms = float(frame_length)
# opts.mel_opts.num_bins = n_mels
# opts.energy_floor = 0
# opts.frame_opts.snip_edges = snip_edges
# opts.mel_opts.debug_mel = False
# self.opts = opts
# self.fbank_fn = None
# self.fbank_beg_idx = 0
# self.reset_fbank_status()
#
# self.lfr_m = lfr_m
# self.lfr_n = lfr_n
# self.cmvn_file = cmvn_file
# self.upsacle_samples = upsacle_samples
#
# def output_size(self) -> int:
# return self.n_mels * self.lfr_m
#
# def forward_fbank(
# self,
# input: torch.Tensor,
# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# batch_size = input.size(0)
# feats = []
# feats_lens = []
# for i in range(batch_size):
# waveform_length = input_lengths[i]
# waveform = input[i][:waveform_length]
# waveform = waveform * (1 << 15)
#
# self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
# frames = self.fbank_fn.num_frames_ready
# frames_cur = frames - self.fbank_beg_idx
# mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
# device=self.opts.device)
# for i in range(self.fbank_beg_idx, frames):
# mat[i, :] = self.fbank_fn.get_frame(i)
# self.fbank_beg_idx += frames_cur
#
# feat_length = mat.size(0)
# feats.append(mat)
# feats_lens.append(feat_length)
#
# feats_lens = torch.as_tensor(feats_lens)
# feats_pad = pad_sequence(feats,
# batch_first=True,
# padding_value=0.0)
# return feats_pad, feats_lens
#
# def reset_fbank_status(self):
# self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
# self.fbank_beg_idx = 0
#
# def forward_lfr_cmvn(
# self,
# input: torch.Tensor,
# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# batch_size = input.size(0)
# feats = []
# feats_lens = []
# for i in range(batch_size):
# mat = input[i, :input_lengths[i], :]
# if self.lfr_m != 1 or self.lfr_n != 1:
# mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
# if self.cmvn_file is not None:
# mat = apply_cmvn(mat, self.cmvn_file)
# feat_length = mat.size(0)
# feats.append(mat)
# feats_lens.append(feat_length)
#
# feats_lens = torch.as_tensor(feats_lens)
# feats_pad = pad_sequence(feats,
# batch_first=True,
# padding_value=0.0)
# return feats_pad, feats_lens

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Sliding Window for raw audio input data."""
from funasr_local.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple
class SlidingWindow(AbsFrontend):
"""Sliding Window.
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
Known issues:
Output length is calculated incorrectly if audio shorter than win_length.
WARNING: trailing values are discarded - padding not implemented yet.
There is currently no additional window function applied to input values.
"""
def __init__(
self,
win_length: int = 400,
hop_length: int = 160,
channels: int = 1,
padding: int = None,
fs=None,
):
"""Initialize.
Args:
win_length: Length of frame.
hop_length: Relative starting point of next frame.
channels: Number of input channels.
padding: Padding (placeholder, currently not implemented).
fs: Sampling rate (placeholder for compatibility, not used).
"""
assert check_argument_types()
super().__init__()
self.fs = fs
self.win_length = win_length
self.hop_length = hop_length
self.channels = channels
self.padding = padding
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a sliding window on the input.
Args:
input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
input_lengths: Input lengths within batch.
Returns:
Tensor: Output with dimensions (B, T, C, D), with D=win_length.
Tensor: Output lengths within batch.
"""
input_size = input.size()
B = input_size[0]
T = input_size[1]
C = self.channels
D = self.win_length
# (B, T, C) --> (T, B, C)
continuous = input.view(B, T, C).permute(1, 0, 2)
windowed = continuous.unfold(0, D, self.hop_length)
# (T, B, C, D) --> (B, T, C, D)
output = windowed.permute(1, 0, 2, 3).contiguous()
# After unfold(), windowed lengths change:
output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
return output, output_lengths
def output_size(self) -> int:
"""Return output length of feature dimension D, i.e. the window length."""
return self.win_length
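# Shape sketch for `SlidingWindow` on a dummy batch (sizes are illustrative):
# 3200 samples with win_length=400 and hop_length=160 give
# (3200 - 400) // 160 + 1 = 18 windows of 400 samples each.
def _sliding_window_sketch():
    frontend = SlidingWindow(win_length=400, hop_length=160, channels=1)
    wav = torch.randn(2, 3200, 1)  # (B, T, C)
    lens = torch.tensor([3200, 3200])
    out, out_lens = frontend(wav, lens)
    return out.shape, out_lens  # torch.Size([2, 18, 1, 400]), tensor([18, 18])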

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,61 @@
"""Transducer joint network implementation."""
import torch
from funasr_local.modules.nets_utils import get_activation
class JointNetwork(torch.nn.Module):
"""Transducer joint network module.
Args:
output_size: Output size.
encoder_size: Encoder output size.
decoder_size: Decoder output size.
joint_space_size: Joint space size.
joint_activation_type: Type of activation for the joint network.
"""
def __init__(
self,
output_size: int,
encoder_size: int,
decoder_size: int,
joint_space_size: int = 256,
joint_activation_type: str = "tanh",
) -> None:
"""Construct a JointNetwork object."""
super().__init__()
self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)
self.lin_out = torch.nn.Linear(joint_space_size, output_size)
self.joint_activation = get_activation(
joint_activation_type
)
def forward(
self,
enc_out: torch.Tensor,
dec_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""Joint computation of encoder and decoder hidden state sequences.
Args:
enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
Returns:
joint_out: Joint output state sequences. (B, T, U, D_out)
"""
if project_input:
joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out))
else:
joint_out = self.joint_activation(enc_out + dec_out)
return self.lin_out(joint_out)
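# Shape sketch for `JointNetwork` with dummy encoder/decoder states (sizes are
# illustrative assumptions): broadcasting the (B, T, 1, D_enc) and
# (B, 1, U, D_dec) inputs yields a (B, T, U, output_size) lattice of logits.
def _joint_network_sketch():
    joint = JointNetwork(output_size=500, encoder_size=256, decoder_size=128, joint_space_size=320)
    enc_out = torch.randn(2, 10, 1, 256)  # (B, T, 1, D_enc)
    dec_out = torch.randn(2, 1, 4, 128)  # (B, 1, U, D_dec)
    logits = joint(enc_out, dec_out)
    return logits.shape  # torch.Size([2, 10, 4, 500])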

View File

View File

@@ -0,0 +1,98 @@
import torch
from typing import Tuple
from typing import Union
from funasr_local.modules.nets_utils import make_non_pad_mask
from torch.nn import functional as F
import math
VAR2STD_EPSILON = 1e-12
class StatisticPooling(torch.nn.Module):
def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
super(StatisticPooling, self).__init__()
if isinstance(pooling_dim, int):
pooling_dim = (pooling_dim, )
self.pooling_dim = pooling_dim
self.eps = eps
def forward(self, xs_pad, ilens=None):
# xs_pad in (Batch, Channel, Time, Frequency)
if ilens is None:
masks = torch.ones_like(xs_pad).to(xs_pad)
else:
masks = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
mean = (torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) /
torch.sum(masks, dim=self.pooling_dim, keepdim=True))
squared_difference = torch.pow(xs_pad - mean, 2.0)
variance = (torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) /
torch.sum(masks, dim=self.pooling_dim, keepdim=True))
for i in reversed(self.pooling_dim):
mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
mask = torch.less_equal(variance, self.eps).float()
variance = (1.0 - mask) * variance + mask * self.eps
stddev = torch.sqrt(variance)
stat_pooling = torch.cat([mean, stddev], dim=1)
return stat_pooling
def convert_tf2torch(self, var_dict_tf, var_dict_torch):
return {}
def statistic_pooling(
xs_pad: torch.Tensor,
ilens: torch.Tensor = None,
pooling_dim: Tuple = (2, 3)
) -> torch.Tensor:
# xs_pad in (Batch, Channel, Time, Frequency)
if ilens is None:
seq_mask = torch.ones_like(xs_pad).to(xs_pad)
else:
seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
squared_difference = torch.pow(xs_pad - mean, 2.0)
variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
for i in reversed(pooling_dim):
mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float()
variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON
stddev = torch.sqrt(variance)
stat_pooling = torch.cat([mean, stddev], dim=1)
return stat_pooling
def windowed_statistic_pooling(
xs_pad: torch.Tensor,
ilens: torch.Tensor = None,
pooling_dim: Tuple = (2, 3),
pooling_size: int = 20,
pooling_stride: int = 1
) -> Tuple[torch.Tensor, int]:
# xs_pad in (Batch, Channel, Time, Frequency)
tt = xs_pad.shape[2]
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
if len(xs_pad.shape) == 4:
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
else:
features = F.pad(xs_pad, (pad, pad), "reflect")
stat_list = []
for i in range(num_chunk):
# B x C
st, ed = i*pooling_stride, i*pooling_stride+pooling_size
stat = statistic_pooling(features[:, :, st: ed], pooling_dim=pooling_dim)
stat_list.append(stat.unsqueeze(2))
# B x C x T
return torch.cat(stat_list, dim=2), ilens / pooling_stride
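# Shape sketch for `statistic_pooling` on a dummy 4-D feature map (sizes are
# illustrative): pooling over the time/frequency dims (2, 3) concatenates the
# per-channel mean and standard deviation, giving (B, 2 * C).
def _statistic_pooling_sketch():
    xs = torch.randn(2, 64, 30, 40)  # (Batch, Channel, Time, Frequency)
    pooled = statistic_pooling(xs, pooling_dim=(2, 3))
    return pooled.shape  # torch.Size([2, 128])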

View File

@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsPostEncoder(torch.nn.Module, ABC):
@abstractmethod
def output_size(self) -> int:
raise NotImplementedError
@abstractmethod
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@@ -0,0 +1,115 @@
#!/usr/bin/env python3
# 2021, University of Stuttgart; Pavel Denisov
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Hugging Face Transformers PostEncoder."""
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from typeguard import check_argument_types
from typing import Tuple
import copy
import logging
import torch
try:
from transformers import AutoModel
is_transformers_available = True
except ImportError:
is_transformers_available = False
class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
"""Hugging Face Transformers PostEncoder."""
def __init__(
self,
input_size: int,
model_name_or_path: str,
):
"""Initialize the module."""
assert check_argument_types()
super().__init__()
if not is_transformers_available:
raise ImportError(
"`transformers` is not available. Please install it via `pip install"
" transformers` or `cd /path/to/espnet/tools && . ./activate_python.sh"
" && ./installers/install_transformers.sh`."
)
model = AutoModel.from_pretrained(model_name_or_path)
if hasattr(model, "encoder"):
self.transformer = model.encoder
else:
self.transformer = model
if hasattr(self.transformer, "embed_tokens"):
del self.transformer.embed_tokens
if hasattr(self.transformer, "wte"):
del self.transformer.wte
if hasattr(self.transformer, "word_embedding"):
del self.transformer.word_embedding
self.pretrained_params = copy.deepcopy(self.transformer.state_dict())
if (
self.transformer.config.is_encoder_decoder
or self.transformer.config.model_type in ["xlnet", "t5"]
):
self.use_inputs_embeds = True
self.extend_attention_mask = False
elif self.transformer.config.model_type == "gpt2":
self.use_inputs_embeds = True
self.extend_attention_mask = True
else:
self.use_inputs_embeds = False
self.extend_attention_mask = True
self.linear_in = torch.nn.Linear(
input_size, self.transformer.config.hidden_size
)
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward."""
input = self.linear_in(input)
args = {"return_dict": True}
mask = (~make_pad_mask(input_lengths)).to(input.device).float()
if self.extend_attention_mask:
args["attention_mask"] = _extend_attention_mask(mask)
else:
args["attention_mask"] = mask
if self.use_inputs_embeds:
args["inputs_embeds"] = input
else:
args["hidden_states"] = input
if self.transformer.config.model_type == "mpnet":
args["head_mask"] = [None for _ in self.transformer.layer]
output = self.transformer(**args).last_hidden_state
return output, input_lengths
def reload_pretrained_parameters(self):
self.transformer.load_state_dict(self.pretrained_params)
logging.info("Pretrained Transformers model parameters reloaded!")
def output_size(self) -> int:
"""Get the output size."""
return self.transformer.config.hidden_size
def _extend_attention_mask(mask: torch.Tensor) -> torch.Tensor:
mask = mask[:, None, None, :]
mask = (1.0 - mask) * -10000.0
return mask
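# Minimal sketch of `_extend_attention_mask` on a dummy padding mask: kept
# positions (1.0) become 0.0 and padded positions become a large negative bias,
# in the (B, 1, 1, T) layout expected by Transformers-style attention.
def _extend_attention_mask_sketch():
    mask = torch.tensor([[1.0, 1.0, 0.0]])  # (B, T), 1.0 marks real frames
    extended = _extend_attention_mask(mask)
    return extended.shape, extended[0, 0, 0]  # (1, 1, 1, 3); zeros at kept frames, -10000.0 at padding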

View File

@@ -0,0 +1,748 @@
import torch
from torch import nn
import logging
import numpy as np
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.streaming_utils.utils import sequence_mask
class CifPredictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
super(CifPredictor, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
def gen_frame_alignments(self,
alphas: torch.Tensor = None,
encoder_sequence_length: torch.Tensor = None):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
is_training = self.training
if is_training:
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
else:
token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
int_type).to(encoder_sequence_length.device)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
class CifPredictorV2(nn.Module):
def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.0,
tf2torch_tensor_name_prefix_torch="predictor",
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
tail_mask=True,
):
super(CifPredictorV2, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.tail_mask = tail_mask
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
if self.tail_mask:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
else:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
def forward_chunk(self, hidden, cache=None):
batch_size, len_time, hidden_size = hidden.shape
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
alphas = alphas.squeeze(-1)
token_length = []
list_fires = []
list_frames = []
cache_alphas = []
cache_hiddens = []
if cache is not None and "chunk_size" in cache:
alphas[:, :cache["chunk_size"][0]] = 0.0
alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
hidden = torch.cat((hidden, tail_hidden), dim=1)
alphas = torch.cat((alphas, tail_alphas), dim=1)
len_time = alphas.shape[1]
for b in range(batch_size):
integrate = 0.0
frames = torch.zeros((hidden_size), device=hidden.device)
list_frame = []
list_fire = []
for t in range(len_time):
alpha = alphas[b][t]
if alpha + integrate < self.threshold:
integrate += alpha
list_fire.append(integrate)
frames += alpha * hidden[b][t]
else:
frames += (self.threshold - integrate) * hidden[b][t]
list_frame.append(frames)
integrate += alpha
list_fire.append(integrate)
integrate -= self.threshold
frames = integrate * hidden[b][t]
cache_alphas.append(integrate)
if integrate > 0.0:
cache_hiddens.append(frames / integrate)
else:
cache_hiddens.append(frames)
token_length.append(torch.tensor(len(list_frame), device=alphas.device))
list_fires.append(list_fire)
list_frames.append(list_frame)
cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
max_token_len = max(token_length)
if max_token_len == 0:
return hidden, torch.stack(token_length, 0)
list_ls = []
for b in range(batch_size):
pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
if token_length[b] == 0:
list_ls.append(pad_frames)
else:
list_frames[b] = torch.stack(list_frames[b])
list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
return torch.stack(list_ls, 0), torch.stack(token_length, 0)
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
if b > 1:
alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
else:
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
def gen_frame_alignments(self,
alphas: torch.Tensor = None,
encoder_sequence_length: torch.Tensor = None):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
is_training = self.training
if is_training:
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
else:
token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
int_type).to(encoder_sequence_length.device)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
## predictor
"{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
}, # (256,256,3),(3,256,256)
"{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.cif_output.weight".format(tensor_name_prefix_torch):
{"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1,256),(1,256,1)
"{}.cif_output.bias".format(tensor_name_prefix_torch):
{"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1,),(1,)
}
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
names = name.split('.')
if names[0] == self.tf2torch_tensor_name_prefix_torch:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
class mae_loss(nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
self.normalize_length = normalize_length
self.criterion = torch.nn.L1Loss(reduction='sum')
def forward(self, token_length, pre_token_length):
loss_token_normalizer = token_length.size(0)
if self.normalize_length:
loss_token_normalizer = token_length.sum().type(torch.float32)
loss = self.criterion(token_length, pre_token_length)
loss = loss / loss_token_normalizer
return loss
def cif(hidden, alphas, threshold):
batch_size, len_time, hidden_size = hidden.size()
# loop vars
integrate = torch.zeros([batch_size], device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
max_label_len = len_labels.max()
for b in range(batch_size):
fire = fires[b, :]
l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
list_ls.append(torch.cat([l, pad_l], 0))
return torch.stack(list_ls, 0), fires
def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()
# loop vars
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []
for t in range(len_time):
alpha = alphas[:, t]
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
integrate - torch.ones([batch_size], device=alphas.device),
integrate)
fires = torch.stack(list_fires, 1)
return fires
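# Toy sketch of the continuous integrate-and-fire (CIF) routine above: with a
# firing threshold of 1.0, the alpha weights below accumulate past the
# threshold twice, so `cif` emits two acoustic embeddings for the utterance
# (all values are illustrative).
def _cif_toy_sketch():
    hidden = torch.randn(1, 4, 8)  # (B, T, D)
    alphas = torch.tensor([[0.4, 0.7, 0.3, 0.8]])  # sums to 2.2 -> 2 fired tokens
    acoustic_embeds, fires = cif(hidden, alphas, 1.0)
    return acoustic_embeds.shape, fires.shape  # torch.Size([1, 2, 8]), torch.Size([1, 4])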
class CifPredictorV3(nn.Module):
def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.0,
tf2torch_tensor_name_prefix_torch="predictor",
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
smooth_factor2=1.0,
noise_threshold2=0,
upsample_times=5,
upsample_type="cnn",
use_cif1_cnn=True,
tail_mask=True,
):
super(CifPredictorV3, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.upsample_times = upsample_times
self.upsample_type = upsample_type
self.use_cif1_cnn = use_cif1_cnn
if self.upsample_type == 'cnn':
self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
self.cif_output2 = nn.Linear(idim, 1)
elif self.upsample_type == 'cnn_blstm':
self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
self.cif_output2 = nn.Linear(idim*2, 1)
elif self.upsample_type == 'cnn_attn':
self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
from funasr_local.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
from funasr_local.modules.attention import MultiHeadedAttention
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForward
positionwise_layer_args = (
idim,
idim*2,
0.1,
)
self.self_attn = TransformerEncoderLayer(
idim,
MultiHeadedAttention(
4, idim, 0.1
),
PositionwiseFeedForward(*positionwise_layer_args),
0.1,
True, #normalize_before,
False, #concat_after,
)
self.cif_output2 = nn.Linear(idim, 1)
self.smooth_factor2 = smooth_factor2
self.noise_threshold2 = noise_threshold2
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
# alphas2 is an extra head for timestamp prediction
if not self.use_cif1_cnn:
_output = context
else:
_output = output
if self.upsample_type == 'cnn':
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1,2)
elif self.upsample_type == 'cnn_blstm':
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1,2)
output2, (_, _) = self.blstm(output2)
elif self.upsample_type == 'cnn_attn':
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1,2)
output2, _ = self.self_attn(output2, mask)
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
# repeat the mask along the T dimension to match the upsampled length
if mask is not None:
mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
mask2 = mask2.unsqueeze(-1)
alphas2 = alphas2 * mask2
alphas2 = alphas2.squeeze(-1)
token_num2 = alphas2.sum(-1)
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
# alphas2 is an extra head for timestamp prediction
if not self.use_cif1_cnn:
_output = context
else:
_output = output
if self.upsample_type == 'cnn':
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1,2)
elif self.upsample_type == 'cnn_blstm':
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1,2)
output2, (_, _) = self.blstm(output2)
elif self.upsample_type == 'cnn_attn':
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1,2)
output2, _ = self.self_attn(output2, mask)
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
# repeat the mask along the T dimension to match the upsampled length
if mask is not None:
mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
mask2 = mask2.unsqueeze(-1)
alphas2 = alphas2 * mask2
alphas2 = alphas2.squeeze(-1)
_token_num = alphas2.sum(-1)
if token_num is not None:
alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
# re-downsample
ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
# upsampled alphas and cif_peak
us_alphas = alphas2
us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
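    # Note (editor): in the method above, the upsampled alphas are rescaled so that
    # their sum matches the decoded token count, then summed back over each
    # upsample_times-sized window to recover frame-rate (ds_) weights; both
    # resolutions are peak-decoded with cif_wo_hidden to obtain timestamps.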
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
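    # Note (editor): tail_process_fn appends one extra frame whose weight equals
    # tail_threshold (placed right after the last valid frame when a mask is given),
    # so that a trailing partial accumulation can still fire a final token; the
    # token count is then re-derived from the padded alphas and floored.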
def gen_frame_alignments(self,
alphas: torch.Tensor = None,
encoder_sequence_length: torch.Tensor = None):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
is_training = self.training
if is_training:
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
else:
token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
int_type).to(encoder_sequence_length.device)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
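# --- Editor's illustrative sketch (not part of the original file) ---
# A minimal, standalone approximation of the integrate-and-fire step that the
# cif/cif_wo_hidden helpers used above are assumed to perform: accumulate the
# per-frame weights and emit a peak whenever the running sum crosses the firing
# threshold. Shapes and the threshold value are assumptions for illustration.
def _toy_cif_peaks(alphas: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """alphas: (B, T) non-negative weights; returns (B, T) binary peak indicators."""
    b, t = alphas.size()
    peaks = torch.zeros_like(alphas)
    integrate = torch.zeros(b, dtype=alphas.dtype, device=alphas.device)
    for i in range(t):
        integrate = integrate + alphas[:, i]
        fired = integrate >= threshold
        peaks[:, i] = fired.to(alphas.dtype)
        # keep only the residual weight after a peak fires
        integrate = torch.where(fired, integrate - threshold, integrate)
    return peaks
# Example: weights summing to ~2.0 fire twice:
# _toy_cif_peaks(torch.tensor([[0.4, 0.4, 0.4, 0.4, 0.4]]))  # peaks at frames 2 and 4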

View File

@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsPreEncoder(torch.nn.Module, ABC):
@abstractmethod
def output_size(self) -> int:
raise NotImplementedError
@abstractmethod
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# 2021, Carnegie Mellon University; Xuankai Chang
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Linear Projection."""
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from typeguard import check_argument_types
from typing import Tuple
import torch
class LinearProjection(AbsPreEncoder):
"""Linear Projection Preencoder."""
def __init__(
self,
input_size: int,
output_size: int,
):
"""Initialize the module."""
assert check_argument_types()
super().__init__()
self.output_dim = output_size
self.linear_out = torch.nn.Linear(input_size, output_size)
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward."""
output = self.linear_out(input)
return output, input_lengths # no state in this layer
def output_size(self) -> int:
"""Get the output size."""
return self.output_dim
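# --- Editor's illustrative usage sketch (not part of the original file) ---
# The 80->256 sizes and batch shape below are arbitrary assumptions, shown only
# to illustrate the forward() contract: lengths pass through unchanged.
if __name__ == "__main__":
    pre = LinearProjection(input_size=80, output_size=256)
    feats = torch.randn(4, 100, 80)
    lengths = torch.full((4,), 100, dtype=torch.long)
    out, out_lengths = pre(feats, lengths)
    print(out.shape, pre.output_size())  # torch.Size([4, 100, 256]) 256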

View File

@@ -0,0 +1,282 @@
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Sinc convolutions for raw audio input."""
from collections import OrderedDict
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.layers.sinc_conv import LogCompression
from funasr_local.layers.sinc_conv import SincConv
import humanfriendly
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from typing import Union
class LightweightSincConvs(AbsPreEncoder):
"""Lightweight Sinc Convolutions.
Instead of using precomputed features, end-to-end speech recognition
can also be done directly from raw audio using sinc convolutions, as
described in "Lightweight End-to-End Speech Recognition from Raw Audio
Data Using Sinc-Convolutions" by Kürzinger et al.
https://arxiv.org/abs/2010.07597
To use Sinc convolutions in your model instead of the default f-bank
frontend, set this module as your pre-encoder with `preencoder: sinc`
and use the input of the sliding window frontend with
`frontend: sliding_window` in your yaml configuration file.
The resulting process flow is:
Frontend (SlidingWindow) -> SpecAug -> Normalization ->
Pre-encoder (LightweightSincConvs) -> Encoder -> Decoder
Note that this method also performs data augmentation in time domain
(vs. in spectral domain in the default frontend).
Use `plot_sinc_filters.py` to visualize the learned Sinc filters.
"""
def __init__(
self,
fs: Union[int, str, float] = 16000,
in_channels: int = 1,
out_channels: int = 256,
activation_type: str = "leakyrelu",
dropout_type: str = "dropout",
windowing_type: str = "hamming",
scale_type: str = "mel",
):
"""Initialize the module.
Args:
fs: Sample rate.
in_channels: Number of input channels.
out_channels: Number of output channels (for each input channel).
activation_type: Choice of activation function.
dropout_type: Choice of dropout function.
windowing_type: Choice of windowing function.
scale_type: Choice of filter-bank initialization scale.
"""
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
self.fs = fs
self.in_channels = in_channels
self.out_channels = out_channels
self.activation_type = activation_type
self.dropout_type = dropout_type
self.windowing_type = windowing_type
self.scale_type = scale_type
self.choices_dropout = {
"dropout": torch.nn.Dropout,
"spatial": SpatialDropout,
"dropout2d": torch.nn.Dropout2d,
}
if dropout_type not in self.choices_dropout:
raise NotImplementedError(
f"Dropout type has to be one of "
f"{list(self.choices_dropout.keys())}",
)
self.choices_activation = {
"leakyrelu": torch.nn.LeakyReLU,
"relu": torch.nn.ReLU,
}
if activation_type not in self.choices_activation:
raise NotImplementedError(
f"Activation type has to be one of "
f"{list(self.choices_activation.keys())}",
)
# initialization
self._create_sinc_convs()
# Sinc filters require custom initialization
self.espnet_initialization_fn()
def _create_sinc_convs(self):
blocks = OrderedDict()
# SincConvBlock
out_channels = 128
self.filters = SincConv(
self.in_channels,
out_channels,
kernel_size=101,
stride=1,
fs=self.fs,
window_func=self.windowing_type,
scale_type=self.scale_type,
)
block = OrderedDict(
[
("Filters", self.filters),
("LogCompression", LogCompression()),
("BatchNorm", torch.nn.BatchNorm1d(out_channels, affine=True)),
("AvgPool", torch.nn.AvgPool1d(2)),
]
)
blocks["SincConvBlock"] = torch.nn.Sequential(block)
in_channels = out_channels
# First convolutional block, connects the sinc output to the front-end "body"
out_channels = 128
blocks["DConvBlock1"] = self.gen_lsc_block(
in_channels,
out_channels,
depthwise_kernel_size=25,
depthwise_stride=2,
pointwise_groups=0,
avgpool=True,
dropout_probability=0.1,
)
in_channels = out_channels
# Second convolutional block, multiple convolutional layers
out_channels = self.out_channels
for layer in [2, 3, 4]:
blocks[f"DConvBlock{layer}"] = self.gen_lsc_block(
in_channels, out_channels, depthwise_kernel_size=9, depthwise_stride=1
)
in_channels = out_channels
# Third Convolutional block, acts as coupling to encoder
out_channels = self.out_channels
blocks["DConvBlock5"] = self.gen_lsc_block(
in_channels,
out_channels,
depthwise_kernel_size=7,
depthwise_stride=1,
pointwise_groups=0,
)
self.blocks = torch.nn.Sequential(blocks)
def gen_lsc_block(
self,
in_channels: int,
out_channels: int,
depthwise_kernel_size: int = 9,
depthwise_stride: int = 1,
depthwise_groups=None,
pointwise_groups=0,
dropout_probability: float = 0.15,
avgpool=False,
):
"""Generate a convolutional block for Lightweight Sinc convolutions.
Each block consists of either a depthwise or a depthwise-separable
convolutions together with dropout, (batch-)normalization layer, and
an optional average-pooling layer.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
depthwise_kernel_size: Kernel size of the depthwise convolution.
depthwise_stride: Stride of the depthwise convolution.
depthwise_groups: Number of groups of the depthwise convolution.
pointwise_groups: Number of groups of the pointwise convolution.
dropout_probability: Dropout probability in the block.
avgpool: If True, an AvgPool layer is inserted.
Returns:
torch.nn.Sequential: Neural network building block.
"""
block = OrderedDict()
if not depthwise_groups:
# GCD(in_channels, out_channels) to prevent size mismatches
depthwise_groups, r = in_channels, out_channels
while r != 0:
depthwise_groups, r = r, depthwise_groups % r
block["depthwise"] = torch.nn.Conv1d(
in_channels,
out_channels,
depthwise_kernel_size,
depthwise_stride,
groups=depthwise_groups,
)
if pointwise_groups:
block["pointwise"] = torch.nn.Conv1d(
out_channels, out_channels, 1, 1, groups=pointwise_groups
)
block["activation"] = self.choices_activation[self.activation_type]()
block["batchnorm"] = torch.nn.BatchNorm1d(out_channels, affine=True)
if avgpool:
block["avgpool"] = torch.nn.AvgPool1d(2)
block["dropout"] = self.choices_dropout[self.dropout_type](dropout_probability)
return torch.nn.Sequential(block)
def espnet_initialization_fn(self):
"""Initialize sinc filters with filterbank values."""
self.filters.init_filters()
for block in self.blocks:
for layer in block:
if type(layer) == torch.nn.BatchNorm1d and layer.affine:
layer.weight.data[:] = 1.0
layer.bias.data[:] = 0.0
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply Lightweight Sinc Convolutions.
The input shall be formatted as (B, T, C_in, D_in)
with B as batch size, T as time dimension, C_in as channels,
and D_in as feature dimension.
The output will then be (B, T, C_out*D_out)
with C_out and D_out as output dimensions.
The current module structure only handles D_in=400, so that D_out=1.
Remark for the multichannel case: C_out is the number of out_channels
given at initialization multiplied by C_in.
"""
# Transform input data:
# (B, T, C_in, D_in) -> (B*T, C_in, D_in)
B, T, C_in, D_in = input.size()
input_frames = input.view(B * T, C_in, D_in)
output_frames = self.blocks.forward(input_frames)
# ---TRANSFORM: (B*T, C_out, D_out) -> (B, T, C_out*D_out)
_, C_out, D_out = output_frames.size()
output_frames = output_frames.view(B, T, C_out * D_out)
return output_frames, input_lengths # no state in this layer
def output_size(self) -> int:
"""Get the output size."""
return self.out_channels * self.in_channels
class SpatialDropout(torch.nn.Module):
"""Spatial dropout module.
Apply dropout to full channels on tensors of input (B, C, D)
"""
def __init__(
self,
dropout_probability: float = 0.15,
shape: Optional[Union[tuple, list]] = None,
):
"""Initialize.
Args:
dropout_probability: Dropout probability.
shape (tuple, list): Shape of input tensors.
"""
assert check_argument_types()
super().__init__()
if shape is None:
shape = (0, 2, 1)
self.dropout = torch.nn.Dropout2d(dropout_probability)
self.shape = (shape,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward of spatial dropout module."""
y = x.permute(*self.shape)
y = self.dropout(y)
return y.permute(*self.shape)
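# --- Editor's illustrative usage sketch (not part of the original file) ---
# Shapes follow the forward() docstring; the batch/time sizes are assumptions.
# With the default kernel/stride/pooling settings, D_in=400 collapses to D_out=1,
# so the flattened output dimension equals out_channels.
if __name__ == "__main__":
    preencoder = LightweightSincConvs(fs=16000, in_channels=1, out_channels=256)
    frames = torch.randn(2, 50, 1, 400)               # (B, T, C_in, D_in)
    lengths = torch.full((2,), 50, dtype=torch.long)
    out, out_lengths = preencoder(frames, lengths)
    print(out.shape)                                   # torch.Size([2, 50, 256])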

View File

View File

@@ -0,0 +1,18 @@
from typing import Optional
from typing import Tuple
import torch
class AbsSpecAug(torch.nn.Module):
"""Abstract class for the augmentation of spectrogram
The process-flow:
Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
"""
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError

View File

@@ -0,0 +1,184 @@
"""SpecAugment module."""
from typing import Optional
from typing import Sequence
from typing import Union
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.layers.mask_along_axis import MaskAlongAxis
from funasr_local.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
from funasr_local.layers.mask_along_axis import MaskAlongAxisLFR
from funasr_local.layers.time_warp import TimeWarp
class SpecAug(AbsSpecAug):
"""Implementation of SpecAug.
Reference:
Daniel S. Park et al.
"SpecAugment: A Simple Data
Augmentation Method for Automatic Speech Recognition"
.. warning::
When using cuda mode, time_warp is not reproducible
due to `torch.nn.functional.interpolate`.
"""
def __init__(
self,
apply_time_warp: bool = True,
time_warp_window: int = 5,
time_warp_mode: str = "bicubic",
apply_freq_mask: bool = True,
freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
num_freq_mask: int = 2,
apply_time_mask: bool = True,
time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
num_time_mask: int = 2,
):
if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
raise ValueError(
"Either one of time_warp, time_mask, or freq_mask should be applied"
)
if (
apply_time_mask
and (time_mask_width_range is not None)
and (time_mask_width_ratio_range is not None)
):
raise ValueError(
'Either one of "time_mask_width_range" or '
'"time_mask_width_ratio_range" can be used'
)
super().__init__()
self.apply_time_warp = apply_time_warp
self.apply_freq_mask = apply_freq_mask
self.apply_time_mask = apply_time_mask
if apply_time_warp:
self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
else:
self.time_warp = None
if apply_freq_mask:
self.freq_mask = MaskAlongAxis(
dim="freq",
mask_width_range=freq_mask_width_range,
num_mask=num_freq_mask,
)
else:
self.freq_mask = None
if apply_time_mask:
if time_mask_width_range is not None:
self.time_mask = MaskAlongAxis(
dim="time",
mask_width_range=time_mask_width_range,
num_mask=num_time_mask,
)
elif time_mask_width_ratio_range is not None:
self.time_mask = MaskAlongAxisVariableMaxWidth(
dim="time",
mask_width_ratio_range=time_mask_width_ratio_range,
num_mask=num_time_mask,
)
else:
raise ValueError(
'Either one of "time_mask_width_range" or '
'"time_mask_width_ratio_range" should be used.'
)
else:
self.time_mask = None
def forward(self, x, x_lengths=None):
if self.time_warp is not None:
x, x_lengths = self.time_warp(x, x_lengths)
if self.freq_mask is not None:
x, x_lengths = self.freq_mask(x, x_lengths)
if self.time_mask is not None:
x, x_lengths = self.time_mask(x, x_lengths)
return x, x_lengths
class SpecAugLFR(AbsSpecAug):
"""Implementation of SpecAug.
lfr_ratelow frame rate
"""
def __init__(
self,
apply_time_warp: bool = True,
time_warp_window: int = 5,
time_warp_mode: str = "bicubic",
apply_freq_mask: bool = True,
freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
num_freq_mask: int = 2,
lfr_rate: int = 0,
apply_time_mask: bool = True,
time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
num_time_mask: int = 2,
):
if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
raise ValueError(
"Either one of time_warp, time_mask, or freq_mask should be applied"
)
if (
apply_time_mask
and (time_mask_width_range is not None)
and (time_mask_width_ratio_range is not None)
):
raise ValueError(
'Either one of "time_mask_width_range" or '
'"time_mask_width_ratio_range" can be used'
)
super().__init__()
self.apply_time_warp = apply_time_warp
self.apply_freq_mask = apply_freq_mask
self.apply_time_mask = apply_time_mask
if apply_time_warp:
self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
else:
self.time_warp = None
if apply_freq_mask:
self.freq_mask = MaskAlongAxisLFR(
dim="freq",
mask_width_range=freq_mask_width_range,
num_mask=num_freq_mask,
lfr_rate=lfr_rate+1,
)
else:
self.freq_mask = None
if apply_time_mask:
if time_mask_width_range is not None:
self.time_mask = MaskAlongAxisLFR(
dim="time",
mask_width_range=time_mask_width_range,
num_mask=num_time_mask,
lfr_rate=lfr_rate + 1,
)
elif time_mask_width_ratio_range is not None:
self.time_mask = MaskAlongAxisVariableMaxWidth(
dim="time",
mask_width_ratio_range=time_mask_width_ratio_range,
num_mask=num_time_mask,
)
else:
raise ValueError(
'Either one of "time_mask_width_range" or '
'"time_mask_width_ratio_range" should be used.'
)
else:
self.time_mask = None
def forward(self, x, x_lengths=None):
if self.time_warp is not None:
x, x_lengths = self.time_warp(x, x_lengths)
if self.freq_mask is not None:
x, x_lengths = self.freq_mask(x, x_lengths)
if self.time_mask is not None:
x, x_lengths = self.time_mask(x, x_lengths)
return x, x_lengths
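# --- Editor's illustrative usage sketch (not part of the original file) ---
# Feature sizes below are assumptions; note that when apply_time_mask is True,
# one of time_mask_width_range / time_mask_width_ratio_range must be given.
if __name__ == "__main__":
    import torch
    specaug = SpecAug(time_mask_width_range=(0, 40))
    feats = torch.randn(4, 300, 80)                    # (B, T, D) fbank-like features
    lengths = torch.tensor([300, 280, 250, 200])
    aug, aug_lengths = specaug(feats, lengths)
    print(aug.shape)                                   # shape preserved; values masked/warped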

View File

@@ -0,0 +1,133 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr_local.modules.embedding import SinusoidalPositionEncoder
#from funasr_local.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr_local.models.encoder.sanm_encoder import SANMEncoder as Encoder
#from funasr_local.modules.mask import subsequent_n_mask
from funasr_local.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
vocab_size: int,
punc_size: int,
pos_enc: str = None,
embed_unit: int = 128,
att_unit: int = 256,
head: int = 2,
unit: int = 1024,
layer: int = 4,
dropout_rate: float = 0.5,
):
super().__init__()
if pos_enc == "sinusoidal":
# pos_enc_class = PositionalEncoding
pos_enc_class = SinusoidalPositionEncoder
elif pos_enc is None:
def pos_enc_class(*args, **kwargs):
return nn.Sequential() # identity
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(vocab_size, embed_unit)
self.encoder = Encoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
input_layer="pe",
# pos_enc_class=pos_enc_class,
padding_idx=0,
)
self.decoder = nn.Linear(att_unit, punc_size)
# def _target_mask(self, ys_in_pad):
# ys_mask = ys_in_pad != 0
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
# return ys_mask.unsqueeze(-2) & m
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(input)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
def with_vad(self):
return False
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
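# --- Editor's illustrative usage sketch (not part of the original file) ---
# The vocabulary and punctuation sizes below are arbitrary assumptions; the
# sketch only shows the forward() contract of per-token punctuation logits.
if __name__ == "__main__":
    model = TargetDelayTransformer(vocab_size=500, punc_size=5)
    tokens = torch.randint(1, 500, (2, 20))
    lengths = torch.tensor([20, 15])
    logits, _ = model(tokens, lengths)
    print(logits.shape)  # torch.Size([2, 20, 5])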

View File

@@ -0,0 +1,136 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr_local.modules.embedding import SinusoidalPositionEncoder
from funasr_local.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr_local.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
vocab_size: int,
punc_size: int,
pos_enc: str = None,
embed_unit: int = 128,
att_unit: int = 256,
head: int = 2,
unit: int = 1024,
layer: int = 4,
dropout_rate: float = 0.5,
kernel_size: int = 11,
sanm_shfit: int = 0,
):
super().__init__()
if pos_enc == "sinusoidal":
# pos_enc_class = PositionalEncoding
pos_enc_class = SinusoidalPositionEncoder
elif pos_enc is None:
def pos_enc_class(*args, **kwargs):
return nn.Sequential() # identity
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(vocab_size, embed_unit)
self.encoder = Encoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
input_layer="pe",
# pos_enc_class=pos_enc_class,
padding_idx=0,
kernel_size=kernel_size,
sanm_shfit=sanm_shfit,
)
self.decoder = nn.Linear(att_unit, punc_size)
# def _target_mask(self, ys_in_pad):
# ys_mask = ys_in_pad != 0
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
# return ys_mask.unsqueeze(-2) & m
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(input)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths, vad_indexes)
y = self.decoder(h)
return y, None
def with_vad(self):
return True
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list