Mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
0  funasr_local/layers/__init__.py  Normal file
14  funasr_local/layers/abs_normalize.py  Normal file
@@ -0,0 +1,14 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch


class AbsNormalize(torch.nn.Module, ABC):
    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # return output, output_lengths
        raise NotImplementedError
191  funasr_local/layers/complex_utils.py  Normal file
@@ -0,0 +1,191 @@
"""Beamformer module."""
from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union

import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor


EPS = torch.finfo(torch.double).eps
is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")


def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    elif is_torch_complex_tensor(ref):
        return torch.complex(*real_imag)
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )


def is_torch_complex_tensor(c):
    return (
        not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
    )


def is_complex(c):
    return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)


def to_double(c):
    if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
        return c.to(dtype=torch.complex128)
    else:
        return c.double()


def to_float(c):
    if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
        return c.to(dtype=torch.complex64)
    else:
        return c.float()


def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.cat(seq, *args, **kwargs)
    else:
        return torch.cat(seq, *args, **kwargs)


def complex_norm(
    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
    if not is_complex(c):
        raise TypeError("Input is not a complex tensor.")
    if is_torch_complex_tensor(c):
        return torch.norm(c, dim=dim, keepdim=keepdim)
    else:
        return torch.sqrt(
            (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
        )


def einsum(equation, *operands):
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
    # mixed input with complex and real tensors.
    if len(operands) == 1:
        if isinstance(operands[0], (tuple, list)):
            operands = operands[0]
        complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)
    elif len(operands) != 2:
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if same_type:
            _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
            return _einsum(equation, *operands)
        else:
            raise ValueError("0 or More than 2 operands are not supported.")

    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.einsum(equation, a, b)
    else:
        return torch.einsum(equation, a, b)


def inverse(
    c: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    if isinstance(c, ComplexTensor):
        return c.inverse2()
    else:
        return c.inverse()


def matmul(
    a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
    # multiplication between complex and real tensors.
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.matmul(a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            o_real = torch.matmul(a, b.real)
            o_imag = torch.matmul(a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.matmul(a.real, b)
            o_imag = torch.matmul(a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.matmul(a, b)
    else:
        return torch.matmul(a, b)


def trace(a: Union[torch.Tensor, ComplexTensor]):
    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    return FC.trace(a)


def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
    if isinstance(a, ComplexTensor):
        return FC.reverse(a, dim=dim)
    else:
        return torch.flip(a, dims=(dim,))


def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
    """Solve the linear equation ax = b."""
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
    # mixed input with complex and real tensors.
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
            return FC.solve(b, a, return_LU=False)
        else:
            return matmul(inverse(a), b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if torch.is_complex(a) and torch.is_complex(b):
            return torch.linalg.solve(a, b)
        else:
            return matmul(inverse(a), b)
    else:
        if is_torch_1_8_plus:
            return torch.linalg.solve(a, b)
        else:
            return torch.solve(b, a)[0]


def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, *args, **kwargs)
    else:
        return torch.stack(seq, *args, **kwargs)
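Not part of the commit: a minimal usage sketch of the dispatch helpers above, assuming torch >= 1.9 so native complex tensors are available; the shapes are illustrative only.

import torch
from funasr_local.layers.complex_utils import einsum, matmul, is_complex

a = torch.randn(4, 3, 5, dtype=torch.complex64)   # (B, M, K)
b = torch.randn(4, 5, 2, dtype=torch.complex64)   # (B, K, N)
c = matmul(a, b)                                  # (B, M, N), stays complex
d = einsum("bmk,bkn->bmn", a, b)                  # same contraction via einsum
assert is_complex(c) and is_complex(d)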
121  funasr_local/layers/global_mvn.py  Normal file
@@ -0,0 +1,121 @@
from pathlib import Path
from typing import Tuple
from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.layers.inversible_interface import InversibleInterface


class GlobalMVN(AbsNormalize, InversibleInterface):
    """Apply global mean and variance normalization.

    TODO(kamo): Make this class portable somehow

    Args:
        stats_file: npy file
        norm_means: Apply mean normalization
        norm_vars: Apply var normalization
        eps:
    """

    def __init__(
        self,
        stats_file: Union[Path, str],
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        assert check_argument_types()
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
        stats_file = Path(stats_file)

        self.stats_file = stats_file
        stats = np.load(stats_file)
        if isinstance(stats, np.ndarray):
            # Kaldi-like stats
            count = stats[0].flatten()[-1]
            mean = stats[0, :-1] / count
            var = stats[1, :-1] / count - mean * mean
        else:
            # New style: npz file
            count = stats["count"]
            sum_v = stats["sum"]
            sum_square_v = stats["sum_square"]
            mean = sum_v / count
            var = sum_square_v / count - mean * mean
        std = np.sqrt(np.maximum(var, eps))

        self.register_buffer("mean", torch.from_numpy(mean))
        self.register_buffer("std", torch.from_numpy(std))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function.

        Args:
            x: (B, L, ...)
            ilens: (B,)
        """
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        # feat: (B, T, D)
        if norm_means:
            if x.requires_grad:
                x = x - self.mean
            else:
                x -= self.mean
        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x /= self.std

        return x, ilens

    def inverse(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x *= self.std

        # feat: (B, T, D)
        if norm_means:
            x += self.mean
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
        return x, ilens
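Not part of the commit: a minimal usage sketch for GlobalMVN; the stats file is fabricated here purely for illustration (a real one comes from the training data collection stage), and all shapes are assumptions.

import numpy as np
import torch
from funasr_local.layers.global_mvn import GlobalMVN

feats = np.random.randn(1000, 80)
np.savez("cmvn_stats.npz", count=feats.shape[0],
         sum=feats.sum(0), sum_square=(feats ** 2).sum(0))

mvn = GlobalMVN("cmvn_stats.npz", norm_means=True, norm_vars=True)
x = torch.randn(2, 100, 80)               # (B, L, D) zero-padded features
ilens = torch.tensor([100, 80])
y, olens = mvn(x, ilens)                  # globally normalized features
x_rec, _ = mvn.inverse(y, olens)          # approximately the original padded input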
14  funasr_local/layers/inversible_interface.py  Normal file
@@ -0,0 +1,14 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch


class InversibleInterface(ABC):
    @abstractmethod
    def inverse(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # return output, output_lengths
        raise NotImplementedError
82  funasr_local/layers/label_aggregation.py  Normal file
@@ -0,0 +1,82 @@
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple

from funasr_local.modules.nets_utils import make_pad_mask


class LabelAggregate(torch.nn.Module):
    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        center: bool = True,
    ):
        assert check_argument_types()
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center

    def extra_repr(self):
        return (
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
        )

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Label_dim)
        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # NOTE(jiatong):
        # The default behaviour of label aggregation is compatible with
        # torch.stft about framing and padding.

        # Step 1: center padding
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
            input[:, :pad, :] = input[:, pad : (2 * pad), :]
            input[:, (max_length - pad) : max_length, :] = input[
                :, (max_length - 2 * pad) : (max_length - pad), :
            ]
        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step 2: framing
        output = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
        )

        # Step 3: aggregate label
        output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
        output = output.float()

        # Step 4: process lengths
        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            olens = (ilens - self.win_length) // self.hop_length + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output.to(input.dtype), olens
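Not part of the commit: a minimal usage sketch for LabelAggregate; sample-level binary labels are aggregated into STFT-compatible frames, and all sizes are illustrative assumptions.

import torch
from funasr_local.layers.label_aggregation import LabelAggregate

agg = LabelAggregate(win_length=512, hop_length=128, center=True)
labels = torch.randint(0, 2, (2, 16000, 4)).float()   # (Batch, Nsamples, Label_dim)
ilens = torch.tensor([16000, 12000])
frame_labels, olens = agg(labels, ilens)               # (Batch, Frames, Label_dim)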
83  funasr_local/layers/log_mel.py  Normal file
@@ -0,0 +1,83 @@
import librosa
import torch
from typing import Tuple

from funasr_local.modules.nets_utils import make_pad_mask


class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats.

    The arguments are the same as librosa.filters.mel.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        log_base: float = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # NOTE(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            logmel_feat = mel_feat.log() / torch.log(self.log_base)

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
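Not part of the commit: a minimal usage sketch for LogMel; the input stands in for an STFT power spectrum with n_fft // 2 + 1 bins, and all shapes are assumptions.

import torch
from funasr_local.layers.log_mel import LogMel

logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
power_spec = torch.rand(2, 100, 257)        # (B, T, D1), D1 = n_fft // 2 + 1
ilens = torch.tensor([100, 73])
feats, olens = logmel(power_spec, ilens)    # (2, 100, 80) log-Mel features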
340  funasr_local/layers/mask_along_axis.py  Normal file
@@ -0,0 +1,340 @@
import math
import torch
from typeguard import check_argument_types
from typing import Sequence
from typing import Union


def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Apply mask along the specified direction.

    Args:
        spec: (Batch, Length, Freq)
        spec_lengths: (Length): Not using lengths in this implementation
        mask_width_range: Select the width randomly between this range
    """

    org_size = spec.size()
    if spec.dim() == 4:
        # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    B = spec.shape[0]
    # D = Length or Freq
    D = spec.shape[dim]
    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # mask_pos: (B, num_mask, 1)
    mask_pos = torch.randint(
        0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
    ).unsqueeze(2)

    # aran: (1, 1, D)
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # mask: (Batch, num_mask, D)
    mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
    # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
    mask = mask.any(dim=1)
    if dim == 1:
        # mask: (Batch, Length, 1)
        mask = mask.unsqueeze(2)
    elif dim == 2:
        # mask: (Batch, 1, Freq)
        mask = mask.unsqueeze(1)

    if replace_with_zero:
        value = 0.0
    else:
        value = spec.mean()

    if spec.requires_grad:
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths


def mask_along_axis_lfr(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
    lfr_rate: int = 1,
):
    """Apply mask along the specified direction.

    Args:
        spec: (Batch, Length, Freq)
        spec_lengths: (Length): Not using lengths in this implementation
        mask_width_range: Select the width randomly between this range
        lfr_rate: low frame rate
    """

    org_size = spec.size()
    if spec.dim() == 4:
        # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    B = spec.shape[0]
    # D = Length or Freq
    D = spec.shape[dim] // lfr_rate
    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)
    if lfr_rate > 1:
        mask_length = mask_length.repeat(1, lfr_rate, 1)
    # mask_pos: (B, num_mask, 1)
    mask_pos = torch.randint(
        0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
    ).unsqueeze(2)
    if lfr_rate > 1:
        mask_pos_raw = mask_pos.clone()
        mask_pos = torch.zeros((B, 0, 1), device=spec.device, dtype=torch.int32)
        for i in range(lfr_rate):
            mask_pos_i = mask_pos_raw + D * i
            mask_pos = torch.cat((mask_pos, mask_pos_i), dim=1)
    # aran: (1, 1, D)
    D = spec.shape[dim]
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # mask: (Batch, num_mask, D)
    mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
    # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
    mask = mask.any(dim=1)
    if dim == 1:
        # mask: (Batch, Length, 1)
        mask = mask.unsqueeze(2)
    elif dim == 2:
        # mask: (Batch, 1, Freq)
        mask = mask.unsqueeze(1)

    if replace_with_zero:
        value = 0.0
    else:
        value = spec.mean()

    if spec.requires_grad:
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths


class MaskAlongAxis(torch.nn.Module):
    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        assert check_argument_types()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )

        assert mask_width_range[1] > mask_width_range[0]
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        super().__init__()
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            spec: (Batch, Length, Freq)
        """

        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )


class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
    """Mask input spec along a specified axis with variable maximum width.

    Formula:
        max_width = max_width_ratio * seq_len
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        assert check_argument_types()
        if isinstance(mask_width_ratio_range, float):
            mask_width_ratio_range = (0.0, mask_width_ratio_range)
        if len(mask_width_ratio_range) != 2:
            raise TypeError(
                f"mask_width_ratio_range must be a tuple of float and float values: "
                f"{mask_width_ratio_range}",
            )

        assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        super().__init__()
        self.mask_width_ratio_range = mask_width_ratio_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            spec: (Batch, Length, Freq)
        """

        max_seq_len = spec.shape[self.dim]
        min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
        min_mask_width = max([0, min_mask_width])
        max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
        max_mask_width = min([max_seq_len, max_mask_width])

        if max_mask_width > min_mask_width:
            return mask_along_axis(
                spec,
                spec_lengths,
                mask_width_range=(min_mask_width, max_mask_width),
                dim=self.dim,
                num_mask=self.num_mask,
                replace_with_zero=self.replace_with_zero,
            )
        return spec, spec_lengths


class MaskAlongAxisLFR(torch.nn.Module):
    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        lfr_rate: int = 1,
    ):
        assert check_argument_types()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )

        assert mask_width_range[1] > mask_width_range[0]
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
                lfr_rate = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")
        if dim == 1:
            self.mask_axis = "time"
            lfr_rate = 1
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        super().__init__()
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.lfr_rate = lfr_rate

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            spec: (Batch, Length, Freq)
        """

        return mask_along_axis_lfr(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            lfr_rate=self.lfr_rate,
        )
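Not part of the commit: a minimal SpecAugment-style usage sketch of the masking modules above; shapes and mask ranges are illustrative assumptions.

import torch
from funasr_local.layers.mask_along_axis import MaskAlongAxis

time_mask = MaskAlongAxis(mask_width_range=(0, 30), num_mask=2, dim="time")
freq_mask = MaskAlongAxis(mask_width_range=(0, 10), num_mask=2, dim="freq")

spec = torch.randn(4, 200, 80)                      # (Batch, Length, Freq)
spec_lengths = torch.tensor([200, 180, 150, 120])
spec, spec_lengths = time_mask(spec, spec_lengths)  # random time-axis stripes zeroed
spec, spec_lengths = freq_mask(spec, spec_lengths)  # random frequency-axis stripes zeroed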
273  funasr_local/layers/sinc_conv.py  Normal file
@@ -0,0 +1,273 @@
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Sinc convolutions."""
import math
import torch
from typeguard import check_argument_types
from typing import Union


class LogCompression(torch.nn.Module):
    """Log Compression Activation.

    Activation function `log(abs(x) + 1)`.
    """

    def __init__(self):
        """Initialize."""
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward.

        Applies the Log Compression function elementwise on tensor x.
        """
        return torch.log(torch.abs(x) + 1)


class SincConv(torch.nn.Module):
    """Sinc Convolution.

    This module performs a convolution using Sinc filters in time domain as kernel.
    Sinc filters function as band passes in spectral domain.
    The filtering is done as a convolution in time domain, and no transformation
    to spectral domain is necessary.

    This implementation of the Sinc convolution is heavily inspired
    by Ravanelli et al. https://github.com/mravanelli/SincNet,
    and adapted for the ESPnet toolkit.
    Combine Sinc convolutions with a log compression activation function, as in:
    https://arxiv.org/abs/2010.07597

    Notes:
        Currently, the same filters are applied to all input channels.
        The windowing function is applied on the kernel to obtain a smoother filter,
        and not on the input values, which is different from traditional ASR.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        window_func: str = "hamming",
        scale_type: str = "mel",
        fs: Union[int, float] = 16000,
    ):
        """Initialize Sinc convolutions.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Sinc filter kernel size (needs to be an odd number).
            stride: See torch.nn.functional.conv1d.
            padding: See torch.nn.functional.conv1d.
            dilation: See torch.nn.functional.conv1d.
            window_func: Window function on the filter, one of ["hamming", "none"].
            fs (str, int, float): Sample rate of the input data.
        """
        assert check_argument_types()
        super().__init__()
        window_funcs = {
            "none": self.none_window,
            "hamming": self.hamming_window,
        }
        if window_func not in window_funcs:
            raise NotImplementedError(
                f"Window function has to be one of {list(window_funcs.keys())}",
            )
        self.window_func = window_funcs[window_func]
        scale_choices = {
            "mel": MelScale,
            "bark": BarkScale,
        }
        if scale_type not in scale_choices:
            raise NotImplementedError(
                f"Scale has to be one of {list(scale_choices.keys())}",
            )
        self.scale = scale_choices[scale_type]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.fs = float(fs)
        if self.kernel_size % 2 == 0:
            raise ValueError("SincConv: Kernel size must be odd.")
        self.f = None
        N = self.kernel_size // 2
        self._x = 2 * math.pi * torch.linspace(1, N, N)
        self._window = self.window_func(torch.linspace(1, N, N))
        # init may get overwritten by E2E network,
        # but is still required to calculate output dim
        self.init_filters()

    @staticmethod
    def sinc(x: torch.Tensor) -> torch.Tensor:
        """Sinc function."""
        x2 = x + 1e-6
        return torch.sin(x2) / x2

    @staticmethod
    def none_window(x: torch.Tensor) -> torch.Tensor:
        """Identity-like windowing function."""
        return torch.ones_like(x)

    @staticmethod
    def hamming_window(x: torch.Tensor) -> torch.Tensor:
        """Hamming Windowing function."""
        L = 2 * x.size(0) + 1
        x = x.flip(0)
        return 0.54 - 0.46 * torch.cos(2.0 * math.pi * x / L)

    def init_filters(self):
        """Initialize filters with filterbank values."""
        f = self.scale.bank(self.out_channels, self.fs)
        f = torch.div(f, self.fs)
        self.f = torch.nn.Parameter(f, requires_grad=True)

    def _create_filters(self, device: str):
        """Calculate coefficients.

        This function (re-)calculates the filter convolution coefficients.
        """
        f_mins = torch.abs(self.f[:, 0])
        f_maxs = torch.abs(self.f[:, 0]) + torch.abs(self.f[:, 1] - self.f[:, 0])

        self._x = self._x.to(device)
        self._window = self._window.to(device)

        f_mins_x = torch.matmul(f_mins.view(-1, 1), self._x.view(1, -1))
        f_maxs_x = torch.matmul(f_maxs.view(-1, 1), self._x.view(1, -1))

        kernel = (torch.sin(f_maxs_x) - torch.sin(f_mins_x)) / (0.5 * self._x)
        kernel = kernel * self._window

        kernel_left = kernel.flip(1)
        kernel_center = (2 * f_maxs - 2 * f_mins).unsqueeze(1)
        filters = torch.cat([kernel_left, kernel_center, kernel], dim=1)

        filters = filters.view(filters.size(0), 1, filters.size(1))
        self.sinc_filters = filters

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Sinc convolution forward function.

        Args:
            xs: Batch in form of torch.Tensor (B, C_in, D_in).

        Returns:
            xs: Batch in form of torch.Tensor (B, C_out, D_out).
        """
        self._create_filters(xs.device)
        xs = torch.nn.functional.conv1d(
            xs,
            self.sinc_filters,
            padding=self.padding,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.in_channels,
        )
        return xs

    def get_odim(self, idim: int) -> int:
        """Obtain the output dimension of the filter."""
        D_out = idim + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1
        D_out = (D_out // self.stride) + 1
        return D_out


class MelScale:
    """Mel frequency scale."""

    @staticmethod
    def convert(f):
        """Convert Hz to mel."""
        return 1125.0 * torch.log(torch.div(f, 700.0) + 1.0)

    @staticmethod
    def invert(x):
        """Convert mel to Hz."""
        return 700.0 * (torch.exp(torch.div(x, 1125.0)) - 1.0)

    @classmethod
    def bank(cls, channels: int, fs: float) -> torch.Tensor:
        """Obtain initialization values for the mel scale.

        Args:
            channels: Number of channels.
            fs: Sample rate.

        Returns:
            torch.Tensor: Filter start frequencies.
            torch.Tensor: Filter stop frequencies.
        """
        assert check_argument_types()
        # min and max bandpass edge frequencies
        min_frequency = torch.tensor(30.0)
        max_frequency = torch.tensor(fs * 0.5)
        frequencies = torch.linspace(
            cls.convert(min_frequency), cls.convert(max_frequency), channels + 2
        )
        frequencies = cls.invert(frequencies)
        f1, f2 = frequencies[:-2], frequencies[2:]
        return torch.stack([f1, f2], dim=1)


class BarkScale:
    """Bark frequency scale.

    Has wider bandwidths at lower frequencies, see:
    Critical bandwidth: BARK
    Zwicker and Terhardt, 1980
    """

    @staticmethod
    def convert(f):
        """Convert Hz to Bark."""
        b = torch.div(f, 1000.0)
        b = torch.pow(b, 2.0) * 1.4
        b = torch.pow(b + 1.0, 0.69)
        return b * 75.0 + 25.0

    @staticmethod
    def invert(x):
        """Convert Bark to Hz."""
        f = torch.div(x - 25.0, 75.0)
        f = torch.pow(f, (1.0 / 0.69))
        f = torch.div(f - 1.0, 1.4)
        f = torch.pow(f, 0.5)
        return f * 1000.0

    @classmethod
    def bank(cls, channels: int, fs: float) -> torch.Tensor:
        """Obtain initialization values for the Bark scale.

        Args:
            channels: Number of channels.
            fs: Sample rate.

        Returns:
            torch.Tensor: Filter start frequencies.
            torch.Tensor: Filter stop frequencies.
        """
        assert check_argument_types()
        # min and max BARK center frequencies by approximation
        min_center_frequency = torch.tensor(70.0)
        max_center_frequency = torch.tensor(fs * 0.45)
        center_frequencies = torch.linspace(
            cls.convert(min_center_frequency),
            cls.convert(max_center_frequency),
            channels,
        )
        center_frequencies = cls.invert(center_frequencies)

        f1 = center_frequencies - torch.div(cls.convert(center_frequencies), 2)
        f2 = center_frequencies + torch.div(cls.convert(center_frequencies), 2)
        return torch.stack([f1, f2], dim=1)
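Not part of the commit: a minimal usage sketch for SincConv on raw waveform snippets, paired with LogCompression as the class docstring suggests; channel counts and lengths are illustrative assumptions.

import torch
from funasr_local.layers.sinc_conv import SincConv, LogCompression

conv = SincConv(in_channels=1, out_channels=128, kernel_size=101, fs=16000)
act = LogCompression()
wav = torch.randn(4, 1, 400)     # (B, C_in, D_in) raw audio frames
feats = act(conv(wav))           # (B, C_out, D_out), D_out = conv.get_odim(400)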
234  funasr_local/layers/stft.py  Normal file
@@ -0,0 +1,234 @@
from distutils.version import LooseVersion
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.layers.complex_utils import is_complex
from funasr_local.layers.inversible_interface import InversibleInterface
import librosa
import numpy as np

is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")

is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")


class Stft(torch.nn.Module, InversibleInterface):
    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            if window.lower() != "povey":
                raise ValueError(f"{window} window is not implemented")
        self.window = window

    def extra_repr(self):
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        # NOTE(kamo):
        # The default behaviour of torch.stft is compatible with librosa.stft
        # about padding and scaling.
        # Note that it's different from scipy.signal.stft

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if self.window is not None:
            if self.window.lower() == "povey":
                window = torch.hann_window(
                    self.win_length,
                    periodic=False,
                    device=input.device,
                    dtype=input.dtype,
                ).pow(0.85)
            else:
                window_func = getattr(torch, f"{self.window}_window")
                window = window_func(
                    self.win_length, dtype=input.dtype, device=input.device
                )
        else:
            window = None

        # For the compatibility of ARM devices, which do not support
        # torch.stft() due to the lack of MKL.
        if input.is_cuda or torch.backends.mkl.is_available():
            stft_kwargs = dict(
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
                normalized=self.normalized,
                onesided=self.onesided,
            )
            if is_torch_1_7_plus:
                stft_kwargs["return_complex"] = False
            output = torch.stft(input, **stft_kwargs)
        else:
            if self.training:
                raise NotImplementedError(
                    "stft is implemented with librosa on this device, which does not "
                    "support the training mode."
                )

            # use stft_kwargs to flexibly control different PyTorch versions' kwargs
            stft_kwargs = dict(
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
            )

            if window is not None:
                # pad the given window to n_fft
                n_pad_left = (self.n_fft - window.shape[0]) // 2
                n_pad_right = self.n_fft - window.shape[0] - n_pad_left
                stft_kwargs["window"] = torch.cat(
                    [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
                ).numpy()
            else:
                win_length = (
                    self.win_length if self.win_length is not None else self.n_fft
                )
                stft_kwargs["window"] = torch.ones(win_length)

            output = []
            # iterate over instances in a batch
            for i, instance in enumerate(input):
                stft = librosa.stft(input[i].numpy(), **stft_kwargs)
                output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
            output = torch.stack(output, 0)
            if not self.onesided:
                len_conj = self.n_fft - output.shape[1]
                conj = output[:, 1 : 1 + len_conj].flip(1)
                conj[:, :, :, -1].data *= -1
                output = torch.cat([output, conj], 1)
            if self.normalized:
                output = output * (stft_kwargs["window"].shape[0] ** (-0.5))

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
                1, 2
            )

        if ilens is not None:
            if self.center:
                pad = self.n_fft // 2
                ilens = ilens + 2 * pad

            olens = (ilens - self.n_fft) // self.hop_length + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens

    def inverse(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Inverse STFT.

        Args:
            input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
            ilens: (batch,)
        Returns:
            wavs: (batch, samples)
            ilens: (batch,)
        """
        if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
            istft = torch.functional.istft
        else:
            try:
                import torchaudio
            except ImportError:
                raise ImportError(
                    "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
                )

            if not hasattr(torchaudio.functional, "istft"):
                raise ImportError(
                    "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
                )
            istft = torchaudio.functional.istft

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            if is_complex(input):
                datatype = input.real.dtype
            else:
                datatype = input.dtype
            window = window_func(self.win_length, dtype=datatype, device=input.device)
        else:
            window = None

        if is_complex(input):
            input = torch.stack([input.real, input.imag], dim=-1)
        elif input.shape[-1] != 2:
            raise TypeError("Invalid input type")
        input = input.transpose(1, 2)

        wavs = istft(
            input,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            center=self.center,
            normalized=self.normalized,
            onesided=self.onesided,
            length=ilens.max() if ilens is not None else ilens,
        )

        return wavs, ilens
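Not part of the commit: a minimal forward-pass sketch for Stft; shapes are illustrative, and inverse() reconstructs waveforms on PyTorch versions that still accept the real/imag (…, 2) layout in istft.

import torch
from funasr_local.layers.stft import Stft

stft = Stft(n_fft=512, win_length=400, hop_length=160, window="hann")
wav = torch.randn(2, 16000)            # (Batch, Nsamples)
ilens = torch.tensor([16000, 12000])
spec, olens = stft(wav, ilens)         # (Batch, Frames, Freq=257, 2=real_imag)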
88  funasr_local/layers/time_warp.py  Normal file
@@ -0,0 +1,88 @@
"""Time warp module."""
import torch

from funasr_local.modules.nets_utils import pad_list


DEFAULT_TIME_WARP_MODE = "bicubic"


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.interpolate.

    Args:
        x: (Batch, Time, Freq)
        window: time warp parameter
        mode: Interpolate mode
    """

    # bicubic supports 4D or more dimension tensor
    org_size = x.size()
    if x.dim() == 3:
        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    if t - window <= window:
        return x.view(*org_size)

    center = torch.randint(window, t - window, (1,))[0]
    warped = torch.randint(center - window, center + window, (1,))[0] + 1

    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
    )

    if x.requires_grad:
        x = torch.cat([left, right], dim=-2)
    else:
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)


class TimeWarp(torch.nn.Module):
    """Time warping using torch.interpolate.

    Args:
        window: time warp parameter
        mode: Interpolate mode
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)
        """

        if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
            # NOTE: apply the same warping to every sample in the batch
            y = time_warp(x, window=self.window, mode=self.mode)
        else:
            # FIXME(kamo): I have no idea how to batchify TimeWarp
            ys = []
            for i in range(x.size(0)):
                _y = time_warp(
                    x[i][None, : x_lengths[i]],
                    window=self.window,
                    mode=self.mode,
                )[0]
                ys.append(_y)
            y = pad_list(ys, 0.0)

        return y, x_lengths
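Not part of the commit: a minimal usage sketch for TimeWarp; the small window is an assumption so that warping actually triggers on a 100-frame input (the default window=80 would return a 100-frame input unchanged).

import torch
from funasr_local.layers.time_warp import TimeWarp

warp = TimeWarp(window=5, mode="bicubic")
x = torch.randn(2, 100, 80)            # (Batch, Time, Freq)
x_lengths = torch.tensor([100, 100])
y, y_lengths = warp(x, x_lengths)      # same shape, time axis locally stretched/squeezed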
88  funasr_local/layers/utterance_mvn.py  Normal file
@@ -0,0 +1,88 @@
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.layers.abs_normalize import AbsNormalize


class UtteranceMVN(AbsNormalize):
    def __init__(
        self,
        norm_means: bool = True,
        norm_vars: bool = False,
        eps: float = 1.0e-20,
    ):
        assert check_argument_types()
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function.

        Args:
            x: (B, L, ...)
            ilens: (B,)
        """
        return utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply utterance mean and variance normalization.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:
    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
    # Zero padding
    if x.requires_grad:
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
    else:
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    # mean: (B, 1, D)
    mean = x.sum(dim=1, keepdim=True) / ilens_

    if norm_means:
        x -= mean

        if norm_vars:
            var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x = x / std.sqrt()
        return x, ilens
    else:
        if norm_vars:
            y = x - mean
            y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
            var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x /= std
        return x, ilens
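Not part of the commit: a minimal usage sketch for UtteranceMVN; shapes are illustrative assumptions.

import torch
from funasr_local.layers.utterance_mvn import UtteranceMVN

mvn = UtteranceMVN(norm_means=True, norm_vars=False)
x = torch.randn(2, 100, 80)            # (B, T, D), zero padded
ilens = torch.tensor([100, 60])
y, olens = mvn(x, ilens)               # per-utterance mean-normalized features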