mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add llm export script
This commit is contained in:
@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x):
|
||||
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
||||
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x_padded = x_padded.view(x.size()[0],
|
||||
x.size()[1],
|
||||
x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)[
|
||||
:, :, :, : x.size(-1) // 2 + 1
|
||||
] # only keep the positions from 0 to time2
|
||||
|
||||
@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
|
||||
memory_mask)
|
||||
return x
|
||||
|
||||
@torch.jit.ignore(drop=True)
|
||||
@torch.jit.unused
|
||||
def forward_layers_checkpointed(self, x: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
|
||||
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(EspnetRelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x):
|
||||
def extend_pe(self, x: torch.Tensor):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
|
||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
|
||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||
return xs
|
||||
|
||||
@torch.jit.ignore(drop=True)
|
||||
@torch.jit.unused
|
||||
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
||||
chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
@@ -180,6 +180,7 @@ class BaseEncoder(torch.nn.Module):
|
||||
mask_pad)
|
||||
return xs
|
||||
|
||||
@torch.jit.export
|
||||
def forward_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
@@ -270,6 +271,7 @@ class BaseEncoder(torch.nn.Module):
|
||||
|
||||
return (xs, r_att_cache, r_cnn_cache)
|
||||
|
||||
@torch.jit.unused
|
||||
def forward_chunk_by_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user