add files

烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions


@@ -0,0 +1,14 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsNormalize(torch.nn.Module, ABC):
@abstractmethod
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
raise NotImplementedError
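
An illustrative sketch (not part of this commit) of a minimal concrete subclass.
GlobalMVN and UtteranceMVN, added later in this commit, are the real
implementations; any normalizer only needs to implement forward():

from typing import Tuple

import torch


class IdentityNormalize(AbsNormalize):
    """Pass-through normalizer that satisfies the AbsNormalize interface."""

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # A real subclass would normalize `input` here.
        return input, input_lengths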


@@ -0,0 +1,191 @@
"""Beamformer module."""
from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
EPS = torch.finfo(torch.double).eps
is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
def new_complex_like(
ref: Union[torch.Tensor, ComplexTensor],
real_imag: Tuple[torch.Tensor, torch.Tensor],
):
if isinstance(ref, ComplexTensor):
return ComplexTensor(*real_imag)
elif is_torch_complex_tensor(ref):
return torch.complex(*real_imag)
else:
raise ValueError(
"Please update your PyTorch version to 1.9+ for complex support."
)
def is_torch_complex_tensor(c):
return (
not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
)
def is_complex(c):
return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
def to_double(c):
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
return c.to(dtype=torch.complex128)
else:
return c.double()
def to_float(c):
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
return c.to(dtype=torch.complex64)
else:
return c.float()
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
if not isinstance(seq, (list, tuple)):
raise TypeError(
"cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
"not Tensor"
)
if isinstance(seq[0], ComplexTensor):
return FC.cat(seq, *args, **kwargs)
else:
return torch.cat(seq, *args, **kwargs)
def complex_norm(
c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
if not is_complex(c):
raise TypeError("Input is not a complex tensor.")
if is_torch_complex_tensor(c):
return torch.norm(c, dim=dim, keepdim=keepdim)
else:
return torch.sqrt(
(c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
)
def einsum(equation, *operands):
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
# NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
# mixed input with complex and real tensors.
if len(operands) == 1:
if isinstance(operands[0], (tuple, list)):
operands = operands[0]
complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
return complex_module.einsum(equation, *operands)
elif len(operands) != 2:
op0 = operands[0]
same_type = all(op.dtype == op0.dtype for op in operands[1:])
if same_type:
_einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
return _einsum(equation, *operands)
else:
raise ValueError("0 or More than 2 operands are not supported.")
a, b = operands
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
return FC.einsum(equation, a, b)
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
if not torch.is_complex(a):
o_real = torch.einsum(equation, a, b.real)
o_imag = torch.einsum(equation, a, b.imag)
return torch.complex(o_real, o_imag)
elif not torch.is_complex(b):
o_real = torch.einsum(equation, a.real, b)
o_imag = torch.einsum(equation, a.imag, b)
return torch.complex(o_real, o_imag)
else:
return torch.einsum(equation, a, b)
else:
return torch.einsum(equation, a, b)
def inverse(
c: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
if isinstance(c, ComplexTensor):
return c.inverse2()
else:
return c.inverse()
def matmul(
a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
# NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
# multiplication between complex and real tensors.
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
return FC.matmul(a, b)
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
if not torch.is_complex(a):
o_real = torch.matmul(a, b.real)
o_imag = torch.matmul(a, b.imag)
return torch.complex(o_real, o_imag)
elif not torch.is_complex(b):
o_real = torch.matmul(a.real, b)
o_imag = torch.matmul(a.imag, b)
return torch.complex(o_real, o_imag)
else:
return torch.matmul(a, b)
else:
return torch.matmul(a, b)
def trace(a: Union[torch.Tensor, ComplexTensor]):
    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
return FC.trace(a)
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
if isinstance(a, ComplexTensor):
return FC.reverse(a, dim=dim)
else:
return torch.flip(a, dims=(dim,))
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
"""Solve the linear equation ax = b."""
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
# NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
# mixed input with complex and real tensors.
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
return FC.solve(b, a, return_LU=False)
else:
return matmul(inverse(a), b)
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
if torch.is_complex(a) and torch.is_complex(b):
return torch.linalg.solve(a, b)
else:
return matmul(inverse(a), b)
else:
if is_torch_1_8_plus:
return torch.linalg.solve(a, b)
else:
return torch.solve(b, a)[0]
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
if not isinstance(seq, (list, tuple)):
raise TypeError(
"stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
"not Tensor"
)
if isinstance(seq[0], ComplexTensor):
return FC.stack(seq, *args, **kwargs)
else:
return torch.stack(seq, *args, **kwargs)
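
An illustrative usage sketch for these helpers (not part of the commit),
assuming PyTorch >= 1.9 for the native-complex path. Per the notes above,
ComplexTensor and torch.complex inputs must not be mixed within a single call:

import torch
from torch_complex.tensor import ComplexTensor

a = torch.randn(4, 3, 3, dtype=torch.complex64)
b = torch.randn(4, 3, 1, dtype=torch.complex64)

x = solve(b, a)                    # solves a @ x = b        -> (4, 3, 1)
y = matmul(a, b)                   # batched complex matmul  -> (4, 3, 1)
z = einsum("bij,bjk->bik", a, b)   # complex einsum          -> (4, 3, 1)

c = ComplexTensor(torch.randn(4, 3, 3), torch.randn(4, 3, 3))
print(is_complex(a), is_complex(c))  # True True
print(inverse(c).real.shape)         # torch.Size([4, 3, 3]), via inverse2()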


@@ -0,0 +1,121 @@
from pathlib import Path
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.layers.inversible_interface import InversibleInterface
class GlobalMVN(AbsNormalize, InversibleInterface):
"""Apply global mean and variance normalization
TODO(kamo): Make this class portable somehow
Args:
stats_file: npy file
norm_means: Apply mean normalization
norm_vars: Apply var normalization
eps:
"""
def __init__(
self,
stats_file: Union[Path, str],
norm_means: bool = True,
norm_vars: bool = True,
eps: float = 1.0e-20,
):
assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
stats_file = Path(stats_file)
self.stats_file = stats_file
stats = np.load(stats_file)
if isinstance(stats, np.ndarray):
# Kaldi like stats
count = stats[0].flatten()[-1]
mean = stats[0, :-1] / count
var = stats[1, :-1] / count - mean * mean
else:
# New style: Npz file
count = stats["count"]
sum_v = stats["sum"]
sum_square_v = stats["sum_square"]
mean = sum_v / count
var = sum_square_v / count - mean * mean
std = np.sqrt(np.maximum(var, eps))
self.register_buffer("mean", torch.from_numpy(mean))
self.register_buffer("std", torch.from_numpy(std))
def extra_repr(self):
return (
f"stats_file={self.stats_file}, "
f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
)
def forward(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
Args:
x: (B, L, ...)
ilens: (B,)
"""
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
norm_means = self.norm_means
norm_vars = self.norm_vars
self.mean = self.mean.to(x.device, x.dtype)
self.std = self.std.to(x.device, x.dtype)
mask = make_pad_mask(ilens, x, 1)
# feat: (B, T, D)
if norm_means:
if x.requires_grad:
x = x - self.mean
else:
x -= self.mean
if x.requires_grad:
x = x.masked_fill(mask, 0.0)
else:
x.masked_fill_(mask, 0.0)
if norm_vars:
x /= self.std
return x, ilens
def inverse(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
norm_means = self.norm_means
norm_vars = self.norm_vars
self.mean = self.mean.to(x.device, x.dtype)
self.std = self.std.to(x.device, x.dtype)
mask = make_pad_mask(ilens, x, 1)
if x.requires_grad:
x = x.masked_fill(mask, 0.0)
else:
x.masked_fill_(mask, 0.0)
if norm_vars:
x *= self.std
# feat: (B, T, D)
if norm_means:
x += self.mean
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
return x, ilens
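
An end-to-end sketch (illustrative; the stats file name is hypothetical). It
writes a Kaldi-style stats array of shape (2, D + 1): row 0 holds per-dimension
sums plus the frame count in the last slot, and row 1 holds per-dimension sums
of squares. Then it normalizes and inverts:

import numpy as np
import torch

D, count = 80, 1000
feats = np.random.randn(count, D)
stats = np.zeros((2, D + 1))
stats[0, :-1] = feats.sum(axis=0)         # per-dimension sums
stats[0, -1] = count                      # frame count in the last slot
stats[1, :-1] = (feats ** 2).sum(axis=0)  # per-dimension sums of squares
np.save("am.mvn.npy", stats)              # hypothetical file name

mvn = GlobalMVN("am.mvn.npy", norm_means=True, norm_vars=True)
x = torch.randn(2, 50, D)
ilens = torch.tensor([50, 30])
y, _ = mvn(x, ilens)              # normalized; padded frames are zeroed
x_rec, _ = mvn.inverse(y, ilens)  # undoes the normalization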


@@ -0,0 +1,14 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class InversibleInterface(ABC):
@abstractmethod
def inverse(
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
raise NotImplementedError


@@ -0,0 +1,82 @@
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from funasr_local.modules.nets_utils import make_pad_mask
class LabelAggregate(torch.nn.Module):
def __init__(
self,
win_length: int = 512,
hop_length: int = 128,
center: bool = True,
):
assert check_argument_types()
super().__init__()
self.win_length = win_length
self.hop_length = hop_length
self.center = center
def extra_repr(self):
return (
f"win_length={self.win_length}, "
f"hop_length={self.hop_length}, "
f"center={self.center}, "
)
def forward(
self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""LabelAggregate forward function.
Args:
input: (Batch, Nsamples, Label_dim)
ilens: (Batch)
Returns:
output: (Batch, Frames, Label_dim)
"""
bs = input.size(0)
max_length = input.size(1)
label_dim = input.size(2)
# NOTE(jiatong):
# The default behaviour of label aggregation is compatible with
# torch.stft about framing and padding.
# Step1: center padding
if self.center:
pad = self.win_length // 2
max_length = max_length + 2 * pad
input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
input[:, :pad, :] = input[:, pad : (2 * pad), :]
input[:, (max_length - pad) : max_length, :] = input[
:, (max_length - 2 * pad) : (max_length - pad), :
]
nframe = (max_length - self.win_length) // self.hop_length + 1
# Step2: framing
output = input.as_strided(
(bs, nframe, self.win_length, label_dim),
(max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
)
# Step3: aggregate label
output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
output = output.float()
# Step4: process lengths
if ilens is not None:
if self.center:
pad = self.win_length // 2
ilens = ilens + 2 * pad
olens = (ilens - self.win_length) // self.hop_length + 1
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
else:
olens = None
return output.to(input.dtype), olens
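
A small sketch (illustrative): aggregate sample-level binary labels into
frame-level labels, using the same framing and padding as torch.stft:

import torch

agg = LabelAggregate(win_length=512, hop_length=128, center=True)
labels = (torch.rand(2, 16000, 1) > 0.5).float()  # (Batch, Nsamples, Label_dim)
ilens = torch.tensor([16000, 12000])
frames, olens = agg(labels, ilens)
# A frame is 1 when more than half of the samples in its window are 1.
print(frames.shape, olens)  # torch.Size([2, 126, 1]) tensor([126, 94])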


@@ -0,0 +1,83 @@
import math
import librosa
import torch
from typing import Tuple
from funasr_local.modules.nets_utils import make_pad_mask
class LogMel(torch.nn.Module):
"""Convert STFT to fbank feats
    The arguments are the same as those of librosa.filters.mel.
Args:
fs: number > 0 [scalar] sampling rate of the incoming signal
n_fft: int > 0 [scalar] number of FFT components
n_mels: int > 0 [scalar] number of Mel bands to generate
fmin: float >= 0 [scalar] lowest frequency (in Hz)
fmax: float >= 0 [scalar] highest frequency (in Hz).
If `None`, use `fmax = fs / 2.0`
htk: use HTK formula instead of Slaney
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = None,
fmax: float = None,
htk: bool = False,
log_base: float = None,
):
super().__init__()
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
_mel_options = dict(
sr=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.mel_options = _mel_options
self.log_base = log_base
# Note(kamo): The mel matrix of librosa is different from kaldi.
melmat = librosa.filters.mel(**_mel_options)
# melmat: (D2, D1) -> (D1, D2)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
def extra_repr(self):
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
def forward(
self,
feat: torch.Tensor,
ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
mel_feat = torch.matmul(feat, self.melmat)
mel_feat = torch.clamp(mel_feat, min=1e-10)
if self.log_base is None:
logmel_feat = mel_feat.log()
elif self.log_base == 2.0:
logmel_feat = mel_feat.log2()
elif self.log_base == 10.0:
logmel_feat = mel_feat.log10()
else:
            # self.log_base is a Python float, so use math.log here
            logmel_feat = mel_feat.log() / math.log(self.log_base)
# Zero padding
if ilens is not None:
logmel_feat = logmel_feat.masked_fill(
make_pad_mask(ilens, logmel_feat, 1), 0.0
)
else:
ilens = feat.new_full(
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
)
return logmel_feat, ilens
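
A short illustrative pipeline (a sketch, not part of the commit) chaining this
module with the Stft class from funasr_local/layers/stft.py, added later in
this commit: square the STFT real/imag parts to get the power spectrum, then
apply the mel filterbank and the log:

import torch

stft = Stft(n_fft=512, hop_length=128)
logmel = LogMel(fs=16000, n_fft=512, n_mels=80)

wav = torch.randn(2, 16000)
ilens = torch.tensor([16000, 12000])
spec, flens = stft(wav, ilens)                 # (B, Frames, Freq, 2=real_imag)
power = spec[..., 0] ** 2 + spec[..., 1] ** 2  # (B, Frames, Freq) power spectrum
feats, flens = logmel(power, flens)            # (B, Frames, 80) log-mel features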


@@ -0,0 +1,340 @@
import math
import torch
from typeguard import check_argument_types
from typing import Sequence
from typing import Union
def mask_along_axis(
spec: torch.Tensor,
spec_lengths: torch.Tensor,
mask_width_range: Sequence[int] = (0, 30),
dim: int = 1,
num_mask: int = 2,
replace_with_zero: bool = True,
):
"""Apply mask along the specified direction.
Args:
spec: (Batch, Length, Freq)
        spec_lengths: (Batch,): not used in this implementation
        mask_width_range: Select the mask width randomly from this range
"""
org_size = spec.size()
if spec.dim() == 4:
# spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
spec = spec.view(-1, spec.size(2), spec.size(3))
B = spec.shape[0]
# D = Length or Freq
D = spec.shape[dim]
# mask_length: (B, num_mask, 1)
mask_length = torch.randint(
mask_width_range[0],
mask_width_range[1],
(B, num_mask),
device=spec.device,
).unsqueeze(2)
# mask_pos: (B, num_mask, 1)
mask_pos = torch.randint(
0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
).unsqueeze(2)
# aran: (1, 1, D)
aran = torch.arange(D, device=spec.device)[None, None, :]
# mask: (Batch, num_mask, D)
mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
# Multiply masks: (Batch, num_mask, D) -> (Batch, D)
mask = mask.any(dim=1)
if dim == 1:
# mask: (Batch, Length, 1)
mask = mask.unsqueeze(2)
elif dim == 2:
# mask: (Batch, 1, Freq)
mask = mask.unsqueeze(1)
if replace_with_zero:
value = 0.0
else:
value = spec.mean()
if spec.requires_grad:
spec = spec.masked_fill(mask, value)
else:
spec = spec.masked_fill_(mask, value)
spec = spec.view(*org_size)
return spec, spec_lengths
def mask_along_axis_lfr(
spec: torch.Tensor,
spec_lengths: torch.Tensor,
mask_width_range: Sequence[int] = (0, 30),
dim: int = 1,
num_mask: int = 2,
replace_with_zero: bool = True,
lfr_rate: int = 1,
):
"""Apply mask along the specified direction.
Args:
spec: (Batch, Length, Freq)
        spec_lengths: (Batch,): not used in this implementation
        mask_width_range: Select the mask width randomly from this range
        lfr_rate: low frame rate factor
"""
org_size = spec.size()
if spec.dim() == 4:
# spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
spec = spec.view(-1, spec.size(2), spec.size(3))
B = spec.shape[0]
# D = Length or Freq
D = spec.shape[dim] // lfr_rate
# mask_length: (B, num_mask, 1)
mask_length = torch.randint(
mask_width_range[0],
mask_width_range[1],
(B, num_mask),
device=spec.device,
).unsqueeze(2)
if lfr_rate > 1:
mask_length = mask_length.repeat(1, lfr_rate, 1)
# mask_pos: (B, num_mask, 1)
mask_pos = torch.randint(
0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
).unsqueeze(2)
if lfr_rate > 1:
mask_pos_raw = mask_pos.clone()
mask_pos = torch.zeros((B, 0, 1), device=spec.device, dtype=torch.int32)
for i in range(lfr_rate):
mask_pos_i = mask_pos_raw + D * i
mask_pos = torch.cat((mask_pos, mask_pos_i), dim=1)
# aran: (1, 1, D)
D = spec.shape[dim]
aran = torch.arange(D, device=spec.device)[None, None, :]
# mask: (Batch, num_mask, D)
mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
# Multiply masks: (Batch, num_mask, D) -> (Batch, D)
mask = mask.any(dim=1)
if dim == 1:
# mask: (Batch, Length, 1)
mask = mask.unsqueeze(2)
elif dim == 2:
# mask: (Batch, 1, Freq)
mask = mask.unsqueeze(1)
if replace_with_zero:
value = 0.0
else:
value = spec.mean()
if spec.requires_grad:
spec = spec.masked_fill(mask, value)
else:
spec = spec.masked_fill_(mask, value)
spec = spec.view(*org_size)
return spec, spec_lengths
class MaskAlongAxis(torch.nn.Module):
def __init__(
self,
mask_width_range: Union[int, Sequence[int]] = (0, 30),
num_mask: int = 2,
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
):
assert check_argument_types()
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:
raise TypeError(
f"mask_width_range must be a tuple of int and int values: "
f"{mask_width_range}",
)
assert mask_width_range[1] > mask_width_range[0]
if isinstance(dim, str):
if dim == "time":
dim = 1
elif dim == "freq":
dim = 2
else:
raise ValueError("dim must be int, 'time' or 'freq'")
if dim == 1:
self.mask_axis = "time"
elif dim == 2:
self.mask_axis = "freq"
else:
self.mask_axis = "unknown"
super().__init__()
self.mask_width_range = mask_width_range
self.num_mask = num_mask
self.dim = dim
self.replace_with_zero = replace_with_zero
def extra_repr(self):
return (
f"mask_width_range={self.mask_width_range}, "
f"num_mask={self.num_mask}, axis={self.mask_axis}"
)
def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
"""Forward function.
Args:
spec: (Batch, Length, Freq)
"""
return mask_along_axis(
spec,
spec_lengths,
mask_width_range=self.mask_width_range,
dim=self.dim,
num_mask=self.num_mask,
replace_with_zero=self.replace_with_zero,
)
class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
"""Mask input spec along a specified axis with variable maximum width.
Formula:
max_width = max_width_ratio * seq_len
"""
def __init__(
self,
mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
num_mask: int = 2,
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
):
assert check_argument_types()
if isinstance(mask_width_ratio_range, float):
mask_width_ratio_range = (0.0, mask_width_ratio_range)
if len(mask_width_ratio_range) != 2:
raise TypeError(
f"mask_width_ratio_range must be a tuple of float and float values: "
f"{mask_width_ratio_range}",
)
assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
if isinstance(dim, str):
if dim == "time":
dim = 1
elif dim == "freq":
dim = 2
else:
raise ValueError("dim must be int, 'time' or 'freq'")
if dim == 1:
self.mask_axis = "time"
elif dim == 2:
self.mask_axis = "freq"
else:
self.mask_axis = "unknown"
super().__init__()
self.mask_width_ratio_range = mask_width_ratio_range
self.num_mask = num_mask
self.dim = dim
self.replace_with_zero = replace_with_zero
def extra_repr(self):
return (
f"mask_width_ratio_range={self.mask_width_ratio_range}, "
f"num_mask={self.num_mask}, axis={self.mask_axis}"
)
def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
"""Forward function.
Args:
spec: (Batch, Length, Freq)
"""
max_seq_len = spec.shape[self.dim]
min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
min_mask_width = max([0, min_mask_width])
max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
max_mask_width = min([max_seq_len, max_mask_width])
if max_mask_width > min_mask_width:
return mask_along_axis(
spec,
spec_lengths,
mask_width_range=(min_mask_width, max_mask_width),
dim=self.dim,
num_mask=self.num_mask,
replace_with_zero=self.replace_with_zero,
)
return spec, spec_lengths
class MaskAlongAxisLFR(torch.nn.Module):
def __init__(
self,
mask_width_range: Union[int, Sequence[int]] = (0, 30),
num_mask: int = 2,
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
lfr_rate: int = 1,
):
assert check_argument_types()
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:
raise TypeError(
f"mask_width_range must be a tuple of int and int values: "
f"{mask_width_range}",
)
assert mask_width_range[1] > mask_width_range[0]
if isinstance(dim, str):
if dim == "time":
dim = 1
lfr_rate = 1
elif dim == "freq":
dim = 2
else:
raise ValueError("dim must be int, 'time' or 'freq'")
if dim == 1:
self.mask_axis = "time"
lfr_rate = 1
elif dim == 2:
self.mask_axis = "freq"
else:
self.mask_axis = "unknown"
super().__init__()
self.mask_width_range = mask_width_range
self.num_mask = num_mask
self.dim = dim
self.replace_with_zero = replace_with_zero
self.lfr_rate = lfr_rate
def extra_repr(self):
return (
f"mask_width_range={self.mask_width_range}, "
f"num_mask={self.num_mask}, axis={self.mask_axis}"
)
def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
"""Forward function.
Args:
spec: (Batch, Length, Freq)
"""
return mask_along_axis_lfr(
spec,
spec_lengths,
mask_width_range=self.mask_width_range,
dim=self.dim,
num_mask=self.num_mask,
replace_with_zero=self.replace_with_zero,
lfr_rate=self.lfr_rate,
)
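
A SpecAugment-style sketch (illustrative): apply time and frequency masking in
sequence; the widths and counts here are arbitrary:

import torch

time_mask = MaskAlongAxis(mask_width_range=(0, 40), num_mask=2, dim="time")
freq_mask = MaskAlongAxis(mask_width_range=(0, 30), num_mask=2, dim="freq")

spec = torch.randn(4, 200, 80)  # (Batch, Length, Freq)
spec_lengths = torch.tensor([200, 180, 160, 120])
spec, spec_lengths = time_mask(spec, spec_lengths)
spec, spec_lengths = freq_mask(spec, spec_lengths)
# Up to 2 time stripes (width < 40) and 2 freq stripes (width < 30) are zeroed.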


@@ -0,0 +1,273 @@
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Sinc convolutions."""
import math
import torch
from typeguard import check_argument_types
from typing import Union
class LogCompression(torch.nn.Module):
"""Log Compression Activation.
Activation function `log(abs(x) + 1)`.
"""
def __init__(self):
"""Initialize."""
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward.
Applies the Log Compression function elementwise on tensor x.
"""
return torch.log(torch.abs(x) + 1)
class SincConv(torch.nn.Module):
"""Sinc Convolution.
    This module performs a convolution using Sinc filters in the time domain as
    kernels. Sinc filters act as band-pass filters in the spectral domain.
    The filtering is done as a convolution in the time domain, and no
    transformation to the spectral domain is necessary.
    This implementation of the Sinc convolution is heavily inspired
    by Ravanelli et al. https://github.com/mravanelli/SincNet,
    and adapted for the ESPnet toolkit.
    Combine Sinc convolutions with a log compression activation function, as in:
    https://arxiv.org/abs/2010.07597
    Notes:
        Currently, the same filters are applied to all input channels.
        The windowing function is applied to the kernel to obtain a smoother
        filter, rather than to the input values, which differs from
        traditional ASR feature extraction.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
window_func: str = "hamming",
scale_type: str = "mel",
fs: Union[int, float] = 16000,
):
"""Initialize Sinc convolutions.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
kernel_size: Sinc filter kernel size (needs to be an odd number).
stride: See torch.nn.functional.conv1d.
padding: See torch.nn.functional.conv1d.
dilation: See torch.nn.functional.conv1d.
window_func: Window function on the filter, one of ["hamming", "none"].
            scale_type: Filterbank initialization scale, one of ["mel", "bark"].
            fs (int, float): Sample rate of the input data.
"""
assert check_argument_types()
super().__init__()
window_funcs = {
"none": self.none_window,
"hamming": self.hamming_window,
}
if window_func not in window_funcs:
raise NotImplementedError(
f"Window function has to be one of {list(window_funcs.keys())}",
)
self.window_func = window_funcs[window_func]
scale_choices = {
"mel": MelScale,
"bark": BarkScale,
}
if scale_type not in scale_choices:
raise NotImplementedError(
f"Scale has to be one of {list(scale_choices.keys())}",
)
self.scale = scale_choices[scale_type]
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.dilation = dilation
self.stride = stride
self.fs = float(fs)
if self.kernel_size % 2 == 0:
raise ValueError("SincConv: Kernel size must be odd.")
self.f = None
N = self.kernel_size // 2
self._x = 2 * math.pi * torch.linspace(1, N, N)
self._window = self.window_func(torch.linspace(1, N, N))
# init may get overwritten by E2E network,
# but is still required to calculate output dim
self.init_filters()
@staticmethod
def sinc(x: torch.Tensor) -> torch.Tensor:
"""Sinc function."""
x2 = x + 1e-6
return torch.sin(x2) / x2
@staticmethod
def none_window(x: torch.Tensor) -> torch.Tensor:
"""Identity-like windowing function."""
return torch.ones_like(x)
@staticmethod
def hamming_window(x: torch.Tensor) -> torch.Tensor:
"""Hamming Windowing function."""
L = 2 * x.size(0) + 1
x = x.flip(0)
return 0.54 - 0.46 * torch.cos(2.0 * math.pi * x / L)
def init_filters(self):
"""Initialize filters with filterbank values."""
f = self.scale.bank(self.out_channels, self.fs)
f = torch.div(f, self.fs)
self.f = torch.nn.Parameter(f, requires_grad=True)
def _create_filters(self, device: str):
"""Calculate coefficients.
        This function (re-)calculates the filter convolution coefficients.
"""
f_mins = torch.abs(self.f[:, 0])
f_maxs = torch.abs(self.f[:, 0]) + torch.abs(self.f[:, 1] - self.f[:, 0])
self._x = self._x.to(device)
self._window = self._window.to(device)
f_mins_x = torch.matmul(f_mins.view(-1, 1), self._x.view(1, -1))
f_maxs_x = torch.matmul(f_maxs.view(-1, 1), self._x.view(1, -1))
kernel = (torch.sin(f_maxs_x) - torch.sin(f_mins_x)) / (0.5 * self._x)
kernel = kernel * self._window
kernel_left = kernel.flip(1)
kernel_center = (2 * f_maxs - 2 * f_mins).unsqueeze(1)
filters = torch.cat([kernel_left, kernel_center, kernel], dim=1)
filters = filters.view(filters.size(0), 1, filters.size(1))
self.sinc_filters = filters
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Sinc convolution forward function.
Args:
xs: Batch in form of torch.Tensor (B, C_in, D_in).
Returns:
xs: Batch in form of torch.Tensor (B, C_out, D_out).
"""
self._create_filters(xs.device)
xs = torch.nn.functional.conv1d(
xs,
self.sinc_filters,
padding=self.padding,
stride=self.stride,
dilation=self.dilation,
groups=self.in_channels,
)
return xs
def get_odim(self, idim: int) -> int:
"""Obtain the output dimension of the filter."""
D_out = idim + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1
D_out = (D_out // self.stride) + 1
return D_out
class MelScale:
"""Mel frequency scale."""
@staticmethod
def convert(f):
"""Convert Hz to mel."""
return 1125.0 * torch.log(torch.div(f, 700.0) + 1.0)
@staticmethod
def invert(x):
"""Convert mel to Hz."""
return 700.0 * (torch.exp(torch.div(x, 1125.0)) - 1.0)
@classmethod
def bank(cls, channels: int, fs: float) -> torch.Tensor:
"""Obtain initialization values for the mel scale.
Args:
channels: Number of channels.
fs: Sample rate.
Returns:
            torch.Tensor: Filter start frequencies.
torch.Tensor: Filter stop frequencies.
"""
assert check_argument_types()
# min and max bandpass edge frequencies
min_frequency = torch.tensor(30.0)
max_frequency = torch.tensor(fs * 0.5)
frequencies = torch.linspace(
cls.convert(min_frequency), cls.convert(max_frequency), channels + 2
)
frequencies = cls.invert(frequencies)
f1, f2 = frequencies[:-2], frequencies[2:]
return torch.stack([f1, f2], dim=1)
class BarkScale:
"""Bark frequency scale.
Has wider bandwidths at lower frequencies, see:
Critical bandwidth: BARK
Zwicker and Terhardt, 1980
"""
@staticmethod
def convert(f):
"""Convert Hz to Bark."""
b = torch.div(f, 1000.0)
b = torch.pow(b, 2.0) * 1.4
b = torch.pow(b + 1.0, 0.69)
return b * 75.0 + 25.0
@staticmethod
def invert(x):
"""Convert Bark to Hz."""
f = torch.div(x - 25.0, 75.0)
f = torch.pow(f, (1.0 / 0.69))
f = torch.div(f - 1.0, 1.4)
f = torch.pow(f, 0.5)
return f * 1000.0
@classmethod
def bank(cls, channels: int, fs: float) -> torch.Tensor:
"""Obtain initialization values for the Bark scale.
Args:
channels: Number of channels.
fs: Sample rate.
Returns:
            torch.Tensor: Filter start frequencies.
            torch.Tensor: Filter stop frequencies.
"""
assert check_argument_types()
# min and max BARK center frequencies by approximation
min_center_frequency = torch.tensor(70.0)
max_center_frequency = torch.tensor(fs * 0.45)
center_frequencies = torch.linspace(
cls.convert(min_center_frequency),
cls.convert(max_center_frequency),
channels,
)
center_frequencies = cls.invert(center_frequencies)
f1 = center_frequencies - torch.div(cls.convert(center_frequencies), 2)
f2 = center_frequencies + torch.div(cls.convert(center_frequencies), 2)
return torch.stack([f1, f2], dim=1)
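
A usage sketch (illustrative): run the Sinc convolution over framed raw audio
and check the output width with get_odim; parameter values are arbitrary:

import torch

conv = SincConv(in_channels=1, out_channels=128, kernel_size=101)
x = torch.randn(8, 1, 400)          # (B, C_in, D_in): raw audio frames
y = conv(x)                         # sinc filters are (re-)built on the fly
print(y.shape, conv.get_odim(400))  # torch.Size([8, 128, 300]) 300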

funasr_local/layers/stft.py

@@ -0,0 +1,234 @@
from distutils.version import LooseVersion
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.layers.complex_utils import is_complex
from funasr_local.layers.inversible_interface import InversibleInterface
import librosa
import numpy as np
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
class Stft(torch.nn.Module, InversibleInterface):
def __init__(
self,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
assert check_argument_types()
super().__init__()
self.n_fft = n_fft
if win_length is None:
self.win_length = n_fft
else:
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
if window.lower() != "povey":
raise ValueError(f"{window} window is not implemented")
self.window = window
def extra_repr(self):
return (
f"n_fft={self.n_fft}, "
f"win_length={self.win_length}, "
f"hop_length={self.hop_length}, "
f"center={self.center}, "
f"normalized={self.normalized}, "
f"onesided={self.onesided}"
)
def forward(
self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""STFT forward function.
Args:
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
ilens: (Batch)
Returns:
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
"""
bs = input.size(0)
if input.dim() == 3:
multi_channel = True
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
input = input.transpose(1, 2).reshape(-1, input.size(1))
else:
multi_channel = False
# NOTE(kamo):
# The default behaviour of torch.stft is compatible with librosa.stft
        # in terms of padding and scaling.
        # Note that it differs from scipy.signal.stft.
# output: (Batch, Freq, Frames, 2=real_imag)
# or (Batch, Channel, Freq, Frames, 2=real_imag)
if self.window is not None:
if self.window.lower() == "povey":
window = torch.hann_window(self.win_length, periodic=False,
device=input.device, dtype=input.dtype).pow(0.85)
else:
window_func = getattr(torch, f"{self.window}_window")
window = window_func(
self.win_length, dtype=input.dtype, device=input.device
)
else:
window = None
        # For compatibility with ARM devices, which do not support
        # torch.stft() due to the lack of MKL.
if input.is_cuda or torch.backends.mkl.is_available():
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
normalized=self.normalized,
onesided=self.onesided,
)
if is_torch_1_7_plus:
stft_kwargs["return_complex"] = False
output = torch.stft(input, **stft_kwargs)
else:
if self.training:
raise NotImplementedError(
"stft is implemented with librosa on this device, which does not "
"support the training mode."
)
# use stft_kwargs to flexibly control different PyTorch versions' kwargs
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
)
if window is not None:
# pad the given window to n_fft
n_pad_left = (self.n_fft - window.shape[0]) // 2
n_pad_right = self.n_fft - window.shape[0] - n_pad_left
stft_kwargs["window"] = torch.cat(
[torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
).numpy()
else:
win_length = (
self.win_length if self.win_length is not None else self.n_fft
)
stft_kwargs["window"] = torch.ones(win_length)
output = []
            # iterate over instances in the batch
            for instance in input:
                stft = librosa.stft(instance.numpy(), **stft_kwargs)
output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
output = torch.stack(output, 0)
if not self.onesided:
len_conj = self.n_fft - output.shape[1]
conj = output[:, 1 : 1 + len_conj].flip(1)
conj[:, :, :, -1].data *= -1
output = torch.cat([output, conj], 1)
if self.normalized:
output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
# output: (Batch, Freq, Frames, 2=real_imag)
# -> (Batch, Frames, Freq, 2=real_imag)
output = output.transpose(1, 2)
if multi_channel:
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
1, 2
)
if ilens is not None:
if self.center:
pad = self.n_fft // 2
ilens = ilens + 2 * pad
olens = (ilens - self.n_fft) // self.hop_length + 1
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
else:
olens = None
return output, olens
def inverse(
self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Inverse STFT.
Args:
input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
ilens: (batch,)
Returns:
wavs: (batch, samples)
ilens: (batch,)
"""
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
istft = torch.functional.istft
else:
try:
import torchaudio
except ImportError:
raise ImportError(
"Please install torchaudio>=0.3.0 or use torch>=1.6.0"
)
if not hasattr(torchaudio.functional, "istft"):
raise ImportError(
"Please install torchaudio>=0.3.0 or use torch>=1.6.0"
)
istft = torchaudio.functional.istft
if self.window is not None:
window_func = getattr(torch, f"{self.window}_window")
if is_complex(input):
datatype = input.real.dtype
else:
datatype = input.dtype
window = window_func(self.win_length, dtype=datatype, device=input.device)
else:
window = None
if is_complex(input):
input = torch.stack([input.real, input.imag], dim=-1)
elif input.shape[-1] != 2:
raise TypeError("Invalid input type")
input = input.transpose(1, 2)
wavs = istft(
input,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=window,
center=self.center,
normalized=self.normalized,
onesided=self.onesided,
length=ilens.max() if ilens is not None else ilens,
)
return wavs, ilens
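
A quick sketch (illustrative) of the forward pass; with center=True the output
length is ilens // hop_length + 1. The inverse() method above reconstructs the
waveform via istft (torch.functional.istft on torch>=1.6, else torchaudio):

import torch

stft = Stft(n_fft=512, hop_length=128, window="hann")
wav = torch.randn(2, 16000)
ilens = torch.tensor([16000, 12000])
spec, olens = stft(wav, ilens)  # (Batch, Frames, Freq, 2=real_imag)
print(spec.shape, olens)        # torch.Size([2, 126, 257, 2]) tensor([126, 94])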


@@ -0,0 +1,88 @@
"""Time warp module."""
import torch
from funasr_local.modules.nets_utils import pad_list
DEFAULT_TIME_WARP_MODE = "bicubic"
def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
"""Time warping using torch.interpolate.
Args:
x: (Batch, Time, Freq)
window: time warp parameter
mode: Interpolate mode
"""
    # bicubic interpolation requires a tensor with 4 or more dimensions
org_size = x.size()
if x.dim() == 3:
# x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
x = x[:, None]
t = x.shape[2]
if t - window <= window:
return x.view(*org_size)
center = torch.randint(window, t - window, (1,))[0]
warped = torch.randint(center - window, center + window, (1,))[0] + 1
# left: (Batch, Channel, warped, Freq)
# right: (Batch, Channel, time - warped, Freq)
left = torch.nn.functional.interpolate(
x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
)
right = torch.nn.functional.interpolate(
x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
)
if x.requires_grad:
x = torch.cat([left, right], dim=-2)
else:
x[:, :, :warped] = left
x[:, :, warped:] = right
return x.view(*org_size)
class TimeWarp(torch.nn.Module):
"""Time warping using torch.interpolate.
Args:
window: time warp parameter
mode: Interpolate mode
"""
def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
super().__init__()
self.window = window
self.mode = mode
def extra_repr(self):
return f"window={self.window}, mode={self.mode}"
def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
"""Forward function.
Args:
x: (Batch, Time, Freq)
x_lengths: (Batch,)
"""
if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
            # Note that the same warping is applied to every sample in the batch
y = time_warp(x, window=self.window, mode=self.mode)
else:
# FIXME(kamo): I have no idea to batchify Timewarp
ys = []
for i in range(x.size(0)):
_y = time_warp(
x[i][None, : x_lengths[i]],
window=self.window,
mode=self.mode,
)[0]
ys.append(_y)
y = pad_list(ys, 0.0)
return y, x_lengths
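
A usage sketch (illustrative): warp the time axis of a batch of equal-length
spectrograms; since all lengths match, a single shared warp is applied:

import torch

warp = TimeWarp(window=40, mode="bicubic")
x = torch.randn(4, 300, 80)        # (Batch, Time, Freq)
x_lengths = torch.full((4,), 300)
y, y_lengths = warp(x, x_lengths)  # shape is preserved: (4, 300, 80)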


@@ -0,0 +1,88 @@
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.layers.abs_normalize import AbsNormalize
class UtteranceMVN(AbsNormalize):
def __init__(
self,
norm_means: bool = True,
norm_vars: bool = False,
eps: float = 1.0e-20,
):
assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
def extra_repr(self):
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
def forward(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
Args:
x: (B, L, ...)
ilens: (B,)
"""
return utterance_mvn(
x,
ilens,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
eps=self.eps,
)
def utterance_mvn(
x: torch.Tensor,
ilens: torch.Tensor = None,
norm_means: bool = True,
norm_vars: bool = False,
eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply utterance mean and variance normalization
Args:
x: (B, T, D), assumed zero padded
ilens: (B,)
norm_means:
norm_vars:
eps:
"""
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
# Zero padding
if x.requires_grad:
x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
else:
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
# mean: (B, 1, D)
mean = x.sum(dim=1, keepdim=True) / ilens_
if norm_means:
x -= mean
if norm_vars:
var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
std = torch.clamp(var.sqrt(), min=eps)
            x = x / std  # dividing by the std gives unit variance
return x, ilens
else:
if norm_vars:
y = x - mean
y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
std = torch.clamp(var.sqrt(), min=eps)
x /= std
return x, ilens
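
A usage sketch (illustrative): per-utterance normalization computed over the
valid (unpadded) frames only:

import torch

mvn = UtteranceMVN(norm_means=True, norm_vars=True)
x = torch.randn(2, 100, 80)  # (B, T, D), assumed zero-padded beyond ilens
ilens = torch.tensor([100, 60])
y, _ = mvn(x, ilens)
# Each utterance now has ~zero mean and ~unit variance over its valid frames.
print(y[0].mean(dim=0).abs().max().item())  # ~0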