Mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
funasr_local/models/frontend/__init__.py | 0 lines (new file, empty)
funasr_local/models/frontend/abs_frontend.py | 17 lines (new file)
@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch


class AbsFrontend(torch.nn.Module, ABC):
    @abstractmethod
    def output_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
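The two abstract methods above form the contract every frontend in this directory implements: output_size() reports the feature dimension, and forward() maps padded waveforms plus lengths to padded features plus lengths. A minimal sketch of a concrete subclass (hypothetical, not part of this commit):

# Hypothetical example: a pass-through frontend that treats each raw
# sample as a 1-dimensional "feature" frame.
from typing import Tuple

import torch

from funasr_local.models.frontend.abs_frontend import AbsFrontend


class IdentityFrontend(AbsFrontend):
    def output_size(self) -> int:
        return 1  # one feature dimension per time step

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (Batch, Samples) -> (Batch, Samples, 1); lengths are unchanged
        return input.unsqueeze(-1), input_lengths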
funasr_local/models/frontend/default.py | 258 lines (new file)
@@ -0,0 +1,258 @@
import copy
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types

from funasr_local.layers.log_mel import LogMel
from funasr_local.layers.stft import Stft
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.utils.get_default_kwargs import get_default_kwargs


class DefaultFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = None,
        fmax: Optional[int] = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "default"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        return input_feats, feats_lens

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
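A usage sketch for DefaultFrontend (hypothetical, not part of this commit): passing frontend_conf=None skips the WPE/beamformer enhancement stage, so the pipeline reduces to Stft -> Power-spec -> Log-Mel.

# Hypothetical usage sketch; shapes assume the defaults n_fft=512, hop_length=128.
import torch

from funasr_local.models.frontend.default import DefaultFrontend

frontend = DefaultFrontend(fs=16000, n_mels=80, frontend_conf=None)
wav = torch.randn(2, 16000)                       # (Batch, Samples)
wav_lens = torch.tensor([16000, 12000])           # valid samples per utterance
feats, feats_lens = frontend(wav, wav_lens)       # feats: (2, T, 80)
assert feats.size(-1) == frontend.output_size()   # 80 mel bins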
class MultiChannelFrontend(AbsFrontend):
    """Conventional frontend structure for multi-channel ASR.

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = None,
        fmax: Optional[int] = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
        frame_length: Optional[int] = None,
        frame_shift: Optional[int] = None,
        lfr_m: Optional[int] = None,
        lfr_n: Optional[int] = None,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "multichannelfrontend"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            if isinstance(input, ComplexTensor):
                input_stft = input
            else:
                input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)
        # 3. STFT -> Power spectrum
        # h: ComplexTensor(B, T, [C,] F) -> torch.Tensor(B, T, [C,] F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 4. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)
        bt = input_feats.size(0)
        if input_feats.dim() == 4:
            channel_size = input_feats.size(2)
            # (Batch, T, Channel, Dim) -> (Batch * Channel, T, Dim)
            input_feats = (
                input_feats.transpose(1, 2)
                .reshape(bt * channel_size, -1, self.n_mels)
                .contiguous()
            )
            # tile the lengths once per channel
            feats_lens = feats_lens.repeat(1, channel_size).squeeze()
        else:
            channel_size = 1
        return input_feats, feats_lens, channel_size

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
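The channel flattening at the end of MultiChannelFrontend.forward can be traced with plain tensors. A hypothetical shape walk-through (not part of this commit), assuming n_mels=80; note that Tensor.repeat tiles the length vector batch-first, which matches the reshape ordering only when all utterances in the batch have equal length.

# Hypothetical shape walk-through of (B, T, C, D) -> (B*C, T, D).
import torch

bt, t, channel_size, n_mels = 2, 100, 4, 80
input_feats = torch.randn(bt, t, channel_size, n_mels)             # (B, T, C, D)
flat = input_feats.transpose(1, 2).reshape(bt * channel_size, -1, n_mels)
assert flat.shape == (8, 100, 80)                                  # (B*C, T, D)
feats_lens = torch.tensor([t, t]).repeat(1, channel_size).squeeze()
assert feats_lens.shape == (8,)   # tiled [l0, l1, l0, l1, ...], see note above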
funasr_local/models/frontend/eend_ola_feature.py | 51 lines (new file)
@@ -0,0 +1,51 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features

import librosa
import numpy as np


def transform(Y, dtype=np.float32):
    Y = np.abs(Y)
    n_fft = 2 * (Y.shape[1] - 1)
    sr = 8000
    n_mels = 23
    # keyword arguments for compatibility with librosa >= 0.10
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    Y = np.dot(Y ** 2, mel_basis.T)
    Y = np.log10(np.maximum(Y, 1e-10))
    mean = np.mean(Y, axis=0)
    Y = Y - mean
    return Y.astype(dtype)


def subsample(Y, T, subsampling=1):
    Y_ss = Y[::subsampling]
    T_ss = T[::subsampling]
    return Y_ss, T_ss


def splice(Y, context_size=0):
    Y_pad = np.pad(
        Y,
        [(context_size, context_size), (0, 0)],
        'constant')
    Y_spliced = np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(Y_pad),
        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
        (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
    return Y_spliced


def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    fft_size = 1 << (frame_size - 1).bit_length()
    if len(data) % frame_shift == 0:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T[:-1]
    else:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T
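A hypothetical end-to-end sketch (not part of this commit) of how these helpers chain together for 8 kHz diarization-style features:

# stft -> transform -> splice -> subsample on random 8 kHz audio.
import numpy as np

import funasr_local.models.frontend.eend_ola_feature as eend_ola_feature

data = np.random.randn(80000).astype(np.float32)            # 10 s at 8 kHz
Y = eend_ola_feature.stft(data, frame_size=1024, frame_shift=256)
feats = eend_ola_feature.transform(Y)                       # (T, 23) mean-normalized log-mel
spliced = eend_ola_feature.splice(feats, context_size=7)    # (T, 23 * 15)
labels = np.zeros((feats.shape[0], 1))                      # dummy per-frame labels
feats_ss, labels_ss = eend_ola_feature.subsample(spliced, labels, subsampling=10)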
funasr_local/models/frontend/fused.py | 146 lines (new file)
@@ -0,0 +1,146 @@
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.s3prl import S3prlFrontend
import numpy as np
import torch
from typeguard import check_argument_types
from typing import Tuple


class FusedFrontends(AbsFrontend):
    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):

        assert check_argument_types()
        super().__init__()
        self.align_method = (
            align_method  # fusing method : linear_projection only for now
        )
        self.proj_dim = proj_dim  # dim of the projection done on each frontend
        self.frontends = []  # list of the frontends to combine

        for i, frontend in enumerate(frontends):
            frontend_type = frontend["frontend_type"]
            if frontend_type == "default":
                n_mels, fs, n_fft, win_length, hop_length = (
                    frontend.get("n_mels", 80),
                    fs,
                    frontend.get("n_fft", 512),
                    frontend.get("win_length"),
                    frontend.get("hop_length", 128),
                )
                window, center, normalized, onesided = (
                    frontend.get("window", "hann"),
                    frontend.get("center", True),
                    frontend.get("normalized", False),
                    frontend.get("onesided", True),
                )
                fmin, fmax, htk, apply_stft = (
                    frontend.get("fmin", None),
                    frontend.get("fmax", None),
                    frontend.get("htk", False),
                    frontend.get("apply_stft", True),
                )

                self.frontends.append(
                    DefaultFrontend(
                        n_mels=n_mels,
                        n_fft=n_fft,
                        fs=fs,
                        win_length=win_length,
                        hop_length=hop_length,
                        window=window,
                        center=center,
                        normalized=normalized,
                        onesided=onesided,
                        fmin=fmin,
                        fmax=fmax,
                        htk=htk,
                        apply_stft=apply_stft,
                    )
                )
            elif frontend_type == "s3prl":
                frontend_conf, download_dir, multilayer_feature = (
                    frontend.get("frontend_conf"),
                    frontend.get("download_dir"),
                    frontend.get("multilayer_feature"),
                )
                self.frontends.append(
                    S3prlFrontend(
                        fs=fs,
                        frontend_conf=frontend_conf,
                        download_dir=download_dir,
                        multilayer_feature=multilayer_feature,
                    )
                )

            else:
                raise NotImplementedError  # frontends are only default or s3prl

        self.frontends = torch.nn.ModuleList(self.frontends)

        self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
        self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
        if torch.cuda.is_available():
            dev = "cuda"
        else:
            dev = "cpu"
        if self.align_method == "linear_projection":
            self.projection_layers = [
                torch.nn.Linear(
                    in_features=frontend.output_size(),
                    out_features=self.factors[i] * self.proj_dim,
                )
                for i, frontend in enumerate(self.frontends)
            ]
            self.projection_layers = torch.nn.ModuleList(self.projection_layers)
            self.projection_layers = self.projection_layers.to(torch.device(dev))

    def output_size(self) -> int:
        return len(self.frontends) * self.proj_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # step 0 : get all frontends features
        self.feats = []
        for frontend in self.frontends:
            with torch.no_grad():
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])

        if (
            self.align_method == "linear_projection"
        ):  # TODO(Dan): to add other align methods

            # first step : projections
            self.feats_proj = []
            for i, frontend in enumerate(self.frontends):
                input_feats = self.feats[i][0]
                self.feats_proj.append(self.projection_layers[i](input_feats))

            # 2nd step : reshape
            self.feats_reshaped = []
            for i, frontend in enumerate(self.frontends):
                input_feats_proj = self.feats_proj[i]
                bs, nf, dim = input_feats_proj.shape
                input_feats_reshaped = torch.reshape(
                    input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
                )
                self.feats_reshaped.append(input_feats_reshaped)

            # 3rd step : drop the few last frames
            m = min([x.shape[1] for x in self.feats_reshaped])
            self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]

            input_feats = torch.cat(
                self.feats_final, dim=-1
            )  # change the input size of the preencoder : proj_dim * n_frontends
            feats_lens = torch.ones_like(self.feats[0][1]) * m

        else:
            raise NotImplementedError

        return input_feats, feats_lens
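A hypothetical usage sketch (not part of this commit): fusing two default frontends with different hop lengths. Each feature stream is linearly projected, upsampled by its hop-length factor to the common frame rate, truncated to the shortest stream, and concatenated.

import torch

from funasr_local.models.frontend.fused import FusedFrontends

frontends = [
    {"frontend_type": "default", "n_mels": 80, "hop_length": 128},
    {"frontend_type": "default", "n_mels": 80, "hop_length": 256},
]
fused = FusedFrontends(frontends=frontends, proj_dim=100, fs=16000)
wav = torch.randn(2, 16000)
wav_lens = torch.tensor([16000, 16000])
feats, feats_lens = fused(wav, wav_lens)
assert feats.size(-1) == fused.output_size()   # 2 frontends * proj_dim = 200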
funasr_local/models/frontend/s3prl.py | 143 lines (new file)
@@ -0,0 +1,143 @@
import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import torch
from typeguard import check_argument_types

from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.modules.frontends.frontend import Frontend
from funasr_local.modules.nets_utils import pad_list
from funasr_local.utils.get_default_kwargs import get_default_kwargs


def base_s3prl_setup(args):
    args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
    args.upstream_model_config = getattr(args, "upstream_model_config", None)
    args.upstream_refresh = getattr(args, "upstream_refresh", False)
    args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
    args.init_ckpt = getattr(args, "init_ckpt", None)
    args.verbose = getattr(args, "verbose", False)
    args.tile_factor = getattr(args, "tile_factor", 1)
    return args


class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation frontend structure for ASR."""

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        download_dir: Optional[str] = None,
        multilayer_feature: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Get S3PRL upstream model."""
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        s3prl_path = None
        python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
        for p in python_path_list:
            if p.endswith("s3prl"):
                s3prl_path = p
                break
        assert s3prl_path is not None

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.

        Input - sequence of representations
            shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
            shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        if self.args.tile_factor != 1:
            feats = self._tile_representations(feats)

        input_feats = pad_list(feats, 0.0)
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

        # Saving CUDA Memory
        del feats

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
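A hypothetical configuration sketch (not part of this commit). S3prlFrontend locates a local s3prl checkout through PYTHONPATH and loads the upstream via torch.hub; "wav2vec2" below is an illustrative upstream name, and the call only works in an environment where s3prl and its checkpoints are actually available.

# Requires an entry ending in "s3prl" on PYTHONPATH, e.g.
#   export PYTHONPATH=/path/to/s3prl:$PYTHONPATH
import torch

from funasr_local.models.frontend.s3prl import S3prlFrontend

frontend = S3prlFrontend(
    fs=16000,
    frontend_conf={"upstream": "wav2vec2"},   # illustrative upstream name
    multilayer_feature=True,
)
wav = torch.randn(1, 16000)
feats, feats_lens = frontend(wav, torch.tensor([16000]))
print(feats.shape, frontend.output_size())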
funasr_local/models/frontend/wav_frontend.py | 503 lines (new file)
@@ -0,0 +1,503 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import copy
from typing import Optional
from typing import Tuple

import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types

import funasr_local.models.frontend.eend_ola_feature as eend_ola_feature
from funasr_local.models.frontend.abs_frontend import AbsFrontend


def load_cmvn(cmvn_file):
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    vars_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                means_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
    return cmvn


def apply_cmvn(inputs, cmvn):  # noqa
    """
    Apply CMVN with mvn data
    """

    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape

    means = cmvn[0:1, :dim]
    vars = cmvn[1:2, :dim]
    inputs += means.to(device)
    inputs *= vars.to(device)

    return inputs.type(torch.float32)


def apply_lfr(inputs, lfr_m, lfr_n):
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
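apply_lfr left-pads with (lfr_m - 1) // 2 copies of the first frame, then stacks lfr_m consecutive frames at a stride of lfr_n, so T input frames become ceil(T / lfr_n) output frames of dimension lfr_m * D. A hypothetical check (not part of this commit), using the common 7/6 setting:

import math

import torch

from funasr_local.models.frontend.wav_frontend import apply_lfr

T, n_mels, lfr_m, lfr_n = 100, 80, 7, 6
feats = torch.randn(T, n_mels)
out = apply_lfr(feats, lfr_m, lfr_n)
assert out.shape == (math.ceil(T / lfr_n), lfr_m * n_mels)   # (17, 560)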
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    """

    def __init__(
        self,
        cmvn_file: Optional[str] = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,  # sic: spelling kept as used throughout the repo
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs,
                              snip_edges=self.snip_edges)

            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens

    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs)

            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens

    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
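A hypothetical usage sketch for WavFrontend (not part of this commit); cmvn_file=None skips normalization, and lfr_m=7, lfr_n=6 mirrors the low-frame-rate setting used by many FunASR recipes:

import torch

from funasr_local.models.frontend.wav_frontend import WavFrontend

frontend = WavFrontend(cmvn_file=None, fs=16000, n_mels=80, lfr_m=7, lfr_n=6)
wav = torch.randn(1, 16000)                        # 1 s of 16 kHz audio in [-1, 1]
feats, feats_lens = frontend(wav, torch.tensor([16000]))
assert feats.size(-1) == frontend.output_size()    # 7 * 80 = 560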
class WavFrontendOnline(AbsFrontend):
    """Conventional frontend structure for streaming ASR/VAD.
    """

    def __init__(
        self,
        cmvn_file: Optional[str] = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.frame_sample_length = int(self.frame_length * self.fs / 1000)
        self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.waveforms = None
        self.reserve_waveforms = None
        self.fbanks = None
        self.fbanks_lens = None
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
        self.input_cache = None
        self.lfr_splice_cache = []

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    @staticmethod
    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
        """
        Apply CMVN with mvn data
        """

        device = inputs.device
        dtype = inputs.dtype
        frame, dim = inputs.shape

        means = np.tile(cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(cmvn[1:2, :dim], (frame, 1))
        inputs += torch.from_numpy(means).type(dtype).to(device)
        inputs *= torch.from_numpy(vars).type(dtype).to(device)

        return inputs.type(torch.float32)

    # The inputs tensor already has the cache tensor concatenated to it.
    @staticmethod
    def apply_lfr(
        inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Apply LFR with data
        """

        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n))  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n:]).view(-1)
                    for _ in range(num_padding):
                        frame = torch.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and break the loop
                    splice_idx = i
                    break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = torch.vstack(LFR_inputs)
        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0

    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        if self.input_cache is None:
            self.input_cache = torch.empty(0)
        input = torch.cat((self.input_cache, input), dim=1)
        frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
        # update self.input_cache
        self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                # keep the exact wave samples that were used for fbank extraction
                waveforms.append(
                    waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
                waveform = waveform * (1 << 15)
                waveform = waveform.unsqueeze(0)
                mat = kaldi.fbank(waveform,
                                  num_mel_bins=self.n_mels,
                                  frame_length=self.frame_length,
                                  frame_shift=self.frame_shift,
                                  dither=self.dither,
                                  energy_floor=0.0,
                                  window_type=self.window,
                                  sample_frequency=self.fs)

                feat_length = mat.size(0)
                feats.append(mat)
                feats_lens.append(feat_length)

            waveforms = torch.stack(waveforms)
            feats_lens = torch.as_tensor(feats_lens)
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        self.fbanks = feats_pad
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.fbanks, self.fbanks_lens

    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.shape[0]
        assert batch_size == 1, 'online feature extraction currently supports batch_size == 1 only'
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
                (self.reserve_waveforms, waveforms), dim=1)
            if not self.lfr_splice_cache:  # initialize the splice cache
                for i in range(batch_size):
                    self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
            # the number of input frames plus self.lfr_splice_cache[0].shape[0] must be at least self.lfr_m
            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    self.reserve_waveforms = self.waveforms[
                        :, reserve_frame_idx * self.frame_shift_sample_length:
                        frame_from_waveforms * self.frame_shift_sample_length]
                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length \
                        + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[
                    :, :-(self.frame_sample_length - self.frame_shift_sample_length)]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                return torch.empty(0), feats_lengths
        else:
            if is_final:
                self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                feats = torch.stack(self.lfr_splice_cache)
                feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        return self.waveforms

    def cache_reset(self):
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []
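A hypothetical streaming sketch (not part of this commit): WavFrontendOnline keeps sample-level and LFR splice caches between calls, may return an empty tensor while it accumulates enough frames, and is_final=True pads the last LFR frame and resets the caches.

import torch

from funasr_local.models.frontend.wav_frontend import WavFrontendOnline

frontend = WavFrontendOnline(cmvn_file=None, fs=16000, n_mels=80, lfr_m=7, lfr_n=6)
wav = torch.randn(1, 16000)
chunk = 1600                                       # 100 ms chunks
for start in range(0, wav.shape[1], chunk):
    piece = wav[:, start:start + chunk]
    is_final = start + chunk >= wav.shape[1]
    feats, feats_lens = frontend(
        piece, torch.tensor([piece.shape[1]]), is_final=is_final)
    if feats.numel():                              # empty until frames are ready
        print(feats.shape)                         # (1, T_chunk, 560)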
class WavFrontendMel23(AbsFrontend):
    """Conventional frontend structure for ASR.
    """

    def __init__(
        self,
        fs: int = 16000,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.n_mels = 23

    def output_size(self) -> int:
        return self.n_mels * (2 * self.lfr_m + 1)

    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform.numpy()
            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
            mat = eend_ola_feature.transform(mat)
            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
            mat = mat[::self.lfr_n]
            mat = torch.from_numpy(mat)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
funasr_local/models/frontend/wav_frontend_kaldifeat.py | 180 lines (new file)
@@ -0,0 +1,180 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.

from typing import Tuple

import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
# import kaldifeat


def load_cmvn(cmvn_file):
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    vars_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                means_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    # np.float was removed in NumPy 1.24; use float32 as in wav_frontend.py
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn)
    return cmvn


def apply_cmvn(inputs, cmvn_file):  # noqa
    """
    Apply CMVN with mvn data
    """

    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape

    cmvn = load_cmvn(cmvn_file)
    means = np.tile(cmvn[0:1, :dim], (frame, 1))
    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)
    inputs *= torch.from_numpy(vars).type(dtype).to(device)

    return inputs.type(torch.float32)


def apply_lfr(inputs, lfr_m, lfr_n):
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)


# class WavFrontend_kaldifeat(AbsFrontend):
#     """Conventional frontend structure for ASR.
#     """
#
#     def __init__(
#             self,
#             cmvn_file: str = None,
#             fs: int = 16000,
#             window: str = 'hamming',
#             n_mels: int = 80,
#             frame_length: int = 25,
#             frame_shift: int = 10,
#             lfr_m: int = 1,
#             lfr_n: int = 1,
#             dither: float = 1.0,
#             snip_edges: bool = True,
#             upsacle_samples: bool = True,
#             device: str = 'cpu',
#             **kwargs,
#     ):
#         super().__init__()
#
#         opts = kaldifeat.FbankOptions()
#         opts.device = device
#         opts.frame_opts.samp_freq = fs
#         opts.frame_opts.dither = dither
#         opts.frame_opts.window_type = window
#         opts.frame_opts.frame_shift_ms = float(frame_shift)
#         opts.frame_opts.frame_length_ms = float(frame_length)
#         opts.mel_opts.num_bins = n_mels
#         opts.energy_floor = 0
#         opts.frame_opts.snip_edges = snip_edges
#         opts.mel_opts.debug_mel = False
#         self.opts = opts
#         self.fbank_fn = None
#         self.fbank_beg_idx = 0
#         self.reset_fbank_status()
#
#         self.lfr_m = lfr_m
#         self.lfr_n = lfr_n
#         self.cmvn_file = cmvn_file
#         self.upsacle_samples = upsacle_samples
#
#     def output_size(self) -> int:
#         return self.n_mels * self.lfr_m
#
#     def forward_fbank(
#             self,
#             input: torch.Tensor,
#             input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         batch_size = input.size(0)
#         feats = []
#         feats_lens = []
#         for i in range(batch_size):
#             waveform_length = input_lengths[i]
#             waveform = input[i][:waveform_length]
#             waveform = waveform * (1 << 15)
#
#             self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
#             frames = self.fbank_fn.num_frames_ready
#             frames_cur = frames - self.fbank_beg_idx
#             mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
#                 device=self.opts.device)
#             for i in range(self.fbank_beg_idx, frames):
#                 mat[i, :] = self.fbank_fn.get_frame(i)
#             self.fbank_beg_idx += frames_cur
#
#             feat_length = mat.size(0)
#             feats.append(mat)
#             feats_lens.append(feat_length)
#
#         feats_lens = torch.as_tensor(feats_lens)
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
#
#     def reset_fbank_status(self):
#         self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
#         self.fbank_beg_idx = 0
#
#     def forward_lfr_cmvn(
#             self,
#             input: torch.Tensor,
#             input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         batch_size = input.size(0)
#         feats = []
#         feats_lens = []
#         for i in range(batch_size):
#             mat = input[i, :input_lengths[i], :]
#             if self.lfr_m != 1 or self.lfr_n != 1:
#                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
#             if self.cmvn_file is not None:
#                 mat = apply_cmvn(mat, self.cmvn_file)
#             feat_length = mat.size(0)
#             feats.append(mat)
#             feats_lens.append(feat_length)
#
#         feats_lens = torch.as_tensor(feats_lens)
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
funasr_local/models/frontend/windowing.py | 81 lines (new file)
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
#  2020, Technische Universität München; Ludwig Kürzinger
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Sliding Window for raw audio input data."""

from funasr_local.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple


class SlidingWindow(AbsFrontend):
    """Sliding Window.

    Provides a sliding window over a batched continuous raw audio tensor.
    Optionally, provides padding (currently not implemented).
    Combine this module with a pre-encoder compatible with raw audio data,
    for example Sinc convolutions.

    Known issues:
    Output length is calculated incorrectly if audio is shorter than win_length.
    WARNING: trailing values are discarded - padding not implemented yet.
    There is currently no additional window function applied to input values.
    """

    def __init__(
        self,
        win_length: int = 400,
        hop_length: int = 160,
        channels: int = 1,
        padding: Optional[int] = None,
        fs=None,
    ):
        """Initialize.

        Args:
            win_length: Length of frame.
            hop_length: Relative starting point of next frame.
            channels: Number of input channels.
            padding: Padding (placeholder, currently not implemented).
            fs: Sampling rate (placeholder for compatibility, not used).
        """
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.win_length = win_length
        self.hop_length = hop_length
        self.channels = channels
        self.padding = padding

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a sliding window on the input.

        Args:
            input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
            input_lengths: Input lengths within batch.

        Returns:
            Tensor: Output with dimensions (B, T, C, D), with D=win_length.
            Tensor: Output lengths within batch.
        """
        input_size = input.size()
        B = input_size[0]
        T = input_size[1]
        C = self.channels
        D = self.win_length
        # (B, T, C) --> (T, B, C)
        continuous = input.view(B, T, C).permute(1, 0, 2)
        windowed = continuous.unfold(0, D, self.hop_length)
        # (T, B, C, D) --> (B, T, C, D)
        output = windowed.permute(1, 0, 2, 3).contiguous()
        # After unfold(), windowed lengths change:
        output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
        return output, output_lengths

    def output_size(self) -> int:
        """Return output length of feature dimension D, i.e. the window length."""
        return self.win_length
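A hypothetical usage sketch (not part of this commit): 25 ms windows with a 10 ms hop over 16 kHz audio, i.e. win_length=400 and hop_length=160.

import torch

from funasr_local.models.frontend.windowing import SlidingWindow

frontend = SlidingWindow(win_length=400, hop_length=160, channels=1)
wav = torch.randn(2, 16000)                        # (B, T), single channel
lens = torch.tensor([16000, 16000])
frames, frame_lens = frontend(wav, lens)
assert frames.shape == (2, 98, 1, 400)             # (B, T', C, D)
assert frame_lens.tolist() == [98, 98]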