add files

This commit is contained in:
烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions

View File

View File

@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsFrontend(torch.nn.Module, ABC):
@abstractmethod
def output_size(self) -> int:
raise NotImplementedError
@abstractmethod
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@@ -0,0 +1,258 @@
import copy
from typing import Optional
from typing import Tuple
from typing import Union
import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types
from funasr_local.layers.log_mel import LogMel
from funasr_local.layers.stft import Stft
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.utils.get_default_kwargs import get_default_kwargs
class DefaultFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)
else:
self.stft = None
self.apply_stft = apply_stft
if frontend_conf is not None:
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = "default"
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))
input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
return input_feats, feats_lens
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
frame_length: int = None,
frame_shift: int = None,
lfr_m: int = None,
lfr_n: int = None,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)
else:
self.stft = None
self.apply_stft = apply_stft
if frontend_conf is not None:
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = "multichannelfrontend"
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
#import pdb;pdb.set_trace()
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
if isinstance(input, ComplexTensor):
input_stft = input
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
bt = input_feats.size(0)
if input_feats.dim() ==4:
channel_size = input_feats.size(2)
# batch * channel * T * D
#pdb.set_trace()
input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
# input_feats = input_feats.transpose(1,2)
# batch * channel
feats_lens = feats_lens.repeat(1,channel_size).squeeze()
else:
channel_size = 1
return input_feats, feats_lens, channel_size
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens

View File

@@ -0,0 +1,51 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features
import librosa
import numpy as np
def transform(Y, dtype=np.float32):
Y = np.abs(Y)
n_fft = 2 * (Y.shape[1] - 1)
sr = 8000
n_mels = 23
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
Y = np.dot(Y ** 2, mel_basis.T)
Y = np.log10(np.maximum(Y, 1e-10))
mean = np.mean(Y, axis=0)
Y = Y - mean
return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
Y_ss = Y[::subsampling]
T_ss = T[::subsampling]
return Y_ss, T_ss
def splice(Y, context_size=0):
Y_pad = np.pad(
Y,
[(context_size, context_size), (0, 0)],
'constant')
Y_spliced = np.lib.stride_tricks.as_strided(
np.ascontiguousarray(Y_pad),
(Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
(Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
return Y_spliced
def stft(
data,
frame_size=1024,
frame_shift=256):
fft_size = 1 << (frame_size - 1).bit_length()
if len(data) % frame_shift == 0:
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
hop_length=frame_shift).T[:-1]
else:
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
hop_length=frame_shift).T

View File

@@ -0,0 +1,146 @@
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.s3prl import S3prlFrontend
import numpy as np
import torch
from typeguard import check_argument_types
from typing import Tuple
class FusedFrontends(AbsFrontend):
def __init__(
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):
assert check_argument_types()
super().__init__()
self.align_method = (
align_method # fusing method : linear_projection only for now
)
self.proj_dim = proj_dim # dim of the projection done on each frontend
self.frontends = [] # list of the frontends to combine
for i, frontend in enumerate(frontends):
frontend_type = frontend["frontend_type"]
if frontend_type == "default":
n_mels, fs, n_fft, win_length, hop_length = (
frontend.get("n_mels", 80),
fs,
frontend.get("n_fft", 512),
frontend.get("win_length"),
frontend.get("hop_length", 128),
)
window, center, normalized, onesided = (
frontend.get("window", "hann"),
frontend.get("center", True),
frontend.get("normalized", False),
frontend.get("onesided", True),
)
fmin, fmax, htk, apply_stft = (
frontend.get("fmin", None),
frontend.get("fmax", None),
frontend.get("htk", False),
frontend.get("apply_stft", True),
)
self.frontends.append(
DefaultFrontend(
n_mels=n_mels,
n_fft=n_fft,
fs=fs,
win_length=win_length,
hop_length=hop_length,
window=window,
center=center,
normalized=normalized,
onesided=onesided,
fmin=fmin,
fmax=fmax,
htk=htk,
apply_stft=apply_stft,
)
)
elif frontend_type == "s3prl":
frontend_conf, download_dir, multilayer_feature = (
frontend.get("frontend_conf"),
frontend.get("download_dir"),
frontend.get("multilayer_feature"),
)
self.frontends.append(
S3prlFrontend(
fs=fs,
frontend_conf=frontend_conf,
download_dir=download_dir,
multilayer_feature=multilayer_feature,
)
)
else:
raise NotImplementedError # frontends are only default or s3prl
self.frontends = torch.nn.ModuleList(self.frontends)
self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
if torch.cuda.is_available():
dev = "cuda"
else:
dev = "cpu"
if self.align_method == "linear_projection":
self.projection_layers = [
torch.nn.Linear(
in_features=frontend.output_size(),
out_features=self.factors[i] * self.proj_dim,
)
for i, frontend in enumerate(self.frontends)
]
self.projection_layers = torch.nn.ModuleList(self.projection_layers)
self.projection_layers = self.projection_layers.to(torch.device(dev))
def output_size(self) -> int:
return len(self.frontends) * self.proj_dim
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# step 0 : get all frontends features
self.feats = []
for frontend in self.frontends:
with torch.no_grad():
input_feats, feats_lens = frontend.forward(input, input_lengths)
self.feats.append([input_feats, feats_lens])
if (
self.align_method == "linear_projection"
): # TODO(Dan): to add other align methods
# first step : projections
self.feats_proj = []
for i, frontend in enumerate(self.frontends):
input_feats = self.feats[i][0]
self.feats_proj.append(self.projection_layers[i](input_feats))
# 2nd step : reshape
self.feats_reshaped = []
for i, frontend in enumerate(self.frontends):
input_feats_proj = self.feats_proj[i]
bs, nf, dim = input_feats_proj.shape
input_feats_reshaped = torch.reshape(
input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
)
self.feats_reshaped.append(input_feats_reshaped)
# 3rd step : drop the few last frames
m = min([x.shape[1] for x in self.feats_reshaped])
self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]
input_feats = torch.cat(
self.feats_final, dim=-1
) # change the input size of the preencoder : proj_dim * n_frontends
feats_lens = torch.ones_like(self.feats[0][1]) * (m)
else:
raise NotImplementedError
return input_feats, feats_lens

