Initial commit

This commit is contained in:
Shivam Mehta
2023-09-16 17:51:36 +00:00
parent b189c1983a
commit f016784049
100 changed files with 6416 additions and 0 deletions

6
.env.example Normal file
View File

@@ -0,0 +1,6 @@
# example of file for storing private and user specific environment variables, like keys or system paths
# rename it to ".env" (excluded from version control by default)
# .env is loaded by train.py automatically
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
MY_VAR="/home/user/my/system/path"

22
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@@ -0,0 +1,22 @@
## What does this PR do?
<!--
Please include a summary of the change and which issue is fixed.
Please also include relevant motivation and context.
List any dependencies that are required for this change.
List all the breaking changes introduced by this pull request.
-->
Fixes #\<issue_number>
## Before submitting
- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**?
- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together?
- [ ] Did you list all the **breaking changes** introduced by this pull request?
- [ ] Did you **test your PR locally** with `pytest` command?
- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command?
## Did you have fun?
Make sure you had fun coding 🙃

15
.github/codecov.yml vendored Normal file
View File

@@ -0,0 +1,15 @@
coverage:
status:
# measures overall project coverage
project:
default:
threshold: 100% # how much decrease in coverage is needed to not consider success
# measures PR or single commit coverage
patch:
default:
threshold: 100% # how much decrease in coverage is needed to not consider success
# project: off
# patch: off

16
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,16 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "daily"
ignore:
- dependency-name: "pytorch-lightning"
update-types: ["version-update:semver-patch"]
- dependency-name: "torchmetrics"
update-types: ["version-update:semver-patch"]

44
.github/release-drafter.yml vendored Normal file
View File

@@ -0,0 +1,44 @@
name-template: "v$RESOLVED_VERSION"
tag-template: "v$RESOLVED_VERSION"
categories:
- title: "🚀 Features"
labels:
- "feature"
- "enhancement"
- title: "🐛 Bug Fixes"
labels:
- "fix"
- "bugfix"
- "bug"
- title: "🧹 Maintenance"
labels:
- "maintenance"
- "dependencies"
- "refactoring"
- "cosmetic"
- "chore"
- title: "📝️ Documentation"
labels:
- "documentation"
- "docs"
change-template: "- $TITLE @$AUTHOR (#$NUMBER)"
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
version-resolver:
major:
labels:
- "major"
minor:
labels:
- "minor"
patch:
labels:
- "patch"
default: patch
template: |
## Changes
$CHANGES

163
.gitignore vendored Normal file
View File

@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
### VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
**/.vscode
# JetBrains
.idea/
# Data & Models
*.h5
*.tar
*.tar.gz
# Lightning-Hydra-Template
configs/local/default.yaml
/data/
/logs/
.env
# Aim logging
.aim
# Cython complied files
matcha/utils/monotonic_align/core.c
# Ignoring hifigan checkpoint
generator_v1
g_02500000
gradio_cached_examples/
synth_output/

59
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,59 @@
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
- id: end-of-file-fixer
# - id: check-docstring-first
- id: check-yaml
- id: debug-statements
- id: detect-private-key
- id: check-toml
- id: check-case-conflict
- id: check-added-large-files
# python code formatting
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
args: [--line-length, "120"]
# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py38-plus]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args:
[
"--max-line-length", "120",
"--extend-ignore",
"E203,E402,E501,F401,F841,RST2,RST301",
"--exclude",
"logs/*,data/*,matcha/hifigan/*",
]
additional_dependencies: [flake8-rst-docstrings==0.3.0]
# pylint
- repo: https://github.com/pycqa/pylint
rev: v2.8.2
hooks:
- id: pylint

2
.project-root Normal file
View File

@@ -0,0 +1,2 @@
# this file is required for inferring the project root directory
# do not delete

603
.pylintrc Normal file
View File

@@ -0,0 +1,603 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=missing-docstring,
too-many-public-methods,
too-many-lines,
bare-except,
## for avoiding weird p3.6 CI linter error
## TODO: see later if we can remove this
assigning-non-slot,
unsupported-assignment-operation,
## end
line-too-long,
fixme,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
import-error,
invalid-name,
too-many-instance-attributes,
arguments-differ,
arguments-renamed,
no-name-in-module,
no-member,
unsubscriptable-object,
print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
useless-object-inheritance,
too-few-public-methods,
too-many-branches,
too-many-arguments,
too-many-locals,
too-many-statements,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating,
deprecated-operator-function,
deprecated-urllib-function,
xreadlines-attribute,
deprecated-sys-function,
exception-escape,
comprehension-escape,
duplicate-code,
not-callable,
import-outside-toplevel,
logging-fstring-interpolation,
logging-not-lazy,
unused-argument,
no-else-return,
chained-comparison,
redefined-outer-name
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package..
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=numpy.*,torch.*
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=120
# Maximum number of lines in a module.
max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
argument-rgx=[a-z_][a-z0-9_]{0,30}$
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
x,
ex,
Run,
_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
variable-rgx=[a-z_][a-z0-9_]{0,30}$
[STRING]
# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=15
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception

13
MANIFEST.in Normal file
View File

@@ -0,0 +1,13 @@
include README.md
include LICENSE.txt
include requirements.*.txt
include *.cff
include requirements.txt
recursive-include matcha *.json
recursive-include matcha *.html
recursive-include matcha *.png
recursive-include matcha *.md
recursive-include matcha *.py
recursive-include matcha *.pyx
recursive-exclude tests *
prune tests*

33
Makefile Normal file
View File

