Use torch.onnx.is_in_onnx_export() instead of torch.jit.is_tracing() since the former is dedicated to this use case (detecting ONNX export).

This commit is contained in:
mush42
2023-09-26 15:28:15 +02:00
parent 01c99161c4
commit 336dd20d5b

View File

@@ -14,10 +14,10 @@ def sequence_mask(length, max_length=None):
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the nearest multiple of ``2**num_downsamplings_in_unet``.

    The name suggests the downstream U-Net halves the time axis
    ``num_downsamplings_in_unet`` times, so lengths must be divisible by
    that power of two — TODO confirm against the caller.

    Parameters
    ----------
    length : torch.Tensor or number
        Sequence length to pad up.
    num_downsamplings_in_unet : int, optional
        Number of factor-2 downsamplings to stay compatible with (default 2).

    Returns
    -------
    int or torch.Tensor
        A plain Python int in normal execution; during ONNX export the
        tensor is returned unconverted so the computation stays symbolic
        in the exported graph instead of being baked in as a constant.
    """
    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
    length = (length / factor).ceil() * factor
    if not torch.onnx.is_in_onnx_export():
        # .item() materializes a host-side Python int — fine at runtime,
        # but it would freeze the value during ONNX export, hence the guard.
        return length.int().item()
    else:
        return length
def convert_pad_shape(pad_shape): def convert_pad_shape(pad_shape):