mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
Use torch.onnx.is_in_onnx_export() instead of torch.jit.is_scripting() since the former is dedicated to this use case.
This commit is contained in:
@@ -14,10 +14,10 @@ def sequence_mask(length, max_length=None):
|
|||||||
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
||||||
factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
|
factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
|
||||||
length = (length / factor).ceil() * factor
|
length = (length / factor).ceil() * factor
|
||||||
if torch.jit.is_tracing():
|
if not torch.onnx.is_in_onnx_export():
|
||||||
return length
|
|
||||||
else:
|
|
||||||
return length.int().item()
|
return length.int().item()
|
||||||
|
else:
|
||||||
|
return length
|
||||||
|
|
||||||
|
|
||||||
def convert_pad_shape(pad_shape):
|
def convert_pad_shape(pad_shape):
|
||||||
|
|||||||
Reference in New Issue
Block a user