diff --git a/matcha/onnx/infer.py b/matcha/onnx/infer.py index 362099b..89ca925 100644 --- a/matcha/onnx/infer.py +++ b/matcha/onnx/infer.py @@ -21,8 +21,8 @@ def validate_args(args): return args -def write_wavs(model, inputs, output_dir, vocoder=None): - if vocoder is None: +def write_wavs(model, inputs, output_dir, external_vocoder=None): + if external_vocoder is None: print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") t0 = perf_counter() wavs, wav_lengths = model.run(None, inputs) @@ -33,10 +33,10 @@ def write_wavs(model, inputs, output_dir, vocoder=None): mel_t0 = perf_counter() mels, mel_lengths = model.run(None, inputs) mel_infer_secs = perf_counter() - mel_t0 - print("Generating waveform from mel using provided vocoder") - vocoder_inputs = {vocoder.get_inputs()[0].name: mels} + print("Generating waveform from mel using external vocoder") + vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} vocoder_t0 = perf_counter() - wavs = vocoder.run(None, vocoder_inputs)[0] + wavs = external_vocoder.run(None, vocoder_inputs)[0] vocoder_infer_secs = perf_counter() - vocoder_t0 wavs = wavs.squeeze(1) wav_lengths = mel_lengths * 256 @@ -156,10 +156,10 @@ def main(): if has_vocoder_embedded: write_wavs(model, inputs, args.output_dir) elif args.vocoder: - vocoder = ort.InferenceSession(args.vocoder, providers=providers) - write_wavs(model, inputs, args.output_dir, vocoder) + external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) + write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) else: - warn = "[!] Vocoder model not embedded nor provided. The mel output will be written to *.numpy files in the output directory" + warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory" warnings.warn(warn, UserWarning) write_mels(model, inputs, args.output_dir) diff --git a/matcha/utils/model.py b/matcha/utils/model.py index 0579e46..7909f87 100644 --- a/matcha/utils/model.py +++ b/matcha/utils/model.py @@ -12,8 +12,12 @@ def sequence_mask(length, max_length=None): def fix_len_compatibility(length, num_downsamplings_in_unet=2): - factor = torch.scalar_tensor(num_downsamplings_in_unet).square() - return (length / factor).ceil() * factor + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if torch.jit.is_tracing(): + return length + else: + return length.int().item() def convert_pad_shape(pad_shape):