ONNX export and inference. Complete and tested implmentation.

This commit is contained in:
mush42
2023-09-24 01:57:35 +02:00
parent 2cd057187b
commit 1b204ed42c
6 changed files with 396 additions and 6 deletions

View File

@@ -7,15 +7,13 @@ import torch
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
while True:
if length % (2**num_downsamplings_in_unet) == 0:
return length
length += 1
factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
return (length / factor).ceil() * factor
def convert_pad_shape(pad_shape):