mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 09:49:21 +08:00
- Fixed several bugs. Thanks @shivammehta25 for the suggestions
This commit is contained in:
@@ -21,8 +21,8 @@ def validate_args(args):
     return args
 
 
-def write_wavs(model, inputs, output_dir, vocoder=None):
-    if vocoder is None:
+def write_wavs(model, inputs, output_dir, external_vocoder=None):
+    if external_vocoder is None:
         print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
         t0 = perf_counter()
         wavs, wav_lengths = model.run(None, inputs)
@@ -33,10 +33,10 @@ def write_wavs(model, inputs, output_dir, vocoder=None):
         mel_t0 = perf_counter()
         mels, mel_lengths = model.run(None, inputs)
         mel_infer_secs = perf_counter() - mel_t0
-        print("Generating waveform from mel using provided vocoder")
-        vocoder_inputs = {vocoder.get_inputs()[0].name: mels}
+        print("Generating waveform from mel using external vocoder")
+        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
         vocoder_t0 = perf_counter()
-        wavs = vocoder.run(None, vocoder_inputs)[0]
+        wavs = external_vocoder.run(None, vocoder_inputs)[0]
         vocoder_infer_secs = perf_counter() - vocoder_t0
         wavs = wavs.squeeze(1)
         wav_lengths = mel_lengths * 256
@@ -156,10 +156,10 @@ def main():
     if has_vocoder_embedded:
         write_wavs(model, inputs, args.output_dir)
     elif args.vocoder:
-        vocoder = ort.InferenceSession(args.vocoder, providers=providers)
-        write_wavs(model, inputs, args.output_dir, vocoder)
+        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
+        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
     else:
-        warn = "[!] Vocoder model not embedded nor provided. The mel output will be written to *.numpy files in the output directory"
+        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
         warnings.warn(warn, UserWarning)
         write_mels(model, inputs, args.output_dir)
 
@@ -12,8 +12,12 @@ def sequence_mask(length, max_length=None):
 
 
 def fix_len_compatibility(length, num_downsamplings_in_unet=2):
-    factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
-    return (length / factor).ceil() * factor
+    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
+    length = (length / factor).ceil() * factor
+    if torch.jit.is_tracing():
+        return length
+    else:
+        return length.int().item()
 
 
 def convert_pad_shape(pad_shape):
Reference in New Issue
Block a user