mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
93
funasr_local/modules/frontends/dnn_wpe.py
Normal file
93
funasr_local/modules/frontends/dnn_wpe.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Tuple
|
||||
|
||||
from pytorch_wpe import wpe_one_iteration
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr_local.modules.frontends.mask_estimator import MaskEstimator
|
||||
from funasr_local.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class DNN_WPE(torch.nn.Module):
    """Iterative WPE front-end with an optional DNN-estimated power mask.

    Each iteration computes the power of the current estimate, optionally
    weights it with a mask produced by ``MaskEstimator`` (first iteration
    only), averages it over the channel axis and hands it to
    ``pytorch_wpe.wpe_one_iteration``.

    Args:
        wtype, widim, wlayers, wunits, wprojs, dropout_rate: Forwarded
            positionally, in this order, to ``MaskEstimator`` (only used
            when ``use_dnn_mask`` is true).
        taps: ``taps`` argument given to ``wpe_one_iteration``.
        delay: ``delay`` argument given to ``wpe_one_iteration``.
        use_dnn_mask: Build a ``MaskEstimator`` and apply its mask to the
            power on the first iteration.
        iterations: Number of ``wpe_one_iteration`` passes in ``forward``.
        normalization: When true, divide the mask by its sum over the last
            axis before applying it.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()

        # WPE settings consumed by forward().
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        # Always passed as ``inverse_power`` to wpe_one_iteration.
        self.inverse_power = True

        # Mask handling switches.
        self.use_dnn_mask = use_dnn_mask
        self.normalization = normalization

        # The estimator (and its parameters) only exists when it is used.
        if use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Run ``self.iterations`` WPE passes over ``data``.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
                NOTE(review): the previous docstring said (B, C, T, F), but
                the axis comments below and the channel mean over dim=-2
                imply (B, T, C, F) -- confirm against callers.
            ilens: (B,)
        Returns:
            enhanced: same layout as the input ``data``
            ilens: (B,), returned unchanged
            mask: the estimated mask with its first and last non-batch axes
                swapped back to the input layout, or ``None`` when no mask
                was estimated (``use_dnn_mask`` false or zero iterations)
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = data
        mask = None

        for step in range(self.iterations):
            # Power of the current estimate: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2

            # The mask is estimated once, on the first pass only.
            if self.use_dnn_mask and step == 0:
                # The estimator is built with nmask=1, so exactly one mask
                # is expected here; mask: (B, F, C, T)
                masks, _ = self.mask_est(enhanced, ilens)
                (mask,) = masks
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Average over the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # Every pass filters the original (permuted) input; only the
            # power estimate comes from the previous pass's output.
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero the padded positions in place.
            pad_mask = make_pad_mask(ilens, enhanced.real)
            enhanced.masked_fill_(pad_mask, 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        mask = None if mask is None else mask.transpose(-1, -3)
        return enhanced, ilens, mask
|
||||
Reference in New Issue
Block a user