68 Commits

Author SHA1 Message Date
Shivam Mehta
4c35836fa5 minor fix 2024-03-02 12:48:54 +00:00
Shivam Mehta
294c6b1327 Adding saving phones while getting durations from matcha 2024-03-02 12:47:08 +00:00
Shivam Mehta
ad76016916 Fixing configs 2024-02-26 09:11:22 +00:00
Shivam Mehta
05c8f9b4a8 updating configs and experiments 2024-02-25 22:02:36 +00:00
Shivam Mehta
4d5b62cea9 Adding a bit of comments 2024-02-24 15:20:13 +00:00
Shivam Mehta
8e87111a98 Adding possibility of getting durations out 2024-02-24 15:10:19 +00:00
Shivam Mehta
def0855608 Adding other experiment configs 2024-01-22 11:46:08 +00:00
Shivam Mehta
6976a91348 Merge branch 'main' into stoc_dur 2024-01-12 11:58:41 +00:00
Shivam Mehta
256adc55d3 Adding ICASSP 2024 2024-01-12 11:31:01 +00:00
Shivam Mehta
bfcbdbc82e Merge pull request #43 from shivammehta25/dev
Removing gdown for HifiGAN checkpoints too
2024-01-12 12:29:03 +01:00
Shivam Mehta
fb7b954de5 Updating different url for hifigan as well 2024-01-12 11:21:51 +00:00
Shivam Mehta
5a52a67cf7 Version bump 2024-01-12 11:11:41 +00:00
Shivam Mehta
39cbd85236 Using wget for new ckpt downloads 2024-01-12 11:09:25 +00:00
Shivam Mehta
47a629f128 Merge pull request #42 from shivammehta25/dev
Merging dev: adding another dataset, piper phonemizer and refactoring
2024-01-12 11:49:53 +01:00
Shivam Mehta
95ec24b599 Version bump 2024-01-12 10:48:52 +00:00
Shivam Mehta
5a2a893750 Merge pull request #19 from shivammehta25/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2024-01-12 11:47:10 +01:00
Shivam Mehta
13ca33fbe5 Merge pull request #37 from shivammehta25/dependabot/pip/dev/diffusers-0.25.0
Bump diffusers from 0.21.3 to 0.25.0
2024-01-12 11:46:40 +01:00
Shivam Mehta
19bea20928 Merge branch 'main' into dev 2024-01-12 10:37:17 +00:00
Shivam Mehta
8268360674 Update download urls 2024-01-12 10:32:59 +00:00
Shivam Mehta
a0bf4e9e9a Merge pull request #40 from shivammehta25/ghenter-readme-update-1
Update README.md with ICASSP acceptance
2024-01-12 10:13:23 +01:00
Shivam Mehta
458e9df236 Adding synthesis 2024-01-10 11:05:22 +00:00
Shivam Mehta
d03bba82bb In the middle of adding discrete nf based duration predictor 2024-01-10 11:04:46 +00:00
Gustav Eje Henter
f1e8efdec2 Update README.md
Add back full stop that erroneously went missing in the shuffle.
2024-01-09 22:53:09 +01:00
Gustav Eje Henter
4ec245e61e Update README.md with ICASSP acceptance
Added ICASSP acceptance to the README and made some tiny tweaks to the text
2024-01-09 22:48:16 +01:00
pre-commit-ci[bot]
dc035a09f2 [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.4.0...v4.5.0)
- [github.com/psf/black: 23.9.1 → 23.12.1](https://github.com/psf/black/compare/23.9.1...23.12.1)
- [github.com/PyCQA/isort: 5.12.0 → 5.13.2](https://github.com/PyCQA/isort/compare/5.12.0...5.13.2)
- [github.com/asottile/pyupgrade: v3.14.0 → v3.15.0](https://github.com/asottile/pyupgrade/compare/v3.14.0...v3.15.0)
- [github.com/PyCQA/flake8: 6.1.0 → 7.0.0](https://github.com/PyCQA/flake8/compare/6.1.0...7.0.0)
- [github.com/pycqa/pylint: v3.0.0 → v3.0.3](https://github.com/pycqa/pylint/compare/v3.0.0...v3.0.3)
2024-01-08 21:15:26 +00:00
Shivam Mehta
a58bab5403 Adding option to do flow matching based duration prediction 2024-01-05 11:13:07 +00:00
dependabot[bot]
254a8e05ce Bump diffusers from 0.21.3 to 0.25.0
Bumps [diffusers](https://github.com/huggingface/diffusers) from 0.21.3 to 0.25.0.
- [Release notes](https://github.com/huggingface/diffusers/releases)
- [Commits](https://github.com/huggingface/diffusers/compare/v0.21.3...v0.25.0)

---
updated-dependencies:
- dependency-name: diffusers
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-12-28 13:20:11 +00:00
Shivam Mehta
0ed9290c31 Logging global step while training 2023-12-06 10:39:54 +00:00
Shivam Mehta
f39ee6cf3b Changing while to for for more readability 2023-12-05 12:10:52 +00:00
Shivam Mehta
6e71dc8b8f adding prior loss as a configuration 2023-12-05 09:57:37 +00:00
Shivam Mehta
ae2417c175 Merge pull request #34 from shivammehta25/piper_phonemize
Piper phonemize
2023-12-04 11:16:24 +01:00
Shivam Mehta
6c7a82a516 Adding dataset information 2023-12-04 10:15:13 +00:00
Shivam Mehta
009b09a8b2 Removing unwanted configs 2023-12-04 10:13:44 +00:00
Shivam Mehta
a18db17330 Removing the option for configuring prior loss, the durations predicted are not so good then 2023-12-04 10:12:39 +00:00
Shivam Mehta
263d5c4d4e Adding piper phonemizer with different dataset 2023-12-01 12:06:26 +00:00
Shivam Mehta
df896301ca Minor changes moving option to disable prior loss in config 2023-12-01 10:44:49 +00:00
Shivam Mehta
c8d0d60f87 Merge pull request #16 from shivammehta25/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-10-06 05:44:02 +02:00
pre-commit-ci[bot]
e540794e7e [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/psf/black: 23.1.0 → 23.9.1](https://github.com/psf/black/compare/23.1.0...23.9.1)
- [github.com/asottile/pyupgrade: v3.3.1 → v3.14.0](https://github.com/asottile/pyupgrade/compare/v3.3.1...v3.14.0)
- [github.com/PyCQA/flake8: 6.0.0 → 6.1.0](https://github.com/PyCQA/flake8/compare/6.0.0...6.1.0)
- [github.com/pycqa/pylint: v2.8.2 → v3.0.0](https://github.com/pycqa/pylint/compare/v2.8.2...v3.0.0)
2023-10-03 13:14:20 +00:00
Shivam Mehta
b756809a32 Merge pull request #13 from shivammehta25/dev
Merging dev to main | adding ONNX support
2023-09-29 16:54:09 +02:00
Shivam Mehta
1ead4303f3 Version Bump 2023-09-29 14:50:46 +00:00
Shivam Mehta
7a29fef719 Merge pull request #12 from shivammehta25/dependabot/pip/dev/diffusers-0.21.3
Bump diffusers from 0.21.2 to 0.21.3
2023-09-29 16:48:13 +02:00
Shivam Mehta
9ace522249 Update README.md 2023-09-29 16:46:38 +02:00
Shivam Mehta
ed6e6bbf6c Merge branch 'ONNX_BRANCH' into dev 2023-09-29 14:43:52 +00:00
Shivam Mehta
51ea36d271 Merge pull request #8 from mush42/onnx
ONNX export and inference
2023-09-29 16:43:19 +02:00
Shivam Mehta
269609003b Adding onnx installation command in the README 2023-09-29 14:38:57 +00:00
dependabot[bot]
2a81800825 Bump diffusers from 0.21.2 to 0.21.3
Bumps [diffusers](https://github.com/huggingface/diffusers) from 0.21.2 to 0.21.3.
- [Release notes](https://github.com/huggingface/diffusers/releases)
- [Commits](https://github.com/huggingface/diffusers/compare/v0.21.2...v0.21.3)

---
updated-dependencies:
- dependency-name: diffusers
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-09-28 13:23:02 +00:00
mush42
336dd20d5b Use torch.onnx.is_in_onnx_export() instead of torch.jit.is_scripting() since the former is dedicated to this use case. 2023-09-26 15:28:15 +02:00
mush42
01c99161c4 - Fixed several bugs. Thanks @shivammehta25 for the suggestions 2023-09-26 14:21:17 +02:00
mush42
2c21a0edac Fixed an error encountered when loading the vocoder during export. 2023-09-24 20:28:59 +02:00
mush42
25767f76a8 Readme: added a note about GPU inference with onnxruntime. 2023-09-24 02:13:27 +02:00
mush42
1b204ed42c ONNX export and inference. Complete and tested implementation. 2023-09-24 01:57:35 +02:00
Shivam Mehta
2cd057187b Update README.md
Add information about installation and compilation of monotonic alignment
2023-09-23 17:39:36 +02:00
Shivam Mehta
d373e9a5b1 Bumping it to an increased version 2023-09-21 13:43:20 +00:00
Shivam Mehta
f12be190a4 Adding video teaser to readme 2023-09-21 13:41:21 +00:00
Shivam Mehta
be998ae31f Merge pull request #4 from shivammehta25/dev
Another version bump because I am new to twine
2023-09-21 15:26:54 +02:00
Shivam Mehta
a49af19f48 Another version bump because I am new to twine 2023-09-21 13:25:40 +00:00
Shivam Mehta
7850aa0910 Merge pull request #3 from shivammehta25/dev
Adding multispeaker 🍵 Matcha-TTS
2023-09-21 15:23:15 +02:00
Shivam Mehta
309739b7cc Version bump 2023-09-21 13:20:17 +00:00
Shivam Mehta
0aaabf90c4 Minor UI fixes 2023-09-21 13:18:32 +00:00
Shivam Mehta
c5dab67d9f Adding teaser url 2023-09-21 13:15:36 +00:00
Shivam Mehta
d098f32730 bumping diffusers, changing dependabot to open PR to dev branch, adding url for multispeaker matcha checkpoint 2023-09-21 12:32:27 +00:00
Shivam Mehta
281a098337 better default speaking rate 2023-09-20 15:23:46 +00:00
Shivam Mehta
db95158043 Will have to load all models to enable multiple synthesis at the same time 2023-09-20 10:40:01 +00:00
Shivam Mehta
267bf96651 Adding multispeaker model in UI 2023-09-20 10:28:48 +00:00
Shivam Mehta
72635012b0 Adding matcha vctk 2023-09-20 07:08:11 +00:00
Gustav Eje Henter
9ceee279f0 Minor improvements to README.md 2023-09-18 18:44:13 +02:00
Shivam Mehta
d7b9a37359 adding more validation to multispeaker CLI 2023-09-18 11:37:58 +00:00
Shivam Mehta
ec43ef0732 Keeping ODE step slider to be greater than 0 always in gradio 2023-09-17 22:08:29 +00:00
45 changed files with 2971 additions and 283 deletions


@@ -7,6 +7,7 @@ version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
target-branch: "dev"
schedule:
interval: "daily"
ignore:


@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
@@ -18,28 +18,28 @@ repos:
# python code formatting
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.12.1
hooks:
- id: black
args: [--line-length, "120"]
# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py38-plus]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
args:
@@ -54,6 +54,6 @@ repos:
# pylint
- repo: https://github.com/pycqa/pylint
rev: v2.8.2
rev: v3.0.3
hooks:
- id: pylint


@@ -82,16 +82,6 @@ disable=missing-docstring,
no-name-in-module,
no-member,
unsubscriptable-object,
print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
@@ -106,67 +96,6 @@ disable=missing-docstring,
too-many-arguments,
too-many-locals,
too-many-statements,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating,
deprecated-operator-function,
deprecated-urllib-function,
xreadlines-attribute,
deprecated-sys-function,
exception-escape,
comprehension-escape,
duplicate-code,
not-callable,
import-outside-toplevel,
@@ -363,13 +292,6 @@ max-line-length=120
# Maximum number of lines in a module.
max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
@@ -599,5 +521,5 @@ min-public-methods=2
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception
overgeneral-exceptions=builtins.BaseException,
builtins.Exception


@@ -17,7 +17,7 @@ create-package: ## Create wheel and tar gz
rm -rf dist/
python setup.py bdist_wheel --plat-name=manylinux1_x86_64
python setup.py sdist
python -m twine upload dist/* --verbose
python -m twine upload dist/* --verbose --skip-existing
format: ## Run pre-commit hooks
pre-commit run -a
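
For reference, the release target touched above can be invoked as sketched below; this assumes `twine` is configured with PyPI credentials, and `--skip-existing` simply keeps the upload from failing when that version is already on PyPI.

```bash
# Hedged sketch: builds the wheel + sdist and uploads them via the Makefile target above
make create-package
```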

README.md

@@ -17,22 +17,24 @@
</div>
> This is the official code implementation of 🍵 Matcha-TTS.
> This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024].
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method:
- Is probabilistic
- Has compact memory footprint
- Sounds highly natural
- Is very fast to synthesise from
Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our ICASSP 2024 paper](https://arxiv.org/abs/2309.03199) for more details.
[Pretrained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be auto downloaded with the CLI or gradio interface.
[Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface.
[Try 🍵 Matcha-TTS on HuggingFace 🤗 spaces!](https://huggingface.co/spaces/shivammehta25/Matcha-TTS)
You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces](https://huggingface.co/spaces/shivammehta25/Matcha-TTS).
<br>
## Teaser video
[![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0)
## Installation
@@ -53,6 +55,8 @@ from source
```bash
pip install git+https://github.com/shivammehta25/Matcha-TTS.git
cd Matcha-TTS
pip install -e .
```
3. Run CLI / gradio app / jupyter notebook
@@ -110,26 +114,13 @@ matcha-tts --text "<INPUT TEXT>" --temperature 0.667
matcha-tts --text "<INPUT TEXT>" --steps 10
```
## Citation information
If you find this work useful, please cite our paper:
```text
@article{mehta2023matcha,
title={Matcha-TTS: A fast TTS architecture with conditional flow matching},
author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje},
journal={arXiv preprint arXiv:2309.03199},
year={2023}
}
```
## Train with your own dataset
Let's assume we are training with LJSpeech
Let's assume we are training with LJ Speech
1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the filelists to point to the extracted data like the [5th point of setup in Tacotron2 repo](https://github.com/NVIDIA/tacotron2#setup).
1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup).
2. Clone and enter this repository
2. Clone and enter the Matcha-TTS repository
```bash
git clone https://github.com/shivammehta25/Matcha-TTS.git
@@ -167,7 +158,7 @@ data_statistics: # Computed for ljspeech dataset
to the paths of your train and validation filelists.
5. Run the training script
6. Run the training script
```bash
make train-ljspeech
@@ -191,20 +182,97 @@ python matcha/train.py experiment=ljspeech_min_memory
python matcha/train.py experiment=ljspeech trainer.devices=[0,1]
```
6. Synthesise from the custom trained model
7. Synthesise from the custom trained model
```bash
matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
```
## ONNX support
> Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support.
It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph.
### ONNX export
To export a checkpoint to ONNX, first install ONNX with
```bash
pip install onnx
```
then run the following:
```bash
python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5
```
Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems).
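For illustration, a hedged sketch of embedding a vocoder at export time; the exact flag spellings (`--vocoder-name`, `--vocoder-checkpoint`) are assumptions based on the argument names mentioned above, and the checkpoint filename is only an example:
```bash
python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5 \
  --vocoder-name hifigan_univ_v1 --vocoder-checkpoint g_02500000
```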
**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**.
**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_dot_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release.
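As a minimal sketch of installing such a pre-release (assuming it is available on PyPI; the exact pin may differ for your setup):
```bash
pip install --upgrade --pre "torch>=2.1.0"
```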
### ONNX Inference
To run inference on the exported model, first install `onnxruntime` using
```bash
pip install onnxruntime
pip install onnxruntime-gpu # for GPU inference
```
then use the following:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs
```
You can also control synthesis parameters:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking_rate 0.9 --spk 0
```
To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu
```
If you exported only Matcha to ONNX, this will write mel-spectrogram as graphs and `numpy` arrays to the output directory.
If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory.
If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx
```
This will write `.wav` audio files to the output directory.
## Citation information
If you use our code or otherwise find this work useful, please cite our paper:
```text
@inproceedings{mehta2024matcha,
title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching},
author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje},
booktitle={Proc. ICASSP},
year={2024}
}
```
## Acknowledgements
Since this code uses: [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that comes with it.
Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it.
Other source codes I would like to acknowledge:
Other source code we would like to acknowledge:
- [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev) :For helping me figure out how to make cython binaries pip installable and encouragement
- [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement
- [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components
- [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For source code of MAS
- [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For the monotonic alignment search source code
- [torchdyn](https://github.com/DiffEqML/torchdyn): Useful for trying other ODE solvers during research and development
- [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For RoPE implementation
- [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For the RoPE implementation


@@ -0,0 +1,14 @@
defaults:
- ljspeech
- _self_
# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female
train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
batch_size: 32
cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset
mel_mean: -6.38385
mel_std: 2.541796
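
As an aside, the `data_statistics` values above can be regenerated from the filelists; a hedged example, assuming the repository's `matcha-data-stats` entry point and this config's filename:

```bash
matcha-data-stats -i hi-fi_en-US_female.yaml
```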


@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: joe_spont_only
train_filelist_path: data/filelists/joe_spontonly_train.txt
valid_filelist_path: data/filelists/joe_spontonly_val.txt
data_statistics:
mel_mean: -5.882903
mel_std: 2.458284

configs/data/ryan.yaml

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: ryan
train_filelist_path: data/filelists/ryan_train.csv
valid_filelist_path: data/filelists/ryan_val.csv
data_statistics:
mel_mean: -4.715779
mel_std: 2.124502

configs/data/tsg2.yaml

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: tsg2
train_filelist_path: data/filelists/cormac_train.txt
valid_filelist_path: data/filelists/cormac_val.txt
data_statistics:
mel_mean: -5.536622
mel_std: 2.116101


@@ -7,8 +7,8 @@
task_name: "debug"
# disable callbacks and loggers during debugging
callbacks: null
logger: null
# callbacks: null
# logger: null
extras:
ignore_warnings: False


@@ -7,6 +7,9 @@ defaults:
trainer:
max_epochs: 1
profiler: "simple"
# profiler: "advanced"
# profiler: "simple"
profiler: "advanced"
# profiler: "pytorch"
accelerator: gpu
limit_train_batches: 0.02


@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: hi-fi_en-US_female.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]
run_name: hi-fi_en-US_female_piper_phonemizer


@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_det_dur


@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_stoc_dur
model:
duration_predictor:
p_dropout: 0.2


@@ -0,0 +1,16 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech


@@ -0,0 +1,18 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_det_dur
trainer:
max_epochs: 3000


@@ -0,0 +1,24 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_stoc_dur
model:
duration_predictor:
p_dropout: 0.2
trainer:
max_epochs: 3000


@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_det_dur


@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_stoc_dur
model:
duration_predictor:
p_dropout: 0.5


@@ -0,0 +1,7 @@
name: deterministic
n_spks: ${model.n_spks}
spk_emb_dim: ${model.spk_emb_dim}
filter_channels: 256
kernel_size: 3
n_channels: ${model.encoder.encoder_params.n_channels}
p_dropout: ${model.encoder.encoder_params.p_dropout}


@@ -0,0 +1,7 @@
defaults:
- deterministic.yaml
- _self_
sigma_min: 1e-4
n_steps: 10
name: flow_matching
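
To train with this stochastic (flow matching) duration predictor instead of the deterministic default, the experiment configs above override `/model/duration_predictor`; a hedged command-line equivalent, assuming standard Hydra group-override syntax:

```bash
python matcha/train.py experiment=ljspeech model/duration_predictor=flow_matching
```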


@@ -3,16 +3,8 @@ encoder_params:
n_feats: ${model.n_feats}
n_channels: 192
filter_channels: 768
filter_channels_dp: 256
n_heads: 2
n_layers: 6
kernel_size: 3
p_dropout: 0.1
spk_emb_dim: 64
n_spks: 1
prenet: true
duration_predictor_params:
filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
kernel_size: 3
p_dropout: ${model.encoder.encoder_params.p_dropout}


@@ -1,6 +1,7 @@
defaults:
- _self_
- encoder: default.yaml
- duration_predictor: deterministic.yaml
- decoder: default.yaml
- cfm: default.yaml
- optimizer: adam.yaml
@@ -12,3 +13,4 @@ spk_emb_dim: 64
n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4
prior_loss: true


@@ -1 +1 @@
0.0.1.dev4
0.0.5.1


@@ -8,7 +8,7 @@ import torch
from matcha.cli import (
MATCHA_URLS,
VOCODER_URL,
VOCODER_URLS,
assert_model_downloaded,
get_device,
load_matcha,
@@ -22,20 +22,80 @@ LOCATION = Path(get_user_data_dir())
args = Namespace(
cpu=False,
model="matcha_ljspeech",
vocoder="hifigan_T2_v1",
spk=None,
model="matcha_vctk",
vocoder="hifigan_univ_v1",
spk=0,
)
MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
VOCODER_LOC = LOCATION / f"{args.vocoder}"
CURRENTLY_LOADED_MODEL = args.model
def MATCHA_TTS_LOC(x):
return LOCATION / f"{x}.ckpt"
def VOCODER_LOC(x):
return LOCATION / f"{x}"
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model])
assert_model_downloaded(VOCODER_LOC, VOCODER_URL[args.vocoder])
RADIO_OPTIONS = {
"Multi Speaker (VCTK)": {
"model": "matcha_vctk",
"vocoder": "hifigan_univ_v1",
},
"Single Speaker (LJ Speech)": {
"model": "matcha_ljspeech",
"vocoder": "hifigan_T2_v1",
},
}
# Ensure all the required models are downloaded
assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
device = get_device(args)
model = load_matcha(args.model, MATCHA_TTS_LOC, device)
vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC, device)
# Load default model
model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
def load_model(model_name, vocoder_name):
model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
return model, vocoder, denoiser
def load_model_ui(model_type, textbox):
model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
if CURRENTLY_LOADED_MODEL != model_name:
model, vocoder, denoiser = load_model(model_name, vocoder_name)
CURRENTLY_LOADED_MODEL = model_name
if model_name == "matcha_ljspeech":
spk_slider = gr.update(visible=False, value=-1)
single_speaker_examples = gr.update(visible=True)
multi_speaker_examples = gr.update(visible=False)
length_scale = gr.update(value=0.95)
else:
spk_slider = gr.update(visible=True, value=0)
single_speaker_examples = gr.update(visible=False)
multi_speaker_examples = gr.update(visible=True)
length_scale = gr.update(value=0.85)
return (
textbox,
gr.update(interactive=True),
spk_slider,
single_speaker_examples,
multi_speaker_examples,
length_scale,
)
@torch.inference_mode()
@@ -45,13 +105,14 @@ def process_text_gradio(text):
@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
output = model.synthesise(
text,
text_length,
n_timesteps=n_timesteps,
temperature=temperature,
spks=args.spk,
spks=spk,
length_scale=length_scale,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
@@ -61,9 +122,27 @@ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
if CURRENTLY_LOADED_MODEL != "matcha_vctk":
global model, vocoder, denoiser # pylint: disable=global-statement
model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
CURRENTLY_LOADED_MODEL = "matcha_vctk"
phones, text, text_lengths = process_text_gradio(text)
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
return phones, audio, mel_spectrogram
def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
global model, vocoder, denoiser # pylint: disable=global-statement
model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
CURRENTLY_LOADED_MODEL = "matcha_ljspeech"
phones, text, text_lengths = process_text_gradio(text)
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
return phones, audio, mel_spectrogram
@@ -92,20 +171,31 @@ def main():
with gr.Box():
with gr.Row():
gr.Markdown(description, scale=3)
gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
with gr.Column():
gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
html = '<br><iframe width="560" height="315" src="https://www.youtube.com/embed/xmvJkz3bqw0?si=jN7ILyDsbPwJCGoa" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>'
gr.HTML(html)
with gr.Box():
radio_options = list(RADIO_OPTIONS.keys())
model_type = gr.Radio(
radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
)
with gr.Row():
gr.Markdown("# Text Input")
with gr.Row():
text = gr.Textbox(value="", lines=2, label="Text to synthesise")
text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
spk_slider = gr.Slider(
minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
)
with gr.Row():
gr.Markdown("### Hyper parameters")
with gr.Row():
n_timesteps = gr.Slider(
label="Number of ODE steps",
minimum=0,
minimum=1,
maximum=100,
step=1,
value=10,
@@ -142,58 +232,110 @@ def main():
# with gr.Row():
audio = gr.Audio(interactive=False, label="Audio")
with gr.Row():
with gr.Row(visible=False) as example_row_lj_speech:
examples = gr.Examples( # pylint: disable=unused-variable
examples=[
[
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
50,
0.677,
1.0,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
2,
0.677,
1.0,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
4,
0.677,
1.0,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
10,
0.677,
1.0,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
50,
0.677,
1.0,
0.95,
],
[
"The narrative of these events is based largely on the recollections of the participants.",
10,
0.677,
1.0,
0.95,
],
[
"The jury did not believe him, and the verdict was for the defendants.",
10,
0.677,
1.0,
0.95,
],
],
fn=run_full_synthesis,
fn=ljspeech_example_cacher,
inputs=[text, n_timesteps, mel_temp, length_scale],
outputs=[phonetised_text, audio, mel_spectrogram],
cache_examples=True,
)
with gr.Row() as example_row_multispeaker:
multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable
examples=[
[
"Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
10,
0.677,
0.85,
0,
],
[
"Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
10,
0.677,
0.85,
16,
],
[
"Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
50,
0.677,
0.85,
44,
],
[
"Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
50,
0.677,
0.85,
45,
],
[
"Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
4,
0.677,
0.85,
58,
],
],
fn=multispeaker_example_cacher,
inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
outputs=[phonetised_text, audio, mel_spectrogram],
cache_examples=True,
label="Multi Speaker Examples",
)
model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
load_model_ui,
inputs=[model_type, text],
outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
)
synth_btn.click(
fn=process_text_gradio,
inputs=[
@@ -204,11 +346,11 @@ def main():
queue=True,
).then(
fn=synthesise_mel,
inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale],
inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
outputs=[audio, mel_spectrogram],
)
demo.queue(concurrency_count=5).launch(share=True)
demo.queue().launch(share=True)
if __name__ == "__main__":
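
A hedged way to try the updated Gradio demo locally, assuming a working install and that `main()` is invoked from the `__main__` guard above:

```bash
python matcha/app.py
```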


@@ -1,6 +1,7 @@
import argparse
import datetime as dt
import os
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
@@ -17,13 +18,20 @@ from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
MATCHA_URLS = {
"matcha_ljspeech": "https://drive.google.com/file/d/1BBzmMU7k3a_WetDfaFblMoN18GqQeHCg/view?usp=drive_link"
} # , "matcha_vctk": ""} # Coming soon
"matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt",
"matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt",
}
MULTISPEAKER_MODEL = {"matcha_vctk"}
SINGLESPEAKER_MODEL = {"matcha_ljspeech"}
VOCODER_URLS = {
"hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1", # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link
"hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000", # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link
}
VOCODER_URL = {"hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link"}
MULTISPEAKER_MODEL = {
"matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)}
}
SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}}
def plot_spectrogram_to_numpy(spectrogram, filename):
@@ -55,17 +63,21 @@ def get_texts(args):
if args.text:
texts = [args.text]
else:
with open(args.file) as f:
with open(args.file, encoding="utf-8") as f:
texts = f.readlines()
return texts
def assert_required_models_available(args):
save_dir = get_user_data_dir()
if not hasattr(args, "checkpoint_path") and args.checkpoint_path is None:
model_path = args.checkpoint_path
else:
model_path = save_dir / f"{args.model}.ckpt"
vocoder_path = save_dir / f"{args.vocoder}"
assert_model_downloaded(model_path, MATCHA_URLS[args.model])
assert_model_downloaded(vocoder_path, VOCODER_URL[args.vocoder])
vocoder_path = save_dir / f"{args.vocoder}"
assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
return {"matcha": model_path, "vocoder": vocoder_path}
@@ -81,7 +93,7 @@ def load_hifigan(checkpoint_path, device):
def load_vocoder(vocoder_name, checkpoint_path, device):
print(f"[!] Loading {vocoder_name}!")
vocoder = None
if vocoder_name == "hifigan_T2_v1":
if vocoder_name in ("hifigan_T2_v1", "hifigan_univ_v1"):
vocoder = load_hifigan(checkpoint_path, device)
else:
raise NotImplementedError(
@@ -124,21 +136,70 @@ def validate_args(args):
args.text or args.file
), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
assert args.temperature >= 0, "Sampling temperature cannot be negative"
assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
assert args.steps > 0, "Number of ODE steps must be greater than 0"
if args.checkpoint_path is None:
# When using pretrained models
if args.model in SINGLESPEAKER_MODEL:
assert args.spk is None, f"Speaker ID is not supported for {args.model}"
if args.spk is not None:
assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108"
assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model"
args = validate_args_for_single_speaker_model(args)
if args.model in MULTISPEAKER_MODEL:
if args.spk is None:
print("[!] Speaker ID not provided! Using speaker ID 0")
args.spk = 0
args = validate_args_for_multispeaker_model(args)
else:
# When using a custom model
if args.vocoder != "hifigan_univ_v1":
warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
warnings.warn(warn_, UserWarning)
if args.speaking_rate is None:
args.speaking_rate = 1.0
if args.batched:
assert args.batch_size > 0, "Batch size must be greater than 0"
assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
return args
def validate_args_for_multispeaker_model(args):
if args.vocoder is not None:
if args.vocoder != MULTISPEAKER_MODEL[args.model]["vocoder"]:
warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}"
warnings.warn(warn_, UserWarning)
else:
args.vocoder = MULTISPEAKER_MODEL[args.model]["vocoder"]
if args.speaking_rate is None:
args.speaking_rate = MULTISPEAKER_MODEL[args.model]["speaking_rate"]
spk_range = MULTISPEAKER_MODEL[args.model]["spk_range"]
if args.spk is not None:
assert (
args.spk >= spk_range[0] and args.spk <= spk_range[-1]
), f"Speaker ID must be between {spk_range} for this model."
else:
available_spk_id = MULTISPEAKER_MODEL[args.model]["spk"]
warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}"
warnings.warn(warn_, UserWarning)
args.spk = available_spk_id
return args
def validate_args_for_single_speaker_model(args):
if args.vocoder is not None:
if args.vocoder != SINGLESPEAKER_MODEL[args.model]["vocoder"]:
warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}"
warnings.warn(warn_, UserWarning)
else:
args.vocoder = SINGLESPEAKER_MODEL[args.model]["vocoder"]
if args.speaking_rate is None:
args.speaking_rate = SINGLESPEAKER_MODEL[args.model]["speaking_rate"]
if args.spk != SINGLESPEAKER_MODEL[args.model]["spk"]:
warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}"
warnings.warn(warn_, UserWarning)
args.spk = SINGLESPEAKER_MODEL[args.model]["spk"]
return args
@@ -166,9 +227,9 @@ def cli():
parser.add_argument(
"--vocoder",
type=str,
default="hifigan_T2_v1",
help="Vocoder to use",
choices=VOCODER_URL.keys(),
default="hifigan_univ_v1",
help="Vocoder to use (default: will use the one suggested with the pretrained model))",
choices=VOCODER_URLS.keys(),
)
parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
@@ -182,7 +243,7 @@ def cli():
parser.add_argument(
"--speaking_rate",
type=float,
default=1.0,
default=None,
help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
)
parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
@@ -199,8 +260,10 @@ def cli():
default=os.getcwd(),
help="Output folder to save results (default: current dir)",
)
parser.add_argument("--batched", action="store_true")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)")
parser.add_argument(
"--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)"
)
args = parser.parse_args()
@@ -333,6 +396,8 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
def print_config(args):
print("[!] Configurations: ")
print(f"\t- Model: {args.model}")
print(f"\t- Vocoder: {args.vocoder}")
print(f"\t- Temperature: {args.temperature}")
print(f"\t- Speaking rate: {args.speaking_rate}")
print(f"\t- Number of ODE steps: {args.steps}")


@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
"""Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass
def state_dict(self): # pylint: disable=no-self-use
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text = self.get_text(text, add_blank=self.add_blank)
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk}
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
return mel
def get_text(self, text, add_blank=True):
text_norm = text_to_sequence(text, self.cleaners)
text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
if self.add_blank:
text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm)
return text_norm
return text_norm, cleaned_text
def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -207,15 +207,16 @@ class TextMelBatchCollate:
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch])
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch])
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths, x_texts = [], []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
@@ -223,9 +224,19 @@ class TextMelBatchCollate:
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
return {
"x": x,
"x_lengths": x_lengths,
"y": y,
"y_lengths": y_lengths,
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
}


@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
dur_loss, prior_loss, diff_loss = self(
dur_loss, prior_loss, diff_loss, *_ = self(
x=x,
x_lengths=x_lengths,
y=y,
@@ -81,7 +81,7 @@ class BaseLightningClass(LightningModule, ABC):
"step",
float(self.global_step),
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)


@@ -0,0 +1,448 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack
from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
from matcha.models.components.text_encoder import LayerNorm
# Define available networks
class DurationPredictorNetwork(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class DurationPredictorNetworkWithTimeStep(nn.Module):
"""Similar architecture but with a time embedding support"""
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.time_embeddings = SinusoidalPosEmb(filter_channels)
self.time_mlp = TimestepEmbedding(
in_channels=filter_channels,
time_embed_dim=filter_channels,
act_fn="silu",
)
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask, enc_outputs, t):
t = self.time_embeddings(t)
t = self.time_mlp(t).unsqueeze(-1)
x = pack([x, enc_outputs], "b * t")[0]
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
# Define available methods to compute loss
# Simple MSE deterministic
class DeterministicDurationPredictor(nn.Module):
def __init__(self, params):
super().__init__()
self.estimator = DurationPredictorNetwork(
params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0),
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
@torch.inference_mode()
def forward(self, x, x_mask):
return self.estimator(x, x_mask)
def compute_loss(self, durations, enc_outputs, x_mask):
return F.mse_loss(self.estimator(enc_outputs, x_mask), durations, reduction="sum") / torch.sum(x_mask)
# Flow Matching duration predictor
class FlowMatchingDurationPrediction(nn.Module):
def __init__(self, params) -> None:
super().__init__()
self.estimator = DurationPredictorNetworkWithTimeStep(
1
+ params.n_channels
+ (
params.spk_emb_dim if params.n_spks > 1 else 0
), # 1 for the durations and n_channels for encoder outputs
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
self.sigma_min = params.sigma_min
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
if n_timesteps is None:
n_timesteps = self.n_steps
b, _, t = enc_outputs.shape
z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
return self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
def solve_euler(self, x, t_span, enc_outputs, mask):
"""
Fixed-step Euler solver for the duration ODE.
Args:
x (torch.Tensor): random noise
shape: (batch_size, 1, text_timesteps)
t_span (torch.Tensor): n_timesteps interpolated time points
shape: (n_timesteps + 1,)
enc_outputs (torch.Tensor): output of encoder, used as conditioning
shape: (batch_size, n_channels, text_timesteps)
mask (torch.Tensor): output mask
shape: (batch_size, 1, text_timesteps)
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, enc_outputs, t)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, enc_outputs, mask):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
enc_outputs = enc_outputs.detach() # don't update encoder from the duration predictor
b, _, t = enc_outputs.shape
# random timestep
t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, enc_outputs, t.squeeze()), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss
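
In equation form, the loss computed above is the usual OT-CFM objective (following the conditional flow matching formulation linked in the README), with the detached encoder outputs acting as conditioning:

```latex
x_t = \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,x_0 + t\,x_1, \qquad
u_t(x_t \mid x_1) = x_1 - (1 - \sigma_{\min})\,x_0, \qquad
\mathcal{L} = \mathbb{E}\,\bigl\|\, v_\theta(x_t, t;\ \mathrm{enc}) - u_t \,\bigr\|^2
```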
# VITS discrete normalising flow based duration predictor
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels  # this needs to be removed in a future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = Log()
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
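# Training (reverse=False): the ground-truth durations w are dequantised by a posterior
# flow (variational dequantisation) and the main flow stack maps them towards a standard
# normal, returning the negative log-likelihood plus the posterior term.
# Inference (reverse=True): noise is sampled and pushed back through the inverted flows
# to obtain per-token log-durations.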
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
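# z_u dequantises the integer durations: u = sigmoid(z_u) lies in (0, 1), so
# z0 = w - u is a continuous version of w; z1 is the remaining latent channel.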
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
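# Sample z ~ N(0, noise_scale^2) and invert the flow stack; the first output channel
# of the result is taken as the predicted log-duration per token.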
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
# Wrapper class that dispatches to the configured duration predictor
class DP(nn.Module):
def __init__(self, params):
super().__init__()
self.name = params.name
if params.name == "deterministic":
self.dp = DeterministicDurationPredictor(
params,
)
elif params.name == "flow_matching":
self.dp = FlowMatchingDurationPrediction(
params,
)
else:
raise ValueError(f"Invalid duration predictor configuration: {params.name}")
@torch.inference_mode()
def forward(self, enc_outputs, mask):
return self.dp(enc_outputs, mask)
def compute_loss(self, durations, enc_outputs, mask):
return self.dp.compute_loss(durations, enc_outputs, mask)
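# A minimal usage sketch (assuming `duration_predictor_cfg` is a config/namespace object
# with a `name` field plus the parameters each predictor expects):
#     dp = DP(duration_predictor_cfg)
#     logw = dp(enc_output, x_mask)                              # inference-only forward
#     dur_loss = dp.compute_loss(logw_gt, enc_output, x_mask)    # training loss
# Note that forward() is wrapped in torch.inference_mode(), so gradients flow only
# through compute_loss().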

View File

@@ -73,16 +73,14 @@ class BASECFM(torch.nn.Module, ABC):
# Or in the future we might add a return_all_steps flag
sol = []
steps = 1
while steps <= len(t_span) - 1:
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if steps < len(t_span) - 1:
dt = t_span[steps + 1] - t
steps += 1
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]

View File

@@ -67,33 +67,6 @@ class ConvReluNorm(nn.Module):
return x * x_mask
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class RotaryPositionalEmbeddings(nn.Module):
"""
## RoPE module
@@ -330,7 +303,6 @@ class TextEncoder(nn.Module):
self,
encoder_type,
encoder_params,
duration_predictor_params,
n_vocab,
n_spks=1,
spk_emb_dim=128,
@@ -368,12 +340,6 @@ class TextEncoder(nn.Module):
)
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
self.proj_w = DurationPredictor(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
duration_predictor_params.filter_channels_dp,
duration_predictor_params.kernel_size,
duration_predictor_params.p_dropout,
)
def forward(self, x, x_lengths, spks=None):
"""Run forward pass to the transformer based encoder and duration predictor
@@ -404,7 +370,7 @@ class TextEncoder(nn.Module):
x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
# x_dp = torch.detach(x)
# logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask
return mu, x, x_mask

View File

@@ -4,14 +4,14 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.duration_predictors import DP
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
@@ -28,12 +28,14 @@ class MatchaTTS(BaseLightningClass): # 🍵
spk_emb_dim,
n_feats,
encoder,
duration_predictor,
decoder,
cfm,
data_statistics,
out_size,
optimizer=None,
scheduler=None,
prior_loss=True,
):
super().__init__()
@@ -44,6 +46,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.spk_emb_dim = spk_emb_dim
self.n_feats = n_feats
self.out_size = out_size
self.prior_loss = prior_loss
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -51,12 +54,13 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.encoder = TextEncoder(
encoder.encoder_type,
encoder.encoder_params,
encoder.duration_predictor_params,
n_vocab,
n_spks,
spk_emb_dim,
)
self.dp = DP(duration_predictor)
self.decoder = CFM(
in_channels=2 * encoder.encoder_params.n_feats,
out_channel=encoder.encoder_params.n_feats,
@@ -110,13 +114,17 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Get speaker embedding
spks = self.spk_emb(spks.long())
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
# Get encoder_outputs `mu_x` and encoded text `enc_output`
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
# Get log-scaled token durations `logw`
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = int(y_lengths.max())
y_max_length = y_lengths.max()
y_max_length_ = fix_len_compatibility(y_max_length)
# Using obtained durations `w` construct alignment map `attn`
@@ -171,7 +179,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
spks = self.spk_emb(spks)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
y_max_length = y.shape[-1]
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
@@ -190,9 +198,8 @@ class MatchaTTS(BaseLightningClass): # 🍵
attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS
# referred to as prior loss in the paper
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
# Cut a small segment of mel-spectrogram in order to increase batch size
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
@@ -228,7 +235,10 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Compute loss of the decoder
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
if self.prior_loss:
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss
return dur_loss, prior_loss, diff_loss, attn

0
matcha/onnx/__init__.py Normal file
View File

181
matcha/onnx/export.py Normal file
View File

@@ -0,0 +1,181 @@
import argparse
import random
from pathlib import Path
import numpy as np
import torch
from lightning import LightningModule
from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder
DEFAULT_OPSET = 15
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class MatchaWithVocoder(LightningModule):
def __init__(self, matcha, vocoder):
super().__init__()
self.matcha = matcha
self.vocoder = vocoder
def forward(self, x, x_lengths, scales, spks=None):
mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
wavs = self.vocoder(mel).clamp(-1, 1)
lengths = mel_lengths * 256
return wavs.squeeze(1), lengths
def get_exportable_module(matcha, vocoder, n_timesteps):
"""
Return an appropriate `LightningModule` and output-node names
based on whether the vocoder is embedded in the final graph
"""
def onnx_forward_func(x, x_lengths, scales, spks=None):
"""
Custom forward function that accepts the scale parameters
(temperature and length scale) as a single tensor
"""
# Unpack the scale parameters from the tensor
temperature = scales[0]
length_scale = scales[1]
output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
return output["mel"], output["mel_lengths"]
# Monkey-patch Matcha's forward function
matcha.forward = onnx_forward_func
if vocoder is None:
model, output_names = matcha, ["mel", "mel_lengths"]
else:
model = MatchaWithVocoder(matcha, vocoder)
output_names = ["wav", "wav_lengths"]
return model, output_names
def get_inputs(is_multi_speaker):
"""
Create dummy inputs for tracing
"""
dummy_input_length = 50
x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long)
x_lengths = torch.LongTensor([dummy_input_length])
# Scales
temperature = 0.667
length_scale = 1.0
scales = torch.Tensor([temperature, length_scale])
model_inputs = [x, x_lengths, scales]
input_names = [
"x",
"x_lengths",
"scales",
]
if is_multi_speaker:
spks = torch.LongTensor([1])
model_inputs.append(spks)
input_names.append("spks")
return tuple(model_inputs), input_names
def main():
parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")
parser.add_argument(
"checkpoint_path",
type=str,
help="Path to the model checkpoint",
)
parser.add_argument("output", type=str, help="Path to output `.onnx` file")
parser.add_argument(
"--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
)
parser.add_argument(
"--vocoder-name",
type=str,
choices=list(VOCODER_URLS.keys()),
default=None,
help="Name of the vocoder to embed in the ONNX graph",
)
parser.add_argument(
"--vocoder-checkpoint-path",
type=str,
default=None,
help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
)
parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15")
args = parser.parse_args()
print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
print(f"Setting n_timesteps to {args.n_timesteps}")
checkpoint_path = Path(args.checkpoint_path)
matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")
if args.vocoder_name or args.vocoder_checkpoint_path:
assert (
args.vocoder_name and args.vocoder_checkpoint_path
), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph."
vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
else:
vocoder = None
is_multi_speaker = matcha.n_spks > 1
dummy_input, input_names = get_inputs(is_multi_speaker)
model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)
# Set dynamic shape for inputs/outputs
dynamic_axes = {
"x": {0: "batch_size", 1: "time"},
"x_lengths": {0: "batch_size"},
}
if vocoder is None:
dynamic_axes.update(
{
"mel": {0: "batch_size", 2: "time"},
"mel_lengths": {0: "batch_size"},
}
)
else:
print("Embedding the vocoder in the ONNX graph")
dynamic_axes.update(
{
"wav": {0: "batch_size", 1: "time"},
"wav_lengths": {0: "batch_size"},
}
)
if is_multi_speaker:
dynamic_axes["spks"] = {0: "batch_size"}
# Create the output directory (if not exists)
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
model.to_onnx(
args.output,
dummy_input,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=args.opset,
export_params=True,
do_constant_folding=True,
)
print(f"[🍵] ONNX model exported to {args.output}")
if __name__ == "__main__":
main()
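# A hypothetical invocation (paths and the vocoder name are placeholders):
#     python -m matcha.onnx.export /path/to/matcha.ckpt matcha.onnx \
#         --n-timesteps 5 --vocoder-name <name-from-VOCODER_URLS> \
#         --vocoder-checkpoint-path /path/to/vocoder.ckpt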

168
matcha/onnx/infer.py Normal file
View File

@@ -0,0 +1,168 @@
import argparse
import os
import warnings
from pathlib import Path
from time import perf_counter
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
from matcha.cli import plot_spectrogram_to_numpy, process_text
def validate_args(args):
assert (
args.text or args.file
), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
assert args.temperature >= 0, "Sampling temperature cannot be negative"
assert args.speaking_rate >= 0, "Speaking rate cannot be negative"
return args
def write_wavs(model, inputs, output_dir, external_vocoder=None):
if external_vocoder is None:
print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
t0 = perf_counter()
wavs, wav_lengths = model.run(None, inputs)
infer_secs = perf_counter() - t0
mel_infer_secs = vocoder_infer_secs = None
else:
print("[🍵] Generating mel using Matcha")
mel_t0 = perf_counter()
mels, mel_lengths = model.run(None, inputs)
mel_infer_secs = perf_counter() - mel_t0
print("Generating waveform from mel using external vocoder")
vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
vocoder_t0 = perf_counter()
wavs = external_vocoder.run(None, vocoder_inputs)[0]
vocoder_infer_secs = perf_counter() - vocoder_t0
wavs = wavs.squeeze(1)
wav_lengths = mel_lengths * 256
infer_secs = mel_infer_secs + vocoder_infer_secs
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
audio = wav[:wav_length]
print(f"Writing audio to {output_filename}")
sf.write(output_filename, audio, 22050, "PCM_24")
wav_secs = wav_lengths.sum() / 22050
print(f"Inference seconds: {infer_secs}")
print(f"Generated wav seconds: {wav_secs}")
rtf = infer_secs / wav_secs
if mel_infer_secs is not None:
mel_rtf = mel_infer_secs / wav_secs
print(f"Matcha RTF: {mel_rtf}")
if vocoder_infer_secs is not None:
vocoder_rtf = vocoder_infer_secs / wav_secs
print(f"Vocoder RTF: {vocoder_rtf}")
print(f"Overall RTF: {rtf}")
def write_mels(model, inputs, output_dir):
t0 = perf_counter()
mels, mel_lengths = model.run(None, inputs)
infer_secs = perf_counter() - t0
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for i, mel in enumerate(mels):
output_stem = output_dir.joinpath(f"output_{i + 1}")
plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
np.save(output_stem.with_suffix(".npy"), mel)
wav_secs = (mel_lengths * 256).sum() / 22050
print(f"Inference seconds: {infer_secs}")
print(f"Generated wav seconds: {wav_secs}")
rtf = infer_secs / wav_secs
print(f"RTF: {rtf}")
def main():
parser = argparse.ArgumentParser(
description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
)
parser.add_argument(
"model",
type=str,
help="ONNX model to use",
)
parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
parser.add_argument(
"--temperature",
type=float,
default=0.667,
help="Variance of the x0 noise (default: 0.667)",
)
parser.add_argument(
"--speaking-rate",
type=float,
default=1.0,
help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
)
parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
parser.add_argument(
"--output-dir",
type=str,
default=os.getcwd(),
help="Output folder to save results (default: current dir)",
)
args = parser.parse_args()
args = validate_args(args)
if args.gpu:
providers = ["GPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
model = ort.InferenceSession(args.model, providers=providers)
model_inputs = model.get_inputs()
model_outputs = list(model.get_outputs())
if args.text:
text_lines = args.text.splitlines()
else:
with open(args.file, encoding="utf-8") as file:
text_lines = file.read().splitlines()
processed_lines = [process_text(0, line, "cpu") for line in text_lines]
x = [line["x"].squeeze() for line in processed_lines]
# Pad
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
x = x.detach().cpu().numpy()
x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
inputs = {
"x": x,
"x_lengths": x_lengths,
"scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
}
is_multi_speaker = len(model_inputs) == 4
if is_multi_speaker:
if args.spk is None:
args.spk = 0
warn = "[!] Speaker ID not provided! Using speaker ID 0"
warnings.warn(warn, UserWarning)
inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)
has_vocoder_embedded = model_outputs[0].name == "wav"
if has_vocoder_embedded:
write_wavs(model, inputs, args.output_dir)
elif args.vocoder:
external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
else:
warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
warnings.warn(warn, UserWarning)
write_mels(model, inputs, args.output_dir)
if __name__ == "__main__":
main()
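# A hypothetical invocation (paths are placeholders):
#     python -m matcha.onnx.infer matcha.onnx --text "Hello world" --output-dir outputs
# Add --vocoder vocoder.onnx when the exported graph does not embed a vocoder.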

View File

@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
for symbol in clean_text:
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):

View File

@@ -15,6 +15,7 @@ import logging
import re
import phonemizer
import piper_phonemize
from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -103,3 +104,13 @@ def english_cleaners2(text):
phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
phonemes = collapse_whitespace(phonemes)
return phonemes
def english_cleaners_piper(text):
"""Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
phonemes = collapse_whitespace(phonemes)
return phonemes

View File

@@ -0,0 +1,192 @@
r"""
This script uses a trained Matcha-TTS checkpoint to extract phoneme durations (from the learnt MAS alignments)
for every utterance in a dataset and saves them per file as `.npy` arrays and `.json` summaries.
Parameters are read from the Hydra data config under `configs/data`.
"""
import argparse
import json
import os
import sys
from pathlib import Path
import lightning
import numpy as np
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from torch import nn
from tqdm.auto import tqdm
from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
"""Generate durations from the model for each datapoint and save it in a folder
Args:
data_loader (torch.utils.data.DataLoader): Dataloader
model (nn.Module): MatchaTTS model
device (torch.device): GPU or CPU
output_folder (Path): Folder where the extracted durations are written
"""
for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
x = x.to(device)
y = y.to(device)
x_lengths = x_lengths.to(device)
y_lengths = y_lengths.to(device)
spks = spks.to(device) if spks is not None else None
_, _, _, attn = model(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
)
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="vctk.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=32,
help="Batch size; a larger value speeds up duration extraction",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
parser.add_argument(
"-c",
"--checkpoint_path",
type=str,
required=True,
help="Path to the checkpoint file to load the model from",
)
parser.add_argument(
"-o",
"--output-folder",
type=str,
default=None,
help="Output folder to save the data statistics",
)
parser.add_argument(
"--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
)
args = parser.parse_args()
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")
sys.exit(1)
output_folder.mkdir(parents=True, exist_ok=True)
print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
print("Loading model...")
device = get_device(args)
model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
try:
print("Computing stats for training set if exists...")
train_dataloader = text_mel_datamodule.train_dataloader()
compute_durations(train_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No training set found")
try:
print("Computing stats for validation set if exists...")
val_dataloader = text_mel_datamodule.val_dataloader()
compute_durations(val_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No validation set found")
try:
print("Computing stats for test set if exists...")
test_dataloader = text_mel_datamodule.test_dataloader()
compute_durations(test_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No test set found")
print(f"[+] Done! Data statistics saved to: {output_folder}")
if __name__ == "__main__":
# Helps with generating durations for the dataset to train other architectures
# that cannot learn to align due to limited size of dataset
# Example usage:
# python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
# This will create a folder data/processed_data/ljspeech/durations containing the durations
main()

View File

@@ -7,15 +7,17 @@ import torch
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
while True:
if length % (2**num_downsamplings_in_unet) == 0:
factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
length = (length / factor).ceil() * factor
if not torch.onnx.is_in_onnx_export():
return length.int().item()
else:
return length
length += 1
def convert_pad_shape(pad_shape):

View File

@@ -2,6 +2,7 @@ import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
@@ -115,7 +116,7 @@ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
return None
if metric_name not in metric_dict:
raise Exception(
raise ValueError(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
@@ -205,13 +206,54 @@ def get_user_data_dir(appname="matcha_tts"):
return final_path
def assert_model_downloaded(checkpoint_path, url, use_wget=False):
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
if Path(checkpoint_path).exists():
log.debug(f"[+] Model already present at {checkpoint_path}!")
print(f"[+] Model already present at {checkpoint_path}!")
return
log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
print(f"[-] Model not found at {checkpoint_path}! Will download it")
checkpoint_path = str(checkpoint_path)
if not use_wget:
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
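# `durations` is assumed to come from the interspersed symbol sequence (a separator token
# between every pair of phones), so len(durations) == 2 * len(phones) + 1; each separator's
# duration is split between its two neighbouring phones below.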
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if this is the last pair, take the full remaining value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json
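# Sketch of the returned structure (times are cumulative frame indices):
#     [{"p1": {"starttime": 0, "endtime": 7, "duration": 7}},
#      {"p2": {"starttime": 7, "endtime": 12, "duration": 5}}, ...]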

View File

@@ -35,10 +35,11 @@ torchaudio
matplotlib
pandas
conformer==0.3.2
diffusers==0.21.1
diffusers==0.25.0
notebook
ipywidgets
gradio
gdown
wget
seaborn
piper_phonemize

15
scripts/get_durations.sh Normal file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
echo "Starting script"
echo "Getting LJ Speech durations"
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
echo "Getting TSG2 durations"
python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
echo "Getting Joe Spont durations"
python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
echo "Getting Ryan durations"
python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f

7
scripts/transcribe.sh Normal file
View File

@@ -0,0 +1,7 @@
echo "Transcribing"
whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions

30
scripts/wer_computer.sh Normal file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
# Run from root folder with: bash scripts/wer_computer.sh
root_folder=${1:-"dur_wer_computation"}
echo "Running WER computation for Duration predictors"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
# echo $cmd
echo "LJ"
echo "==================================="
echo "Flow Matching"
$cmd
echo "-----------------------------------"
echo "LJ Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac"
echo "==================================="
echo "Cormac Flow Matching"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
$cmd
echo "-----------------------------------"

File diff suppressed because one or more lines are too long