add llm export script

This commit is contained in:
lyuxiang.lx
2024-08-28 18:21:11 +08:00
parent f4e70e222c
commit 9ab298dd49
9 changed files with 34 additions and 17 deletions

View File

@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
"""Compute relative positional encoding.
Args:
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2

View File

@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
memory_mask)
return x
@torch.jit.ignore(drop=True)
@torch.jit.unused
def forward_layers_checkpointed(self, x: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,

View File

@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""Construct an PositionalEncoding object."""
super(EspnetRelPositionalEncoding, self).__init__()
self.d_model = d_model
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
def extend_pe(self, x: torch.Tensor):
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:

View File

@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs
@torch.jit.ignore(drop=True)
@torch.jit.unused
def forward_layers_checkpointed(self, xs: torch.Tensor,
chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
@@ -180,6 +180,7 @@ class BaseEncoder(torch.nn.Module):
mask_pad)
return xs
@torch.jit.export
def forward_chunk(
self,
xs: torch.Tensor,
@@ -270,6 +271,7 @@ class BaseEncoder(torch.nn.Module):
return (xs, r_att_cache, r_cnn_cache)
@torch.jit.unused
def forward_chunk_by_chunk(
self,
xs: torch.Tensor,