81 Commits
0.0.2 ... main

Author SHA1 Message Date
Shivam Mehta
bd4d90d932 Update README.md 2025-09-17 08:49:57 -07:00
Shivam Mehta
108906c603 Merge pull request #121 from jimregan/english-data
ljspeech/hificaptain from #99
2024-12-02 09:02:41 -06:00
Shivam Mehta
354f5dc69f Merge pull request #123 from jimregan/patch-1
Fix a typo
2024-12-02 08:26:00 -06:00
Jim O’Regan
8e5f98476e Fix a typo 2024-12-02 15:21:31 +01:00
Jim O'Regan
7e499df0b2 ljspeech/hificaptain from #99 2024-12-02 11:01:04 +00:00
Shivam Mehta
0735e653fc Merge pull request #103 from jimregan/mmconv-cleaner
add a cleaner for IPA data (pre-phonetised)
2024-11-13 22:15:47 -08:00
Shivam Mehta
f9843cfca4 Merge pull request #101 from jimregan/pylint
Make pylint happy
2024-11-13 22:13:36 -08:00
Shivam Mehta
289ef51578 Fixing the usage of denoiser_strength from the command line. 2024-11-14 06:55:51 +01:00
Shivam Mehta
7a65f83b17 Updating the version 2024-11-14 06:42:06 +01:00
Shivam Mehta
7275764a48 Fixing espeak not removing brackets in some cases 2024-11-14 06:39:58 +01:00
Jim O'Regan
863bfbdd8b rename method, it's more generic than the previous name suggested 2024-10-03 18:51:47 +00:00
Jim O'Regan
4bc541705a add a cleaner for the mmconv data
Different versions of espeak represent things differently, it seems
(also, there are some distinctions none of our speakers make, so
normalising those away reduces perplexity a tiny amount).
2024-10-03 17:18:58 +00:00
pre-commit-ci[bot]
a3fea22988 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-10-02 14:31:11 +00:00
Jim O'Regan
d56f40765c disable consider-using-from-import instead (missed one) 2024-10-02 14:30:18 +00:00
Jim O'Regan
b0ba920dc1 disable consider-using-from-import instead 2024-10-02 14:29:06 +00:00
Jim O'Regan
a220f283e3 disable consider-using-generator 2024-10-02 13:57:12 +00:00
Jim O'Regan
1df73ef43e disable global-variable-not-assigned 2024-10-02 13:55:44 +00:00
Jim O'Regan
404b045b65 add dummy exception (W0719) 2024-10-02 13:51:17 +00:00
Jim O'Regan
7cfae6bed4 add dummy exception (W0719) 2024-10-02 13:49:47 +00:00
Jim O'Regan
a83fd29829 C0209 2024-10-02 13:45:27 +00:00
pre-commit-ci[bot]
c8178bf2cd [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-10-02 13:32:45 +00:00
Jim O'Regan
8b1284993a W1514 + R1732 2024-10-02 13:31:57 +00:00
Jim O'Regan
0000f93021 R1732 + W1514 2024-10-02 13:25:02 +00:00
pre-commit-ci[bot]
c2569a1018 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-10-02 13:21:37 +00:00
Jim O'Regan
bd058a68f7 R0402 2024-10-02 13:21:00 +00:00
Jim O'Regan
362ba2dce7 C0209 2024-10-02 08:38:28 +00:00
Shivam Mehta
77804265f8 removing diffuser versioning 2024-08-09 18:28:56 +02:00
Shivam Mehta
d31cd92a61 Merge pull request #75 from shivammehta25/dev
Adding alignment information to readme
2024-05-27 13:57:49 +02:00
Shivam Mehta
068d135e20 Adding alignment information to readme 2024-05-27 13:57:10 +02:00
Shivam Mehta
bd37d03b62 Merge pull request #74 from shivammehta25/dev
Adding the possibility to use Matcha-TTS as an aligner and train from pretrained extracted alignments.
2024-05-27 13:54:27 +02:00
Shivam Mehta
ac0b258f80 Adding configuration for training from durations 2024-05-27 13:50:21 +02:00
Shivam Mehta
de910380bc Fixing batched synthesis for multispeaker model 2024-05-27 13:40:02 +02:00
Shivam Mehta
aa496aa13f Adding the possibility to train with durations 2024-05-27 13:24:21 +02:00
Shivam Mehta
e658aee6a5 Pinning gradio 2024-05-25 20:15:17 +02:00
Shivam Mehta
d816c40e3d Updating the notebook to adjust to the change 2024-05-24 11:46:03 +02:00
Shivam Mehta
4b39f6cad0 Adding the possibility to get durations out of a pretrained model 2024-05-24 11:34:51 +02:00
Shivam Mehta
dd9105b34b Merge pull request #60 from jimregan/patch-1
Pin gradio to 3.43.2
2024-02-27 13:29:42 +01:00
Jim O’Regan
7d9d4cfd40 Pin gradio to 3.43.2
Fixes #59
2024-02-27 13:25:08 +01:00
Shivam Mehta
256adc55d3 Adding ICASSP 2024 2024-01-12 11:31:01 +00:00
Shivam Mehta
bfcbdbc82e Merge pull request #43 from shivammehta25/dev
Removing gdown for HifiGAN checkpoints too
2024-01-12 12:29:03 +01:00
Shivam Mehta
fb7b954de5 Updating different url for hifigan as well 2024-01-12 11:21:51 +00:00
Shivam Mehta
5a52a67cf7 Version bump 2024-01-12 11:11:41 +00:00
Shivam Mehta
39cbd85236 Using wget for new ckpt downloads 2024-01-12 11:09:25 +00:00
Shivam Mehta
47a629f128 Merge pull request #42 from shivammehta25/dev
Merging dev, adding another dataset, piper phonemizer and refactoring
2024-01-12 11:49:53 +01:00
Shivam Mehta
95ec24b599 Version bump 2024-01-12 10:48:52 +00:00
Shivam Mehta
5a2a893750 Merge pull request #19 from shivammehta25/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2024-01-12 11:47:10 +01:00
Shivam Mehta
13ca33fbe5 Merge pull request #37 from shivammehta25/dependabot/pip/dev/diffusers-0.25.0
Bump diffusers from 0.21.3 to 0.25.0
2024-01-12 11:46:40 +01:00
Shivam Mehta
19bea20928 Merge branch 'main' into dev 2024-01-12 10:37:17 +00:00
Shivam Mehta
8268360674 Update download urls 2024-01-12 10:32:59 +00:00
Shivam Mehta
a0bf4e9e9a Merge pull request #40 from shivammehta25/ghenter-readme-update-1
Update README.md with ICASSP acceptance
2024-01-12 10:13:23 +01:00
Gustav Eje Henter
f1e8efdec2 Update README.md
Add back full stop that erroneously went missing in the shuffle.
2024-01-09 22:53:09 +01:00
Gustav Eje Henter
4ec245e61e Update README.md with ICASSP acceptance
Added ICASSP acceptance to the README and made some tiny tweaks to the text
2024-01-09 22:48:16 +01:00
pre-commit-ci[bot]
dc035a09f2 [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.4.0...v4.5.0)
- [github.com/psf/black: 23.9.1 → 23.12.1](https://github.com/psf/black/compare/23.9.1...23.12.1)
- [github.com/PyCQA/isort: 5.12.0 → 5.13.2](https://github.com/PyCQA/isort/compare/5.12.0...5.13.2)
- [github.com/asottile/pyupgrade: v3.14.0 → v3.15.0](https://github.com/asottile/pyupgrade/compare/v3.14.0...v3.15.0)
- [github.com/PyCQA/flake8: 6.1.0 → 7.0.0](https://github.com/PyCQA/flake8/compare/6.1.0...7.0.0)
- [github.com/pycqa/pylint: v3.0.0 → v3.0.3](https://github.com/pycqa/pylint/compare/v3.0.0...v3.0.3)
2024-01-08 21:15:26 +00:00
dependabot[bot]
254a8e05ce Bump diffusers from 0.21.3 to 0.25.0
Bumps [diffusers](https://github.com/huggingface/diffusers) from 0.21.3 to 0.25.0.
- [Release notes](https://github.com/huggingface/diffusers/releases)
- [Commits](https://github.com/huggingface/diffusers/compare/v0.21.3...v0.25.0)

---
updated-dependencies:
- dependency-name: diffusers
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-12-28 13:20:11 +00:00
Shivam Mehta
0ed9290c31 Logging global step while training 2023-12-06 10:39:54 +00:00
Shivam Mehta
f39ee6cf3b Changing while to for, for more readability 2023-12-05 12:10:52 +00:00
Shivam Mehta
6e71dc8b8f adding prior loss as a configuration 2023-12-05 09:57:37 +00:00
Shivam Mehta
ae2417c175 Merge pull request #34 from shivammehta25/piper_phonemize
Piper phonemize
2023-12-04 11:16:24 +01:00
Shivam Mehta
6c7a82a516 Adding dataset information 2023-12-04 10:15:13 +00:00
Shivam Mehta
009b09a8b2 Removing unwanted configs 2023-12-04 10:13:44 +00:00
Shivam Mehta
a18db17330 Removing the option for configuring prior loss; the predicted durations are not as good then 2023-12-04 10:12:39 +00:00
Shivam Mehta
263d5c4d4e Adding piper phonemizer with different dataset 2023-12-01 12:06:26 +00:00
Shivam Mehta
df896301ca Minor changes moving option to disable prior loss in config 2023-12-01 10:44:49 +00:00
Shivam Mehta
c8d0d60f87 Merge pull request #16 from shivammehta25/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-10-06 05:44:02 +02:00
pre-commit-ci[bot]
e540794e7e [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/psf/black: 23.1.0 → 23.9.1](https://github.com/psf/black/compare/23.1.0...23.9.1)
- [github.com/asottile/pyupgrade: v3.3.1 → v3.14.0](https://github.com/asottile/pyupgrade/compare/v3.3.1...v3.14.0)
- [github.com/PyCQA/flake8: 6.0.0 → 6.1.0](https://github.com/PyCQA/flake8/compare/6.0.0...6.1.0)
- [github.com/pycqa/pylint: v2.8.2 → v3.0.0](https://github.com/pycqa/pylint/compare/v2.8.2...v3.0.0)
2023-10-03 13:14:20 +00:00
Shivam Mehta
b756809a32 Merge pull request #13 from shivammehta25/dev
Merging dev to main | adding ONNX support
2023-09-29 16:54:09 +02:00
Shivam Mehta
1ead4303f3 Version Bump 2023-09-29 14:50:46 +00:00
Shivam Mehta
7a29fef719 Merge pull request #12 from shivammehta25/dependabot/pip/dev/diffusers-0.21.3
Bump diffusers from 0.21.2 to 0.21.3
2023-09-29 16:48:13 +02:00
Shivam Mehta
9ace522249 Update README.md 2023-09-29 16:46:38 +02:00
Shivam Mehta
ed6e6bbf6c Merge branch 'ONNX_BRANCH' into dev 2023-09-29 14:43:52 +00:00
Shivam Mehta
51ea36d271 Merge pull request #8 from mush42/onnx
ONNX export and inference
2023-09-29 16:43:19 +02:00
Shivam Mehta
269609003b Adding onnx installation command in the README 2023-09-29 14:38:57 +00:00
dependabot[bot]
2a81800825 Bump diffusers from 0.21.2 to 0.21.3
Bumps [diffusers](https://github.com/huggingface/diffusers) from 0.21.2 to 0.21.3.
- [Release notes](https://github.com/huggingface/diffusers/releases)
- [Commits](https://github.com/huggingface/diffusers/compare/v0.21.2...v0.21.3)

---
updated-dependencies:
- dependency-name: diffusers
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-09-28 13:23:02 +00:00
mush42
336dd20d5b Use torch.onnx.is_in_onnx_export() instead of torch.jit.is_scripting() since the former is dedicated to this use case. 2023-09-26 15:28:15 +02:00
mush42
01c99161c4 - Fixed several bugs. Thanks @shivammehta25 for the suggestions 2023-09-26 14:21:17 +02:00
mush42
2c21a0edac Fixed an error encountered when loading the vocoder during export. 2023-09-24 20:28:59 +02:00
mush42
25767f76a8 Readme: added a note about GPU inference with onnxruntime. 2023-09-24 02:13:27 +02:00
mush42
1b204ed42c ONNX export and inference. Complete and tested implementation. 2023-09-24 01:57:35 +02:00
Shivam Mehta
2cd057187b Update README.md
Add information about installation and compilation of monotonic alignment
2023-09-23 17:39:36 +02:00
Shivam Mehta
d373e9a5b1 Bumping it to an increased version 2023-09-21 13:43:20 +00:00
Shivam Mehta
f12be190a4 Adding video teaser to readme 2023-09-21 13:41:21 +00:00
41 changed files with 1257 additions and 177 deletions

.gitignore
@@ -161,3 +161,4 @@ generator_v1
 g_02500000
 gradio_cached_examples/
 synth_output/
+/data

.pre-commit-config.yaml
@@ -1,9 +1,9 @@
 default_language_version:
-  python: python3.10
+  python: python3.11

 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       # list of supported hooks: https://pre-commit.com/hooks.html
       - id: trailing-whitespace
@@ -18,28 +18,28 @@ repos:
   # python code formatting
   - repo: https://github.com/psf/black
-    rev: 23.1.0
+    rev: 23.12.1
     hooks:
       - id: black
        args: [--line-length, "120"]

   # python import sorting
   - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
        args: ["--profile", "black", "--filter-files"]

   # python upgrading syntax to newer version
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.3.1
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
        args: [--py38-plus]

   # python check (PEP8), programming errors and code complexity
   - repo: https://github.com/PyCQA/flake8
-    rev: 6.0.0
+    rev: 7.0.0
     hooks:
       - id: flake8
        args:
@@ -54,6 +54,6 @@ repos:
   # pylint
   - repo: https://github.com/pycqa/pylint
-    rev: v2.8.2
+    rev: v3.0.3
     hooks:
       - id: pylint

.pylintrc
@@ -82,16 +82,6 @@ disable=missing-docstring,
         no-name-in-module,
         no-member,
         unsubscriptable-object,
-        print-statement,
-        parameter-unpacking,
-        unpacking-in-except,
-        old-raise-syntax,
-        backtick,
-        long-suffix,
-        old-ne-operator,
-        old-octal-literal,
-        import-star-module-level,
-        non-ascii-bytes-literal,
         raw-checker-failed,
         bad-inline-option,
         locally-disabled,
@@ -106,67 +96,6 @@ disable=missing-docstring,
         too-many-arguments,
         too-many-locals,
         too-many-statements,
-        apply-builtin,
-        basestring-builtin,
-        buffer-builtin,
-        cmp-builtin,
-        coerce-builtin,
-        execfile-builtin,
-        file-builtin,
-        long-builtin,
-        raw_input-builtin,
-        reduce-builtin,
-        standarderror-builtin,
-        unicode-builtin,
-        xrange-builtin,
-        coerce-method,
-        delslice-method,
-        getslice-method,
-        setslice-method,
-        no-absolute-import,
-        old-division,
-        dict-iter-method,
-        dict-view-method,
-        next-method-called,
-        metaclass-assignment,
-        indexing-exception,
-        raising-string,
-        reload-builtin,
-        oct-method,
-        hex-method,
-        nonzero-method,
-        cmp-method,
-        input-builtin,
-        round-builtin,
-        intern-builtin,
-        unichr-builtin,
-        map-builtin-not-iterating,
-        zip-builtin-not-iterating,
-        range-builtin-not-iterating,
-        filter-builtin-not-iterating,
-        using-cmp-argument,
-        eq-without-hash,
-        div-method,
-        idiv-method,
-        rdiv-method,
-        exception-message-attribute,
-        invalid-str-codec,
-        sys-max-int,
-        bad-python3-import,
-        deprecated-string-function,
-        deprecated-str-translate-call,
-        deprecated-itertools-function,
-        deprecated-types-field,
-        next-method-defined,
-        dict-items-not-iterating,
-        dict-keys-not-iterating,
-        dict-values-not-iterating,
-        deprecated-operator-function,
-        deprecated-urllib-function,
-        xreadlines-attribute,
-        deprecated-sys-function,
-        exception-escape,
-        comprehension-escape,
         duplicate-code,
         not-callable,
         import-outside-toplevel,
@@ -363,13 +292,6 @@ max-line-length=120
 # Maximum number of lines in a module.
 max-module-lines=1000

-# List of optional constructs for which whitespace checking is disabled. `dict-
-# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
-# `trailing-comma` allows a space between comma and closing bracket: (a, ).
-# `empty-line` allows space-only lines.
-no-space-check=trailing-comma,
-               dict-separator
-
 # Allow the body of a class to be on the same line as the declaration if body
 # contains single statement.
 single-line-class-stmt=no
@@ -599,5 +521,5 @@ min-public-methods=2
 # Exceptions that will emit a warning when being caught. Defaults to
 # "BaseException, Exception".
-overgeneral-exceptions=BaseException,
-                       Exception
+overgeneral-exceptions=builtins.BaseException,
+                       builtins.Exception

README.md
@@ -10,14 +10,14 @@
 [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
 [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
 [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
+[![PyPI Downloads](https://static.pepy.tech/personalized-badge/matcha-tts?period=total&units=INTERNATIONAL_SYSTEM&left_color=BLACK&right_color=GREEN&left_text=downloads)](https://pepy.tech/projects/matcha-tts)

 <p style="text-align: center;">
   <img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/>
 </p>
 </div>

-> This is the official code implementation of 🍵 Matcha-TTS.
+> This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024].

 We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method:
@@ -26,11 +26,15 @@ We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, tha
 - Sounds highly natural
 - Is very fast to synthesise from

-Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our arXiv preprint](https://arxiv.org/abs/2309.03199) for more details.
+Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our ICASSP 2024 paper](https://arxiv.org/abs/2309.03199) for more details.

 [Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface.

-[Try 🍵 Matcha-TTS on HuggingFace 🤗 spaces!](https://huggingface.co/spaces/shivammehta25/Matcha-TTS)
+You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces](https://huggingface.co/spaces/shivammehta25/Matcha-TTS).
+
+## Teaser video
+
+[![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0)

 ## Installation
@@ -41,7 +45,7 @@ conda create -n matcha-tts python=3.10 -y
 conda activate matcha-tts
 ```

 2. Install Matcha TTS using pip or from source

 ```bash
 pip install matcha-tts
@@ -51,6 +55,8 @@ from source
 ```bash
 pip install git+https://github.com/shivammehta25/Matcha-TTS.git
+cd Matcha-TTS
+pip install -e .
 ```

 3. Run CLI / gradio app / jupyter notebook
@@ -182,16 +188,117 @@ python matcha/train.py experiment=ljspeech trainer.devices=[0,1]
 matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
 ```

+## ONNX support
+
+> Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support.
+
+It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph.
+
+### ONNX export
+
+To export a checkpoint to ONNX, first install ONNX with
+
+```bash
+pip install onnx
+```
+
+then run the following:
+
+```bash
+python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5
+```
+
+Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems).
+
+**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**.
+
+**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_dot_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release.
+
+### ONNX Inference
+
+To run inference on the exported model, first install `onnxruntime` using
+
+```bash
+pip install onnxruntime
+pip install onnxruntime-gpu  # for GPU inference
+```
+
+then use the following:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs
+```
+
+You can also control synthesis parameters:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking-rate 0.9 --spk 0
+```
+
+To run inference on **GPU**, make sure to install the **onnxruntime-gpu** package, and then pass `--gpu` to the inference command:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu
+```
+
+If you exported only Matcha to ONNX, this will write mel spectrograms as plots and `numpy` arrays to the output directory.
+If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory.
+
+If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format:
+
+```bash
+python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx
+```
+
+This will write `.wav` audio files to the output directory.
+
+## Extract phoneme alignments from Matcha-TTS
+
+If the dataset is structured as
+
+```bash
+data/
+└── LJSpeech-1.1
+    ├── metadata.csv
+    ├── README
+    ├── test.txt
+    ├── train.txt
+    ├── val.txt
+    └── wavs
+```
+
+then you can extract the phoneme-level alignments from a trained Matcha-TTS model using:
+
+```bash
+python matcha/utils/get_durations_from_trained_model.py -i dataset_yaml -c <checkpoint>
+```
+
+Example:
+
+```bash
+python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt
+```
+
+or simply:
+
+```bash
+matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt
+```
+
+---
+
+## Train using extracted alignments
+
+In the dataset config, turn on `load_durations`.
+
+Example: `ljspeech.yaml`
+
+```
+load_durations: True
+```
+
+or see an example in configs/experiment/ljspeech_from_durations.yaml
+
 ## Citation information

 If you use our code or otherwise find this work useful, please cite our paper:

 ```text
-@article{mehta2023matcha,
-  title={Matcha-TTS: A fast TTS architecture with conditional flow matching},
+@inproceedings{mehta2024matcha,
+  title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching},
   author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje},
-  journal={arXiv preprint arXiv:2309.03199},
-  year={2023}
+  booktitle={Proc. ICASSP},
+  year={2024}
 }
 ```
@@ -199,7 +306,7 @@ If you use our code or otherwise find this work useful, please cite our paper:
 Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it.

-Other source code I would like to acknowledge:
+Other source code we would like to acknowledge:

 - [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement
 - [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components

configs/data/hi-fi_en-US_female.yaml (new file)
@@ -0,0 +1,14 @@
defaults:
  - ljspeech
  - _self_

# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female
train_filelist_path: data/hi-fi_en-US_female/train.txt
valid_filelist_path: data/hi-fi_en-US_female/val.txt
batch_size: 32
cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset
  mel_mean: -6.38385
  mel_std: 2.541796

configs/data/ljspeech.yaml
@@ -1,7 +1,7 @@
 _target_: matcha.data.text_mel_datamodule.TextMelDataModule
 name: ljspeech
-train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
-valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
+train_filelist_path: data/LJSpeech-1.1/train.txt
+valid_filelist_path: data/LJSpeech-1.1/val.txt
 batch_size: 32
 num_workers: 20
 pin_memory: True
@@ -19,3 +19,4 @@ data_statistics: # Computed for ljspeech dataset
   mel_mean: -5.536622
   mel_std: 2.116101
 seed: ${seed}
+load_durations: false

(new experiment config)
@@ -0,0 +1,14 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=multispeaker

defaults:
  - override /data: hi-fi_en-US_female.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]

run_name: hi-fi_en-US_female_piper_phonemizer

configs/experiment/ljspeech_from_durations.yaml (new file)
@@ -0,0 +1,19 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=multispeaker

defaults:
  - override /data: ljspeech.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["ljspeech"]

run_name: ljspeech

data:
  load_durations: True
  batch_size: 64

@@ -12,3 +12,5 @@ spk_emb_dim: 64
 n_feats: 80
 data_statistics: ${data.data_statistics}
 out_size: null # Must be divisible by 4
+prior_loss: true
+use_precomputed_durations: ${data.load_durations}

data (deleted symlink)
@@ -1 +0,0 @@
-/home/smehta/Projects/Speech-Backbones/Grad-TTS/data

@@ -1 +1 @@
-0.0.2
+0.0.7.2

matcha/app.py
@@ -29,8 +29,15 @@ args = Namespace(
 CURRENTLY_LOADED_MODEL = args.model

-MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt"  # noqa: E731
-VOCODER_LOC = lambda x: LOCATION / f"{x}"  # noqa: E731
+
+def MATCHA_TTS_LOC(x):
+    return LOCATION / f"{x}.ckpt"
+
+
+def VOCODER_LOC(x):
+    return LOCATION / f"{x}"
+
+
 LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"

 RADIO_OPTIONS = {
     "Multi Speaker (VCTK)": {

matcha/cli.py
@@ -18,13 +18,13 @@ from matcha.text import sequence_to_text, text_to_sequence
 from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse

 MATCHA_URLS = {
-    "matcha_ljspeech": "https://drive.google.com/file/d/1BBzmMU7k3a_WetDfaFblMoN18GqQeHCg/view?usp=drive_link",
-    "matcha_vctk": "https://drive.google.com/file/d/1enuxmfslZciWGAl63WGh2ekVo00FYuQ9/view?usp=drive_link",
+    "matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt",
+    "matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt",
 }

 VOCODER_URLS = {
-    "hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link",
-    "hifigan_univ_v1": "https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link",
+    "hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1",  # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link
+    "hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000",  # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link
 }

 MULTISPEAKER_MODEL = {
@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
 def process_text(i: int, text: str, device: torch.device):
     print(f"[{i}] - Input text: {text}")
     x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
+        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
         dtype=torch.long,
         device=device,
     )[None]
@@ -63,7 +63,7 @@ def get_texts(args):
     if args.text:
         texts = [args.text]
     else:
-        with open(args.file) as f:
+        with open(args.file, encoding="utf-8") as f:
             texts = f.readlines()
     return texts
@@ -114,10 +114,10 @@ def load_matcha(model_name, checkpoint_path, device):
     return model


-def to_waveform(mel, vocoder, denoiser=None):
+def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
     audio = vocoder(mel).clamp(-1, 1)
     if denoiser is not None:
-        audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
+        audio = denoiser(audio.squeeze(), strength=denoiser_strength).cpu().squeeze()

     return audio.cpu().squeeze()
@@ -140,7 +140,7 @@ def validate_args(args):
     if args.checkpoint_path is None:
         # When using pretrained models
-        if args.model in SINGLESPEAKER_MODEL.keys():
+        if args.model in SINGLESPEAKER_MODEL:
             args = validate_args_for_single_speaker_model(args)
         if args.model in MULTISPEAKER_MODEL:
@@ -326,16 +326,17 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
     for i, batch in enumerate(dataloader):
         i = i + 1
         start_t = dt.datetime.now()
+        b = batch["x"].shape[0]
         output = model.synthesise(
             batch["x"].to(device),
             batch["x_lengths"].to(device),
             n_timesteps=args.steps,
             temperature=args.temperature,
-            spks=spk,
+            spks=spk.expand(b) if spk is not None else spk,
             length_scale=args.speaking_rate,
         )
-        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
         t = (dt.datetime.now() - start_t).total_seconds()
         rtf_w = t * 22050 / (output["waveform"].shape[-1])
         print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
@@ -376,7 +377,7 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
             spks=spk,
             length_scale=args.speaking_rate,
         )
-        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)

         # RTF with HiFiGAN
         t = (dt.datetime.now() - start_t).total_seconds()
         rtf_w = t * 22050 / (output["waveform"].shape[-1])
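Read together, these hunks thread a user-supplied denoiser strength from the CLI down to the vocoder stage. Below is a minimal sketch of driving the patched `to_waveform` programmatically; the checkpoint paths are placeholders, and `load_vocoder` returning a `(vocoder, denoiser)` pair is an assumption based on how it is unpacked elsewhere in this changeset:

```python
from matcha.cli import load_matcha, load_vocoder, process_text, to_waveform

# Placeholder checkpoints; the CLI normally fetches these from MATCHA_URLS / VOCODER_URLS.
model = load_matcha("matcha_ljspeech", "matcha_ljspeech.ckpt", "cpu")
vocoder, denoiser = load_vocoder("hifigan_T2_v1", "generator_v1", "cpu")  # (vocoder, denoiser) assumed

x = process_text(0, "The denoiser strength is configurable now.", "cpu")
output = model.synthesise(x["x"], x["x_lengths"], n_timesteps=10)

# strength was previously hard-coded to 0.00025 inside to_waveform
waveform = to_waveform(output["mel"], vocoder, denoiser, denoiser_strength=0.0005)
```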

matcha/data/text_mel_datamodule.py
@@ -1,6 +1,8 @@
 import random
+from pathlib import Path
 from typing import Any, Dict, Optional

+import numpy as np
 import torch
 import torchaudio as ta
 from lightning import LightningDataModule
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
         f_max,
         data_statistics,
         seed,
+        load_durations,
     ):
         super().__init__()
@@ -68,6 +71,7 @@
             self.hparams.f_max,
             self.hparams.data_statistics,
             self.hparams.seed,
+            self.hparams.load_durations,
         )
         self.validset = TextMelDataset(  # pylint: disable=attribute-defined-outside-init
             self.hparams.valid_filelist_path,
@@ -83,6 +87,7 @@
             self.hparams.f_max,
             self.hparams.data_statistics,
             self.hparams.seed,
+            self.hparams.load_durations,
         )

     def train_dataloader(self):
@@ -109,7 +114,7 @@
         """Clean up after fit or test."""
         pass  # pylint: disable=unnecessary-pass

-    def state_dict(self):  # pylint: disable=no-self-use
+    def state_dict(self):
         """Extra things to save to checkpoint."""
         return {}
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
         f_max=8000,
         data_parameters=None,
         seed=None,
+        load_durations=False,
     ):
         self.filepaths_and_text = parse_filelist(filelist_path)
         self.n_spks = n_spks
@@ -146,6 +152,8 @@
         self.win_length = win_length
         self.f_min = f_min
         self.f_max = f_max
+        self.load_durations = load_durations
+
         if data_parameters is not None:
             self.data_parameters = data_parameters
         else:
@@ -164,10 +172,29 @@
         filepath, text = filepath_and_text[0], filepath_and_text[1]
         spk = None

-        text = self.get_text(text, add_blank=self.add_blank)
+        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
         mel = self.get_mel(filepath)

-        return {"x": text, "y": mel, "spk": spk}
+        durations = self.get_durations(filepath, text) if self.load_durations else None
+
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
+
+    def get_durations(self, filepath, text):
+        filepath = Path(filepath)
+        data_dir, name = filepath.parent.parent, filepath.stem
+
+        try:
+            dur_loc = data_dir / "durations" / f"{name}.npy"
+            durs = torch.from_numpy(np.load(dur_loc).astype(int))
+        except FileNotFoundError as e:
+            raise FileNotFoundError(
+                f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generated the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
+            ) from e
+
+        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
+
+        return durs

     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)
@@ -187,11 +214,11 @@
         return mel

     def get_text(self, text, add_blank=True):
-        text_norm = text_to_sequence(text, self.cleaners)
+        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
         if self.add_blank:
             text_norm = intersperse(text_norm, 0)
         text_norm = torch.IntTensor(text_norm)
-        return text_norm
+        return text_norm, cleaned_text

     def __getitem__(self, index):
         datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -207,15 +234,18 @@
     def __call__(self, batch):
         B = len(batch)
-        y_max_length = max([item["y"].shape[-1] for item in batch])
+        y_max_length = max([item["y"].shape[-1] for item in batch])  # pylint: disable=consider-using-generator
         y_max_length = fix_len_compatibility(y_max_length)
-        x_max_length = max([item["x"].shape[-1] for item in batch])
+        x_max_length = max([item["x"].shape[-1] for item in batch])  # pylint: disable=consider-using-generator
         n_feats = batch[0]["y"].shape[-2]

         y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
         x = torch.zeros((B, x_max_length), dtype=torch.long)
+        durations = torch.zeros((B, x_max_length), dtype=torch.long)
+
         y_lengths, x_lengths = [], []
         spks = []
+        filepaths, x_texts = [], []
         for i, item in enumerate(batch):
             y_, x_ = item["y"], item["x"]
             y_lengths.append(y_.shape[-1])
@@ -223,9 +253,22 @@
             y[i, :, : y_.shape[-1]] = y_
             x[i, : x_.shape[-1]] = x_
             spks.append(item["spk"])
+            filepaths.append(item["filepath"])
+            x_texts.append(item["x_text"])
+            if item["durations"] is not None:
+                durations[i, : item["durations"].shape[-1]] = item["durations"]

         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
         spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None

-        return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
+        return {
+            "x": x,
+            "x_lengths": x_lengths,
+            "y": y,
+            "y_lengths": y_lengths,
+            "spks": spks,
+            "filepaths": filepaths,
+            "x_texts": x_texts,
+            "durations": durations if not torch.eq(durations, 0).all() else None,
+        }
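To make the on-disk layout that `get_durations` expects concrete, here is a self-contained sketch of the same lookup; the example path is illustrative, while the `durations/` sibling-directory convention and the `.npy` integer format come from the hunk above:

```python
from pathlib import Path

import numpy as np
import torch


def load_durations(wav_path: str) -> torch.Tensor:
    """Load precomputed per-token durations for one utterance.

    For data/LJSpeech-1.1/wavs/LJ001-0001.wav this looks up
    data/LJSpeech-1.1/durations/LJ001-0001.npy: an integer array with one
    frame count per (blank-interspersed) input token.
    """
    wav = Path(wav_path)
    dur_loc = wav.parent.parent / "durations" / f"{wav.stem}.npy"
    return torch.from_numpy(np.load(dur_loc).astype(int))


# Illustrative call:
# durs = load_durations("data/LJSpeech-1.1/wavs/LJ001-0001.wav")
```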

matcha/hifigan/denoiser.py
@@ -4,6 +4,10 @@
 import torch


+class ModeException(Exception):
+    pass
+
+
 class Denoiser(torch.nn.Module):
     """Removes model bias from audio produced with waveglow"""
@@ -20,7 +24,7 @@ class Denoiser(torch.nn.Module):
         elif mode == "normal":
             mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
         else:
-            raise Exception(f"Mode {mode} if not supported")
+            raise ModeException(f"Mode {mode} is not supported")

         def stft_fn(audio, n_fft, hop_length, win_length, window):
             spec = torch.stft(

@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     if torch.max(y) > 1.0:
         print("max value is ", torch.max(y))

-    global mel_basis, hann_window  # pylint: disable=global-statement
+    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
     if fmax not in mel_basis:
         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

matcha/hifigan/models.py
@@ -1,7 +1,7 @@
 """ from https://github.com/jik876/hifi-gan """

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
 from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

matcha/models/baselightningmodule.py
@@ -58,13 +58,14 @@ class BaseLightningClass(LightningModule, ABC):
         y, y_lengths = batch["y"], batch["y_lengths"]
         spks = batch["spks"]

-        dur_loss, prior_loss, diff_loss = self(
+        dur_loss, prior_loss, diff_loss, *_ = self(
             x=x,
             x_lengths=x_lengths,
             y=y,
             y_lengths=y_lengths,
             spks=spks,
             out_size=self.out_size,
+            durations=batch["durations"],
         )
         return {
             "dur_loss": dur_loss,
@@ -81,7 +82,7 @@
             "step",
             float(self.global_step),
             on_step=True,
-            on_epoch=True,
+            prog_bar=True,
             logger=True,
             sync_dist=True,
         )

matcha/models/components/decoder.py
@@ -2,7 +2,7 @@ import math
 from typing import Optional

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
 from conformer import ConformerBlock
 from diffusers.models.activations import get_activation

matcha/models/components/flow_matching.py
@@ -73,16 +73,14 @@ class BASECFM(torch.nn.Module, ABC):
         # Or in future might add like a return_all_steps flag
         sol = []

-        steps = 1
-        while steps <= len(t_span) - 1:
+        for step in range(1, len(t_span)):
             dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)

-            if steps < len(t_span) - 1:
-                dt = t_span[steps + 1] - t
-                steps += 1
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t

         return sol[-1]
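The rewritten loop is plain fixed-step Euler integration of the learned vector field. A self-contained sketch under that reading, with a toy `f` standing in for `self.estimator` (which in Matcha also receives `mask`, `mu`, `spks`, and `cond`):

```python
import torch


def euler_solve(f, x0: torch.Tensor, t_span: torch.Tensor) -> torch.Tensor:
    """Integrate dx/dt = f(x, t) over the grid t_span with fixed Euler steps."""
    x, t = x0, t_span[0]
    dt = t_span[1] - t_span[0]
    sol = []
    for step in range(1, len(t_span)):
        x = x + dt * f(x, t)  # Euler update
        t = t + dt
        sol.append(x)
        if step < len(t_span) - 1:  # recompute dt so non-uniform grids work too
            dt = t_span[step + 1] - t
    return sol[-1]


# Example: integrate dx/dt = -x from t=0 to t=1 in five steps.
print(euler_solve(lambda x, t: -x, torch.ones(1), torch.linspace(0, 1, 6)))
```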

matcha/models/components/text_encoder.py
@@ -3,10 +3,10 @@
 import math

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 from einops import rearrange

-import matcha.utils as utils
+import matcha.utils as utils  # pylint: disable=consider-using-from-import
 from matcha.utils.model import sequence_mask

 log = utils.get_pylogger(__name__)

matcha/models/components/transformer.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 from diffusers.models.attention import (
     GEGLU,
     GELU,

matcha/models/matcha_tts.py
@@ -4,7 +4,7 @@ import random

 import torch

-import matcha.utils.monotonic_align as monotonic_align
+import matcha.utils.monotonic_align as monotonic_align  # pylint: disable=consider-using-from-import
 from matcha import utils
 from matcha.models.baselightningmodule import BaseLightningClass
 from matcha.models.components.flow_matching import CFM
@@ -34,6 +34,8 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         out_size,
         optimizer=None,
         scheduler=None,
+        prior_loss=True,
+        use_precomputed_durations=False,
     ):
         super().__init__()
@@ -44,6 +46,8 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.spk_emb_dim = spk_emb_dim
         self.n_feats = n_feats
         self.out_size = out_size
+        self.prior_loss = prior_loss
+        self.use_precomputed_durations = use_precomputed_durations

         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -102,6 +106,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
                 # Lengths of mel spectrograms
                 "rtf": float,
                 # Real-time factor
+            }
         """
         # For RTF computation
         t = dt.datetime.now()
@@ -116,7 +121,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         w = torch.exp(logw) * x_mask
         w_ceil = torch.ceil(w) * length_scale
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-        y_max_length = int(y_lengths.max())
+        y_max_length = y_lengths.max()
         y_max_length_ = fix_len_compatibility(y_max_length)

         # Using obtained durations `w` construct alignment map `attn`
@@ -145,10 +150,10 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             "rtf": rtf,
         }

-    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
+    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
         """
         Computes 3 losses:
-            1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
+            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
             2. prior loss: loss between mel-spectrogram and encoder outputs.
             3. flow matching loss: loss between mel-spectrogram and decoder outputs.
@@ -177,17 +182,20 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

-        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
-        with torch.no_grad():
-            const = -0.5 * math.log(2 * math.pi) * self.n_feats
-            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
-            y_square = torch.matmul(factor.transpose(1, 2), y**2)
-            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
-            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
-            log_prior = y_square - y_mu_double + mu_square + const
-
-            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
-            attn = attn.detach()
+        if self.use_precomputed_durations:
+            attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
+        else:
+            # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
+            with torch.no_grad():
+                const = -0.5 * math.log(2 * math.pi) * self.n_feats
+                factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+                y_square = torch.matmul(factor.transpose(1, 2), y**2)
+                y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
+                mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
+                log_prior = y_square - y_mu_double + mu_square + const

+                attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
+                attn = attn.detach()  # b, t_text, T_mel

         # Compute loss between predicted log-scaled durations and those obtained from MAS
         # referred to as prior loss in the paper
@@ -228,7 +236,10 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         # Compute loss of the decoder
         diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)

-        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
-        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        if self.prior_loss:
+            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        else:
+            prior_loss = 0

-        return dur_loss, prior_loss, diff_loss
+        return dur_loss, prior_loss, diff_loss, attn
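When `use_precomputed_durations` is set, `generate_path` expands the integer durations into the hard text-to-mel alignment that MAS would otherwise search for. A minimal re-implementation of that expansion, written here only for illustration (this is not the repo's `generate_path`, which also works with the attention mask):

```python
import torch


def durations_to_alignment(durations: torch.Tensor, t_mel: int) -> torch.Tensor:
    """Expand integer token durations into a hard (t_text, t_mel) alignment map.

    Mel frame j is assigned to token i iff j falls inside token i's span.
    """
    ends = torch.cumsum(durations, dim=-1)  # exclusive end frame per token
    starts = ends - durations               # start frame per token
    frames = torch.arange(t_mel)
    return ((frames >= starts[:, None]) & (frames < ends[:, None])).float()


# durations_to_alignment(torch.tensor([2, 3, 1]), 6) gives a 3x6 0/1 map whose
# row sums equal the durations, matching attn's (t_text, T_mel) layout.
print(durations_to_alignment(torch.tensor([2, 3, 1]), 6))
```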

matcha/onnx/__init__.py (new, empty file)

matcha/onnx/export.py (new file)
@@ -0,0 +1,181 @@
import argparse
import random
from pathlib import Path

import numpy as np
import torch
from lightning import LightningModule

from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder

DEFAULT_OPSET = 15

SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class MatchaWithVocoder(LightningModule):
    def __init__(self, matcha, vocoder):
        super().__init__()
        self.matcha = matcha
        self.vocoder = vocoder

    def forward(self, x, x_lengths, scales, spks=None):
        mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
        wavs = self.vocoder(mel).clamp(-1, 1)
        lengths = mel_lengths * 256
        return wavs.squeeze(1), lengths


def get_exportable_module(matcha, vocoder, n_timesteps):
    """
    Return an appropriate `LightningModule` and output-node names
    based on whether the vocoder is embedded in the final graph
    """

    def onnx_forward_func(x, x_lengths, scales, spks=None):
        """
        Custom forward function for accepting
        scalar parameters as tensors
        """
        # Extract scalar parameters from tensors
        temperature = scales[0]
        length_scale = scales[1]
        output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
        return output["mel"], output["mel_lengths"]

    # Monkey-patch Matcha's forward function
    matcha.forward = onnx_forward_func

    if vocoder is None:
        model, output_names = matcha, ["mel", "mel_lengths"]
    else:
        model = MatchaWithVocoder(matcha, vocoder)
        output_names = ["wav", "wav_lengths"]
    return model, output_names


def get_inputs(is_multi_speaker):
    """
    Create dummy inputs for tracing
    """
    dummy_input_length = 50
    x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long)
    x_lengths = torch.LongTensor([dummy_input_length])

    # Scales
    temperature = 0.667
    length_scale = 1.0
    scales = torch.Tensor([temperature, length_scale])

    model_inputs = [x, x_lengths, scales]
    input_names = [
        "x",
        "x_lengths",
        "scales",
    ]

    if is_multi_speaker:
        spks = torch.LongTensor([1])
        model_inputs.append(spks)
        input_names.append("spks")

    return tuple(model_inputs), input_names


def main():
    parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")

    parser.add_argument(
        "checkpoint_path",
        type=str,
        help="Path to the model checkpoint",
    )
    parser.add_argument("output", type=str, help="Path to output `.onnx` file")
    parser.add_argument(
        "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
    )
    parser.add_argument(
        "--vocoder-name",
        type=str,
        choices=list(VOCODER_URLS.keys()),
        default=None,
        help="Name of the vocoder to embed in the ONNX graph",
    )
    parser.add_argument(
        "--vocoder-checkpoint-path",
        type=str,
        default=None,
        help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
    )
    parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15)")

    args = parser.parse_args()

    print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
    print(f"Setting n_timesteps to {args.n_timesteps}")

    checkpoint_path = Path(args.checkpoint_path)
    matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")

    if args.vocoder_name or args.vocoder_checkpoint_path:
        assert (
            args.vocoder_name and args.vocoder_checkpoint_path
        ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph."
        vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
    else:
        vocoder = None

    is_multi_speaker = matcha.n_spks > 1

    dummy_input, input_names = get_inputs(is_multi_speaker)
    model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)

    # Set dynamic shape for inputs/outputs
    dynamic_axes = {
        "x": {0: "batch_size", 1: "time"},
        "x_lengths": {0: "batch_size"},
    }

    if vocoder is None:
        dynamic_axes.update(
            {
                "mel": {0: "batch_size", 2: "time"},
                "mel_lengths": {0: "batch_size"},
            }
        )
    else:
        print("Embedding the vocoder in the ONNX graph")
        dynamic_axes.update(
            {
                "wav": {0: "batch_size", 1: "time"},
                "wav_lengths": {0: "batch_size"},
            }
        )

    if is_multi_speaker:
        dynamic_axes["spks"] = {0: "batch_size"}

    # Create the output directory (if not exists)
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    model.to_onnx(
        args.output,
        dummy_input,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
        export_params=True,
        do_constant_folding=True,
    )
    print(f"[🍵] ONNX model exported to {args.output}")


if __name__ == "__main__":
    main()

168
matcha/onnx/infer.py Normal file
View File

@@ -0,0 +1,168 @@
import argparse
import os
import warnings
from pathlib import Path
from time import perf_counter
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
from matcha.cli import plot_spectrogram_to_numpy, process_text
def validate_args(args):
assert (
args.text or args.file
), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
assert args.temperature >= 0, "Sampling temperature cannot be negative"
assert args.speaking_rate >= 0, "Speaking rate must be greater than 0"
return args
def write_wavs(model, inputs, output_dir, external_vocoder=None):
if external_vocoder is None:
print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
t0 = perf_counter()
wavs, wav_lengths = model.run(None, inputs)
infer_secs = perf_counter() - t0
mel_infer_secs = vocoder_infer_secs = None
else:
print("[🍵] Generating mel using Matcha")
mel_t0 = perf_counter()
mels, mel_lengths = model.run(None, inputs)
mel_infer_secs = perf_counter() - mel_t0
print("Generating waveform from mel using external vocoder")
vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
vocoder_t0 = perf_counter()
wavs = external_vocoder.run(None, vocoder_inputs)[0]
vocoder_infer_secs = perf_counter() - vocoder_t0
wavs = wavs.squeeze(1)
wav_lengths = mel_lengths * 256
infer_secs = mel_infer_secs + vocoder_infer_secs
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
audio = wav[:wav_length]
print(f"Writing audio to {output_filename}")
sf.write(output_filename, audio, 22050, "PCM_24")
wav_secs = wav_lengths.sum() / 22050
print(f"Inference seconds: {infer_secs}")
print(f"Generated wav seconds: {wav_secs}")
rtf = infer_secs / wav_secs
if mel_infer_secs is not None:
mel_rtf = mel_infer_secs / wav_secs
print(f"Matcha RTF: {mel_rtf}")
if vocoder_infer_secs is not None:
vocoder_rtf = vocoder_infer_secs / wav_secs
print(f"Vocoder RTF: {vocoder_rtf}")
print(f"Overall RTF: {rtf}")
def write_mels(model, inputs, output_dir):
t0 = perf_counter()
mels, mel_lengths = model.run(None, inputs)
infer_secs = perf_counter() - t0
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for i, mel in enumerate(mels):
output_stem = output_dir.joinpath(f"output_{i + 1}")
plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
np.save(output_stem.with_suffix(".numpy"), mel)
wav_secs = (mel_lengths * 256).sum() / 22050
print(f"Inference seconds: {infer_secs}")
print(f"Generated wav seconds: {wav_secs}")
rtf = infer_secs / wav_secs
print(f"RTF: {rtf}")
def main():
parser = argparse.ArgumentParser(
description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
)
parser.add_argument(
"model",
type=str,
help="ONNX model to use",
)
parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
parser.add_argument(
"--temperature",
type=float,
default=0.667,
help="Variance of the x0 noise (default: 0.667)",
)
parser.add_argument(
"--speaking-rate",
type=float,
default=1.0,
help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
)
parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
parser.add_argument(
"--output-dir",
type=str,
default=os.getcwd(),
help="Output folder to save results (default: current dir)",
)
args = parser.parse_args()
args = validate_args(args)
if args.gpu:
# "GPUExecutionProvider" is not a valid ONNX Runtime provider name; CUDA is the GPU provider
providers = ["CUDAExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
model = ort.InferenceSession(args.model, providers=providers)
model_inputs = model.get_inputs()
model_outputs = list(model.get_outputs())
if args.text:
text_lines = args.text.splitlines()
else:
with open(args.file, encoding="utf-8") as file:
text_lines = file.read().splitlines()
processed_lines = [process_text(0, line, "cpu") for line in text_lines]
x = [line["x"].squeeze() for line in processed_lines]
# Pad
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
x = x.detach().cpu().numpy()
x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
inputs = {
"x": x,
"x_lengths": x_lengths,
"scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
}
is_multi_speaker = len(model_inputs) == 4
if is_multi_speaker:
if args.spk is None:
args.spk = 0
warn = "[!] Speaker ID not provided! Using speaker ID 0"
warnings.warn(warn, UserWarning)
inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)
has_vocoder_embedded = model_outputs[0].name == "wav"
if has_vocoder_embedded:
write_wavs(model, inputs, args.output_dir)
elif args.vocoder:
external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
else:
warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
warnings.warn(warn, UserWarning)
write_mels(model, inputs, args.output_dir)
if __name__ == "__main__":
main()
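
Hypothetical invocations, assuming matcha/onnx is importable as a package; the ONNX paths are placeholders:

# Example usage (vocoder embedded in the graph):
# python -m matcha.onnx.infer matcha_with_hifigan.onnx --text "Hello world" --output-dir ./outputs
# Acoustic model only, with an external vocoder ONNX:
# python -m matcha.onnx.infer matcha.onnx --vocoder hifigan.onnx --file sentences.txt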


@@ -7,6 +7,10 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension
+
+class UnknownCleanerException(Exception):
+    pass
+
 def text_to_sequence(text, cleaner_names):
     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
     Args:
@@ -21,7 +25,7 @@ def text_to_sequence(text, cleaner_names):
     for symbol in clean_text:
         symbol_id = _symbol_to_id[symbol]
         sequence += [symbol_id]
-    return sequence
+    return sequence, clean_text
 def cleaned_text_to_sequence(cleaned_text):
@@ -48,6 +52,6 @@ def _clean_text(text, cleaner_names):
     for name in cleaner_names:
         cleaner = getattr(cleaners, name)
         if not cleaner:
-            raise Exception("Unknown cleaner: %s" % name)
+            raise UnknownCleanerException(f"Unknown cleaner: {name}")
         text = cleaner(text)
     return text
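
Note that text_to_sequence now returns a (sequence, cleaned_text) tuple, so existing call sites must unpack or index it; a minimal sketch:

from matcha.text import text_to_sequence

sequence, clean_text = text_to_sequence("Hello world.", ["english_cleaners2"])
# or keep the old behaviour: sequence = text_to_sequence(text, cleaners)[0]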


@@ -36,9 +36,12 @@ global_phonemizer = phonemizer.backend.EspeakBackend(
 # Regular expression matching whitespace:
 _whitespace_re = re.compile(r"\s+")
+# Remove brackets
+_brackets_re = re.compile(r"[\[\]\(\)\{\}]")
 # List of (regular expression, replacement) pairs for abbreviations:
 _abbreviations = [
-    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
     for x in [
         ("mrs", "misess"),
         ("mr", "mister"),
@@ -72,6 +75,10 @@ def lowercase(text):
     return text.lower()
+def remove_brackets(text):
+    return re.sub(_brackets_re, "", text)
 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)
@@ -101,5 +108,37 @@ def english_cleaners2(text):
     text = lowercase(text)
     text = expand_abbreviations(text)
     phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
+    # Added because in some cases espeak does not remove brackets
+    phonemes = remove_brackets(phonemes)
     phonemes = collapse_whitespace(phonemes)
     return phonemes
+def ipa_simplifier(text):
+    replacements = [
+        ("ɐ", "ə"),
+        ("ˈə", "ə"),
+        ("ʤ", "dʒ"),  # affricate ligature targets were lost in extraction; "dʒ"/"tʃ" assumed
+        ("ʧ", "tʃ"),
+        ("ᵻ", "ɪ"),   # source character lost in extraction; espeak's "ᵻ" assumed
+    ]
+    for replacement in replacements:
+        text = text.replace(replacement[0], replacement[1])
+    phonemes = collapse_whitespace(text)
+    return phonemes
+# I am removing this due to incompatibility with several versions of Python.
+# However, if you want to use it, you can uncomment it
+# and install piper-phonemize with the following command:
+#   pip install piper-phonemize
+# import piper_phonemize
+# def english_cleaners_piper(text):
+#     """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
+#     text = convert_to_ascii(text)
+#     text = lowercase(text)
+#     text = expand_abbreviations(text)
+#     phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
+#     phonemes = collapse_whitespace(phonemes)
+#     return phonemes
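
A toy trace of ipa_simplifier, under the assumption that the reconstructed replacement pairs above are correct (the input string is invented for illustration):

print(ipa_simplifier("ʧiːz  ænd  ʤæm"))
# -> "tʃiːz ænd dʒæm": affricate ligatures expanded, double spaces collapsed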


@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     if torch.max(y) > 1.0:
         print("max value is ", torch.max(y))
-    global mel_basis, hann_window  # pylint: disable=global-statement
+    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)


@@ -0,0 +1,148 @@
#!/usr/bin/env python
import argparse
import os
import sys
import tempfile
from pathlib import Path
import torchaudio
from torch.hub import download_url_to_file
from tqdm import tqdm
from matcha.utils.data.utils import _extract_zip
URLS = {
"en-US": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip",
},
"ja-JP": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip",
},
}
INFO_PAGE = "https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/"
# On their website they say "We NICT open-sourced Hi-Fi-CAPTAIN",
# but they use this very-much-not-open-source licence.
# Dunno if this is open washing or stupidity.
LICENCE = "CC BY-NC-SA 4.0"
# I'd normally put the citation here. It's on their website.
# Boo to non-open-source stuff.
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"-r",
"--skip-resampling",
action="store_true",
default=False,
help="Skip resampling the data (from 48 to 22.05)",
)
parser.add_argument(
"-l", "--language", type=str, choices=["en-US", "ja-JP"], default="en-US", help="The language to download"
)
parser.add_argument(
"-g",
"--gender",
type=str,
choices=["male", "female"],
default="female",
help="The gender of the speaker to download",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
default="data",
help="Place to store the converted data. Top-level only, the subdirectory will be created",
)
return parser.parse_args()
def process_text(infile, outpath: Path):
outmode = "w"
if infile.endswith("dev.txt"):
outfile = outpath / "valid.txt"
elif infile.endswith("eval.txt"):
outfile = outpath / "test.txt"
else:
outfile = outpath / "train.txt"
if outfile.exists():
outmode = "a"
with (
open(infile, encoding="utf-8") as inf,
open(outfile, outmode, encoding="utf-8") as of,
):
for line in inf.readlines():
line = line.strip()
fileid, rest = line.split(" ", maxsplit=1)
wavfile = str(outpath / f"{fileid}.wav")
of.write(f"{wavfile}|{rest}\n")
def process_files(zipfile, outpath, resample=True):
with tempfile.TemporaryDirectory() as tmpdirname:
for filename in tqdm(_extract_zip(zipfile, tmpdirname)):
if not filename.startswith(tmpdirname):
filename = os.path.join(tmpdirname, filename)
if filename.endswith(".txt"):
process_text(filename, outpath)
elif filename.endswith(".wav"):
filepart = filename.rsplit("/", maxsplit=1)[-1]
outfile = str(outpath / filepart)
arr, sr = torchaudio.load(filename)
if resample:
arr = torchaudio.functional.resample(arr, orig_freq=sr, new_freq=22050)
sr = 22050
# save at the target rate, or at the original rate if resampling was skipped
torchaudio.save(outfile, arr, sr)
else:
continue
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
if not args.output_dir:
print("output directory not specified, exiting")
sys.exit(1)
URL = URLS[args.language][args.gender]
dirname = f"hi-fi_{args.language}_{args.gender}"
outbasepath = Path(args.output_dir)
if not outbasepath.is_dir():
outbasepath.mkdir()
outpath = outbasepath / dirname
if not outpath.is_dir():
outpath.mkdir()
resample = True
if args.skip_resampling:
resample = False
if save_dir:
zipname = URL.rsplit("/", maxsplit=1)[-1]
zipfile = save_dir / zipname
if not zipfile.exists():
download_url_to_file(URL, zipfile, progress=True)
process_files(zipfile, outpath, resample)
else:
with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
process_files(zf.name, outpath, resample)
if __name__ == "__main__":
main()
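
An illustrative invocation, following the repo's comment convention (paths are examples; the zips are several gigabytes):

# Example usage:
# python matcha/utils/data/hificaptain.py -l en-US -g female -s ./zips -o data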


@@ -0,0 +1,97 @@
#!/usr/bin/env python
import argparse
import random
import tempfile
from pathlib import Path
from torch.hub import download_url_to_file
from matcha.utils.data.utils import _extract_tar
URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
INFO_PAGE = "https://keithito.com/LJ-Speech-Dataset/"
LICENCE = "Public domain (LibriVox copyright disclaimer)"
CITATION = """
@misc{ljspeech17,
author = {Keith Ito and Linda Johnson},
title = {The LJ Speech Dataset},
howpublished = {\\url{https://keithito.com/LJ-Speech-Dataset/}},
year = 2017
}
"""
def decision():
return random.random() < 0.98
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"output_dir",
type=str,
nargs="?",
default="data",
help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)",
)
return parser.parse_args()
def process_csv(ljpath: Path):
if (ljpath / "metadata.csv").exists():
basepath = ljpath
elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists():
basepath = ljpath / "LJSpeech-1.1"
csvpath = basepath / "metadata.csv"
wavpath = basepath / "wavs"
with (
open(csvpath, encoding="utf-8") as csvf,
open(basepath / "train.txt", "w", encoding="utf-8") as tf,
open(basepath / "val.txt", "w", encoding="utf-8") as vf,
):
for line in csvf.readlines():
line = line.strip()
parts = line.split("|")
wavfile = str(wavpath / f"{parts[0]}.wav")
if decision():
tf.write(f"{wavfile}|{parts[1]}\n")
else:
vf.write(f"{wavfile}|{parts[1]}\n")
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
outpath = Path(args.output_dir)
if not outpath.is_dir():
outpath.mkdir()
if save_dir:
tarname = URL.rsplit("/", maxsplit=1)[-1]
tarfile = save_dir / tarname
if not tarfile.exists():
download_url_to_file(URL, str(tarfile), progress=True)
_extract_tar(tarfile, outpath)
process_csv(outpath)
else:
with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
_extract_tar(zf.name, outpath)
process_csv(outpath)
if __name__ == "__main__":
main()
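
An illustrative invocation (downloads and extracts LJSpeech-1.1, then writes train.txt/val.txt filelists next to metadata.csv):

# Example usage:
# python matcha/utils/data/ljspeech.py -s ./tarballs data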


@@ -0,0 +1,53 @@
# taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala)
# Licence: BSD 2-Clause
# pylint: disable=C0123
import logging
import os
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Union
_LG = logging.getLogger(__name__)
def _extract_tar(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
from_path = str(from_path)
if to_path is None:
to_path = os.path.dirname(from_path)
with tarfile.open(from_path, "r") as tar:
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
return files
def _extract_zip(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
from_path = str(from_path)
if to_path is None:
to_path = os.path.dirname(from_path)
with zipfile.ZipFile(from_path, "r") as zfile:
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
return files
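
A minimal usage sketch of these helpers (archive path and destination are illustrative):

from matcha.utils.data.utils import _extract_zip

files = _extract_zip("hfc_en-US_F.zip", "/tmp/hfc")  # returns paths of the extracted members
print(len(files))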


@@ -94,6 +94,7 @@ def main():
cfg["batch_size"] = args.batch_size cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
text_mel_datamodule = TextMelDataModule(**cfg) text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup() text_mel_datamodule.setup()
@@ -101,10 +102,8 @@ def main():
log.info("Dataloader loaded! Now computing stats...") log.info("Dataloader loaded! Now computing stats...")
params = compute_data_statistics(data_loader, cfg["n_feats"]) params = compute_data_statistics(data_loader, cfg["n_feats"])
print(params) print(params)
json.dump( with open(output_file, "w", encoding="utf-8") as dumpfile:
params, json.dump(params, dumpfile)
open(output_file, "w"),
)
if __name__ == "__main__": if __name__ == "__main__":


@@ -0,0 +1,195 @@
r"""
This script uses a trained Matcha-TTS checkpoint to extract phoneme durations from its learned
alignments and saves them per utterance (a .npy array plus a .json breakdown), so that other
models can be trained from precomputed durations.
Parameters are read from the dataset yaml config under configs/data.
"""
import argparse
import json
import os
import sys
from pathlib import Path
import lightning
import numpy as np
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from torch import nn
from tqdm.auto import tqdm
from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
"""Generate durations from the model for each datapoint and save it in a folder
Args:
data_loader (torch.utils.data.DataLoader): Dataloader
model (nn.Module): MatchaTTS model
device (torch.device): GPU or CPU
output_folder (Path): folder where the computed durations are saved
"""
for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
x = x.to(device)
y = y.to(device)
x_lengths = x_lengths.to(device)
y_lengths = y_lengths.to(device)
spks = spks.to(device) if spks is not None else None
_, _, _, attn = model(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
)
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="ljspeech.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default="32",
help="Can have increased batch size for faster computation",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
parser.add_argument(
"-c",
"--checkpoint_path",
type=str,
required=True,
help="Path to the checkpoint file to load the model from",
)
parser.add_argument(
"-o",
"--output-folder",
type=str,
default=None,
help="Output folder to save the data statistics",
)
parser.add_argument(
"--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
)
args = parser.parse_args()
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
print(f"Output folder set to: {output_folder}")
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")
sys.exit(1)
output_folder.mkdir(parents=True, exist_ok=True)
print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
print("Loading model...")
device = get_device(args)
model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
try:
print("Computing stats for training set if exists...")
train_dataloader = text_mel_datamodule.train_dataloader()
compute_durations(train_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No training set found")
try:
print("Computing stats for validation set if exists...")
val_dataloader = text_mel_datamodule.val_dataloader()
compute_durations(val_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No validation set found")
try:
print("Computing stats for test set if exists...")
test_dataloader = text_mel_datamodule.test_dataloader()
compute_durations(test_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No test set found")
print(f"[+] Done! Data statistics saved to: {output_folder}")
if __name__ == "__main__":
# Helps with generating durations for the dataset to train other architectures
# that cannot learn to align due to limited size of dataset
# Example usage:
# python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
# This will create a folder in data/processed_data/durations/ljspeech with the durations
main()


@@ -7,15 +7,17 @@ import torch
 def sequence_mask(length, max_length=None):
     if max_length is None:
         max_length = length.max()
-    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
     return x.unsqueeze(0) < length.unsqueeze(1)
 def fix_len_compatibility(length, num_downsamplings_in_unet=2):
-    while True:
-        if length % (2**num_downsamplings_in_unet) == 0:
-            return length
-        length += 1
+    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
+    length = (length / factor).ceil() * factor
+    if not torch.onnx.is_in_onnx_export():
+        return length.int().item()
+    else:
+        return length
 def convert_pad_shape(pad_shape):
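
The rewrite is what makes ONNX export possible here: the old data-dependent while loop cannot be traced into a static graph, while rounding up to a multiple of 2**num_downsamplings_in_unet is pure tensor arithmetic. A worked check of the new formula:

import torch

# Round 173 frames up to a multiple of 2**2 = 4 (two U-Net downsamplings)
factor = torch.scalar_tensor(2).pow(2)
length = (torch.tensor(173.0) / factor).ceil() * factor
print(int(length))  # 176 -- the same answer the old increment loop produced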


@@ -72,7 +72,7 @@ def print_config_tree(
     # save config tree to file
     if save_to_file:
-        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
             rich.print(tree, file=file)
@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
     log.info(f"Tags: {cfg.tags}")
     if save_to_file:
-        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
             rich.print(cfg.tags, file=file)


@@ -2,6 +2,7 @@ import os
 import sys
 import warnings
 from importlib.util import find_spec
+from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple
@@ -115,7 +116,7 @@ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
         return None
     if metric_name not in metric_dict:
-        raise Exception(
+        raise ValueError(
             f"Metric value not found! <metric_name={metric_name}>\n"
             "Make sure metric name logged in LightningModule is correct!\n"
             "Make sure `optimized_metric` name in `hparams_search` config is correct!"
@@ -205,13 +206,54 @@ def get_user_data_dir(appname="matcha_tts"):
     return final_path
-def assert_model_downloaded(checkpoint_path, url, use_wget=False):
+def assert_model_downloaded(checkpoint_path, url, use_wget=True):
     if Path(checkpoint_path).exists():
         log.debug(f"[+] Model already present at {checkpoint_path}!")
+        print(f"[+] Model already present at {checkpoint_path}!")
         return
     log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
+    print(f"[-] Model not found at {checkpoint_path}! Will download it")
     checkpoint_path = str(checkpoint_path)
     if not use_wget:
         gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
     else:
         wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json
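
A toy trace of get_phoneme_durations, assuming the durations come from a blank-interspersed symbol sequence (odd length, blanks shared between neighbouring phones):

from matcha.utils.utils import get_phoneme_durations

# Frames for [blank, p1, blank, p2, blank]; 15 frames in total
durations = [2, 5, 3, 4, 1]
print(get_phoneme_durations(durations, ["p1", "p2"]))
# i=1: next_half = ceil(3/2) = 2, so p1 gets 2 + 5 + 2 = 9 frames, carrying prev = 3 - 2 = 1
# i=3: last pair, so p2 gets 1 + 4 + 1 = 6 frames; the cumulative end is 15 = sum(durations)
# -> [{'p1': {'starttime': 0, 'endtime': 9, 'duration': 9}},
#     {'p2': {'starttime': 9, 'endtime': 15, 'duration': 6}}]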


@@ -35,10 +35,10 @@ torchaudio
 matplotlib
 pandas
 conformer==0.3.2
-diffusers==0.21.2
+diffusers  # developed using version ==0.25.0
 notebook
 ipywidgets
-gradio
+gradio==3.43.2
 gdown
 wget
 seaborn


@@ -16,9 +16,16 @@ with open("README.md", encoding="utf-8") as readme_file:
     README = readme_file.read()
 cwd = os.path.dirname(os.path.abspath(__file__))
-with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
+with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
     version = fin.read().strip()
+def get_requires():
+    requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
+    with open(requirements, encoding="utf-8") as reqfile:
+        return [str(r).strip() for r in reqfile]
 setup(
     name="matcha-tts",
     version=version,
@@ -28,7 +35,7 @@ setup(
     author="Shivam Mehta",
     author_email="shivam.mehta25@gmail.com",
     url="https://shivammehta25.github.io/Matcha-TTS",
-    install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
+    install_requires=get_requires(),
     include_dirs=[numpy.get_include()],
     include_package_data=True,
     packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
@@ -38,6 +45,7 @@ setup(
         "matcha-data-stats=matcha.utils.generate_data_statistics:main",
         "matcha-tts=matcha.cli:cli",
         "matcha-tts-app=matcha.app:main",
+        "matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main",
     ]
     },
     ext_modules=cythonize(exts, language_level=3),


@@ -19,7 +19,7 @@
     },
     {
         "cell_type": "code",
-        "execution_count": 1,
+        "execution_count": null,
         "id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992",
         "metadata": {},
         "outputs": [
@@ -192,7 +192,7 @@
     "source": [
         "@torch.inference_mode()\n",
         "def process_text(text: str):\n",
-        "    x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2']), 0),dtype=torch.long, device=device)[None]\n",
+        "    x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2'])[0], 0),dtype=torch.long, device=device)[None]\n",
         "    x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n",
         "    x_phones = sequence_to_text(x.squeeze(0).tolist())\n",
         "    return {\n",