- Fixed several bugs. Thanks @shivammehta25 for the suggestions

This commit is contained in:
mush42
2023-09-26 14:21:17 +02:00
parent 2c21a0edac
commit 01c99161c4
2 changed files with 14 additions and 10 deletions

View File

@@ -12,8 +12,12 @@ def sequence_mask(length, max_length=None):
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
return (length / factor).ceil() * factor
factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
length = (length / factor).ceil() * factor
if torch.jit.is_tracing():
return length
else:
return length.int().item()
def convert_pad_shape(pad_shape):