update model inference

This commit is contained in:
lyuxiang.lx
2024-07-24 19:18:09 +08:00
parent a13411c561
commit 02f941d348
5 changed files with 85 additions and 64 deletions

View File

@@ -43,7 +43,7 @@ class InterpolateRegulator(nn.Module):
def forward(self, x, ylens=None):
# x in (B, T, D)
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
out = self.model(x).transpose(1, 2).contiguous()
olens = ylens
return out * mask, olens