mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add llm export script
This commit is contained in:
@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x):
|
||||
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
||||
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x_padded = x_padded.view(x.size()[0],
|
||||
x.size()[1],
|
||||
x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)[
|
||||
:, :, :, : x.size(-1) // 2 + 1
|
||||
] # only keep the positions from 0 to time2
|
||||
|
||||
Reference in New Issue
Block a user