View File

@@ -0,0 +1,143 @@
import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union
import humanfriendly
import torch
from typeguard import check_argument_types
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.modules.nets_utils import pad_list
from funasr_local.utils.get_default_kwargs import get_default_kwargs
def base_s3prl_setup(args):
args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
args.upstream_model_config = getattr(args, "upstream_model_config", None)
args.upstream_refresh = getattr(args, "upstream_refresh", False)
args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
args.init_ckpt = getattr(args, "init_ckpt", None)
args.verbose = getattr(args, "verbose", False)
args.tile_factor = getattr(args, "tile_factor", 1)
return args
class S3prlFrontend(AbsFrontend):
"""Speech Pretrained Representation frontend structure for ASR."""
def __init__(
self,
fs: Union[int, str] = 16000,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
download_dir: str = None,
multilayer_feature: bool = False,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
if download_dir is not None:
torch.hub.set_dir(download_dir)
self.multilayer_feature = multilayer_feature
self.upstream, self.featurizer = self._get_upstream(frontend_conf)
self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
self.output_dim = self.featurizer.output_dim
self.frontend_type = "s3prl"
self.hop_length = self.upstream.get_downsample_rates("key")
def _get_upstream(self, frontend_conf):
"""Get S3PRL upstream model."""
s3prl_args = base_s3prl_setup(
Namespace(**frontend_conf, device="cpu"),
)
self.args = s3prl_args
s3prl_path = None
python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
for p in python_path_list:
if p.endswith("s3prl"):
s3prl_path = p
break
assert s3prl_path is not None
s3prl_upstream = torch.hub.load(
s3prl_path,
s3prl_args.upstream,
ckpt=s3prl_args.upstream_ckpt,
model_config=s3prl_args.upstream_model_config,
refresh=s3prl_args.upstream_refresh,
source="local",
).to("cpu")
if getattr(
s3prl_upstream, "model", None
) is not None and s3prl_upstream.model.__class__.__name__ in [
"Wav2Vec2Model",
"HubertModel",
]:
s3prl_upstream.model.encoder.layerdrop = 0.0
from s3prl.upstream.interfaces import Featurizer
if self.multilayer_feature is None:
feature_selection = "last_hidden_state"
else:
feature_selection = "hidden_states"
s3prl_featurizer = Featurizer(
upstream=s3prl_upstream,
feature_selection=feature_selection,
upstream_device="cpu",
)
return s3prl_upstream, s3prl_featurizer
def _tile_representations(self, feature):
"""Tile up the representations by `tile_factor`.
Input - sequence of representations
shape: (batch_size, seq_len, feature_dim)
Output - sequence of tiled representations
shape: (batch_size, seq_len * factor, feature_dim)
"""
assert (
len(feature.shape) == 3
), "Input argument `feature` has invalid shape: {}".format(feature.shape)
tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
tiled_feature = tiled_feature.reshape(
feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
)
return tiled_feature
def output_size(self) -> int:
return self.output_dim
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
self.upstream.eval()
with torch.no_grad():
feats = self.upstream(wavs)
feats = self.featurizer(wavs, feats)
if self.args.tile_factor != 1:
feats = self._tile_representations(feats)
input_feats = pad_list(feats, 0.0)
feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)
# Saving CUDA Memory
del feats
return input_feats, feats_lens
def reload_pretrained_parameters(self):
self.upstream.load_state_dict(self.pretrained_params)
logging.info("Pretrained S3PRL frontend model parameters reloaded!")

