Mirror of https://github.com/aigc3d/LAM_Audio2Expression.git (synced 2026-02-05 01:49:23 +08:00)
feat: Initial commit
models/encoder/wavlm.py (87 lines added, normal file)
@@ -0,0 +1,87 @@
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F

def linear_interpolation(features, output_len: int):
    # Resample (batch, time, channels) features to output_len frames along
    # the time axis using linear interpolation.
    features = features.transpose(1, 2)
    output_features = F.interpolate(
        features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
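
# A minimal usage sketch, assuming ~50 Hz WavLM features that should be
# time-aligned with a 30 fps expression sequence; the shapes below are
# illustrative assumptions, not values taken from this repository.
def _linear_interpolation_example():
    feats = torch.randn(2, 100, 768)  # ~2 s of audio features: (batch, time, channels)
    aligned = linear_interpolation(feats, output_len=60)  # resample to 60 frames (30 fps)
    assert aligned.shape == (2, 60, 768)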


# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model  # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.

class WavLMModel(WavLMModel):
    # Thin wrapper around transformers' WavLMModel; the imported name is
    # rebound to this subclass, which adds frame-rate interpolation and a
    # parameter-freezing helper.

    def __init__(self, config):
        super().__init__(config)

    def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
        # Freeze (or unfreeze) all parameters, e.g. to use the pre-trained
        # model as a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        frame_num=None,
        interpolate_pos: int = 0,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        # frame_num is the target number of output frames; interpolate_pos
        # selects where the time axis is resampled to frame_num:
        #   0 -> interpolate the CNN features before the transformer encoder,
        #   1 -> interpolate the encoder outputs afterwards.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if interpolate_pos == 0:
            extract_features = linear_interpolation(
                extract_features, output_len=frame_num)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if interpolate_pos == 1:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
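

# A minimal end-to-end sketch under stated assumptions: the checkpoint name,
# the 16 kHz input, and the 30 fps target are illustrative and are not taken
# from this repository.
if __name__ == "__main__":
    # Load pre-trained WavLM weights into the wrapper defined above.
    model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
    model._freeze_wav2vec2_parameters(True)  # keep the encoder frozen
    model.eval()

    waveform = torch.randn(1, 16000 * 2)  # 2 s of (assumed) 16 kHz mono audio
    with torch.no_grad():
        out = model(waveform, frame_num=60, interpolate_pos=0)
    # last_hidden_state is time-aligned to frame_num: (1, 60, hidden_size)
    print(out.last_hidden_state.shape)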