mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
84
funasr_local/modules/frontends/beamformer.py
Normal file
84
funasr_local/modules/frontends/beamformer.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch_complex import functional as FC
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
|
||||
def get_power_spectral_density_matrix(
    xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
) -> ComplexTensor:
    """Compute the mask-weighted cross-channel power spectral density (PSD).

    For every frame the outer product of the channel vector with its own
    conjugate is formed, weighted by the channel-averaged mask, and the
    weighted products are summed over the time axis.

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool): If true, divide the channel-averaged mask by
            its sum over the time axis (plus ``eps``) before weighting.
        eps (float): Constant added to the normaliser so the division
            never hits an exact zero.
    Returns:
        psd (ComplexTensor): (..., F, C, C)

    """
    # Collapse the channel axis of the mask: (..., C, T) -> (..., T)
    frame_weight = mask.mean(dim=-2)

    if normalization:
        # Assuming zero padding, the sum over time does not depend on the
        # padded length, so padded batches are normalised consistently.
        normaliser = frame_weight.sum(dim=-1, keepdim=True) + eps
        frame_weight = frame_weight / normaliser

    # Per-frame outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
    outer = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])

    # Weight each frame, then reduce over time: (..., T, C, C) -> (..., C, C)
    weighted = outer * frame_weight[..., None, None]
    return weighted.sum(dim=-3)
|
||||
|
||||
|
||||
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C). Not modified by this function.
        reference_vector (torch.Tensor): (..., C)
        eps (float): Added to the diagonal of ``psd_n`` before inversion and
            to the trace before division.
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps to the diagonal of the noise PSD before inverting it.
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    # Leading singleton axes so the identity broadcasts over (..., C, C).
    eye = eye.view(*([1] * (psd_n.dim() - 2)), C, C)
    # Out-of-place on purpose: `psd_n += ...` would silently alter the
    # caller's tensor (and accumulate eps if the tensor is reused).
    psd_n = psd_n + eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector
|
||||
|
||||
|
||||
def apply_beamforming_vector(
    beamform_vector: ComplexTensor, mix: ComplexTensor
) -> ComplexTensor:
    """Apply a beamforming filter to a multi-channel mixture.

    The conjugated filter is contracted against ``mix`` over the channel
    axis.

    Args:
        beamform_vector (ComplexTensor): (..., C)
        mix (ComplexTensor): (..., C, T)
    Returns:
        ComplexTensor: (..., T)
    """
    conj_filter = beamform_vector.conj()
    # (..., C) x (..., C, T) -> (..., T)
    return FC.einsum("...c,...ct->...t", [conj_filter, mix])
|
||||
Reference in New Issue
Block a user