View File

@@ -0,0 +1,503 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
import funasr_local.models.frontend.eend_ola_feature as eend_ola_feature
from funasr_local.models.frontend.abs_frontend import AbsFrontend
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
means = cmvn[0:1, :dim]
vars = cmvn[1:2, :dim]
inputs += means.to(device)
inputs *= vars.to(device)
return inputs.type(torch.float32)
def apply_lfr(inputs, lfr_m, lfr_n):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs,
snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
def forward_fbank(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
def forward_lfr_cmvn(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
class WavFrontendOnline(AbsFrontend):
"""Conventional frontend structure for streaming ASR/VAD.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.frame_sample_length = int(self.frame_length * self.fs / 1000)
self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
self.waveforms = None
self.reserve_waveforms = None
self.fbanks = None
self.fbanks_lens = None
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
self.input_cache = None
self.lfr_splice_cache = []
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@staticmethod
def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
inputs *= torch.from_numpy(vars).type(dtype).to(device)
return inputs.type(torch.float32)
@staticmethod
# inputs tensor has catted the cache tensor
# def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
# is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
torch.Tensor, torch.Tensor, int]:
"""
Apply lfr with data
"""
LFR_inputs = []
# inputs = torch.vstack((inputs_lfr_cache, inputs))
T = inputs.shape[0] # include the right context
T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
if is_final:
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
else:
# update splice_idx and break the circle
splice_idx = i
break
splice_idx = min(T - 1, splice_idx * lfr_n)
lfr_splice_cache = inputs[splice_idx:, :]
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
@staticmethod
def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
def forward_fbank(
self,
input: torch.Tensor,
input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
if self.input_cache is None:
self.input_cache = torch.empty(0)
input = torch.cat((self.input_cache, input), dim=1)
frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
# update self.in_cache
self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
waveforms = torch.empty(0)
feats_pad = torch.empty(0)
feats_lens = torch.empty(0)
if frame_num:
waveforms = []
feats = []
feats_lens = []
for i in range(batch_size):
waveform = input[i]
# we need accurate wave samples that used for fbank extracting
waveforms.append(
waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
waveforms = torch.stack(waveforms)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
self.fbanks = feats_pad
import copy
self.fbanks_lens = copy.deepcopy(feats_lens)
return waveforms, feats_pad, feats_lens
def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
return self.fbanks, self.fbanks_lens
def forward_lfr_cmvn(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
is_final)
if self.cmvn_file is not None:
mat = self.apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
return feats_pad, feats_lens, lfr_splice_frame_idxs
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.shape[0]
assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D
if feats.shape[0]:
# if self.reserve_waveforms is None and self.lfr_m > 1:
# self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
(self.reserve_waveforms, waveforms), dim=1)
if not self.lfr_splice_cache: # 初始化splice_cache
for i in range(batch_size):
self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
# need the number of the input frames + self.lfr_splice_cache[0].shape[0] is greater than self.lfr_m
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache) # B T D
feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
feats_lengths += lfr_splice_cache_tensor[0].shape[0]
frame_from_waveforms = int(
(self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
if self.lfr_m == 1:
self.reserve_waveforms = None
else:
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
# print('frame_frame: ' + str(frame_from_waveforms))
self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
self.waveforms = self.waveforms[:, :sample_length]
else:
# update self.reserve_waveforms and self.lfr_splice_cache
self.reserve_waveforms = self.waveforms[:,
:-(self.frame_sample_length - self.frame_shift_sample_length)]
for i in range(batch_size):
self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
return torch.empty(0), feats_lengths
else:
if is_final:
self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
feats = torch.stack(self.lfr_splice_cache)
feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
if is_final:
self.cache_reset()
return feats, feats_lengths
def get_waveforms(self):
return self.waveforms
def cache_reset(self):
self.reserve_waveforms = None
self.input_cache = None
self.lfr_splice_cache = []
class WavFrontendMel23(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
fs: int = 16000,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.frame_length = frame_length
self.frame_shift = frame_shift
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.n_mels = 23
def output_size(self) -> int:
return self.n_mels * (2 * self.lfr_m + 1)
def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform.numpy()
mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
mat = eend_ola_feature.transform(mat)
mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
mat = mat[::self.lfr_n]
mat = torch.from_numpy(mat)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens

View File

@@ -0,0 +1,180 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
# import kaldifeat
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn)
return cmvn
def apply_cmvn(inputs, cmvn_file): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
cmvn = load_cmvn(cmvn_file)
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
inputs *= torch.from_numpy(vars).type(dtype).to(device)
return inputs.type(torch.float32)
def apply_lfr(inputs, lfr_m, lfr_n):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
# class WavFrontend_kaldifeat(AbsFrontend):
# """Conventional frontend structure for ASR.
# """
#
# def __init__(
# self,
# cmvn_file: str = None,
# fs: int = 16000,
# window: str = 'hamming',
# n_mels: int = 80,
# frame_length: int = 25,
# frame_shift: int = 10,
# lfr_m: int = 1,
# lfr_n: int = 1,
# dither: float = 1.0,
# snip_edges: bool = True,
# upsacle_samples: bool = True,
# device: str = 'cpu',
# **kwargs,
# ):
# super().__init__()
#
# opts = kaldifeat.FbankOptions()
# opts.device = device
# opts.frame_opts.samp_freq = fs
# opts.frame_opts.dither = dither
# opts.frame_opts.window_type = window
# opts.frame_opts.frame_shift_ms = float(frame_shift)
# opts.frame_opts.frame_length_ms = float(frame_length)
# opts.mel_opts.num_bins = n_mels
# opts.energy_floor = 0
# opts.frame_opts.snip_edges = snip_edges
# opts.mel_opts.debug_mel = False
# self.opts = opts
# self.fbank_fn = None
# self.fbank_beg_idx = 0
# self.reset_fbank_status()
#
# self.lfr_m = lfr_m
# self.lfr_n = lfr_n
# self.cmvn_file = cmvn_file
# self.upsacle_samples = upsacle_samples
#
# def output_size(self) -> int:
# return self.n_mels * self.lfr_m
#
# def forward_fbank(
# self,
# input: torch.Tensor,
# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# batch_size = input.size(0)
# feats = []
# feats_lens = []
# for i in range(batch_size):
# waveform_length = input_lengths[i]
# waveform = input[i][:waveform_length]
# waveform = waveform * (1 << 15)
#
# self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
# frames = self.fbank_fn.num_frames_ready
# frames_cur = frames - self.fbank_beg_idx
# mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
# device=self.opts.device)
# for i in range(self.fbank_beg_idx, frames):
# mat[i, :] = self.fbank_fn.get_frame(i)
# self.fbank_beg_idx += frames_cur
#
# feat_length = mat.size(0)
# feats.append(mat)
# feats_lens.append(feat_length)
#
# feats_lens = torch.as_tensor(feats_lens)
# feats_pad = pad_sequence(feats,
# batch_first=True,
# padding_value=0.0)
# return feats_pad, feats_lens
#
# def reset_fbank_status(self):
# self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
# self.fbank_beg_idx = 0
#
# def forward_lfr_cmvn(
# self,
# input: torch.Tensor,
# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# batch_size = input.size(0)
# feats = []
# feats_lens = []
# for i in range(batch_size):
# mat = input[i, :input_lengths[i], :]
# if self.lfr_m != 1 or self.lfr_n != 1:
# mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
# if self.cmvn_file is not None:
# mat = apply_cmvn(mat, self.cmvn_file)
# feat_length = mat.size(0)
# feats.append(mat)
# feats_lens.append(feat_length)
#
# feats_lens = torch.as_tensor(feats_lens)
# feats_pad = pad_sequence(feats,
# batch_first=True,
# padding_value=0.0)
# return feats_pad, feats_lens

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Sliding Window for raw audio input data."""
from funasr_local.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple
class SlidingWindow(AbsFrontend):
"""Sliding Window.
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
Known issues:
Output length is calculated incorrectly if audio shorter than win_length.
WARNING: trailing values are discarded - padding not implemented yet.
There is currently no additional window function applied to input values.
"""
def __init__(
self,
win_length: int = 400,
hop_length: int = 160,
channels: int = 1,
padding: int = None,
fs=None,
):
"""Initialize.
Args:
win_length: Length of frame.
hop_length: Relative starting point of next frame.
channels: Number of input channels.
padding: Padding (placeholder, currently not implemented).
fs: Sampling rate (placeholder for compatibility, not used).
"""
assert check_argument_types()
super().__init__()
self.fs = fs
self.win_length = win_length
self.hop_length = hop_length
self.channels = channels
self.padding = padding
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a sliding window on the input.
Args:
input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
input_lengths: Input lengths within batch.
Returns:
Tensor: Output with dimensions (B, T, C, D), with D=win_length.
Tensor: Output lengths within batch.
"""
input_size = input.size()
B = input_size[0]
T = input_size[1]
C = self.channels
D = self.win_length
# (B, T, C) --> (T, B, C)
continuous = input.view(B, T, C).permute(1, 0, 2)
windowed = continuous.unfold(0, D, self.hop_length)
# (T, B, C, D) --> (B, T, C, D)
output = windowed.permute(1, 0, 2, 3).contiguous()
# After unfold(), windowed lengths change:
output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
return output, output_lengths
def output_size(self) -> int:
"""Return output length of feature dimension D, i.e. the window length."""
return self.win_length