mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
58  funasr_local/modules/positionwise_feed_forward.py  Normal file
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positionwise feed forward layer definition."""

import torch

from funasr_local.modules.layer_norm import LayerNorm


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
    """Positionwise feed forward layer for the SANM decoder.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        adim (int, optional): Output dimension of the second linear layer;
            defaults to idim.

    """

    def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForwardDecoderSANM object."""
        super(PositionwiseFeedForwardDecoderSANM, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation
        self.norm = LayerNorm(hidden_units)

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
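For context, a minimal usage sketch of the two modules added in this commit; the tensor shapes, dimension values, and the import path below are illustrative assumptions, not taken from the repository.

# Minimal usage sketch (illustrative values; assumes the repo's funasr_local
# package is on the Python path).
import torch

from funasr_local.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,
    PositionwiseFeedForwardDecoderSANM,
)

# (batch, time, feature) input, e.g. a 256-dim encoder state over 100 frames.
x = torch.randn(4, 100, 256)

# Standard FFN: Linear -> activation -> dropout -> Linear, output dim == idim.
ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
y = ffn(x)
print(y.shape)  # torch.Size([4, 100, 256])

# SANM decoder variant: applies LayerNorm before the second (bias-free) linear
# layer and can project to a different output dim via adim.
ffn_sanm = PositionwiseFeedForwardDecoderSANM(
    idim=256, hidden_units=1024, dropout_rate=0.1, adim=320
)
z = ffn_sanm(x)
print(z.shape)  # torch.Size([4, 100, 320])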