From 336dd20d5bc81c57dc2824fee0179fddc80a56e3 Mon Sep 17 00:00:00 2001 From: mush42 Date: Tue, 26 Sep 2023 15:28:15 +0200 Subject: [PATCH] Use torch.onnx.is_in_onnx_export() instead of torch.jit.is_scripting() since the former is dedicated to this use case. --- matcha/utils/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/matcha/utils/model.py b/matcha/utils/model.py index 7909f87..869cc60 100644 --- a/matcha/utils/model.py +++ b/matcha/utils/model.py @@ -14,10 +14,10 @@ def sequence_mask(length, max_length=None): def fix_len_compatibility(length, num_downsamplings_in_unet=2): factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) length = (length / factor).ceil() * factor - if torch.jit.is_tracing(): - return length - else: + if not torch.onnx.is_in_onnx_export(): return length.int().item() + else: + return length def convert_pad_shape(pad_shape):