mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
143
funasr_local/models/frontend/s3prl.py
Normal file
143
funasr_local/models/frontend/s3prl.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import humanfriendly
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr_local.modules.frontends.frontend import Frontend
|
||||
from funasr_local.modules.nets_utils import pad_list
|
||||
from funasr_local.utils.get_default_kwargs import get_default_kwargs
|
||||
|
||||
|
||||
def base_s3prl_setup(args):
    """Fill in missing S3PRL-related attributes on ``args`` with defaults.

    Mutates the given ``Namespace`` in place (attributes already present are
    kept as-is) and returns the same object for convenience.
    """
    s3prl_defaults = {
        "upstream_feature_selection": None,
        "upstream_model_config": None,
        "upstream_refresh": False,
        "upstream_ckpt": None,
        "init_ckpt": None,
        "verbose": False,
        "tile_factor": 1,
    }
    for attr_name, default_value in s3prl_defaults.items():
        if not hasattr(args, attr_name):
            setattr(args, attr_name, default_value)
    return args
|
||||
|
||||
|
||||
class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation frontend structure for ASR.

    Wraps an S3PRL upstream model (e.g. wav2vec2 / HuBERT) plus an S3PRL
    ``Featurizer`` so that raw waveforms can be turned into frame-level
    features for a downstream ASR model. The upstream is loaded via
    ``torch.hub.load`` from a local s3prl checkout found on ``PYTHONPATH``.
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        # NOTE(review): mutable default evaluated once at import time and
        # shared across instances; callers should pass their own dict.
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        download_dir: str = None,
        multilayer_feature: bool = False,
    ):
        """Initialize the frontend.

        Args:
            fs: Sampling rate; a human-friendly string such as "16k" is
                parsed to an int via ``humanfriendly.parse_size``.
            frontend_conf: Keyword options forwarded to the S3PRL upstream
                loader (must at least contain ``upstream``).
            download_dir: If given, redirects the torch.hub cache directory.
            multilayer_feature: Intended to toggle multi-layer feature
                selection; see the note in ``_get_upstream``.
        """
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        # Snapshot of the freshly loaded upstream weights, taken before any
        # training can modify them; restored by reload_pretrained_parameters().
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Get S3PRL upstream model.

        Builds the loader arguments, locates a local ``s3prl`` checkout on
        ``PYTHONPATH``, loads the upstream via ``torch.hub.load`` on CPU,
        and wraps it in an S3PRL ``Featurizer``.

        Returns:
            Tuple of (upstream module, featurizer module).
        """
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        # Find a PYTHONPATH entry that ends with "s3prl"; torch.hub.load is
        # then pointed at that directory with source="local".
        s3prl_path = None
        python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
        for p in python_path_list:
            if p.endswith("s3prl"):
                s3prl_path = p
                break
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would be more robust here.
        assert s3prl_path is not None

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        # Disable LayerDrop for wav2vec2/HuBERT so the layer count (and thus
        # the featurizer's expectations) is deterministic.
        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        # NOTE(review): __init__ declares multilayer_feature as bool with
        # default False, so `is None` is never True and "hidden_states" is
        # always selected; likely intended `if not self.multilayer_feature`.
        # Left unchanged to preserve existing behavior/checkpoints — verify.
        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.

        Input - sequence of representations
        shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
        shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        # Repeat along the feature axis, then reshape so the repetition shows
        # up as consecutive duplicated frames along the time axis.
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        """Return the feature dimension produced by the featurizer."""
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features from padded waveforms.

        Args:
            input: Padded waveform batch; assumed shape (batch, samples) —
                TODO confirm against callers.
            input_lengths: Valid length of each waveform in samples.

        Returns:
            Tuple of (padded feature tensor, per-utterance feature lengths).
        """
        # Strip padding: the S3PRL upstream expects a list of 1-D waveforms.
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        # Upstream is used as a frozen feature extractor: eval mode, no grad.
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        if self.args.tile_factor != 1:
            feats = self._tile_representations(feats)

        # Re-pad the per-utterance features into one batch tensor.
        input_feats = pad_list(feats, 0.0)
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

        # Saving CUDA Memory
        del feats

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        """Restore the upstream weights snapshotted at construction time."""
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
|
||||
Reference in New Issue
Block a user