Mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
0
funasr_local/models/postencoder/__init__.py
Normal file
17
funasr_local/models/postencoder/abs_postencoder.py
Normal file
@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch


class AbsPostEncoder(torch.nn.Module, ABC):
    @abstractmethod
    def output_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
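For reference (illustration only, not part of this commit): a concrete post-encoder only needs to implement output_size() and forward() against the interface above. A minimal identity sketch, assuming the package layout added by this commit:

from typing import Tuple

import torch

from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder


class IdentityPostEncoder(AbsPostEncoder):
    """Hypothetical example subclass: passes features through unchanged."""

    def __init__(self, input_size: int):
        super().__init__()
        self._output_size = input_size

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No transformation; lengths are preserved.
        return input, input_lengths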
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
#  2021, University of Stuttgart;  Pavel Denisov
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Hugging Face Transformers PostEncoder."""

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from typeguard import check_argument_types
from typing import Tuple

import copy
import logging
import torch

try:
    from transformers import AutoModel

    is_transformers_available = True
except ImportError:
    is_transformers_available = False


class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
    """Hugging Face Transformers PostEncoder."""

    def __init__(
        self,
        input_size: int,
        model_name_or_path: str,
    ):
        """Initialize the module."""
        assert check_argument_types()
        super().__init__()

        if not is_transformers_available:
            raise ImportError(
                "`transformers` is not available. Please install it via `pip install"
                " transformers` or `cd /path/to/espnet/tools && . ./activate_python.sh"
                " && ./installers/install_transformers.sh`."
            )

        model = AutoModel.from_pretrained(model_name_or_path)

        if hasattr(model, "encoder"):
            self.transformer = model.encoder
        else:
            self.transformer = model

        if hasattr(self.transformer, "embed_tokens"):
            del self.transformer.embed_tokens
        if hasattr(self.transformer, "wte"):
            del self.transformer.wte
        if hasattr(self.transformer, "word_embedding"):
            del self.transformer.word_embedding

        self.pretrained_params = copy.deepcopy(self.transformer.state_dict())

        if (
            self.transformer.config.is_encoder_decoder
            or self.transformer.config.model_type in ["xlnet", "t5"]
        ):
            self.use_inputs_embeds = True
            self.extend_attention_mask = False
        elif self.transformer.config.model_type == "gpt2":
            self.use_inputs_embeds = True
            self.extend_attention_mask = True
        else:
            self.use_inputs_embeds = False
            self.extend_attention_mask = True

        self.linear_in = torch.nn.Linear(
            input_size, self.transformer.config.hidden_size
        )

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward."""
        input = self.linear_in(input)

        args = {"return_dict": True}

        mask = (~make_pad_mask(input_lengths)).to(input.device).float()

        if self.extend_attention_mask:
            args["attention_mask"] = _extend_attention_mask(mask)
        else:
            args["attention_mask"] = mask

        if self.use_inputs_embeds:
            args["inputs_embeds"] = input
        else:
            args["hidden_states"] = input

        if self.transformer.config.model_type == "mpnet":
            args["head_mask"] = [None for _ in self.transformer.layer]

        output = self.transformer(**args).last_hidden_state

        return output, input_lengths

    def reload_pretrained_parameters(self):
        self.transformer.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Transformers model parameters reloaded!")

    def output_size(self) -> int:
        """Get the output size."""
        return self.transformer.config.hidden_size


def _extend_attention_mask(mask: torch.Tensor) -> torch.Tensor:
    mask = mask[:, None, None, :]
    mask = (1.0 - mask) * -10000.0
    return mask
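A hedged usage sketch of the class above (illustration only, not part of this commit); the import path and the bert-base-uncased checkpoint are assumptions chosen for the example:

import torch

# Assumed module path, mirroring the package layout added in this commit.
from funasr_local.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,
)

# Wrap a pretrained encoder-only model; requires the `transformers` package.
postencoder = HuggingFaceTransformersPostEncoder(
    input_size=256, model_name_or_path="bert-base-uncased"
)

feats = torch.randn(2, 50, 256)          # (batch, time, input_size)
feats_lengths = torch.tensor([50, 30])   # valid frames per utterance

out, out_lengths = postencoder(feats, feats_lengths)
# out has shape (2, 50, postencoder.output_size()); lengths are returned unchanged,
# since the post-encoder projects and re-encodes frames without subsampling.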