mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add llm export script
This commit is contained in:
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(EspnetRelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x):
|
||||
def extend_pe(self, x: torch.Tensor):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
|
||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user