- Fixed several bugs. Thanks @shivammehta25 for the suggestions

This commit is contained in:
mush42
2023-09-26 14:21:17 +02:00
parent 2c21a0edac
commit 01c99161c4
2 changed files with 14 additions and 10 deletions

View File

@@ -21,8 +21,8 @@ def validate_args(args):
return args return args
def write_wavs(model, inputs, output_dir, vocoder=None): def write_wavs(model, inputs, output_dir, external_vocoder=None):
if vocoder is None: if external_vocoder is None:
print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
t0 = perf_counter() t0 = perf_counter()
wavs, wav_lengths = model.run(None, inputs) wavs, wav_lengths = model.run(None, inputs)
@@ -33,10 +33,10 @@ def write_wavs(model, inputs, output_dir, vocoder=None):
mel_t0 = perf_counter() mel_t0 = perf_counter()
mels, mel_lengths = model.run(None, inputs) mels, mel_lengths = model.run(None, inputs)
mel_infer_secs = perf_counter() - mel_t0 mel_infer_secs = perf_counter() - mel_t0
print("Generating waveform from mel using provided vocoder") print("Generating waveform from mel using external vocoder")
vocoder_inputs = {vocoder.get_inputs()[0].name: mels} vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
vocoder_t0 = perf_counter() vocoder_t0 = perf_counter()
wavs = vocoder.run(None, vocoder_inputs)[0] wavs = external_vocoder.run(None, vocoder_inputs)[0]
vocoder_infer_secs = perf_counter() - vocoder_t0 vocoder_infer_secs = perf_counter() - vocoder_t0
wavs = wavs.squeeze(1) wavs = wavs.squeeze(1)
wav_lengths = mel_lengths * 256 wav_lengths = mel_lengths * 256
@@ -156,10 +156,10 @@ def main():
if has_vocoder_embedded: if has_vocoder_embedded:
write_wavs(model, inputs, args.output_dir) write_wavs(model, inputs, args.output_dir)
elif args.vocoder: elif args.vocoder:
vocoder = ort.InferenceSession(args.vocoder, providers=providers) external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
write_wavs(model, inputs, args.output_dir, vocoder) write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
else: else:
warn = "[!] Vocoder model not embedded nor provided. The mel output will be written to *.numpy files in the output directory" warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
warnings.warn(warn, UserWarning) warnings.warn(warn, UserWarning)
write_mels(model, inputs, args.output_dir) write_mels(model, inputs, args.output_dir)

View File

@@ -12,8 +12,12 @@ def sequence_mask(length, max_length=None):
def fix_len_compatibility(length, num_downsamplings_in_unet=2): def fix_len_compatibility(length, num_downsamplings_in_unet=2):
factor = torch.scalar_tensor(num_downsamplings_in_unet).square() factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
return (length / factor).ceil() * factor length = (length / factor).ceil() * factor
if torch.jit.is_tracing():
return length
else:
return length.int().item()
def convert_pad_shape(pad_shape): def convert_pad_shape(pad_shape):