mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-04 17:39:19 +08:00
add files
This commit is contained in:
37
funasr_local/models/decoder/sv_decoder.py
Normal file
37
funasr_local/models/decoder/sv_decoder.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from funasr_local.models.decoder.abs_decoder import AbsDecoder
|
||||
|
||||
|
||||
class DenseDecoder(AbsDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size,
|
||||
encoder_output_size,
|
||||
num_nodes_resnet1: int = 256,
|
||||
num_nodes_last_layer: int = 256,
|
||||
batchnorm_momentum: float = 0.5,
|
||||
):
|
||||
super(DenseDecoder, self).__init__()
|
||||
self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
|
||||
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
|
||||
|
||||
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
|
||||
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
|
||||
|
||||
self.output_dense = torch.nn.Linear(num_nodes_last_layer, vocab_size, bias=False)
|
||||
|
||||
def forward(self, features):
|
||||
embeddings = {}
|
||||
features = self.resnet1_dense(features)
|
||||
embeddings["resnet1_dense"] = features
|
||||
features = F.relu(features)
|
||||
features = self.resnet1_bn(features)
|
||||
|
||||
features = self.resnet2_dense(features)
|
||||
embeddings["resnet2_dense"] = features
|
||||
features = F.relu(features)
|
||||
features = self.resnet2_bn(features)
|
||||
|
||||
features = self.output_dense(features)
|
||||
return features, embeddings
|
||||
Reference in New Issue
Block a user