mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 09:49:21 +08:00
- Fixed several bugs. Thanks @shivammehta25 for the suggestions
This commit is contained in:
@@ -21,8 +21,8 @@ def validate_args(args):
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def write_wavs(model, inputs, output_dir, vocoder=None):
|
def write_wavs(model, inputs, output_dir, external_vocoder=None):
|
||||||
if vocoder is None:
|
if external_vocoder is None:
|
||||||
print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
|
print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
|
||||||
t0 = perf_counter()
|
t0 = perf_counter()
|
||||||
wavs, wav_lengths = model.run(None, inputs)
|
wavs, wav_lengths = model.run(None, inputs)
|
||||||
@@ -33,10 +33,10 @@ def write_wavs(model, inputs, output_dir, vocoder=None):
|
|||||||
mel_t0 = perf_counter()
|
mel_t0 = perf_counter()
|
||||||
mels, mel_lengths = model.run(None, inputs)
|
mels, mel_lengths = model.run(None, inputs)
|
||||||
mel_infer_secs = perf_counter() - mel_t0
|
mel_infer_secs = perf_counter() - mel_t0
|
||||||
print("Generating waveform from mel using provided vocoder")
|
print("Generating waveform from mel using external vocoder")
|
||||||
vocoder_inputs = {vocoder.get_inputs()[0].name: mels}
|
vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
|
||||||
vocoder_t0 = perf_counter()
|
vocoder_t0 = perf_counter()
|
||||||
wavs = vocoder.run(None, vocoder_inputs)[0]
|
wavs = external_vocoder.run(None, vocoder_inputs)[0]
|
||||||
vocoder_infer_secs = perf_counter() - vocoder_t0
|
vocoder_infer_secs = perf_counter() - vocoder_t0
|
||||||
wavs = wavs.squeeze(1)
|
wavs = wavs.squeeze(1)
|
||||||
wav_lengths = mel_lengths * 256
|
wav_lengths = mel_lengths * 256
|
||||||
@@ -156,10 +156,10 @@ def main():
|
|||||||
if has_vocoder_embedded:
|
if has_vocoder_embedded:
|
||||||
write_wavs(model, inputs, args.output_dir)
|
write_wavs(model, inputs, args.output_dir)
|
||||||
elif args.vocoder:
|
elif args.vocoder:
|
||||||
vocoder = ort.InferenceSession(args.vocoder, providers=providers)
|
external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
|
||||||
write_wavs(model, inputs, args.output_dir, vocoder)
|
write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
|
||||||
else:
|
else:
|
||||||
warn = "[!] Vocoder model not embedded nor provided. The mel output will be written to *.numpy files in the output directory"
|
warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
|
||||||
warnings.warn(warn, UserWarning)
|
warnings.warn(warn, UserWarning)
|
||||||
write_mels(model, inputs, args.output_dir)
|
write_mels(model, inputs, args.output_dir)
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,12 @@ def sequence_mask(length, max_length=None):
|
|||||||
|
|
||||||
|
|
||||||
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
||||||
factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
|
factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
|
||||||
return (length / factor).ceil() * factor
|
length = (length / factor).ceil() * factor
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
return length
|
||||||
|
else:
|
||||||
|
return length.int().item()
|
||||||
|
|
||||||
|
|
||||||
def convert_pad_shape(pad_shape):
|
def convert_pad_shape(pad_shape):
|
||||||
|
|||||||
Reference in New Issue
Block a user