@@ -0,0 +1,33 @@
help: ## Show help
@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
clean: ## Clean autogenerated files
rm -rf dist
find . -type f -name "*.DS_Store" -ls -delete
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
find . | grep -E ".pytest_cache" | xargs rm -rf
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
rm -f .coverage
clean-logs: ## Clean logs
rm -rf logs/**
format: ## Run pre-commit hooks
pre-commit run -a
sync: ## Merge changes from main branch to your current branch
git pull
git pull origin main
test: ## Run not slow tests
pytest -k "not slow"
test-full: ## Run all tests
pytest
train-ljspeech: ## Train the model
python matcha/train.py experiment=ljspeech
train-ljspeech-min: ## Train the model with minimum memory
python matcha/train.py experiment=ljspeech_min_memory

171
README.md Normal file
View File

@@ -0,0 +1,171 @@
<div align="center">
# Matcha-TTS: A fast TTS architecture with conditional flow matching
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
[![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
</div>
<p style="text-align: center;">
<img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/>
</p>
> This is the official code implementation of 🍵 Matcha-TTS.
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
- Is probabilistic
- Has compact memory footprint
- Sounds highly natural
- Is very fast to synthesise from
Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
<br>
## Installation
1. Create an environment (suggested but optional)
```
conda create -n matcha_tts python=3.10 -y
conda activate matcha_tts
```
2. Install Matcha TTS using pip from source
(in future we plan to add it to PyPI)
```bash
pip install git+https://github.com/shivammehta25/Matcha-TTS.git
```
3. Run CLI / gradio app / jupyter notebook
```bash
# This will download the required models and list available arguments
matcha_tts --help
```
or
```bash
matcha_tts_app
```
or open `synthesis.ipynb` on jupyter notebook
### CLI Arguments
- To synthesise from given text, run:
```bash
match_tts --text "<INPUT TEXT>"
```
- To synthesise from a file, run:
```bash
match_tts --file <PATH TO FILE>
```
- To batch synthesise from a file, run:
```bash
match_tts --file <PATH TO FILE> --batched
```
Additional arguments
- Speaking rate
```bash
match_tts --text "<INPUT TEXT>" --speaking_rate 1.0
```
- Sampling temperature
```bash
match_tts --text "<INPUT TEXT>" --temperature 0.667
```
- Euler ODE solver steps
```bash
match_tts --text "<INPUT TEXT>" --steps 10
```
## Citation information
If you find this work useful, please cite our paper:
```text
@article{mehta2023matcha,
title={Matcha-TTS: A fast TTS architecture with conditional flow matching},
author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje},
journal={arXiv preprint arXiv:2309.03199},
year={2023}
}
```
## Train with your own dataset
Let's assume we are training with LJSpeech
1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the filelists to point to the extracted data like the [5th point of setup in Tacotron2 repo](https://github.com/NVIDIA/tacotron2#setup).
2. Clone and enter this repository
```bash
git clone https://github.com/shivammehta25/Matcha-TTS.git
cd Matcha-TTS
```
3. Install the package from source
```bash
pip install -e .
```
4. Go to `configs/data/ljspeech.yaml` and change
```yaml
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
```
to the paths of your train and validation filelists.
5. Run the training script
```bash
make train-ljspeech
```
or
```bash
python matcha/train.py experiment=ljspeech
```
for multi-gpu training, run
```bash
python matcha/train.py experiment=ljspeech trainer.devices=[0,1]
```
## Acknowledgements
Since this code uses: [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers attached to it.
Other source codes I would like to acknowledge:
- [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev)
- [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS)
- [torchdyn](https://github.com/DiffEqML/torchdyn)

1
configs/__init__.py Normal file
View File

@@ -0,0 +1 @@
# this file is needed here to include configs when building project as a package

View File

@@ -0,0 +1,5 @@
defaults:
- model_checkpoint.yaml
- model_summary.yaml
- rich_progress_bar.yaml
- _self_

View File

@@ -0,0 +1,17 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
model_checkpoint:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
dirpath: ${paths.output_dir}/checkpoints # directory to save the model file
filename: checkpoint_{epoch:03d} # checkpoint filename
monitor: epoch # name of the logged metric which determines when model is improving
verbose: False # verbosity mode
save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 10 # save k best models (determined by above metric)
mode: "max" # "max" means higher metric value is better, can be also "min"
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
save_weights_only: False # if True, then only the models weights will be saved
every_n_train_steps: null # number of training steps between checkpoints
train_time_interval: null # checkpoints are monitored at the specified time interval
every_n_epochs: 100 # number of epochs between checkpoints
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation

View File

@@ -0,0 +1,5 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
model_summary:
_target_: lightning.pytorch.callbacks.RichModelSummary
max_depth: 3 # the maximum depth of layer nesting that the summary will include

View File

View File

@@ -0,0 +1,4 @@
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
rich_progress_bar:
_target_: lightning.pytorch.callbacks.RichProgressBar

View File

@@ -0,0 +1,21 @@
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
batch_size: 32
num_workers: 20
pin_memory: True
cleaners: [english_cleaners2]
add_blank: True
n_spks: 1
n_fft: 1024
n_feats: 80
sample_rate: 22050
hop_length: 256
win_length: 1024
f_min: 0
f_max: 8000
data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622
mel_std: 2.116101
seed: ${seed}

14
configs/data/vctk.yaml Normal file
View File

@@ -0,0 +1,14 @@
defaults:
- ljspeech
- _self_
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: vctk
train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt
valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt
batch_size: 32
add_blank: True
n_spks: 109
data_statistics: # Computed for vctk dataset
mel_mean: -6.630575
mel_std: 2.482914

View File

@@ -0,0 +1,35 @@
# @package _global_
# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one
# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"
# disable callbacks and loggers during debugging
callbacks: null
logger: null
extras:
ignore_warnings: False
enforce_tags: False
# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
job_logging:
root:
level: DEBUG
# use this to also set hydra loggers to 'DEBUG'
# verbose: True
trainer:
max_epochs: 1
accelerator: cpu # debuggers don't like gpus
devices: 1 # debuggers don't like multiprocessing
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
data:
num_workers: 0 # debuggers don't like multiprocessing
pin_memory: False # disable gpu memory pin

9
configs/debug/fdr.yaml Normal file
View File

@@ -0,0 +1,9 @@
# @package _global_
# runs 1 train, 1 validation and 1 test step
defaults:
- default
trainer:
fast_dev_run: true

12
configs/debug/limit.yaml Normal file
View File

@@ -0,0 +1,12 @@
# @package _global_
# uses only 1% of the training data and 5% of validation/test data
defaults:
- default
trainer:
max_epochs: 3
limit_train_batches: 0.01
limit_val_batches: 0.05
limit_test_batches: 0.05

View File

@@ -0,0 +1,13 @@
# @package _global_
# overfits to 3 batches
defaults:
- default
trainer:
max_epochs: 20
overfit_batches: 3
# model ckpt and early stopping need to be disabled during overfitting
callbacks: null

View File

@@ -0,0 +1,12 @@
# @package _global_
# runs with execution time profiling
defaults:
- default
trainer:
max_epochs: 1
profiler: "simple"
# profiler: "advanced"
# profiler: "pytorch"

18
configs/eval.yaml Normal file
View File

@@ -0,0 +1,18 @@
# @package _global_
defaults:
- _self_
- data: mnist # choose datamodule with `test_dataloader()` for evaluation
- model: mnist
- logger: null
- trainer: default
- paths: default
- extras: default
- hydra: default
task_name: "eval"
tags: ["dev"]
# passing checkpoint path is necessary for evaluation
ckpt_path: ???

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech

View File

@@ -0,0 +1,18 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech_min
model:
out_size: 172

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: vctk.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["multispeaker"]
run_name: multispeaker

View File

@@ -0,0 +1,8 @@
# disable python warnings if they annoy you
ignore_warnings: False
# ask user for tags if none are provided in the config
enforce_tags: True
# pretty print config tree at the start of the run using Rich library
print_config: True

View File

@@ -0,0 +1,52 @@
# @package _global_
# example hyperparameter optimization of some experiment with Optuna:
# python train.py -m hparams_search=mnist_optuna experiment=example
defaults:
- override /hydra/sweeper: optuna
# choose metric which will be optimized by Optuna
# make sure this is the correct name of some metric logged in lightning module!
optimized_metric: "val/acc_best"
# here we define Optuna hyperparameter search
# it optimizes for value returned from function with @hydra.main decorator
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
sweeper:
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
# storage URL to persist optimization results
# for example, you can use SQLite if you set 'sqlite:///example.db'
storage: null
# name of the study to persist optimization results
study_name: null
# number of parallel workers
n_jobs: 1
# 'minimize' or 'maximize' the objective
direction: maximize
# total number of runs that will be executed
n_trials: 20
# choose Optuna hyperparameter sampler
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
sampler:
_target_: optuna.samplers.TPESampler
seed: 1234
n_startup_trials: 10 # number of random sampling runs before optimization starts
# define hyperparameter search space
params:
model.optimizer.lr: interval(0.0001, 0.1)
data.batch_size: choice(32, 64, 128, 256)
model.net.lin1_size: choice(64, 128, 256)
model.net.lin2_size: choice(64, 128, 256)
model.net.lin3_size: choice(32, 64, 128, 256)

View File

@@ -0,0 +1,19 @@
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging
defaults:
- override hydra_logging: colorlog
- override job_logging: colorlog
# output directory, generated dynamically on each run
run:
dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
subdir: ${hydra.job.num}
job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log

0
configs/local/.gitkeep Normal file
View File

28
configs/logger/aim.yaml Normal file
View File

@@ -0,0 +1,28 @@
# https://aimstack.io/
# example usage in lightning module:
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
# `aim up`
aim:
_target_: aim.pytorch_lightning.AimLogger
repo: ${paths.root_dir} # .aim folder will be created here
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
# aim allows to group runs under experiment name
experiment: null # any string, set to "default" if not specified
train_metric_prefix: "train/"
val_metric_prefix: "val/"
test_metric_prefix: "test/"
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
system_tracking_interval: 10 # set to null to disable system metrics tracking
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
log_system_params: true
# enable/disable tracking console logs (default value is true)
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550

12
configs/logger/comet.yaml Normal file
View File

@@ -0,0 +1,12 @@
# https://www.comet.ml
comet:
_target_: lightning.pytorch.loggers.comet.CometLogger
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
save_dir: "${paths.output_dir}"
project_name: "lightning-hydra-template"
rest_api_key: null
# experiment_name: ""
experiment_key: null # set to resume experiment
offline: False
prefix: ""

7
configs/logger/csv.yaml Normal file
View File

@@ -0,0 +1,7 @@
# csv logger built in lightning
csv:
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
save_dir: "${paths.output_dir}"
name: "csv/"
prefix: ""

View File

@@ -0,0 +1,9 @@
# train with many loggers at once
defaults:
# - comet
- csv
# - mlflow
# - neptune
- tensorboard
- wandb

View File

@@ -0,0 +1,12 @@
# https://mlflow.org
mlflow:
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
# experiment_name: ""
# run_name: ""
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
tags: null
# save_dir: "./mlruns"
prefix: ""
artifact_location: null
# run_id: ""

View File

@@ -0,0 +1,9 @@
# https://neptune.ai
neptune:
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
project: username/lightning-hydra-template
# name: ""
log_model_checkpoints: True
prefix: ""

View File

@@ -0,0 +1,10 @@
# https://www.tensorflow.org/tensorboard/
tensorboard:
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${paths.output_dir}/tensorboard/"
name: null
log_graph: False
default_hp_metric: True
prefix: ""
# version: ""

16
configs/logger/wandb.yaml Normal file
View File

@@ -0,0 +1,16 @@
# https://wandb.ai
wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "lightning-hydra-template"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
# entity: "" # set to name of your wandb team
group: ""
tags: []
job_type: ""

View File

@@ -0,0 +1,3 @@
name: CFM
solver: euler
sigma_min: 1e-4

View File

@@ -0,0 +1,7 @@
channels: [256, 256]
dropout: 0.05
attention_head_dim: 64
n_blocks: 1
num_mid_blocks: 2
num_heads: 2
act_fn: snakebeta

View File

@@ -0,0 +1,18 @@
encoder_type: RoPE Encoder
encoder_params:
n_feats: ${model.n_feats}
n_channels: 192
filter_channels: 768
filter_channels_dp: 256
n_heads: 2
n_layers: 6
kernel_size: 3
p_dropout: 0.1
spk_emb_dim: 64
n_spks: 1
prenet: true
duration_predictor_params:
filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
kernel_size: 3
p_dropout: ${model.encoder.encoder_params.p_dropout}

14
configs/model/matcha.yaml Normal file
View File

@@ -0,0 +1,14 @@
defaults:
- _self_
- encoder: default.yaml
- decoder: default.yaml
- cfm: default.yaml
- optimizer: adam.yaml
_target_: matcha.models.matcha_tts.MatchaTTS
n_vocab: 178
n_spks: ${data.n_spks}
spk_emb_dim: 64
n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4

View File

@@ -0,0 +1,4 @@
_target_: torch.optim.Adam
_partial_: true
lr: 1e-4
weight_decay: 0.0

View File

@@ -0,0 +1,18 @@
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
root_dir: ${oc.env:PROJECT_ROOT}
# path to data directory
data_dir: ${paths.root_dir}/data/
# path to logging directory
log_dir: ${paths.root_dir}/logs/
# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}
# path to working directory
work_dir: ${hydra:runtime.cwd}

51
configs/train.yaml Normal file
View File

@@ -0,0 +1,51 @@
# @package _global_
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: ljspeech
- model: matcha
- callbacks: default
- logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
- hydra: default
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null
# config for hyperparameter optimization
- hparams_search: null
# optional local config for machine/user specific settings
# it's optional since it doesn't need to exist and is excluded from version control
- optional local: default
# debugging config (enable through command line, e.g. `python train.py debug=default)
- debug: null
# task name, determines output directory path
task_name: "train"
run_name: ???
# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]
# set False to skip model training
train: True
# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: True
# simply provide checkpoint path to resume training
ckpt_path: null
# seed for random number generators in pytorch, numpy and python.random
seed: 1234

5
configs/trainer/cpu.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: cpu
devices: 1

9
configs/trainer/ddp.yaml Normal file
View File

@@ -0,0 +1,9 @@
defaults:
- default
strategy: ddp
accelerator: gpu
devices: [0,1]
num_nodes: 1
sync_batchnorm: True

View File

@@ -0,0 +1,7 @@
defaults:
- default
# simulate DDP on CPU, useful for debugging
accelerator: cpu
devices: 2
strategy: ddp_spawn

View File

@@ -0,0 +1,20 @@
_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.output_dir}
max_epochs: -1
accelerator: gpu
devices: [0]
# mixed precision for extra speed-up
precision: 16-mixed
# perform a validation loop every N training epochs
check_val_every_n_epoch: 1
# set True to to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
gradient_clip_val: 5.0

5
configs/trainer/gpu.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: gpu
devices: 1

5
configs/trainer/mps.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: mps
devices: 1

1
data Symbolic link
View File

@@ -0,0 +1 @@
/home/smehta/Projects/Speech-Backbones/Grad-TTS/data

0
matcha/__init__.py Normal file
View File

207
matcha/app.py Normal file
View File

@@ -0,0 +1,207 @@
import tempfile
from argparse import Namespace
from pathlib import Path
import gradio as gr
import soundfile as sf
import torch
from matcha.cli import (MATCHA_URLS, VOCODER_URL, assert_model_downloaded,
get_device, load_matcha, load_vocoder, process_text,
to_waveform)
from matcha.utils.utils import get_user_data_dir, plot_tensor
LOCATION = Path(get_user_data_dir())
args = Namespace(
cpu=False,
model="matcha_ljspeech",
vocoder="hifigan_T2_v1",
spk=None,
)
MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
VOCODER_LOC = LOCATION / f"{args.vocoder}"
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model], use_wget=True)
assert_model_downloaded(VOCODER_LOC, VOCODER_URL[args.vocoder])
device = get_device(args)
model = load_matcha(args.model, MATCHA_TTS_LOC, device)
vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC, device)
@torch.inference_mode()
def process_text_gradio(text):
output = process_text(1, text, device)
return output["x_phones"][1::2], output["x"], output["x_lengths"]
@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
output = model.synthesise(
text,
text_length,
n_timesteps=n_timesteps,
temperature=temperature,
spks=args.spk,
length_scale=length_scale,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
sf.write(fp.name, output["waveform"], 22050, "PCM_24")
return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
phones, text, text_lengths = process_text_gradio(text)
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
return phones, audio, mel_spectrogram
def main():
description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
* Is probabilistic
* Has compact memory footprint
* Sounds highly natural
* Is very fast to synthesise from
Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.
Cached examples are available at the bottom of the page.
"""
with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
processed_text = gr.State(value=None)
processed_text_len = gr.State(value=None)
mel_variable = gr.State(value=None)
with gr.Box():
with gr.Row():
gr.Markdown(description, scale=3)
gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
with gr.Box():
with gr.Row():
gr.Markdown("# Text Input")
with gr.Row():
text = gr.Textbox(value="", lines=2, label="Text to synthesise")
with gr.Row():
gr.Markdown("### Hyper parameters")
with gr.Row():
n_timesteps = gr.Slider(
label="Number of ODE steps",
minimum=0,
maximum=100,
step=1,
value=10,
interactive=True,
)
length_scale = gr.Slider(
label="Length scale (Speaking rate)",
minimum=0.01,
maximum=3.0,
step=0.05,
value=1.0,
interactive=True,
)
mel_temp = gr.Slider(
label="Sampling temperature",
minimum=0.00,
maximum=2.001,
step=0.16675,
value=0.667,
interactive=True,
)
synth_btn = gr.Button("Synthesise")
with gr.Box():
with gr.Row():
gr.Markdown("### Phonetised text")
phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")
with gr.Box():
with gr.Row():
mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
# with gr.Row():
audio = gr.Audio(interactive=False, label="Audio")
with gr.Row():
examples = gr.Examples(
examples=[
[
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
50,
0.677,
1.0,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
2,
0.677,
1.0,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
4,
0.677,
1.0,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
10,
0.677,
1.0,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
50,
0.677,
1.0,
],
[
"The narrative of these events is based largely on the recollections of the participants.",
10,
0.677,
1.0,
],
[
"The jury did not believe him, and the verdict was for the defendants.",
10,
0.677,
1.0,
],
],
fn=run_full_synthesis,
inputs=[text, n_timesteps, mel_temp, length_scale],
outputs=[phonetised_text, audio, mel_spectrogram],
cache_examples=True,
)
synth_btn.click(
fn=process_text_gradio,
inputs=[
text,
],
outputs=[phonetised_text, processed_text, processed_text_len],
api_name="matcha_tts",
queue=True,
).then(
fn=synthesise_mel,
inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale],
outputs=[audio, mel_spectrogram],
)
demo.queue(concurrency_count=5).launch(share=True)
if __name__ == "__main__":
main()

339
matcha/cli.py Normal file
View File

@@ -0,0 +1,339 @@
import argparse
import datetime as dt
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import (assert_model_downloaded, get_user_data_dir,
intersperse)
MATCHA_URLS = {"matcha_ljspeech": ""} # , "matcha_vctk": ""} # Coming soon
MULTISPEAKER_MODEL = {"matcha_vctk"}
SINGLESPEAKER_MODEL = {"matcha_ljspeech"}
VOCODER_URL = {
"hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link",
}
def plot_spectrogram_to_numpy(spectrogram, filename):
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.title("Synthesised Mel-Spectrogram")
fig.canvas.draw()
plt.savefig(filename)
def process_text(i: int, text: str, device: torch.device):
print(f"[{i}] - Input text: {text}")
x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
dtype=torch.long,
device=device,
)[None]
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
x_phones = sequence_to_text(x.squeeze(0).tolist())
print(f"[{i}] - Phonetised text: {x_phones[1::2]}")
return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
def get_texts(args):
if args.text:
texts = [args.text]
else:
with open(args.file) as f:
texts = f.readlines()
return texts
def assert_required_models_available(args):
save_dir = get_user_data_dir()
model_path = save_dir / f"{args.model}.ckpt"
vocoder_path = save_dir / f"{args.vocoder}"
assert_model_downloaded(model_path, MATCHA_URLS[args.model], use_wget=True)
assert_model_downloaded(vocoder_path, VOCODER_URL[args.vocoder])
return {"matcha": model_path, "vocoder": vocoder_path}
def load_hifigan(checkpoint_path, device):
h = AttrDict(v1)
hifigan = HiFiGAN(h).to(device)
hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"])
_ = hifigan.eval()
hifigan.remove_weight_norm()
return hifigan
def load_vocoder(vocoder_name, checkpoint_path, device):
print(f"[!] Loading {vocoder_name}!")
vocoder = None
if vocoder_name == "hifigan_T2_v1":
vocoder = load_hifigan(checkpoint_path, device)
else:
raise NotImplementedError(
f"Vocoder {vocoder_name} not implemented! define a load_<<vocoder_name>> method for it"
)
denoiser = Denoiser(vocoder, mode="zeros")
print(f"[+] {vocoder_name} loaded!")
return vocoder, denoiser
def load_matcha(model_name, checkpoint_path, device):
print(f"[!] Loading {model_name}!")
model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
_ = model.eval()
print(f"[+] {model_name} loaded!")
return model
def to_waveform(mel, vocoder, denoiser=None):
audio = vocoder(mel).clamp(-1, 1)
if denoiser is not None:
audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze()
def save_to_folder(filename: str, output: dict, folder: str):
folder = Path(folder)
folder.mkdir(exist_ok=True, parents=True)
plot_spectrogram_to_numpy(np.array(output["mel"].squeeze().float().cpu()), f"{filename}.png")
np.save(folder / f"{filename}", output["mel"].cpu().numpy())
sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
return folder.resolve() / f"{filename}.wav"
def validate_args(args):
assert (
args.text or args.file
), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
assert args.temperature >= 0, "Sampling temperature cannot be negative"
assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
assert args.steps > 0, "Number of ODE steps must be greater than 0"
if args.model in SINGLESPEAKER_MODEL:
assert args.spk is None, f"Speaker ID is not supported for {args.model}"
if args.spk is not None:
assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108"
assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model"
if args.model in MULTISPEAKER_MODEL:
if args.spk is None:
print("[!] Speaker ID not provided! Using speaker ID 0")
args.spk = 0
if args.batched:
assert args.batch_size > 0, "Batch size must be greater than 0"
return args
@torch.inference_mode()
def cli():
parser = argparse.ArgumentParser(description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching")
parser.add_argument(
"--model",
type=str,
default="matcha_ljspeech",
help="Model to use",
choices=MATCHA_URLS.keys(),
)
parser.add_argument(
"--vocoder",
type=str,
default="hifigan_T2_v1",
help="Vocoder to use",
choices=VOCODER_URL.keys(),
)
parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
parser.add_argument(
"--temperature",
type=float,
default=0.667,
help="Variance of the x0 noise (default: 0.667)",
)
parser.add_argument(
"--speaking_rate",
type=float,
default=1.0,
help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
)
parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
parser.add_argument("--cpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
parser.add_argument(
"--denoiser_strength",
type=float,
default=0.00025,
help="Strength of the vocoder bias denoiser (default: 0.00025)",
)
parser.add_argument(
"--output_folder",
type=str,
default=os.getcwd(),
help="Output folder to save results (default: current dir)",
)
parser.add_argument("--batched", action="store_true")
parser.add_argument("--batch_size", type=int, default=32)
args = parser.parse_args()
args = validate_args(args)
device = get_device(args)
print_config(args)
paths = assert_required_models_available(args)
model = load_matcha(args.model, paths["matcha"], device)
vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device)
texts = get_texts(args)
spk = torch.tensor([args.spk], device=device, dtype=torch.long) if args.spk is not None else None
if len(texts) == 1 or not args.batched:
unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
else:
batched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
class BatchedSynthesisDataset(torch.utils.data.Dataset):
def __init__(self, processed_texts):
self.processed_texts = processed_texts
def __len__(self):
return len(self.processed_texts)
def __getitem__(self, idx):
return self.processed_texts[idx]
def batched_collate_fn(batch):
x = []
x_lengths = []
for b in batch:
x.append(b["x"].squeeze(0))
x_lengths.append(b["x_lengths"])
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
x_lengths = torch.concat(x_lengths, dim=0)
return {"x": x, "x_lengths": x_lengths}
def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
total_rtf = []
total_rtf_w = []
processed_text = [process_text(i, text, "cpu") for i, text in enumerate(texts)]
dataloader = torch.utils.data.DataLoader(
BatchedSynthesisDataset(processed_text),
batch_size=args.batch_size,
collate_fn=batched_collate_fn,
num_workers=8,
)
for i, batch in enumerate(dataloader):
i = i + 1
start_t = dt.datetime.now()
output = model.synthesise(
batch["x"].to(device),
batch["x_lengths"].to(device),
n_timesteps=args.steps,
temperature=args.temperature,
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
total_rtf.append(output["rtf"])
total_rtf_w.append(rtf_w)
for j in range(output["mel"].shape[0]):
base_name = f"utterance_{j:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{j:03d}"
length = output["mel_lengths"][j]
new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]}
location = save_to_folder(base_name, new_dict, args.output_folder)
print(f"[🍵-{j}] Waveform saved: {location}")
print("".join(["="] * 100))
print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
total_rtf = []
total_rtf_w = []
for i, text in enumerate(texts):
i = i + 1
base_name = f"utterance_{i:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{i:03d}"
print("".join(["="] * 100))
text = text.strip()
text_processed = process_text(i, text, device)
print(f"[🍵] Whisking Matcha-T(ea)TS for: {i}")
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
text_processed["x_lengths"],
n_timesteps=args.steps,
temperature=args.temperature,
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
# RTF with HiFiGAN
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-{i}] Matcha-TTS RTF: {output['rtf']:.4f}")
print(f"[🍵-{i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
total_rtf.append(output["rtf"])
total_rtf_w.append(rtf_w)
location = save_to_folder(base_name, output, args.output_folder)
print(f"[+] Waveform saved: {location}")
print("".join(["="] * 100))
print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
def print_config(args):
print("[!] Configurations: ")
print(f"\t- Temperature: {args.temperature}")
print(f"\t- Speaking rate: {args.speaking_rate}")
print(f"\t- Number of ODE steps: {args.steps}")
print(f"\t- Speaker: {args.spk}")
def get_device(args):
if torch.cuda.is_available() and not args.cpu:
print("[+] GPU Available! Using GPU")
device = torch.device("cuda")
else:
print("[-] GPU not available or forced CPU run! Using CPU")
device = torch.device("cpu")
return device
if __name__ == "__main__":
cli()

0
matcha/data/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,231 @@
import random
from typing import Any, Dict, Optional
import torch
import torchaudio as ta
from lightning import LightningDataModule
from torch.utils.data.dataloader import DataLoader
from matcha.text import text_to_sequence
from matcha.utils.audio import mel_spectrogram
from matcha.utils.model import fix_len_compatibility, normalize
from matcha.utils.utils import intersperse
def parse_filelist(filelist_path, split_char="|"):
with open(filelist_path, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split_char) for line in f]
return filepaths_and_text
class TextMelDataModule(LightningDataModule):
def __init__( # pylint: disable=unused-argument
self,
name,
train_filelist_path,
valid_filelist_path,
batch_size,
num_workers,
pin_memory,
cleaners,
add_blank,
n_spks,
n_fft,
n_feats,
sample_rate,
hop_length,
win_length,
f_min,
f_max,
data_statistics,
seed,
):
super().__init__()
# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)
def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute things like random split twice!
"""
# load and split datasets only if not loaded already
self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.train_filelist_path,
self.hparams.n_spks,
self.hparams.cleaners,
self.hparams.add_blank,
self.hparams.n_fft,
self.hparams.n_feats,
self.hparams.sample_rate,
self.hparams.hop_length,
self.hparams.win_length,
self.hparams.f_min,
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
)
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path,
self.hparams.n_spks,
self.hparams.cleaners,
self.hparams.add_blank,
self.hparams.n_fft,
self.hparams.n_feats,
self.hparams.sample_rate,
self.hparams.hop_length,
self.hparams.win_length,
self.hparams.f_min,
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
)
def train_dataloader(self):
return DataLoader(
dataset=self.trainset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
collate_fn=TextMelBatchCollate(self.hparams.n_spks),
)
def val_dataloader(self):
return DataLoader(
dataset=self.validset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
collate_fn=TextMelBatchCollate(self.hparams.n_spks),
)
def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass
def state_dict(self): # pylint: disable=no-self-use
"""Extra things to save to checkpoint."""
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
pass # pylint: disable=unnecessary-pass
class TextMelDataset(torch.utils.data.Dataset):
def __init__(
self,
filelist_path,
n_spks,
cleaners,
add_blank=True,
n_fft=1024,
n_mels=80,
sample_rate=22050,
hop_length=256,
win_length=1024,
f_min=0.0,
f_max=8000,
data_parameters=None,
seed=None,
):
self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks
self.cleaners = cleaners
self.add_blank = add_blank
self.n_fft = n_fft
self.n_mels = n_mels
self.sample_rate = sample_rate
self.hop_length = hop_length
self.win_length = win_length
self.f_min = f_min
self.f_max = f_max
if data_parameters is not None:
self.data_parameters = data_parameters
else:
self.data_parameters = {"mel_mean": 0, "mel_std": 1}
random.seed(seed)
random.shuffle(self.filepaths_and_text)
def get_datapoint(self, filepath_and_text):
if self.n_spks > 1:
filepath, spk, text = (
filepath_and_text[0],
int(filepath_and_text[1]),
filepath_and_text[2],
)
else:
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
assert sr == self.sample_rate
mel = mel_spectrogram(
audio,
self.n_fft,
self.n_mels,
self.sample_rate,
self.hop_length,
self.win_length,
self.f_min,
self.f_max,
center=False,
).squeeze()
mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
return mel
def get_text(self, text, add_blank=True):
text_norm = text_to_sequence(text, self.cleaners)
if self.add_blank:
text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm)
return text_norm
def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index])
return datapoint
def __len__(self):
return len(self.filepaths_and_text)
class TextMelBatchCollate:
def __init__(self, n_spks):
self.n_spks = n_spks
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch])
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch])
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
x_lengths.append(x_.shape[-1])
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}

21
matcha/hifigan/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

101
matcha/hifigan/README.md Normal file
View File

@@ -0,0 +1,101 @@
# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
In our [paper](https://arxiv.org/abs/2010.05646),
we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
We provide our implementation and pretrained models as open source in this repository.
**Abstract :**
Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
Although such methods improve the sampling efficiency and memory usage,
their sample quality has not yet reached that of autoregressive and flow-based generative models.
In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
As speech audio consists of sinusoidal signals with various periods,
we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
faster than real-time on CPU with comparable quality to an autoregressive counterpart.
Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
## Pre-requisites
1. Python >= 3.6
2. Clone this repository.
3. Install python requirements. Please refer [requirements.txt](requirements.txt)
4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
And move all wav files to `LJSpeech-1.1/wavs`
## Training
```
python train.py --config config_v1.json
```
To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
You can change the path by adding `--checkpoint_path` option.
Validation loss during training with V1 generator.<br>
![validation loss](./validation_loss.png)
## Pretrained Model
You can also use pretrained models we provide.<br/>
[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
Details of each folder are as in follows:
| Folder Name | Generator | Dataset | Fine-Tuned |
| ------------ | --------- | --------- | ------------------------------------------------------ |
| LJ_V1 | V1 | LJSpeech | No |
| LJ_V2 | V2 | LJSpeech | No |
| LJ_V3 | V3 | LJSpeech | No |
| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
| VCTK_V1 | V1 | VCTK | No |
| VCTK_V2 | V2 | VCTK | No |
| VCTK_V3 | V3 | VCTK | No |
| UNIVERSAL_V1 | V1 | Universal | No |
We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
## Fine-Tuning
1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
Example:
` Audio File : LJ001-0001.wav
Mel-Spectrogram File : LJ001-0001.npy`
2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
3. Run the following command.
```
python train.py --fine_tuning True --config config_v1.json
```
For other command line options, please refer to the training section.
## Inference from wav file
1. Make `test_files` directory and copy wav files into the directory.
2. Run the following command.
` python inference.py --checkpoint_file [generator checkpoint file path]`
Generated wav files are saved in `generated_files` by default.<br>
You can change the path by adding `--output_dir` option.
## Inference for end-to-end speech synthesis
1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
[Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
2. Run the following command.
` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
Generated wav files are saved in `generated_files_from_mel` by default.<br>
You can change the path by adding `--output_dir` option.
## Acknowledgements
We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.

View File

28
matcha/hifigan/config.py Normal file
View File

@@ -0,0 +1,28 @@
v1 = {
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0004,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_initial_channel": 256,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_loss": None,
"num_workers": 4,
"dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
}

View File

@@ -0,0 +1,64 @@
# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
import torch
class Denoiser(torch.nn.Module):
"""Removes model bias from audio produced with waveglow"""
def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
super().__init__()
self.filter_length = filter_length
self.hop_length = int(filter_length / n_overlap)
self.win_length = win_length
dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
self.device = device
if mode == "zeros":
mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
elif mode == "normal":
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
else:
raise Exception(f"Mode {mode} if not supported")
def stft_fn(audio, n_fft, hop_length, win_length, window):
spec = torch.stft(
audio,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
return_complex=True,
)
spec = torch.view_as_real(spec)
return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
self.stft = lambda x: stft_fn(
audio=x,
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=torch.hann_window(self.win_length, device=device),
)
self.istft = lambda x, y: torch.istft(
torch.complex(x * torch.cos(y), x * torch.sin(y)),
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=torch.hann_window(self.win_length, device=device),
)
with torch.no_grad():
bias_audio = vocoder(mel_input).float().squeeze(0)
bias_spec, _ = self.stft(bias_audio)
self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
@torch.inference_mode()
def forward(self, audio, strength=0.0005):
audio_spec, audio_angles = self.stft(audio)
audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
audio_denoised = self.istft(audio_spec_denoised, audio_angles)
return audio_denoised

17
matcha/hifigan/env.py Normal file
View File

@@ -0,0 +1,17 @@
""" from https://github.com/jik876/hifi-gan """
import os
import shutil
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))

View File

@@ -0,0 +1,217 @@
""" from https://github.com/jik876/hifi-gan """
import math
import os
import random
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from librosa.util import normalize
from scipy.io.wavfile import read
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
def get_dataset_filelist(a):
with open(a.input_training_file, encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
with open(a.input_validation_file, encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
return training_files, validation_files
class MelDataset(torch.utils.data.Dataset):
def __init__(
self,
training_files,
segment_size,
n_fft,
num_mels,
hop_size,
win_size,
sampling_rate,
fmin,
fmax,
split=True,
shuffle=True,
n_cache_reuse=1,
device=None,
fmax_loss=None,
fine_tuning=False,
base_mels_path=None,
):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.cached_wav = None
self.n_cache_reuse = n_cache_reuse
self._cache_ref_count = 0
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
def __getitem__(self, index):
filename = self.audio_files[index]
if self._cache_ref_count == 0:
audio, sampling_rate = load_wav(filename)
audio = audio / MAX_WAV_VALUE
if not self.fine_tuning:
audio = normalize(audio) * 0.95
self.cached_wav = audio
if sampling_rate != self.sampling_rate:
raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
self._cache_ref_count = self.n_cache_reuse
else:
audio = self.cached_wav
self._cache_ref_count -= 1
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
if not self.fine_tuning:
if self.split:
if audio.size(1) >= self.segment_size:
max_audio_start = audio.size(1) - self.segment_size
audio_start = random.randint(0, max_audio_start)
audio = audio[:, audio_start : audio_start + self.segment_size]
else:
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
mel = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax,
center=False,
)
else:
mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
mel = torch.from_numpy(mel)
if len(mel.shape) < 3:
mel = mel.unsqueeze(0)
if self.split:
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
if audio.size(1) >= self.segment_size:
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
else:
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
mel_loss = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax_loss,
center=False,
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
def __len__(self):
return len(self.audio_files)

368
matcha/hifigan/models.py Normal file
View File

@@ -0,0 +1,368 @@
""" from https://github.com/jik876/hifi-gan """
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from .xutils import get_padding, init_weights
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.h = h
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super().__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super().__init__()
self.period = period
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for _, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
]
)
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i - 1](y)
y_hat = self.meanpools[i - 1](y_hat)
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses

60
matcha/hifigan/xutils.py Normal file
View File

@@ -0,0 +1,60 @@
""" from https://github.com/jik876/hifi-gan """
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
weight_norm(m)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def save_checkpoint(filepath, obj):
print(f"Saving checkpoint to {filepath}")
torch.save(obj, filepath)
print("Complete.")
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + "????????")
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None
return sorted(cp_list)[-1]

View File

View File

@@ -0,0 +1,209 @@
"""
This is a base lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
"""
import inspect
from abc import ABC
from typing import Any, Dict
import torch
from lightning import LightningModule
from lightning.pytorch.utilities import grad_norm
from matcha import utils
from matcha.utils.utils import plot_tensor
log = utils.get_pylogger(__name__)
class BaseLightningClass(LightningModule, ABC):
def update_data_statistics(self, data_statistics):
if data_statistics is None:
data_statistics = {
"mel_mean": 0.0,
"mel_std": 1.0,
}
self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
def configure_optimizers(self) -> Any:
optimizer = self.hparams.optimizer(params=self.parameters())
if self.hparams.scheduler not in (None, {}):
scheduler_args = {}
# Manage last epoch for exponential schedulers
if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
if hasattr(self, "ckpt_loaded_epoch"):
current_epoch = self.ckpt_loaded_epoch - 1
else:
current_epoch = -1
scheduler_args.update({"optimizer": optimizer})
scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
scheduler.last_epoch = current_epoch
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": self.hparams.scheduler.lightning_args.interval,
"frequency": self.hparams.scheduler.lightning_args.frequency,
"name": "learning_rate",
},
}
return {"optimizer": optimizer}
def get_losses(self, batch):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
dur_loss, prior_loss, diff_loss = self(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
out_size=self.out_size,
)
return {
"dur_loss": dur_loss,
"prior_loss": prior_loss,
"diff_loss": diff_loss,
}
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init
def training_step(self, batch: Any, batch_idx: int):
loss_dict = self.get_losses(batch)
self.log(
"step",
float(self.global_step),
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/train_dur_loss",
loss_dict["dur_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/train_prior_loss",
loss_dict["prior_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/train_diff_loss",
loss_dict["diff_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
total_loss = sum(loss_dict.values())
self.log(
"loss/train",
total_loss,
on_step=True,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)
return {"loss": total_loss, "log": loss_dict}
def validation_step(self, batch: Any, batch_idx: int):
loss_dict = self.get_losses(batch)
self.log(
"sub_loss/val_dur_loss",
loss_dict["dur_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/val_prior_loss",
loss_dict["prior_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/val_diff_loss",
loss_dict["diff_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
total_loss = sum(loss_dict.values())
self.log(
"loss/val",
total_loss,
on_step=True,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)
return total_loss
def on_validation_end(self) -> None:
if self.trainer.is_global_zero:
one_batch = next(iter(self.trainer.val_dataloaders))
if self.current_epoch == 0:
log.debug("Plotting original samples")
for i in range(2):
y = one_batch["y"][i].unsqueeze(0).to(self.device)
self.logger.experiment.add_image(
f"original/{i}",
plot_tensor(y.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
log.debug("Synthesising...")
for i in range(2):
x = one_batch["x"][i].unsqueeze(0).to(self.device)
x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
attn = output["attn"]
self.logger.experiment.add_image(
f"generated_enc/{i}",
plot_tensor(y_enc.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
self.logger.experiment.add_image(
f"generated_dec/{i}",
plot_tensor(y_dec.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
self.logger.experiment.add_image(
f"alignment/{i}",
plot_tensor(attn.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
def on_before_optimizer_step(self, optimizer):
self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})

View File

View File

@@ -0,0 +1,394 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import TimestepEmbedding
from einops import pack, rearrange, repeat
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Block1D(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
torch.nn.GroupNorm(groups, dim_out),
nn.Mish(),
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class ResnetBlock1D(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super().__init__()
self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.block1 = Block1D(dim, dim_out, groups=groups)
self.block2 = Block1D(dim_out, dim_out, groups=groups)
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
class Downsample1D(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs):
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
if self.use_conv:
outputs = self.conv(outputs)
return outputs
class ConformerWrapper(ConformerBlock):
def __init__( # pylint: disable=useless-super-delegation
self,
*,
dim,
dim_head=64,
heads=8,
ff_mult=4,
conv_expansion_factor=2,
conv_kernel_size=31,
attn_dropout=0,
ff_dropout=0,
conv_dropout=0,
conv_causal=False,
):
super().__init__(
dim=dim,
dim_head=dim_head,
heads=heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
conv_dropout=conv_dropout,
conv_causal=conv_causal,
)
def forward(
self,
hidden_states,
attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
):
return super().forward(x=hidden_states, mask=attention_mask.bool())
class Decoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
down_block_type="transformer",
mid_block_type="transformer",
up_block_type="transformer",
):
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
self.get_block(
down_block_type,
output_channel,
attention_head_dim,
num_heads,
dropout,
act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for i in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
self.get_block(
mid_block_type,
output_channel,
attention_head_dim,
num_heads,
dropout,
act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i]
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=2 * input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
self.get_block(
up_block_type,
output_channel,
attention_head_dim,
num_heads,
dropout,
act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
# nn.init.normal_(self.final_proj.weight)
@staticmethod
def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
if block_type == "conformer":
block = ConformerWrapper(
dim=dim,
dim_head=attention_head_dim,
heads=num_heads,
ff_mult=1,
conv_expansion_factor=2,
ff_dropout=dropout,
attn_dropout=dropout,
conv_dropout=dropout,
conv_kernel_size=31,
)
elif block_type == "transformer":
block = BasicTransformerBlock(
dim=dim,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
else:
raise ValueError(f"Unknown block type {block_type}")
return block
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c")
mask_down = rearrange(mask_down, "b 1 t -> b t")
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=mask_down,
timestep=t,
)
x = rearrange(x, "b t c -> b c t")
mask_down = rearrange(mask_down, "b t -> b 1 t")
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c")
mask_mid = rearrange(mask_mid, "b 1 t -> b t")
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=mask_mid,
timestep=t,
)
x = rearrange(x, "b t c -> b c t")
mask_mid = rearrange(mask_mid, "b t -> b 1 t")
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
x = rearrange(x, "b c t -> b t c")
mask_up = rearrange(mask_up, "b 1 t -> b t")
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=mask_up,
timestep=t,
)
x = rearrange(x, "b t c -> b c t")
mask_up = rearrange(mask_up, "b t -> b 1 t")
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask

View File

@@ -0,0 +1,114 @@
from abc import ABC
import torch
import torch.nn.functional as F
from matcha.models.components.decoder import Decoder
from matcha.utils.pylogger import get_pylogger
log = get_pylogger(__name__)
class BASECFM(torch.nn.Module, ABC):
def __init__(
self,
n_feats,
cfm_params,
n_spks=1,
spk_emb_dim=128,
):
super().__init__()
self.n_feats = n_feats
self.n_spks = n_spks
self.spk_emb_dim = spk_emb_dim
self.solver = cfm_params.solver
if hasattr(cfm_params, "sigma_min"):
self.sigma_min = cfm_params.sigma_min
else:
self.sigma_min = 1e-4
self.estimator = None
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
z (_type_): mu + noise (we don't need this in this formulation), we will sample the noise again
mask (_type_): output_mask
mu (_type_): output of encoder
n_timesteps (_type_): number of diffusion steps
stoc (bool, optional): _description_. Defaults to False.
spks (_type_, optional): _description_. Defaults to None.
Returns:
sample: _description_
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (_type_): _description_
t (_type_): _description_
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
sol = []
steps = 1
while steps <= len(t_span) - 1:
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if steps < len(t_span) - 1:
dt = t_span[steps + 1] - t
steps += 1
return sol[-1]
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (_type_): Target
mask (_type_): target mask
mu (_type_): output of encoder
spks (_type_, optional): speaker embedding. Defaults to None.
Returns:
loss: diffusion loss
y: conditional flow
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss, y
class CFM(BASECFM):
def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
# Just change the architecture of the estimator here
self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)

View File

@@ -0,0 +1,392 @@
""" from https://github.com/jaywalnut310/glow-tts """
import math
import torch
import torch.nn as nn
from einops import rearrange
import matcha.utils as utils
from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-4):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = torch.nn.Parameter(torch.ones(channels))
self.beta = torch.nn.Parameter(torch.zeros(channels))
def forward(self, x):
n_dims = len(x.shape)
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
shape = [1, -1] + [1] * (n_dims - 2)
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
return x
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.conv_layers = torch.nn.ModuleList()
self.norm_layers = torch.nn.ModuleList()
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class RotaryPositionalEmbeddings(nn.Module):
"""
## RoPE module
Rotary encoding transforms pairs of features by rotating in the 2D plane.
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
by an angle depending on the position of the token.
"""
def __init__(self, d: int, base: int = 10_000):
r"""
* `d` is the number of features $d$
* `base` is the constant used for calculating $\Theta$
"""
super().__init__()
self.base = base
self.d = int(d)
self.cos_cached = None
self.sin_cached = None
def _build_cache(self, x: torch.Tensor):
r"""
Cache $\cos$ and $\sin$ values
"""
# Return if cache is already built
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
return
# Get sequence length
seq_len = x.shape[0]
# $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
# Concatenate so that for row $m$ we have
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
# Cache them
self.cos_cached = idx_theta2.cos()[:, None, None, :]
self.sin_cached = idx_theta2.sin()[:, None, None, :]
def _neg_half(self, x: torch.Tensor):
# $\frac{d}{2}$
d_2 = self.d // 2
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
def forward(self, x: torch.Tensor):
"""
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
"""
# Cache $\cos$ and $\sin$ values
x = rearrange(x, "b h t d -> t b h d")
self._build_cache(x)
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
# Calculate
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
neg_half_x = self._neg_half(x_rope)
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
heads_share=True,
p_dropout=0.0,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.heads_share = heads_share
self.proximal_bias = proximal_bias
self.p_dropout = p_dropout
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
# from https://nn.labml.ai/transformers/rope/index.html
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)
torch.nn.init.xavier_uniform_(self.conv_q.weight)
torch.nn.init.xavier_uniform_(self.conv_k.weight)
if proximal_init:
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
torch.nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
b, d, t_s, t_t = (*key.size(), query.size(2))
query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
query = self.query_rotary_pe(query)
key = self.key_rotary_pe(key)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = torch.nn.functional.softmax(scores, dim=-1)
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
return output, p_attn
@staticmethod
def _attention_bias_proximal(length):
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.drop = torch.nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.attn_layers = torch.nn.ModuleList()
self.norm_layers_1 = torch.nn.ModuleList()
self.ffn_layers = torch.nn.ModuleList()
self.norm_layers_2 = torch.nn.ModuleList()
for _ in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.n_layers):
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class TextEncoder(nn.Module):
def __init__(
self,
encoder_type,
encoder_params,
duration_predictor_params,
n_vocab,
n_spks=1,
spk_emb_dim=128,
):
super().__init__()
self.encoder_type = encoder_type
self.n_vocab = n_vocab
self.n_feats = encoder_params.n_feats
self.n_channels = encoder_params.n_channels
self.spk_emb_dim = spk_emb_dim
self.n_spks = n_spks
self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
if encoder_params.prenet:
self.prenet = ConvReluNorm(
self.n_channels,
self.n_channels,
self.n_channels,
kernel_size=5,
n_layers=3,
p_dropout=0.5,
)
else:
self.prenet = lambda x, x_mask: x
self.encoder = Encoder(
encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
encoder_params.filter_channels,
encoder_params.n_heads,
encoder_params.n_layers,
encoder_params.kernel_size,
encoder_params.p_dropout,
)
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
self.proj_w = DurationPredictor(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
duration_predictor_params.filter_channels_dp,
duration_predictor_params.kernel_size,
duration_predictor_params.p_dropout,
)
def forward(self, x, x_lengths, spks=None):
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.prenet(x, x_mask)
if self.n_spks > 1:
x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask

211
matcha/models/matcha_tts.py Normal file
View File

@@ -0,0 +1,211 @@
import datetime as dt
import math
import random
import torch
import matcha.utils.monotonic_align as monotonic_align
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
)
log = utils.get_pylogger(__name__)
class MatchaTTS(BaseLightningClass): # 🍵
def __init__(
self,
n_vocab,
n_spks,
spk_emb_dim,
n_feats,
encoder,
decoder,
cfm,
data_statistics,
out_size,
optimizer=None,
scheduler=None,
):
super().__init__()
self.save_hyperparameters(logger=False)
self.n_vocab = n_vocab
self.n_spks = n_spks
self.spk_emb_dim = spk_emb_dim
self.n_feats = n_feats
self.out_size = out_size
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
self.encoder = TextEncoder(
encoder.encoder_type,
encoder.encoder_params,
encoder.duration_predictor_params,
n_vocab,
n_spks,
spk_emb_dim,
)
self.decoder = CFM(
in_channels=2 * encoder.encoder_params.n_feats,
out_channel=encoder.encoder_params.n_feats,
cfm_params=cfm,
decoder_params=decoder,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.update_data_statistics(data_statistics)
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
"""
Generates mel-spectrogram from text. Returns:
1. encoder outputs
2. decoder outputs
3. generated alignment
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
x_lengths (torch.Tensor): lengths of texts in batch.
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
temperature (float, optional): controls variance of terminal distribution.
stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
Usually, does not provide synthesis improvements.
length_scale (float, optional): controls speech pace.
Increase value to slow down generated speech and vice versa.
"""
# For RTF computation
t = dt.datetime.now()
if self.n_spks > 1:
# Get speaker embedding
spks = self.spk_emb(spks.long())
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = int(y_lengths.max())
y_max_length_ = fix_len_compatibility(y_max_length)
# Using obtained durations `w` construct alignment map `attn`
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
# Align encoded text and get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length]
# Generate sample by performing reverse dynamics
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
decoder_outputs = decoder_outputs[:, :, :y_max_length]
t = (dt.datetime.now() - t).total_seconds()
rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)
return {
"encoder_outputs": encoder_outputs,
"decoder_outputs": decoder_outputs,
"attn": attn[:, :, :y_max_length],
"mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
"mel_lengths": y_lengths,
"rtf": rtf,
}
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
x_lengths (torch.Tensor): lengths of texts in batch.
y (torch.Tensor): batch of corresponding mel-spectrograms.
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
"""
if self.n_spks > 1:
# Get speaker embedding
spks = self.spk_emb(spks)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
y_max_length = y.shape[-1]
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS
# refered to as prior loss in the paper
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
# Cut a small segment of mel-spectrogram in order to increase batch size
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
# - Do not need this hack for Matcha-TTS, but it works with it as well
if not isinstance(out_size, type(None)):
max_offset = (y_lengths - out_size).clamp(0)
offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
out_offset = torch.LongTensor(
[torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
).to(y_lengths)
attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
y_cut_lengths = []
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
y_cut_lengths.append(y_cut_length)
cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
y_cut_lengths = torch.LongTensor(y_cut_lengths)
y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
attn = attn_cut
y = y_cut
y_mask = y_cut_mask
# Align encoded text with mel-spectrogram and get mu_y segment
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
# Compute loss of the decoder
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
return dur_loss, prior_loss, diff_loss

53
matcha/text/__init__.py Normal file
View File

@@ -0,0 +1,53 @@
""" from https://github.com/keithito/tacotron """
from matcha.text import cleaners
from matcha.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
def text_to_sequence(text, cleaner_names):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
"""
sequence = []
clean_text = _clean_text(text, cleaner_names)
for symbol in clean_text:
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
def cleaned_text_to_sequence(cleaned_text):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
"""
sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
return sequence
def sequence_to_text(sequence):
"""Converts a sequence of IDs back to a string"""
result = ""
for symbol_id in sequence:
s = _id_to_symbol[symbol_id]
result += s
return result
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text

105
matcha/text/cleaners.py Normal file
View File

@@ -0,0 +1,105 @@
""" from https://github.com/keithito/tacotron
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
"""
import logging
import re
import phonemizer
from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical
critical_logger = logging.getLogger("phonemizer")
critical_logger.setLevel(logging.CRITICAL)
# Intializing the phonemizer globally significantly reduces the speed
# now the phonemizer is not initialising at every call
# Might be less flexible, but it is much-much faster
global_phonemizer = phonemizer.backend.EspeakBackend(
language="en-us",
preserve_punctuation=True,
with_stress=True,
language_switch="remove-flags",
logger=critical_logger,
)
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners2(text):
"""Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
phonemes = collapse_whitespace(phonemes)
return phonemes

71
matcha/text/numbers.py Normal file
View File

@@ -0,0 +1,71 @@
""" from https://github.com/keithito/tacotron """
import re
import inflect
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars"
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return f"{dollars} {dollar_unit}"
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return f"{cents} {cent_unit}"
else:
return "zero dollars"
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text

17
matcha/text/symbols.py Normal file
View File

@@ -0,0 +1,17 @@
""" from https://github.com/keithito/tacotron
Defines the set of symbols used in text input to the model.
"""
_pad = "_"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
# Special symbol ids
SPACE_ID = symbols.index(" ")

122
matcha/train.py Normal file
View File

@@ -0,0 +1,122 @@
from typing import Any, Dict, List, Optional, Tuple
import hydra
import lightning as L
import rootutils
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from matcha import utils
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
log = utils.get_pylogger(__name__)
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: A DictConfig configuration composed by Hydra.
:return: A tuple with metrics and dict with all instantiated objects.
"""
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
log.info("Instantiating loggers...")
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
train_metrics = trainer.callback_metrics
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
log.info(f"Best ckpt path: {ckpt_path}")
test_metrics = trainer.callback_metrics
# merge train and test metrics
metric_dict = {**train_metrics, **test_metrics}
return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
"""Main entry point for training.
:param cfg: DictConfig configuration composed by Hydra.
:return: Optional[float] with optimized metric value.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
# train the model
metric_dict, _ = train(cfg)
# safely retrieve metric value for hydra-based hyperparameter optimization
metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
# return optimized metric
return metric_value
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

5
matcha/utils/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
from matcha.utils.logging_utils import log_hyperparameters
from matcha.utils.pylogger import get_pylogger
from matcha.utils.rich_utils import enforce_tags, print_config_tree
from matcha.utils.utils import extras, get_metric_value, task_wrapper

82
matcha/utils/audio.py Normal file
View File

@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec

View File

@@ -0,0 +1,111 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from tqdm.auto import tqdm
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.utils.logging_utils import pylogger
log = pylogger.get_pylogger(__name__)
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
"""Generate data mean and standard deviation helpful in data normalisation
Args:
data_loader (torch.utils.data.Dataloader): _description_
out_channels (int): mel spectrogram channels
"""
total_mel_sum = 0
total_mel_sq_sum = 0
total_mel_len = 0
for batch in tqdm(data_loader, leave=False):
mels = batch["y"]
mel_lengths = batch["y_lengths"]
total_mel_len += torch.sum(mel_lengths)
total_mel_sum += torch.sum(mels)
total_mel_sq_sum += torch.sum(torch.pow(mels, 2))
data_mean = total_mel_sum / (total_mel_len * out_channels)
data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="vctk.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default="256",
help="Can have increased batch size for faster computation",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
args = parser.parse_args()
output_file = Path(args.input_config).with_suffix(".json")
if os.path.exists(output_file) and not args.force:
print("File already exists. Use -f to force overwrite")
sys.exit(1)
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["data_statistics"] = None
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
data_loader = text_mel_datamodule.train_dataloader()
log.info("Dataloader loaded! Now computing stats...")
params = compute_data_statistics(data_loader, cfg["n_feats"])
print(params)
json.dump(
params,
open(output_file, "w"),
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,56 @@
from typing import List
import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
"""Instantiates callbacks from config.
:param callbacks_cfg: A DictConfig object containing callback configurations.
:return: A list of instantiated callbacks.
"""
callbacks: List[Callback] = []
if not callbacks_cfg:
log.warning("No callback configs found! Skipping..")
return callbacks
if not isinstance(callbacks_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
for _, cb_conf in callbacks_cfg.items():
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access
callbacks.append(hydra.utils.instantiate(cb_conf))
return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
"""Instantiates loggers from config.
:param logger_cfg: A DictConfig object containing logger configurations.
:return: A list of instantiated loggers.
"""
logger: List[Logger] = []
if not logger_cfg:
log.warning("No logger configs found! Skipping...")
return logger
if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
for _, lg_conf in logger_cfg.items():
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access
logger.append(hydra.utils.instantiate(lg_conf))
return logger

View File

@@ -0,0 +1,53 @@
from typing import Any, Dict
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
"""Controls which config parts are saved by Lightning loggers.
Additionally saves:
- Number of model parameters
:param object_dict: A dictionary containing the following objects:
- `"cfg"`: A DictConfig object containing the main config.
- `"model"`: The Lightning model.
- `"trainer"`: The Lightning trainer.
"""
hparams = {}
cfg = OmegaConf.to_container(object_dict["cfg"])
model = object_dict["model"]
trainer = object_dict["trainer"]
if not trainer.logger:
log.warning("Logger not found! Skipping hyperparameter logging...")
return
hparams["model"] = cfg["model"]
# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
hparams["data"] = cfg["data"]
hparams["trainer"] = cfg["trainer"]
hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")
hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")
# send hparams to all loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams)

88
matcha/utils/model.py Normal file
View File

@@ -0,0 +1,88 @@
""" from https://github.com/jaywalnut310/glow-tts """
import numpy as np
import torch
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
while True:
if length % (2**num_downsamplings_in_unet) == 0:
return length
length += 1
def convert_pad_shape(pad_shape):
inverted_shape = pad_shape[::-1]
pad_shape = [item for sublist in inverted_shape for item in sublist]
return pad_shape
def generate_path(duration, mask):
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
def duration_loss(logw, logw_, lengths):
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
return loss
def normalize(data, mu, std):
if not isinstance(mu, (float, int)):
if isinstance(mu, list):
mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
elif isinstance(mu, torch.Tensor):
mu = mu.to(data.device)
elif isinstance(mu, np.ndarray):
mu = torch.from_numpy(mu).to(data.device)
mu = mu.unsqueeze(-1)
if not isinstance(std, (float, int)):
if isinstance(std, list):
std = torch.tensor(std, dtype=data.dtype, device=data.device)
elif isinstance(std, torch.Tensor):
std = std.to(data.device)
elif isinstance(std, np.ndarray):
std = torch.from_numpy(std).to(data.device)
std = std.unsqueeze(-1)
return (data - mu) / std
def denormalize(data, mu, std):
if not isinstance(mu, float):
if isinstance(mu, list):
mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
elif isinstance(mu, torch.Tensor):
mu = mu.to(data.device)
elif isinstance(mu, np.ndarray):
mu = torch.from_numpy(mu).to(data.device)
mu = mu.unsqueeze(-1)
if not isinstance(std, float):
if isinstance(std, list):
std = torch.tensor(std, dtype=data.dtype, device=data.device)
elif isinstance(std, torch.Tensor):
std = std.to(data.device)
elif isinstance(std, np.ndarray):
std = torch.from_numpy(std).to(data.device)
std = std.unsqueeze(-1)
return data * std + mu

View File

@@ -0,0 +1,22 @@
import numpy as np
import torch
from matcha.utils.monotonic_align.core import maximum_path_c
def maximum_path(value, mask):
"""Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32)
path = np.zeros_like(value).astype(np.int32)
mask = mask.data.cpu().numpy()
t_x_max = mask.sum(1)[:, 0].astype(np.int32)
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
maximum_path_c(path, value, t_x_max, t_y_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)

View File

@@ -0,0 +1,47 @@
import numpy as np
cimport cython
cimport numpy as np
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[x, y-1]
if x == 0:
if y == 0:
v_prev = 0.
else:
v_prev = max_neg_val
else:
v_prev = value[x-1, y-1]
value[x, y] = max(v_cur, v_prev) + value[x, y]
for y in range(t_y - 1, -1, -1):
path[index, y] = 1
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
cdef int b = values.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)

View File

@@ -0,0 +1,7 @@
# from distutils.core import setup
# from Cython.Build import cythonize
# import numpy
# setup(name='monotonic_align',
# ext_modules=cythonize("core.pyx"),
# include_dirs=[numpy.get_include()])

21
matcha/utils/pylogger.py Normal file
View File

@@ -0,0 +1,21 @@
import logging
from lightning.pytorch.utilities import rank_zero_only
def get_pylogger(name: str = __name__) -> logging.Logger:
"""Initializes a multi-GPU-friendly python command line logger.
:param name: The name of the logger, defaults to ``__name__``.
:return: A logger object.
"""
logger = logging.getLogger(name)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
for level in logging_levels:
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger

101
matcha/utils/rich_utils.py Normal file
View File

@@ -0,0 +1,101 @@
from pathlib import Path
from typing import Sequence
import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
@rank_zero_only
def print_config_tree(
cfg: DictConfig,
print_order: Sequence[str] = (
"data",
"model",
"callbacks",
"logger",
"trainer",
"paths",
"extras",
),
resolve: bool = False,
save_to_file: bool = False,
) -> None:
"""Prints the contents of a DictConfig as a tree structure using the Rich library.
:param cfg: A DictConfig composed by Hydra.
:param print_order: Determines in what order config components are printed. Default is ``("data", "model",
"callbacks", "logger", "trainer", "paths", "extras")``.
:param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
:param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
queue = []
# add fields from `print_order` to queue
for field in print_order:
_ = (
queue.append(field)
if field in cfg
else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
)
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, style=style, guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
else:
branch_content = str(config_group)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
# print config tree
rich.print(tree)
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config.
:param cfg: A DictConfig composed by Hydra.
:param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
"""
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
tags = [t.strip() for t in tags.split(",") if t != ""]
with open_dict(cfg):
cfg.tags = tags
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
rich.print(cfg.tags, file=file)

214
matcha/utils/utils.py Normal file
View File

@@ -0,0 +1,214 @@
import os
import sys
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
import wget
from omegaconf import DictConfig
from matcha.utils import pylogger, rich_utils
log = pylogger.get_pylogger(__name__)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
:param cfg: A DictConfig object containing the config tree.
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
rich_utils.enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
```
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: The name of the metric to retrieve.
:return: The value of the metric.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise Exception(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
def intersperse(lst, item):
# Adds blank symbol
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def save_figure_to_numpy(fig):
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
def plot_tensor(tensor):
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
def save_plot(tensor, savepath):
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.tight_layout()
fig.canvas.draw()
plt.savefig(savepath)
plt.close()
def to_numpy(tensor):
if isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
elif isinstance(tensor, list):
return np.array(tensor)
else:
raise TypeError("Unsupported type for conversion to numpy array")
def get_user_data_dir(appname="matcha_tts"):
"""
Args:
appname (str): Name of application
Returns:
Path: path to user data directory
"""
MATCHA_HOME = os.environ.get("MATCHA_HOME")
if MATCHA_HOME is not None:
ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
elif sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath(".local/share")
return ans.joinpath(appname)
def assert_model_downloaded(checkpoint_path, url, use_wget=False):
if Path(checkpoint_path).exists():
log.debug(f"[+] Model already present at {checkpoint_path}!")
return
log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
checkpoint_path = str(checkpoint_path)
if not use_wget:
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)

0
notebooks/.gitkeep Normal file
View File

51
pyproject.toml Normal file
View File

@@ -0,0 +1,51 @@
[build-system]
requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"]
[tool.black]
line-length = 120
target-version = ['py310']
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
| foo.py # also separately exclude a file named foo.py in
# the root of the project
)
'''
[tool.pytest.ini_options]
addopts = [
"--color=yes",
"--durations=0",
"--strict-markers",
"--doctest-modules",
]
filterwarnings = [
"ignore::DeprecationWarning",
"ignore::UserWarning",
]
log_cli = "True"
markers = [
"slow: slow tests",
]
minversion = "6.0"
testpaths = "tests/"
[tool.coverage.report]
exclude_lines = [
"pragma: nocover",
"raise NotImplementedError",
"raise NotImplementedError()",
"if __name__ == .__main__.:",
]

45
requirements.txt Normal file
View File

@@ -0,0 +1,45 @@
# --------- pytorch --------- #
torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
torchmetrics>=0.11.4
# --------- hydra --------- #
hydra-core==1.3.2
hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0
# --------- loggers --------- #
# wandb
# neptune-client
# mlflow
# comet-ml
# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
# --------- others --------- #
rootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)
phonemizer # phonemization of text
tensorboard
librosa
Cython
numpy
einops
inflect
Unidecode
scipy
torchaudio
matplotlib
pandas
conformer
diffusers @ git+https://github.com/shivammehta25/diffusers.git@matcha_tts_version
notebook
ipywidgets
gradio
gdown
wget
seaborn
gradio

7
scripts/schedule.sh Normal file
View File

@@ -0,0 +1,7 @@
#!/bin/bash
# Schedule execution of many runs
# Run from root folder with: bash scripts/schedule.sh
python src/train.py trainer.max_epochs=5 logger=csv
python src/train.py trainer.max_epochs=10 logger=csv

38
setup.py Normal file
View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python
import os
import numpy
import pkg_resources
from Cython.Build import cythonize
from setuptools import Extension, find_packages, setup
exts = [
Extension(
name="matcha.utils.monotonic_align.core",
sources=["matcha/utils/monotonic_align/core.pyx"],
)
]
setup(
name="matcha",
version="0.1.0",
description="A fast TTS architecture with conditional flow matching",
author="Shivam Mehta",
author_email="shivam.mehta25@gmail.com",
url="https://shivammehta25.github.io/Matcha-TTS",
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(open(os.path.join(os.path.dirname(__file__), "requirements.txt")))
],
include_dirs=[numpy.get_include()],
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
# use this to customize global commands available in the terminal after installing the package
entry_points={
"console_scripts": [
"matcha-data-stats=matcha.utils.generate_data_statistics:main",
"matcha_tts=matcha.cli:cli",
"matcha_tts_app=matcha.app:main",
]
},
ext_modules=cythonize(exts, language_level=3),
)

419
synthesis.ipynb Normal file

File diff suppressed because one or more lines are too long