Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
172
funasr_local/modules/frontends/dnn_beamformer.py
Normal file
@@ -0,0 +1,172 @@
"""DNN beamformer module."""
from typing import Tuple

import torch
from torch.nn import functional as F

from funasr_local.modules.frontends.beamformer import apply_beamforming_vector
from funasr_local.modules.frontends.beamformer import get_mvdr_vector
from funasr_local.modules.frontends.beamformer import (
    get_power_spectral_density_matrix,  # noqa: H301
)
from funasr_local.modules.frontends.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor


class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783

    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F)
            ilens (torch.Tensor): (B,)
            mask_speech (torch.Tensor or List[torch.Tensor]): (B, T, C, F)

        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech

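# Note (added, hedged): the body of get_mvdr_vector is not part of this diff.
# In mask-based MVDR front ends of this design, a standard choice, which the
# helper presumably implements, is the reference-vector ("Souden") filter
#
#     w(f) = ( Phi_N(f)^{-1} Phi_S(f) / tr[ Phi_N(f)^{-1} Phi_S(f) ] ) u
#
# where Phi_S(f) and Phi_N(f) are the mask-weighted speech and noise PSD
# matrices computed above, and u is either the one-hot vector for a fixed
# ref_channel or the soft channel weights produced by AttentionReference below.
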
class AttentionReference(torch.nn.Module):
    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """The forward function

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): sharpening factor for the channel softmax
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C); zero out the diagonal (auto-spectra)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
        # psd: (B, F, C, C) -> (B, C, F), averaging over the other channels
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens
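Below is a minimal usage sketch for the file added in this commit. It is not part of the commit itself: the shapes follow the forward docstrings, and it assumes the funasr_local package and torch_complex are importable and that the default "blstmp" MaskEstimator configuration builds.

    import torch
    from torch_complex.tensor import ComplexTensor

    from funasr_local.modules.frontends.dnn_beamformer import DNN_Beamformer

    B, T, C, n_freq = 2, 100, 4, 257  # batch, frames, channels (mics), freq bins
    stft = ComplexTensor(torch.randn(B, T, C, n_freq), torch.randn(B, T, C, n_freq))
    ilens = torch.LongTensor([100, 80])

    bf = DNN_Beamformer(bidim=n_freq)  # defaults: bnmask=2, i.e. single-speaker MVDR
    with torch.no_grad():
        enhanced, olens, mask_speech = bf(stft, ilens)

    print(enhanced.shape)     # (B, T, F): one enhanced single-channel STFT
    print(mask_speech.shape)  # (B, T, C, F): speech mask after the final transpose

    # The attention-based reference selection can also be probed on its own;
    # it turns a (B, F, C, C) PSD matrix into softmax weights over channels.
    psd = ComplexTensor(torch.randn(B, n_freq, C, C), torch.randn(B, n_freq, C, C))
    u, _ = bf.ref(psd, ilens)
    print(u.shape)  # (B, C)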