mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
38
funasr_local/models/preencoder/linear.py
Normal file
38
funasr_local/models/preencoder/linear.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
# 2021, Carnegie Mellon University; Xuankai Chang
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Linear Projection."""
|
||||
|
||||
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from typeguard import check_argument_types
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LinearProjection(AbsPreEncoder):
|
||||
"""Linear Projection Preencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
):
|
||||
"""Initialize the module."""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
self.output_dim = output_size
|
||||
self.linear_out = torch.nn.Linear(input_size, output_size)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, input_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward."""
|
||||
output = self.linear_out(input)
|
||||
return output, input_lengths # no state in this layer
|
||||
|
||||
def output_size(self) -> int:
|
||||
"""Get the output size."""
|
||||
return self.output_dim
|
||||
Reference in New Issue
Block a user