diff --git a/README.md b/README.md index 99df7bb..dd7cdb3 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,55 @@ python matcha/train.py experiment=ljspeech trainer.devices=[0,1] matcha-tts --text "" --checkpoint_path ``` +## ONNX support + +It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph. + +### ONNX export + +To export a checkpoint to ONNX, run the following: + +```bash +python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5 +``` + +Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems). + +**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**. + +**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release. + +### ONNX Inference + +To run inference on the exported model, use the following: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs +``` + +You can also control synthesis parameters: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking_rate 0.9 --spk 0 +``` + +To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu +``` + +If you exported only Matcha to ONNX, this will write mel-spectrogram as graphs and `numpy` arrays to the output directory. +If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory. + +If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx +``` + +This will write `.wav` audio files to the output directory. + ## Citation information If you use our code or otherwise find this work useful, please cite our paper: diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index bc5ed06..6feb9e7 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -116,7 +116,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 w = torch.exp(logw) * x_mask w_ceil = torch.ceil(w) * length_scale y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_max_length = int(y_lengths.max()) + y_max_length = y_lengths.max() y_max_length_ = fix_len_compatibility(y_max_length) # Using obtained durations `w` construct alignment map `attn` diff --git a/matcha/onnx/__init__.py b/matcha/onnx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/matcha/onnx/export.py b/matcha/onnx/export.py new file mode 100644 index 0000000..9b79508 --- /dev/null +++ b/matcha/onnx/export.py @@ -0,0 +1,181 @@ +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +from lightning import LightningModule + +from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder + +DEFAULT_OPSET = 15 + +SEED = 1234 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +class MatchaWithVocoder(LightningModule): + def __init__(self, matcha, vocoder): + super().__init__() + self.matcha = matcha + self.vocoder = vocoder + + def forward(self, x, x_lengths, scales, spks=None): + mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) + wavs = self.vocoder(mel).clamp(-1, 1) + lengths = mel_lengths * 256 + return wavs.squeeze(1), lengths + + +def get_exportable_module(matcha, vocoder, n_timesteps): + """ + Return an appropriate `LighteningModule` and output-node names + based on whether the vocoder is embedded in the final graph + """ + + def onnx_forward_func(x, x_lengths, scales, spks=None): + """ + Custom forward function for accepting + scaler parameters as tensors + """ + # Extract scaler parameters from tensors + temperature = scales[0] + length_scale = scales[1] + output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) + return output["mel"], output["mel_lengths"] + + # Monkey-patch Matcha's forward function + matcha.forward = onnx_forward_func + + if vocoder is None: + model, output_names = matcha, ["mel", "mel_lengths"] + else: + model = MatchaWithVocoder(matcha, vocoder) + output_names = ["wav", "wav_lengths"] + return model, output_names + + +def get_inputs(is_multi_speaker): + """ + Create dummy inputs for tracing + """ + dummy_input_length = 50 + x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) + x_lengths = torch.LongTensor([dummy_input_length]) + + # Scales + temperature = 0.667 + length_scale = 1.0 + scales = torch.Tensor([temperature, length_scale]) + + model_inputs = [x, x_lengths, scales] + input_names = [ + "x", + "x_lengths", + "scales", + ] + + if is_multi_speaker: + spks = torch.LongTensor([1]) + model_inputs.append(spks) + input_names.append("spks") + + return tuple(model_inputs), input_names + + +def main(): + parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") + + parser.add_argument( + "checkpoint_path", + type=str, + help="Path to the model checkpoint", + ) + parser.add_argument("output", type=str, help="Path to output `.onnx` file") + parser.add_argument( + "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" + ) + parser.add_argument( + "--vocoder-name", + type=str, + choices=list(VOCODER_URLS.keys()), + default=None, + help="Name of the vocoder to embed in the ONNX graph", + ) + parser.add_argument( + "--vocoder-checkpoint-path", + type=str, + default=None, + help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", + ) + parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") + + args = parser.parse_args() + + print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") + print(f"Setting n_timesteps to {args.n_timesteps}") + + checkpoint_path = Path(args.checkpoint_path) + matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") + + if args.vocoder_name or args.vocoder_checkpoint_path: + assert ( + args.vocoder_name and args.vocoder_checkpoint_path + ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." + vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") + else: + vocoder = None + + is_multi_speaker = matcha.n_spks > 1 + + dummy_input, input_names = get_inputs(is_multi_speaker) + model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) + + # Set dynamic shape for inputs/outputs + dynamic_axes = { + "x": {0: "batch_size", 1: "time"}, + "x_lengths": {0: "batch_size"}, + } + + if vocoder is None: + dynamic_axes.update( + { + "mel": {0: "batch_size", 2: "time"}, + "mel_lengths": {0: "batch_size"}, + } + ) + else: + print("Embedding the vocoder in the ONNX graph") + dynamic_axes.update( + { + "wav": {0: "batch_size", 1: "time"}, + "wav_lengths": {0: "batch_size"}, + } + ) + + if is_multi_speaker: + dynamic_axes["spks"] = {0: "batch_size"} + + # Create the output directory (if not exists) + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + + model.to_onnx( + args.output, + dummy_input, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=args.opset, + export_params=True, + do_constant_folding=True, + ) + print(f"[🍵] ONNX model exported to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/matcha/onnx/infer.py b/matcha/onnx/infer.py new file mode 100644 index 0000000..89ca925 --- /dev/null +++ b/matcha/onnx/infer.py @@ -0,0 +1,168 @@ +import argparse +import os +import warnings +from pathlib import Path +from time import perf_counter + +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + +from matcha.cli import plot_spectrogram_to_numpy, process_text + + +def validate_args(args): + assert ( + args.text or args.file + ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." + assert args.temperature >= 0, "Sampling temperature cannot be negative" + assert args.speaking_rate >= 0, "Speaking rate must be greater than 0" + return args + + +def write_wavs(model, inputs, output_dir, external_vocoder=None): + if external_vocoder is None: + print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") + t0 = perf_counter() + wavs, wav_lengths = model.run(None, inputs) + infer_secs = perf_counter() - t0 + mel_infer_secs = vocoder_infer_secs = None + else: + print("[🍵] Generating mel using Matcha") + mel_t0 = perf_counter() + mels, mel_lengths = model.run(None, inputs) + mel_infer_secs = perf_counter() - mel_t0 + print("Generating waveform from mel using external vocoder") + vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} + vocoder_t0 = perf_counter() + wavs = external_vocoder.run(None, vocoder_inputs)[0] + vocoder_infer_secs = perf_counter() - vocoder_t0 + wavs = wavs.squeeze(1) + wav_lengths = mel_lengths * 256 + infer_secs = mel_infer_secs + vocoder_infer_secs + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): + output_filename = output_dir.joinpath(f"output_{i + 1}.wav") + audio = wav[:wav_length] + print(f"Writing audio to {output_filename}") + sf.write(output_filename, audio, 22050, "PCM_24") + + wav_secs = wav_lengths.sum() / 22050 + print(f"Inference seconds: {infer_secs}") + print(f"Generated wav seconds: {wav_secs}") + rtf = infer_secs / wav_secs + if mel_infer_secs is not None: + mel_rtf = mel_infer_secs / wav_secs + print(f"Matcha RTF: {mel_rtf}") + if vocoder_infer_secs is not None: + vocoder_rtf = vocoder_infer_secs / wav_secs + print(f"Vocoder RTF: {vocoder_rtf}") + print(f"Overall RTF: {rtf}") + + +def write_mels(model, inputs, output_dir): + t0 = perf_counter() + mels, mel_lengths = model.run(None, inputs) + infer_secs = perf_counter() - t0 + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + for i, mel in enumerate(mels): + output_stem = output_dir.joinpath(f"output_{i + 1}") + plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) + np.save(output_stem.with_suffix(".numpy"), mel) + + wav_secs = (mel_lengths * 256).sum() / 22050 + print(f"Inference seconds: {infer_secs}") + print(f"Generated wav seconds: {wav_secs}") + rtf = infer_secs / wav_secs + print(f"RTF: {rtf}") + + +def main(): + parser = argparse.ArgumentParser( + description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" + ) + parser.add_argument( + "model", + type=str, + help="ONNX model to use", + ) + parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") + parser.add_argument("--text", type=str, default=None, help="Text to synthesize") + parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") + parser.add_argument("--spk", type=int, default=None, help="Speaker ID") + parser.add_argument( + "--temperature", + type=float, + default=0.667, + help="Variance of the x0 noise (default: 0.667)", + ) + parser.add_argument( + "--speaking-rate", + type=float, + default=1.0, + help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)", + ) + parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)") + parser.add_argument( + "--output-dir", + type=str, + default=os.getcwd(), + help="Output folder to save results (default: current dir)", + ) + + args = parser.parse_args() + args = validate_args(args) + + if args.gpu: + providers = ["GPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + model = ort.InferenceSession(args.model, providers=providers) + + model_inputs = model.get_inputs() + model_outputs = list(model.get_outputs()) + + if args.text: + text_lines = args.text.splitlines() + else: + with open(args.file, encoding="utf-8") as file: + text_lines = file.read().splitlines() + + processed_lines = [process_text(0, line, "cpu") for line in text_lines] + x = [line["x"].squeeze() for line in processed_lines] + # Pad + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) + x = x.detach().cpu().numpy() + x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) + inputs = { + "x": x, + "x_lengths": x_lengths, + "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), + } + is_multi_speaker = len(model_inputs) == 4 + if is_multi_speaker: + if args.spk is None: + args.spk = 0 + warn = "[!] Speaker ID not provided! Using speaker ID 0" + warnings.warn(warn, UserWarning) + inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) + + has_vocoder_embedded = model_outputs[0].name == "wav" + if has_vocoder_embedded: + write_wavs(model, inputs, args.output_dir) + elif args.vocoder: + external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) + write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) + else: + warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory" + warnings.warn(warn, UserWarning) + write_mels(model, inputs, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/matcha/utils/model.py b/matcha/utils/model.py index d76b4ba..869cc60 100644 --- a/matcha/utils/model.py +++ b/matcha/utils/model.py @@ -7,15 +7,17 @@ import torch def sequence_mask(length, max_length=None): if max_length is None: max_length = length.max() - x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) + x = torch.arange(max_length, dtype=length.dtype, device=length.device) return x.unsqueeze(0) < length.unsqueeze(1) def fix_len_compatibility(length, num_downsamplings_in_unet=2): - while True: - if length % (2**num_downsamplings_in_unet) == 0: - return length - length += 1 + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if not torch.onnx.is_in_onnx_export(): + return length.int().item() + else: + return length def convert_pad_shape(pad_shape):