mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 18:29:19 +08:00
ONNX export and inference. Complete and tested implmentation.
This commit is contained in:
@@ -7,15 +7,13 @@ import torch
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
||||
while True:
|
||||
if length % (2**num_downsamplings_in_unet) == 0:
|
||||
return length
|
||||
length += 1
|
||||
factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
|
||||
return (length / factor).ceil() * factor
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
|
||||
Reference in New Issue
Block a user