mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
83
funasr_local/layers/log_mel.py
Normal file
83
funasr_local/layers/log_mel.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import librosa
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from funasr_local.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class LogMel(torch.nn.Module):
|
||||
"""Convert STFT to fbank feats
|
||||
|
||||
The arguments is same as librosa.filters.mel
|
||||
|
||||
Args:
|
||||
fs: number > 0 [scalar] sampling rate of the incoming signal
|
||||
n_fft: int > 0 [scalar] number of FFT components
|
||||
n_mels: int > 0 [scalar] number of Mel bands to generate
|
||||
fmin: float >= 0 [scalar] lowest frequency (in Hz)
|
||||
fmax: float >= 0 [scalar] highest frequency (in Hz).
|
||||
If `None`, use `fmax = fs / 2.0`
|
||||
htk: use HTK formula instead of Slaney
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 80,
|
||||
fmin: float = None,
|
||||
fmax: float = None,
|
||||
htk: bool = False,
|
||||
log_base: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
_mel_options = dict(
|
||||
sr=fs,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
htk=htk,
|
||||
)
|
||||
self.mel_options = _mel_options
|
||||
self.log_base = log_base
|
||||
|
||||
# Note(kamo): The mel matrix of librosa is different from kaldi.
|
||||
melmat = librosa.filters.mel(**_mel_options)
|
||||
# melmat: (D2, D1) -> (D1, D2)
|
||||
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
|
||||
|
||||
def extra_repr(self):
|
||||
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
ilens: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
|
||||
mel_feat = torch.matmul(feat, self.melmat)
|
||||
mel_feat = torch.clamp(mel_feat, min=1e-10)
|
||||
|
||||
if self.log_base is None:
|
||||
logmel_feat = mel_feat.log()
|
||||
elif self.log_base == 2.0:
|
||||
logmel_feat = mel_feat.log2()
|
||||
elif self.log_base == 10.0:
|
||||
logmel_feat = mel_feat.log10()
|
||||
else:
|
||||
logmel_feat = mel_feat.log() / torch.log(self.log_base)
|
||||
|
||||
# Zero padding
|
||||
if ilens is not None:
|
||||
logmel_feat = logmel_feat.masked_fill(
|
||||
make_pad_mask(ilens, logmel_feat, 1), 0.0
|
||||
)
|
||||
else:
|
||||
ilens = feat.new_full(
|
||||
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
|
||||
)
|
||||
return logmel_feat, ilens
|
||||
Reference in New Issue
Block a user