From 1b204ed42c7f9d9092c5d2295e753de33e9417f8 Mon Sep 17 00:00:00 2001
From: mush42
Date: Sun, 24 Sep 2023 01:57:35 +0200
Subject: [PATCH 1/9] ONNX export and inference. Complete and tested
 implementation.

---
 README.md                   |  43 +++++++++
 matcha/models/matcha_tts.py |   2 +-
 matcha/onnx/__init__.py     |   0
 matcha/onnx/export.py       | 181 ++++++++++++++++++++++++++++++++++++
 matcha/onnx/infer.py        | 168 +++++++++++++++++++++++++++++++++
 matcha/utils/model.py       |   8 +-
 6 files changed, 396 insertions(+), 6 deletions(-)
 create mode 100644 matcha/onnx/__init__.py
 create mode 100644 matcha/onnx/export.py
 create mode 100644 matcha/onnx/infer.py

diff --git a/README.md b/README.md
index 99df7bb..a448004 100644
--- a/README.md
+++ b/README.md
@@ -189,6 +189,49 @@ python matcha/train.py experiment=ljspeech trainer.devices=[0,1]
 matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
 ```
 
+## ONNX support
+
+It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph.
+
+### ONNX export
+
+To export a checkpoint to ONNX, run the following:
+
+```bash
+python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5
+```
+
+Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint-path** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems).
+
+**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**.
+
+**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_dot_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release.
+
+### ONNX Inference
+
+To run inference on the exported model, use the following:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs
+```
+
+You can also control synthesis parameters:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking-rate 0.9 --spk 0
+```
+
+If you exported only Matcha to ONNX, this will write mel-spectrogram plots and `numpy` arrays to the output directory.
+If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory.
+
+If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx
+```
+
+This will write `.wav` audio files to the output directory.
+
 ## Citation information
 
 If you use our code or otherwise find this work useful, please cite our paper:
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index bc5ed06..6feb9e7 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -116,7 +116,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         w = torch.exp(logw) * x_mask
         w_ceil = torch.ceil(w) * length_scale
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-        y_max_length = int(y_lengths.max())
+        y_max_length = y_lengths.max()
         y_max_length_ = fix_len_compatibility(y_max_length)
 
         # Using obtained durations `w` construct alignment map `attn`
diff --git a/matcha/onnx/__init__.py b/matcha/onnx/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/matcha/onnx/export.py b/matcha/onnx/export.py
new file mode 100644
index 0000000..6fc0765
--- /dev/null
+++ b/matcha/onnx/export.py
@@ -0,0 +1,181 @@
+import argparse
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+from lightning import LightningModule
+
+from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder
+
+DEFAULT_OPSET = 15
+
+SEED = 1234
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+torch.cuda.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+class MatchaWithVocoder(LightningModule):
+    def __init__(self, matcha, vocoder):
+        super().__init__()
+        self.matcha = matcha
+        self.vocoder = vocoder
+
+    def forward(self, x, x_lengths, scales, spks=None):
+        mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
+        wavs = self.vocoder(mel).clamp(-1, 1)
+        lengths = mel_lengths * 256
+        return wavs.squeeze(1), lengths
+
+
+def get_exportable_module(matcha, vocoder, n_timesteps):
+    """
+    Return an appropriate `LightningModule` and output-node names
+    based on whether the vocoder is embedded in the final graph
+    """
+
+    def onnx_forward_func(x, x_lengths, scales, spks=None):
+        """
+        Custom forward function for accepting
+        scalar parameters as tensors
+        """
+        # Extract scalar parameters from tensors
+        temperature = scales[0]
+        length_scale = scales[1]
+        output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
+        return output["mel"], output["mel_lengths"]
+
+    # Monkey-patch Matcha's forward function
+    matcha.forward = onnx_forward_func
+
+    if vocoder is None:
+        model, output_names = matcha, ["mel", "mel_lengths"]
+    else:
+        model = MatchaWithVocoder(matcha, vocoder)
+        output_names = ["wav", "wav_lengths"]
+    return model, output_names
+
+
+def get_inputs(is_multi_speaker):
+    """
+    Create dummy inputs for tracing
+    """
+    dummy_input_length = 50
+    x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long)
+    x_lengths = torch.LongTensor([dummy_input_length])
+
+    # Scales
+    temperature = 0.667
+    length_scale = 1.0
+    scales = torch.Tensor([temperature, length_scale])
+
+    model_inputs = [x, x_lengths, scales]
+    input_names = [
+        "x",
+        "x_lengths",
+        "scales",
+    ]
+
+    if is_multi_speaker:
+        spks = torch.LongTensor([1])
+        model_inputs.append(spks)
+        input_names.append("spks")
+
+    return tuple(model_inputs), input_names
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")
+
+    parser.add_argument(
+        "checkpoint_path",
+        type=str,
+        help="Path to the model checkpoint",
+    )
+    parser.add_argument("output", type=str, help="Path to output `.onnx` file")
+    parser.add_argument(
+        "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
+    )
+    parser.add_argument(
+        "--vocoder-name",
+        type=str,
+        choices=list(VOCODER_URLS.keys()),
+        default=None,
+        help="Name of the vocoder to embed in the ONNX graph",
+    )
+    parser.add_argument(
+        "--vocoder-checkpoint-path",
+        type=str,
+        default=None,
+        help="Vocoder checkpoint to embed in the ONNX graph for an `e2e`-like experience",
+    )
+    parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15)")
+
+    args = parser.parse_args()
+
+    print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
+    print(f"Setting n_timesteps to {args.n_timesteps}")
+
+    checkpoint_path = Path(args.checkpoint_path)
+    matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")
+
+    if args.vocoder_name or args.vocoder_checkpoint_path:
+        assert (
+            args.vocoder_name and args.vocoder_checkpoint_path
+        ), "Both --vocoder-name and --vocoder-checkpoint-path are required when embedding the vocoder in the ONNX graph."
+        vocoder = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
+    else:
+        vocoder = None
+
+    is_multi_speaker = matcha.n_spks > 1
+
+    dummy_input, input_names = get_inputs(is_multi_speaker)
+    model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)
+
+    # Set dynamic shape for inputs/outputs
+    dynamic_axes = {
+        "x": {0: "batch_size", 1: "time"},
+        "x_lengths": {0: "batch_size"},
+    }
+
+    if vocoder is None:
+        dynamic_axes.update(
+            {
+                "mel": {0: "batch_size", 2: "time"},
+                "mel_lengths": {0: "batch_size"},
+            }
+        )
+    else:
+        print("Embedding the vocoder in the ONNX graph")
+        dynamic_axes.update(
+            {
+                "wav": {0: "batch_size", 1: "time"},
+                "wav_lengths": {0: "batch_size"},
+            }
+        )
+
+    if is_multi_speaker:
+        dynamic_axes["spks"] = {0: "batch_size"}
+
+    # Create the output directory (if not exists)
+    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
+
+    model.to_onnx(
+        args.output,
+        dummy_input,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic_axes,
+        opset_version=args.opset,
+        export_params=True,
+        do_constant_folding=True,
+    )
+    print(f"[🍵] ONNX model exported to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/matcha/onnx/infer.py b/matcha/onnx/infer.py
new file mode 100644
index 0000000..362099b
--- /dev/null
+++ b/matcha/onnx/infer.py
@@ -0,0 +1,168 @@
+import argparse
+import os
+import warnings
+from pathlib import Path
+from time import perf_counter
+
+import numpy as np
+import onnxruntime as ort
+import soundfile as sf
+import torch
+
+from matcha.cli import plot_spectrogram_to_numpy, process_text
+
+
+def validate_args(args):
+    assert (
+        args.text or args.file
+    ), "Either text or file must be provided. Matcha-T(ea)TTS needs some text to whisk the waveforms."
+    assert args.temperature >= 0, "Sampling temperature cannot be negative"
+    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
+    return args
+
+
+def write_wavs(model, inputs, output_dir, vocoder=None):
+    if vocoder is None:
+        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
+        t0 = perf_counter()
+        wavs, wav_lengths = model.run(None, inputs)
+        infer_secs = perf_counter() - t0
+        mel_infer_secs = vocoder_infer_secs = None
+    else:
+        print("[🍵] Generating mel using Matcha")
+        mel_t0 = perf_counter()
+        mels, mel_lengths = model.run(None, inputs)
+        mel_infer_secs = perf_counter() - mel_t0
+        print("Generating waveform from mel using provided vocoder")
+        vocoder_inputs = {vocoder.get_inputs()[0].name: mels}
+        vocoder_t0 = perf_counter()
+        wavs = vocoder.run(None, vocoder_inputs)[0]
+        vocoder_infer_secs = perf_counter() - vocoder_t0
+        wavs = wavs.squeeze(1)
+        wav_lengths = mel_lengths * 256
+        infer_secs = mel_infer_secs + vocoder_infer_secs
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
+        output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
+        audio = wav[:wav_length]
+        print(f"Writing audio to {output_filename}")
+        sf.write(output_filename, audio, 22050, "PCM_24")
+
+    wav_secs = wav_lengths.sum() / 22050
+    print(f"Inference seconds: {infer_secs}")
+    print(f"Generated wav seconds: {wav_secs}")
+    rtf = infer_secs / wav_secs
+    if mel_infer_secs is not None:
+        mel_rtf = mel_infer_secs / wav_secs
+        print(f"Matcha RTF: {mel_rtf}")
+    if vocoder_infer_secs is not None:
+        vocoder_rtf = vocoder_infer_secs / wav_secs
+        print(f"Vocoder RTF: {vocoder_rtf}")
+    print(f"Overall RTF: {rtf}")
+
+
+def write_mels(model, inputs, output_dir):
+    t0 = perf_counter()
+    mels, mel_lengths = model.run(None, inputs)
+    infer_secs = perf_counter() - t0
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for i, mel in enumerate(mels):
+        output_stem = output_dir.joinpath(f"output_{i + 1}")
+        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
+        np.save(output_stem.with_suffix(".npy"), mel)
+
+    wav_secs = (mel_lengths * 256).sum() / 22050
+    print(f"Inference seconds: {infer_secs}")
+    print(f"Generated wav seconds: {wav_secs}")
+    rtf = infer_secs / wav_secs
+    print(f"RTF: {rtf}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
+    )
+    parser.add_argument(
+        "model",
+        type=str,
+        help="ONNX model to use",
+    )
+    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
+    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
+    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
+    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.667,
+        help="Variance of the x0 noise (default: 0.667)",
+    )
+    parser.add_argument(
+        "--speaking-rate",
+        type=float,
+        default=1.0,
+        help="Change the speaking rate; a higher value means slower speech (default: 1.0)",
+    )
+    parser.add_argument("--gpu", action="store_true", help="Run inference on GPU (default: run on CPU)")
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=os.getcwd(),
+        help="Output folder to save results (default: current dir)",
+    )
+
+    args = parser.parse_args()
+    args = validate_args(args)
+
+    if args.gpu:
+        providers = ["CUDAExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+    model = ort.InferenceSession(args.model, providers=providers)
+
+    model_inputs = model.get_inputs()
+    model_outputs = list(model.get_outputs())
+
+    if args.text:
+        text_lines = args.text.splitlines()
+    else:
+        with open(args.file, encoding="utf-8") as file:
+            text_lines = file.read().splitlines()
+
+    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
+    x = [line["x"].squeeze() for line in processed_lines]
+    # Pad
+    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
+    x = x.detach().cpu().numpy()
+    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
+    inputs = {
+        "x": x,
+        "x_lengths": x_lengths,
+        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
+    }
+    is_multi_speaker = len(model_inputs) == 4
+    if is_multi_speaker:
+        if args.spk is None:
+            args.spk = 0
+            warn = "[!] Speaker ID not provided! Using speaker ID 0"
+            warnings.warn(warn, UserWarning)
+        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)
+
+    has_vocoder_embedded = model_outputs[0].name == "wav"
+    if has_vocoder_embedded:
+        write_wavs(model, inputs, args.output_dir)
+    elif args.vocoder:
+        vocoder = ort.InferenceSession(args.vocoder, providers=providers)
+        write_wavs(model, inputs, args.output_dir, vocoder)
+    else:
+        warn = "[!] Vocoder model not embedded nor provided. The mel output will be written to *.numpy files in the output directory"
+        warnings.warn(warn, UserWarning)
+        write_mels(model, inputs, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/matcha/utils/model.py b/matcha/utils/model.py
index d76b4ba..0579e46 100644
--- a/matcha/utils/model.py
+++ b/matcha/utils/model.py
@@ -7,15 +7,13 @@ import torch
 def sequence_mask(length, max_length=None):
     if max_length is None:
         max_length = length.max()
-    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
     return x.unsqueeze(0) < length.unsqueeze(1)
 
 
 def fix_len_compatibility(length, num_downsamplings_in_unet=2):
-    while True:
-        if length % (2**num_downsamplings_in_unet) == 0:
-            return length
-        length += 1
+    factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
+    return (length / factor).ceil() * factor
 
 
 def convert_pad_shape(pad_shape):

From 25767f76a8109642c6b52b753e250e024513d679 Mon Sep 17 00:00:00 2001
From: mush42
Date: Sun, 24 Sep 2023 02:13:27 +0200
Subject: [PATCH 2/9] Readme: added a note about GPU inference with
 onnxruntime.

---
 README.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/README.md b/README.md
index a448004..dd7cdb3 100644
--- a/README.md
+++ b/README.md
@@ -221,6 +221,12 @@ You can also control synthesis parameters:
 python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking-rate 0.9 --spk 0
 ```
 
+To run inference on **GPU**, make sure to install the **onnxruntime-gpu** package, and then pass `--gpu` to the inference command:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu
+```
+
 If you exported only Matcha to ONNX, this will write mel-spectrogram plots and `numpy` arrays to the output directory.
 If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory.
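For readers who want to drive the exported graph from their own code rather than through `matcha.onnx.infer`, here is a minimal sketch (not part of the patch series itself) of what the inference script does, assuming a single-speaker, Matcha-only export saved as `model.onnx`; the token IDs below are dummy placeholders standing in for the phoneme IDs that `matcha.cli.process_text` would produce:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Dummy phoneme IDs, mirroring the tracing inputs in matcha/onnx/export.py;
# real inputs come from matcha.cli.process_text
x = np.random.randint(low=0, high=20, size=(1, 50), dtype=np.int64)
x_lengths = np.array([50], dtype=np.int64)

# `scales` packs temperature and length_scale into one float32 tensor;
# n_timesteps is not an input here -- it was fixed as a hyper-parameter at export time
scales = np.array([0.667, 1.0], dtype=np.float32)

mel, mel_lengths = session.run(None, {"x": x, "x_lengths": x_lengths, "scales": scales})
print(mel.shape)  # (batch, n_feats, frames) for a Matcha-only graph
```

A multi-speaker export additionally expects an `int64` `spks` input, and a graph with an embedded vocoder returns `wav` and `wav_lengths` instead of the mel outputs.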
From 2c21a0edac9f57141967814f785c841eb95bcf5b Mon Sep 17 00:00:00 2001
From: mush42
Date: Sun, 24 Sep 2023 20:28:59 +0200
Subject: [PATCH 3/9] Fixed an error encountered when loading the vocoder
 during export.

---
 matcha/onnx/export.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/matcha/onnx/export.py b/matcha/onnx/export.py
index 6fc0765..9b79508 100644
--- a/matcha/onnx/export.py
+++ b/matcha/onnx/export.py
@@ -127,7 +127,7 @@ def main():
         assert (
             args.vocoder_name and args.vocoder_checkpoint_path
         ), "Both --vocoder-name and --vocoder-checkpoint-path are required when embedding the vocoder in the ONNX graph."
-        vocoder = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
+        vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
     else:
         vocoder = None
 

From 01c99161c4f7bf4e37b953300d45a0ff9c01b360 Mon Sep 17 00:00:00 2001
From: mush42
Date: Tue, 26 Sep 2023 14:21:17 +0200
Subject: [PATCH 4/9] Fixed several bugs. Thanks @shivammehta25 for the
 suggestions.

---
 matcha/onnx/infer.py  | 16 ++++++++--------
 matcha/utils/model.py |  8 ++++++--
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/matcha/onnx/infer.py b/matcha/onnx/infer.py
index 362099b..89ca925 100644
--- a/matcha/onnx/infer.py
+++ b/matcha/onnx/infer.py
@@ -21,8 +21,8 @@ def validate_args(args):
     return args
 
 
-def write_wavs(model, inputs, output_dir, vocoder=None):
-    if vocoder is None:
+def write_wavs(model, inputs, output_dir, external_vocoder=None):
+    if external_vocoder is None:
         print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
         t0 = perf_counter()
         wavs, wav_lengths = model.run(None, inputs)
@@ -33,10 +33,10 @@ def write_wavs(model, inputs, output_dir, vocoder=None):
         mel_t0 = perf_counter()
         mels, mel_lengths = model.run(None, inputs)
         mel_infer_secs = perf_counter() - mel_t0
-        print("Generating waveform from mel using provided vocoder")
-        vocoder_inputs = {vocoder.get_inputs()[0].name: mels}
+        print("Generating waveform from mel using external vocoder")
+        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
         vocoder_t0 = perf_counter()
-        wavs = vocoder.run(None, vocoder_inputs)[0]
+        wavs = external_vocoder.run(None, vocoder_inputs)[0]
         vocoder_infer_secs = perf_counter() - vocoder_t0
         wavs = wavs.squeeze(1)
         wav_lengths = mel_lengths * 256
@@ -156,10 +156,10 @@ def main():
 
     if has_vocoder_embedded:
         write_wavs(model, inputs, args.output_dir)
     elif args.vocoder:
-        vocoder = ort.InferenceSession(args.vocoder, providers=providers)
-        write_wavs(model, inputs, args.output_dir, vocoder)
+        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
+        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
     else:
-        warn = "[!] Vocoder model not embedded nor provided. The mel output will be written to *.numpy files in the output directory"
+        warn = "[!] No vocoder is embedded in the graph, and no external vocoder was provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
         warnings.warn(warn, UserWarning)
         write_mels(model, inputs, args.output_dir)
diff --git a/matcha/utils/model.py b/matcha/utils/model.py
index 0579e46..7909f87 100644
--- a/matcha/utils/model.py
+++ b/matcha/utils/model.py
@@ -12,8 +12,12 @@ def sequence_mask(length, max_length=None):
 
 
 def fix_len_compatibility(length, num_downsamplings_in_unet=2):
-    factor = torch.scalar_tensor(num_downsamplings_in_unet).square()
-    return (length / factor).ceil() * factor
+    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
+    length = (length / factor).ceil() * factor
+    if torch.jit.is_tracing():
+        return length
+    else:
+        return length.int().item()
 
 
 def convert_pad_shape(pad_shape):

From 336dd20d5bc81c57dc2824fee0179fddc80a56e3 Mon Sep 17 00:00:00 2001
From: mush42
Date: Tue, 26 Sep 2023 15:28:15 +0200
Subject: [PATCH 5/9] Use torch.onnx.is_in_onnx_export() instead of
 torch.jit.is_tracing(), since the former is dedicated to this use case.

---
 matcha/utils/model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/matcha/utils/model.py b/matcha/utils/model.py
index 7909f87..869cc60 100644
--- a/matcha/utils/model.py
+++ b/matcha/utils/model.py
@@ -14,10 +14,10 @@ def sequence_mask(length, max_length=None):
 def fix_len_compatibility(length, num_downsamplings_in_unet=2):
     factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
     length = (length / factor).ceil() * factor
-    if torch.jit.is_tracing():
-        return length
-    else:
+    if not torch.onnx.is_in_onnx_export():
         return length.int().item()
+    else:
+        return length
 
 
 def convert_pad_shape(pad_shape):

From 2a81800825c38fbdbdc925c612a9d33e4a9043cb Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 28 Sep 2023 13:23:02 +0000
Subject: [PATCH 6/9] Bump diffusers from 0.21.2 to 0.21.3

Bumps [diffusers](https://github.com/huggingface/diffusers) from 0.21.2 to 0.21.3.
- [Release notes](https://github.com/huggingface/diffusers/releases)
- [Commits](https://github.com/huggingface/diffusers/compare/v0.21.2...v0.21.3)

---
updated-dependencies:
- dependency-name: diffusers
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot]
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index c058372..c1be781 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,7 +35,7 @@ torchaudio
 matplotlib
 pandas
 conformer==0.3.2
-diffusers==0.21.2
+diffusers==0.21.3
 notebook
 ipywidgets
 gradio

From 269609003b6140b579be1ae302cb45a9eed7d551 Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Fri, 29 Sep 2023 14:38:57 +0000
Subject: [PATCH 7/9] Adding onnx installation command in the README

---
 README.md | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index dd7cdb3..696e9b9 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,6 @@ Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [
 
 [![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0)
 
-
 ## Installation
 
 1. Create an environment (suggested but optional)
@@ -46,7 +45,7 @@ conda create -n matcha-tts python=3.10 -y
 conda activate matcha-tts
 ```
 
-2. Install Matcha TTS using pip or from source 
+2. Install Matcha TTS using pip or from source
 
 ```bash
 pip install matcha-tts
 ```
@@ -191,11 +190,19 @@ matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
 
 ## ONNX support
 
+> Special thanks to @mush42 for implementing ONNX export and inference support.
+
 It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph.
 
 ### ONNX export
 
-To export a checkpoint to ONNX, run the following:
+To export a checkpoint to ONNX, first install ONNX with
+
+```bash
+pip install onnx
+```
+
+then run the following:
 
 ```bash
 python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5
 ```
@@ -205,11 +212,18 @@ Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint-path**
 
 **Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**.
 
-**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_dot_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release. 
+**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_dot_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release.
 
 ### ONNX Inference
 
-To run inference on the exported model, use the following:
+To run inference on the exported model, first install `onnxruntime` using
+
+```bash
+pip install onnxruntime
+pip install onnxruntime-gpu  # for GPU inference
+```
+
+then use the following:
 
 ```bash
 python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs
 ```

From 9ace522249590e12dedabde1ae072693af5a8bc2 Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Fri, 29 Sep 2023 16:46:38 +0200
Subject: [PATCH 8/9] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 696e9b9..2534da0 100644
--- a/README.md
+++ b/README.md
@@ -190,7 +190,7 @@ matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
 
 ## ONNX support
 
-> Special thanks to @mush42 for implementing ONNX export and inference support.
+> Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support.
 
 It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph.

From 1ead4303f3cb7a0e07b3c7e5234b21a2b9056117 Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Fri, 29 Sep 2023 14:50:46 +0000
Subject: [PATCH 9/9] Version Bump

---
 matcha/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/matcha/VERSION b/matcha/VERSION
index bcab45a..81340c7 100644
--- a/matcha/VERSION
+++ b/matcha/VERSION
@@ -1 +1 @@
-0.0.3
+0.0.4
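Patches 4 and 5 both revise `fix_len_compatibility`, which pads the predicted mel length up to the next multiple of `2**num_downsamplings_in_unet` so that the decoder U-Net can downsample and upsample without shape mismatches. The following is a plain-Python sketch of the same arithmetic, assuming an integer length outside of tracing/export (the function name is ours, for illustration only):

```python
import math

def fix_len_compatibility_reference(length: int, num_downsamplings_in_unet: int = 2) -> int:
    # Round up to the next multiple of 2**num_downsamplings_in_unet,
    # mirroring the eager branch of matcha/utils/model.py after patch 5
    factor = 2**num_downsamplings_in_unet
    return math.ceil(length / factor) * factor

assert fix_len_compatibility_reference(100) == 100  # already a multiple of 4
assert fix_len_compatibility_reference(101) == 104  # padded up to the next multiple of 4
```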