mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
151
funasr_local/modules/frontends/frontend.py
Normal file
151
funasr_local/modules/frontends/frontend.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr_local.modules.frontends.dnn_beamformer import DNN_Beamformer
|
||||
# from funasr_local.modules.frontends.dnn_wpe import DNN_WPE
|
||||
|
||||
|
||||
class Frontend(nn.Module):
    """Optional WPE + beamformer frontend applied to multi-channel input.

    The module holds up to two sub-modules, ``self.wpe`` and
    ``self.beamformer``; each is ``None`` when the corresponding ``use_*``
    flag is off.  ``forward`` only touches 4-dimensional input
    ``(B, T, C, F)``; 3-dimensional input ``(B, T, F)`` is passed through.
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate: float = 0.0,
    ):
        """Build the enabled sub-modules.

        Args:
            idim: Input feature dimension, forwarded as ``widim`` / ``bidim``.
            use_wpe: Create and use the ``DNN_WPE`` sub-module.
            wtype, wlayers, wunits, wprojs, wdropout_rate, taps, delay:
                Forwarded unchanged to ``DNN_WPE``.
            use_dnn_mask_for_wpe: Forwarded as ``use_dnn_mask``; also selects
                the WPE iteration count (1 with the DNN mask, 2 without).
            use_beamformer: Create and use the ``DNN_Beamformer`` sub-module.
            btype, blayers, bunits, bprojs, bnmask, badim, ref_channel,
            bdropout_rate: Forwarded unchanged to ``DNN_Beamformer``.

        Raises:
            ImportError: ``use_wpe`` is true but the ``dnn_wpe`` module
                cannot be imported.
        """
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            # The module-level import of DNN_WPE is disabled in this file, so
            # resolve it only when WPE is actually requested and report a
            # clear error instead of failing later with a bare NameError.
            try:
                from funasr_local.modules.frontends.dnn_wpe import DNN_WPE
            except ImportError as err:
                raise ImportError(
                    "use_wpe=True requires "
                    "funasr_local.modules.frontends.dnn_wpe.DNN_WPE, "
                    "which could not be imported"
                ) from err

            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Apply the enabled frontend stages to ``x``.

        Args:
            x: Input of shape ``(B, T, F)`` or ``(B, T, C, F)``.
            ilens: One length per batch item; non-tensor input is converted
                to a tensor on ``x.device``.

        Returns:
            ``(h, ilens, mask)``.  ``h`` is ``x`` itself when no stage ran.
            ``mask`` is whatever the last executed stage returned, or
            ``None`` when no stage ran.

        Raises:
            ValueError: ``x`` is neither 3- nor 4-dimensional.
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # During training one configuration is drawn at random per
                # call: pass-through (unless the frontend is mandatory),
                # WPE only, or beamformer only.
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))

                if self.use_beamformer:
                    choices.append((False, True))

                if choices:
                    use_wpe, use_beamformer = choices[
                        numpy.random.randint(len(choices))
                    ]
                else:
                    # Frontend is mandatory but no stage is enabled: there is
                    # nothing to apply, so pass through (same result as eval
                    # mode) instead of crashing in randint(0).
                    use_wpe, use_beamformer = False, False

            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
|
||||
|
||||
|
||||
def frontend_for(args, idim):
    """Construct a ``Frontend`` from an argument namespace.

    Args:
        args: Object exposing the WPE and beamformer settings as attributes
            (``use_wpe``, ``wtype``, ..., ``bdropout_rate``).
        idim: Input feature dimension passed to ``Frontend``.

    Returns:
        The newly created ``Frontend`` instance.
    """
    # WPE options
    wpe_options = {
        "use_wpe": args.use_wpe,
        "wtype": args.wtype,
        "wlayers": args.wlayers,
        "wunits": args.wunits,
        "wprojs": args.wprojs,
        "wdropout_rate": args.wdropout_rate,
        "taps": args.wpe_taps,
        "delay": args.wpe_delay,
        "use_dnn_mask_for_wpe": args.use_dnn_mask_for_wpe,
    }
    # Beamformer options
    beamformer_options = {
        "use_beamformer": args.use_beamformer,
        "btype": args.btype,
        "blayers": args.blayers,
        "bunits": args.bunits,
        "bprojs": args.bprojs,
        "bnmask": args.bnmask,
        "badim": args.badim,
        "ref_channel": args.ref_channel,
        "bdropout_rate": args.bdropout_rate,
    }
    return Frontend(idim=idim, **wpe_options, **beamformer_options)
|
||||
Reference in New Issue
Block a user