23 Commits

Author SHA1 Message Date
Gustav Eje Henter
5a7f220662 Update demo page
Tweak the text, mention ICASSP acceptance, and update citation information.
2024-01-14 15:43:13 +01:00
Shivam Mehta
188cdb6844 Removing architecture from the docs 2024-01-10 11:06:48 +00:00
Shivam Mehta
30ae66af39 Update README.md 2023-09-28 10:15:09 +02:00
Shivam Mehta
158d95e278 Updating docs 2023-09-17 15:54:16 +00:00
Shivam Mehta
47b2fc9244 Update README.md 2023-09-14 15:07:04 +02:00
Shivam Mehta
00bf704393 centering the logo 2023-09-07 09:48:38 +00:00
Shivam Mehta
be07e9aed7 Trying to add logo 2023-09-07 09:38:41 +00:00
Shivam Mehta
87dc55646e Adding arXiv link 2023-09-07 08:17:54 +00:00
Shivam Mehta
3e297d9b25 Adding arXiv link 2023-09-07 08:15:38 +00:00
Shivam Mehta
5d6519b004 Truning on the hit counter 2023-09-06 18:23:02 +00:00
Shivam Mehta
7bd56fcdae Updating links 2023-09-06 17:45:04 +00:00
Shivam Mehta
859a6fe45c minor change to retrigger deployment 2023-09-06 16:19:57 +00:00
Shivam Mehta
a29515877b updating number of ode solver audios 2023-09-06 16:14:40 +00:00
Shivam Mehta
af5a398021 Changing the number of ode solver steps 2023-09-06 16:03:59 +00:00
Shivam Mehta
54bb4aca58 Adding stimuli for different number of ode solver steps 2023-09-06 15:51:06 +00:00
Shivam Mehta
94c39f85b2 Minor formatting fixes 2023-09-06 12:35:18 +00:00
Shivam Mehta
6b5936b616 minor changes 2023-09-06 12:28:57 +00:00
Shivam Mehta
82e33b4240 minor changes 2023-09-06 12:26:14 +00:00
Shivam Mehta
82a799f56e minor changes 2023-09-06 11:10:47 +00:00
Shivam Mehta
0e35178fbe adding hit counter 2023-09-06 02:19:18 +00:00
Shivam Mehta
683fa463bb adding hit counter 2023-09-06 02:15:42 +00:00
Shivam Mehta
d3208d86e3 Aligning image of play 2023-09-06 02:12:12 +00:00
Shivam Mehta
af03022f71 Adding first table 2023-09-06 02:04:14 +00:00
293 changed files with 501 additions and 9608 deletions

View File

@@ -1,6 +0,0 @@
# example of file for storing private and user specific environment variables, like keys or system paths
# rename it to ".env" (excluded from version control by default)
# .env is loaded by train.py automatically
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
MY_VAR="/home/user/my/system/path"

View File

@@ -1,22 +0,0 @@
## What does this PR do?
<!--
Please include a summary of the change and which issue is fixed.
Please also include relevant motivation and context.
List any dependencies that are required for this change.
List all the breaking changes introduced by this pull request.
-->
Fixes #\<issue_number>
## Before submitting
- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**?
- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together?
- [ ] Did you list all the **breaking changes** introduced by this pull request?
- [ ] Did you **test your PR locally** with `pytest` command?
- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command?
## Did you have fun?
Make sure you had fun coding 🙃

15
.github/codecov.yml vendored
View File

@@ -1,15 +0,0 @@
coverage:
status:
# measures overall project coverage
project:
default:
threshold: 100% # how much decrease in coverage is needed to not consider success
# measures PR or single commit coverage
patch:
default:
threshold: 100% # how much decrease in coverage is needed to not consider success
# project: off
# patch: off

View File

@@ -1,17 +0,0 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
target-branch: "dev"
schedule:
interval: "daily"
ignore:
- dependency-name: "pytorch-lightning"
update-types: ["version-update:semver-patch"]
- dependency-name: "torchmetrics"
update-types: ["version-update:semver-patch"]

View File

@@ -1,44 +0,0 @@
name-template: "v$RESOLVED_VERSION"
tag-template: "v$RESOLVED_VERSION"
categories:
- title: "🚀 Features"
labels:
- "feature"
- "enhancement"
- title: "🐛 Bug Fixes"
labels:
- "fix"
- "bugfix"
- "bug"
- title: "🧹 Maintenance"
labels:
- "maintenance"
- "dependencies"
- "refactoring"
- "cosmetic"
- "chore"
- title: "📝️ Documentation"
labels:
- "documentation"
- "docs"
change-template: "- $TITLE @$AUTHOR (#$NUMBER)"
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
version-resolver:
major:
labels:
- "major"
minor:
labels:
- "minor"
patch:
labels:
- "patch"
default: patch
template: |
## Changes
$CHANGES

163
.gitignore vendored
View File

@@ -1,163 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
### VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
**/.vscode
# JetBrains
.idea/
# Data & Models
*.h5
*.tar
*.tar.gz
# Lightning-Hydra-Template
configs/local/default.yaml
/data/
/logs/
.env
# Aim logging
.aim
# Cython complied files
matcha/utils/monotonic_align/core.c
# Ignoring hifigan checkpoint
generator_v1
g_02500000
gradio_cached_examples/
synth_output/

View File

@@ -1,59 +0,0 @@
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
- id: end-of-file-fixer
# - id: check-docstring-first
- id: check-yaml
- id: debug-statements
- id: detect-private-key
- id: check-toml
- id: check-case-conflict
- id: check-added-large-files
# python code formatting
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
args: [--line-length, "120"]
# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py38-plus]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args:
[
"--max-line-length", "120",
"--extend-ignore",
"E203,E402,E501,F401,F841,RST2,RST301",
"--exclude",
"logs/*,data/*,matcha/hifigan/*",
]
additional_dependencies: [flake8-rst-docstrings==0.3.0]
# pylint
- repo: https://github.com/pycqa/pylint
rev: v3.0.3
hooks:
- id: pylint

View File

@@ -1,2 +0,0 @@
# this file is required for inferring the project root directory
# do not delete

525
.pylintrc
View File

@@ -1,525 +0,0 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=missing-docstring,
too-many-public-methods,
too-many-lines,
bare-except,
## for avoiding weird p3.6 CI linter error
## TODO: see later if we can remove this
assigning-non-slot,
unsupported-assignment-operation,
## end
line-too-long,
fixme,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
import-error,
invalid-name,
too-many-instance-attributes,
arguments-differ,
arguments-renamed,
no-name-in-module,
no-member,
unsubscriptable-object,
raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
useless-object-inheritance,
too-few-public-methods,
too-many-branches,
too-many-arguments,
too-many-locals,
too-many-statements,
duplicate-code,
not-callable,
import-outside-toplevel,
logging-fstring-interpolation,
logging-not-lazy,
unused-argument,
no-else-return,
chained-comparison,
redefined-outer-name
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package..
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=numpy.*,torch.*
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=120
# Maximum number of lines in a module.
max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
argument-rgx=[a-z_][a-z0-9_]{0,30}$
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
x,
ex,
Run,
_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
variable-rgx=[a-z_][a-z0-9_]{0,30}$
[STRING]
# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=15
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=builtins.BaseException,
builtins.Exception

21
LICENSE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 Shivam Mehta
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,14 +0,0 @@
include README.md
include LICENSE.txt
include requirements.*.txt
include *.cff
include requirements.txt
include matcha/VERSION
recursive-include matcha *.json
recursive-include matcha *.html
recursive-include matcha *.png
recursive-include matcha *.md
recursive-include matcha *.py
recursive-include matcha *.pyx
recursive-exclude tests *
prune tests*

View File

@@ -1,42 +0,0 @@
help: ## Show help
@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
clean: ## Clean autogenerated files
rm -rf dist
find . -type f -name "*.DS_Store" -ls -delete
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
find . | grep -E ".pytest_cache" | xargs rm -rf
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
rm -f .coverage
clean-logs: ## Clean logs
rm -rf logs/**
create-package: ## Create wheel and tar gz
rm -rf dist/
python setup.py bdist_wheel --plat-name=manylinux1_x86_64
python setup.py sdist
python -m twine upload dist/* --verbose --skip-existing
format: ## Run pre-commit hooks
pre-commit run -a
sync: ## Merge changes from main branch to your current branch
git pull
git pull origin main
test: ## Run not slow tests
pytest -k "not slow"
test-full: ## Run all tests
pytest
train-ljspeech: ## Train the model
python matcha/train.py experiment=ljspeech
train-ljspeech-min: ## Train the model with minimum memory
python matcha/train.py experiment=ljspeech_min_memory
start_app: ## Start the app
python matcha/app.py

751
README.md
View File

@@ -1,24 +1,28 @@
<div align="center">
# Matcha-TTS: A fast TTS architecture with conditional flow matching
# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
<head>
<link rel="icon" type="image/x-icon" href="favicon.ico">
<meta name="msapplication-TileColor" content="#da532c">
<meta charset="UTF-8">
<meta name="theme-color" content="#ffffff">
<meta property="og:title" content="Matcha-TTS: A fast TTS architecture with conditional flow matching" />
<meta name="og:description" content="We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching to speed up ODE-based speech synthesis. Our method is probabilistic, has compact memory footprint, sounds highly natural, is very fast to synthesise from">
<meta property="og:image" content="images/architecture.png" />
<meta property="twitter:image" content="images/architecture.png" />
<meta property="og:type" content="website" />
<meta property="og:site_name" content="Matcha-TTS" />
<meta name="twitter:card" content="images/architecture.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="keywords" content="tts, text to speech, probabilistic machine learning, diffusion models, conditional flow matching, generative modelling, machine learning, deep learning, speech synthesis, research, phd">
<meta name="description" content="We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching to speed up ODE-based speech synthesis. Our method is probabilistic, has compact memory footprint, sounds highly natural, is very fast to synthesise from." />
</head>
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
[![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-3100/)
[![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
##### [Shivam Mehta][shivam_profile], [Ruibo Tu][ruibo_profile], [Jonas Beskow][jonas_profile], [Éva Székely][eva_profile], and [Gustav Eje Henter][gustav_profile]
<p style="text-align: center;">
<img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/>
<img src="images/logo.png" height="128"/>
</p>
</div>
> This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024].
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method:
- Is probabilistic
@@ -26,237 +30,486 @@ We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, tha
- Sounds highly natural
- Is very fast to synthesise from
Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our ICASSP 2024 paper](https://arxiv.org/abs/2309.03199) for more details.
[Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface.
You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces](https://huggingface.co/spaces/shivammehta25/Matcha-TTS).
## Teaser video
[![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0)
## Installation
1. Create an environment (suggested but optional)
```
conda create -n matcha-tts python=3.10 -y
conda activate matcha-tts
```
2. Install Matcha TTS using pip or from source
```bash
pip install matcha-tts
```
from source
```bash
pip install git+https://github.com/shivammehta25/Matcha-TTS.git
cd Matcha-TTS
pip install -e .
```
3. Run CLI / gradio app / jupyter notebook
```bash
# This will download the required models
matcha-tts --text "<INPUT TEXT>"
```
or
```bash
matcha-tts-app
```
or open `synthesis.ipynb` on jupyter notebook
### CLI Arguments
- To synthesise from given text, run:
```bash
matcha-tts --text "<INPUT TEXT>"
```
- To synthesise from a file, run:
```bash
matcha-tts --file <PATH TO FILE>
```
- To batch synthesise from a file, run:
```bash
matcha-tts --file <PATH TO FILE> --batched
```
Additional arguments
- Speaking rate
```bash
matcha-tts --text "<INPUT TEXT>" --speaking_rate 1.0
```
- Sampling temperature
```bash
matcha-tts --text "<INPUT TEXT>" --temperature 0.667
```
- Euler ODE solver steps
```bash
matcha-tts --text "<INPUT TEXT>" --steps 10
```
## Train with your own dataset
Let's assume we are training with LJ Speech
1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup).
2. Clone and enter the Matcha-TTS repository
```bash
git clone https://github.com/shivammehta25/Matcha-TTS.git
cd Matcha-TTS
```
3. Install the package from source
```bash
pip install -e .
```
4. Go to `configs/data/ljspeech.yaml` and change
```yaml
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
```
5. Generate normalisation statistics with the yaml file of dataset configuration
```bash
matcha-data-stats -i ljspeech.yaml
# Output:
#{'mel_mean': -5.53662231756592, 'mel_std': 2.1161014277038574}
```
Update these values in `configs/data/ljspeech.yaml` under `data_statistics` key.
```bash
data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622
mel_std: 2.116101
```
to the paths of your train and validation filelists.
6. Run the training script
```bash
make train-ljspeech
```
or
```bash
python matcha/train.py experiment=ljspeech
```
- for a minimum memory run
```bash
python matcha/train.py experiment=ljspeech_min_memory
```
- for multi-gpu training, run
```bash
python matcha/train.py experiment=ljspeech trainer.devices=[0,1]
```
7. Synthesise from the custom trained model
```bash
matcha-tts --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
```
## ONNX support
> Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support.
It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph.
### ONNX export
To export a checkpoint to ONNX, first install ONNX with
```bash
pip install onnx
```
then run the following:
```bash
python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5
```
Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems).
**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**.
**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release.
### ONNX Inference
To run inference on the exported model, first install `onnxruntime` using
```bash
pip install onnxruntime
pip install onnxruntime-gpu # for GPU inference
```
then use the following:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs
```
You can also control synthesis parameters:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking_rate 0.9 --spk 0
```
To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu
```
If you exported only Matcha to ONNX, this will write mel-spectrogram as graphs and `numpy` arrays to the output directory.
If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory.
If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format:
```bash
python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx
```
This will write `.wav` audio files to the output directory.
See below for audio examples, or read [our ICASSP 2024 paper][arxiv_link] for more details.
Code is available in our [GitHub repository][github_link], along with pre-trained models.
You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces][hf_space].
[shivam_profile]: https://www.kth.se/profile/smehta
[ruibo_profile]: https://www.kth.se/profile/ruibo
[jonas_profile]: https://www.kth.se/profile/beskow
[eva_profile]: https://www.kth.se/profile/szekely
[gustav_profile]: https://people.kth.se/~ghe/
[this_page]: https://shivammehta25.github.io/Matcha-TTS
[arxiv_link]: https://arxiv.org/abs/2309.03199
[grad_tts_paper]: https://arxiv.org/abs/2105.06337
[vits_paper]: https://arxiv.org/abs/2106.06103
[fastspeech2_paper]: https://arxiv.org/abs/2006.04558
[github_link]: https://github.com/shivammehta25/Matcha-TTS
[hf_space]: https://huggingface.co/spaces/shivammehta25/Matcha-TTS
<style type="text/css">
.tg {
border-collapse: collapse;
border-color: #9ABAD9;
border-spacing: 0;
}
.tg td {
background-color: #EBF5FF;
border-color: #9ABAD9;
border-style: solid;
border-width: 1px;
color: #444;
font-family: Arial, sans-serif;
font-size: 14px;
overflow: hidden;
padding: 0px 20px;
word-break: normal;
font-weight: bold;
vertical-align: middle;
text-align: center;
white-space: nowrap;
}
.tg th {
background-color: #409cff;
border-color: #9ABAD9;
border-style: solid;
border-width: 1px;
color: #fff;
font-family: Arial, sans-serif;
font-size: 14px;
font-weight: bold;
overflow: hidden;
padding: 0px 20px;
word-break: normal;
font-weight: bold;
vertical-align: middle;
text-align: center;
white-space: nowrap;
margin: auto;
}
.tg th a {
background-color: #409cff;
color: #fff;
text-decoration: none;
font-family: Arial, sans-serif;
font-size: 14px;
font-weight: bold;
overflow: hidden;
padding: 0px 20px;
word-break: normal;
font-weight: bold;
vertical-align: middle;
text-align: center;
white-space: nowrap;
margin: auto;
}
.tg .tg-0pky {
border-color: inherit;
text-align: center;
vertical-align: top,
}
td img {
position: relative;
margin: 0 auto;
max-width: 650px;
padding: 5px;
border: 0px;
}
.tg .tg-fymr {
border-color: inherit;
font-weight: bold;
text-align: center;
vertical-align: top
}
.slider {
-webkit-appearance: none;
width: 75%;
height: 15px;
border-radius: 5px;
background: #d3d3d3;
outline: none;
opacity: 0.7;
-webkit-transition: .2s;
transition: opacity .2s;
}
.slider::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 25px;
height: 25px;
border-radius: 50%;
background: #409cff;
cursor: pointer;
}
.slider::-moz-range-thumb {
width: 25px;
height: 25px;
border-radius: 50%;
background: #409cff;
cursor: pointer;
}
/* audio {
width: 240px;
} */
/* CSS */
.button-12 {
display: flex;
flex-direction: column;
align-items: center;
padding: 10px 54px;
font-family: -apple-system, BlinkMacSystemFont, 'Roboto', sans-serif;
font-weight: bold;
border-radius: 6px;
border: none;
background: #6E6D70;
box-shadow: 0px 0.5px 1px rgba(0, 0, 0, 0.1), inset 0px 0.5px 0.5px rgba(255, 255, 255, 0.5), 0px 0px 0px 0.5px rgba(0, 0, 0, 0.12);
color: #DFDEDF;
user-select: none;
-webkit-user-select: none;
touch-action: manipulation;
}
.button-12:focus {
box-shadow: inset 0px 0.8px 0px -0.25px rgba(255, 255, 255, 0.2), 0px 0.5px 1px rgba(0, 0, 0, 0.1), 0px 0px 0px 3.5px rgba(58, 108, 217, 0.5);
outline: 0;
}
audio {
margin: 0.5em;
}
.slider {
-webkit-appearance: none;
width: 75%;
height: 15px;
border-radius: 5px;
background: #d3d3d3;
outline: none;
opacity: 0.7;
-webkit-transition: .2s;
transition: opacity .2s;
}
.slider::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 25px;
height: 25px;
border-radius: 50%;
background: #409cff;
cursor: pointer;
}
.slider::-moz-range-thumb {
width: 25px;
height: 25px;
border-radius: 50%;
background: #409cff;
cursor: pointer;
}
</style>
<script src="transcripts.js"></script>
<!-- ## Architecture
<img src="images/architecture.png" alt="Architecture of Matcha-TTS" width="750"/> -->
<script>
transcript_listening_test = {
1: "It had established periodic regular review of the status of four hundred individuals;", //4
2: "The narrative of these events is based largely on the recollections of the participants,", // 3
3: "The jury did not believe him, and the verdict was for the defendants.", // 7
4: "One by one the huge uprights of black timber were fitted together,", // 19
5: "The position of this palmprint on the carton was parallel with the long axis of the box, and at right angles with the short axis;", // 23
6: "The boy declared he saw no one, and accordingly passed through without paying the toll of a penny." // 38
}
function play_audio(filename, audio_id, condition_name, transcription){
audio = document.getElementById(audio_id);
audio_source = document.getElementById(audio_id + "-src");
block_quote = document.getElementById(audio_id + "-transcript");
stimulus_span = document.getElementById(audio_id + "-span");
audio.pause();
audio_source.src = filename;
block_quote.innerHTML = transcription;
stimulus_span.innerHTML = condition_name;
audio.load();
audio.play();
}
</script>
## Stimuli from the listening test
> Click the buttons in the table to load and play the different stimuli.
Currently loaded stimulus: <span id="stimuli-from-listening-test-span" style="font-weight: bold;"> MAT-10 : Sentence 1</span>
<p>Audio player: </p>
<audio id="stimuli-from-listening-test" controls>
<source id="stimuli-from-listening-test-src" src="stimuli/sample_from_test/MAT-10_1.wav" type="audio/wav">
</audio>
<p> Transcription: </p>
<blockquote style="height: 60px">
<p id="stimuli-from-listening-test-transcript">
It had established periodic regular review of the status of four hundred individuals;
</p>
</blockquote>
<table class="tg">
<thead>
<tr>
<th class="tg-0pky">System</th>
<th class="tg-0pky">Condition</th>
<th class="tg-0pky">Sentence 1</th>
<th class="tg-0pky">Sentence 2</th>
<th class="tg-0pky">Sentence 3</th>
<th class="tg-0pky">Sentence 4</th>
<th class="tg-0pky">Sentence 5</th>
<th class="tg-0pky">Sentence 6</th>
</tr>
</thead>
<tbody>
<tr>
<th class="tg-0pky">Vocoded <br> speech</th>
<th class="tg-0pky">VOC</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VOC_1.wav', 'stimuli-from-listening-test', 'VOC , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VOC_2.wav', 'stimuli-from-listening-test', 'VOC , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VOC_3.wav', 'stimuli-from-listening-test', 'VOC , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VOC_4.wav', 'stimuli-from-listening-test', 'VOC , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VOC_5.wav', 'stimuli-from-listening-test', 'VOC , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VOC_6.wav', 'stimuli-from-listening-test', 'VOC , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky" rowspan="3"><a href="https://shivammehta25.github.io/Matcha-TTS"> Matcha-TTS</a></th>
<th class="tg-0pky">MAT-10</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-10_1.wav', 'stimuli-from-listening-test', 'MAT-10 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-10_2.wav', 'stimuli-from-listening-test', 'MAT-10 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-10_3.wav', 'stimuli-from-listening-test', 'MAT-10 , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-10_4.wav', 'stimuli-from-listening-test', 'MAT-10 , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-10_5.wav', 'stimuli-from-listening-test', 'MAT-10 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-10_6.wav', 'stimuli-from-listening-test', 'MAT-10 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky">MAT-4</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-4_1.wav', 'stimuli-from-listening-test', 'MAT-4 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-4_2.wav', 'stimuli-from-listening-test', 'MAT-4 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-4_3.wav', 'stimuli-from-listening-test', 'MAT-4 : Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-4_4.wav', 'stimuli-from-listening-test', 'MAT-4 : Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-4_5.wav', 'stimuli-from-listening-test', 'MAT-4 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-4_6.wav', 'stimuli-from-listening-test', 'MAT-4 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky">MAT-2</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-2_1.wav', 'stimuli-from-listening-test', 'MAT-2 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-2_2.wav', 'stimuli-from-listening-test', 'MAT-2 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-2_3.wav', 'stimuli-from-listening-test', 'MAT-2 , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-2_4.wav', 'stimuli-from-listening-test', 'MAT-2 , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-2_5.wav', 'stimuli-from-listening-test', 'MAT-2 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/MAT-2_6.wav', 'stimuli-from-listening-test', 'MAT-2 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky" rowspan="2"><a href="https://arxiv.org/abs/2105.06337">Grad-TTS</a></th>
<th class="tg-0pky">GRAD-10</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-10_1.wav', 'stimuli-from-listening-test', 'GRAD-10 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-10_2.wav', 'stimuli-from-listening-test', 'GRAD-10 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-10_3.wav', 'stimuli-from-listening-test', 'GRAD-10 , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-10_4.wav', 'stimuli-from-listening-test', 'GRAD-10 , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-10_5.wav', 'stimuli-from-listening-test', 'GRAD-10 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-10_6.wav', 'stimuli-from-listening-test', 'GRAD-10 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky">GRAD-4</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-4_1.wav', 'stimuli-from-listening-test', 'GRAD-4 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-4_2.wav', 'stimuli-from-listening-test', 'GRAD-4 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-4_3.wav', 'stimuli-from-listening-test', 'GRAD-4 , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-4_4.wav', 'stimuli-from-listening-test', 'GRAD-4 , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-4_5.wav', 'stimuli-from-listening-test', 'GRAD-4 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GRAD-4_6.wav', 'stimuli-from-listening-test', 'GRAD-4 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky">Grad-TTS+CFM</th>
<th class="tg-0pky">GCFM-4</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GCFM-4_1.wav', 'stimuli-from-listening-test', 'GCFM-4 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GCFM-4_2.wav', 'stimuli-from-listening-test', 'GCFM-4 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GCFM-4_3.wav', 'stimuli-from-listening-test', 'GCFM-4 , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GCFM-4_4.wav', 'stimuli-from-listening-test', 'GCFM-4 , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GCFM-4_5.wav', 'stimuli-from-listening-test', 'GCFM-4 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/GCFM-4_6.wav', 'stimuli-from-listening-test', 'GCFM-4 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky"><a href="https://arxiv.org/abs/2006.04558">FastSpeech 2</a></th>
<th class="tg-0pky">FS2</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/FS2_1.wav', 'stimuli-from-listening-test', 'FS2 , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/FS2_2.wav', 'stimuli-from-listening-test', 'FS2 , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/FS2_3.wav', 'stimuli-from-listening-test', 'FS2 , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/FS2_4.wav', 'stimuli-from-listening-test', 'FS2 , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/FS2_5.wav', 'stimuli-from-listening-test', 'FS2 , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/FS2_6.wav', 'stimuli-from-listening-test', 'FS2 , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
<tr>
<th class="tg-0pky"><a href="https://arxiv.org/abs/2106.06103">VITS</a></th>
<th class="tg-0pky">VITS</th>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VITS_1.wav', 'stimuli-from-listening-test', 'VITS , Sentence 1', transcript_listening_test[1])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VITS_2.wav', 'stimuli-from-listening-test', 'VITS , Sentence 2', transcript_listening_test[2])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VITS_3.wav', 'stimuli-from-listening-test', 'VITS , Sentence 3', transcript_listening_test[3])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VITS_4.wav', 'stimuli-from-listening-test', 'VITS , Sentence 4', transcript_listening_test[4])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VITS_5.wav', 'stimuli-from-listening-test', 'VITS , Sentence 5', transcript_listening_test[5])"/> </td>
<td> <img src="images/play_button.png" height=40 style="cursor: pointer;" onclick="play_audio('stimuli/sample_from_test/VITS_6.wav', 'stimuli-from-listening-test', 'VITS , Sentence 6', transcript_listening_test[6])"/> </td>
</tr>
</tbody>
</table>
## Effect of the number of ODE solver steps
<div class="slidecontainer">
<label for="itr_slider"><span style="font-weight:bold"> 1 </span></label>
<input type="range" min="1" max="12" value="6" class="slider" id="itr_slider">
<label for="itr_slider"><span style="font-weight:bold"> 500 </span> </label>
<p><span style="font-weight:bold">Steps:</span> <span class="itr_val"></span>
</p>
</div>
<script>
var itr_slider = document.getElementById("itr_slider");
var itr_vals = document.getElementsByClassName("itr_val");
// Functions to update values
var iterations = {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 10,
7: 15,
8: 20,
9: 25,
10: 50,
11: 100,
12: 500,
};
function updateVals(classes, value){
for(var i=0; i < classes.length; i++) {
classes[i].innerHTML= iterations[parseInt(value)];
}
}
let systems = [
"MAT",
"GRAD",
"GCFM"
]
updateVals(itr_vals, 6);
itr_slider.oninput = function() {
updateVals(itr_vals, this.value);
let iteration = iterations[parseInt(this.value)];
// Update sources
for (let sent=1; sent <= 3; sent++){
for (let system_idx = 0; system_idx < systems.length; system_idx++){
let audio = document.getElementById(systems[system_idx] + "_sent_" + sent);
let audio_src = document.getElementById( systems[system_idx] + "_sent_src_" + sent);
audio_src.src = "stimuli/number_of_ode_solver/" + systems[system_idx] + "-" + iteration + "_" + sent + ".wav";
audio.load();
}
}
}
</script>
<table class="tg">
<thead>
<tr>
<th class="tg-0pky">System</th>
<th class="tg-0pky">Sentence 1</th>
<th class="tg-0pky">Sentence 2</th>
<th class="tg-0pky">Sentence 3</th>
</tr>
</thead>
<tbody>
<tr>
<th class="tg-0pky"><a href="https://shivammehta25.github.io/Matcha-TTS">Matcha-TTS</a></th>
<td>
<audio id="MAT_sent_1" controls>
<source id="MAT_sent_src_1" src="stimuli/number_of_ode_solver/MAT-10_1.wav" type="audio/wav">
</audio>
</td>
<td>
<audio id="MAT_sent_2" controls>
<source id="MAT_sent_src_2" src="stimuli/number_of_ode_solver/MAT-10_2.wav" type="audio/wav">
</audio>
</td>
<td>
<audio id="MAT_sent_3" controls>
<source id="MAT_sent_src_3" src="stimuli/sample_from_test/MAT-10_3.wav" type="audio/wav">
</audio>
</td>
</tr>
<tr>
<th class="tg-0pky"><a href="https://arxiv.org/abs/2105.06337">Grad-TTS</a></th>
<td>
<audio id="GRAD_sent_1" controls>
<source id="GRAD_sent_src_1" src="stimuli/number_of_ode_solver/GRAD-10_1.wav" type="audio/wav">
</audio>
</td>
<td>
<audio id="GRAD_sent_2" controls>
<source id="GRAD_sent_src_2" src="stimuli/number_of_ode_solver/GRAD-10_2.wav" type="audio/wav">
</audio>
</td>
<td>
<audio id="GRAD_sent_3" controls>
<source id="GRAD_sent_src_3" src="stimuli/number_of_ode_solver/GRAD-10_3.wav" type="audio/wav">
</audio>
</td>
</tr>
<tr>
<th class="tg-0pky">Grad-TTS + CFM</th>
<td>
<audio id="GCFM_sent_1" controls>
<source id="GCFM_sent_src_1" src="stimuli/number_of_ode_solver/GCFM-10_1.wav" type="audio/wav">
</audio>
</td>
<td>
<audio id="GCFM_sent_2" controls>
<source id="GCFM_sent_src_2" src="stimuli/number_of_ode_solver/GCFM-10_2.wav" type="audio/wav">
</audio>
</td>
<td>
<audio id="GCFM_sent_3" controls>
<source id="GCFM_sent_src_3" src="stimuli/number_of_ode_solver/GCFM-10_3.wav" type="audio/wav">
</audio>
</td>
</tr>
</tbody>
</table>
## Citation information
If you use our code or otherwise find this work useful, please cite our paper:
```text
```
@inproceedings{mehta2024matcha,
title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching},
author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje},
@@ -265,14 +518,4 @@ If you use our code or otherwise find this work useful, please cite our paper:
}
```
## Acknowledgements
Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it.
Other source code we would like to acknowledge:
- [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement
- [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components
- [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For the monotonic alignment search source code
- [torchdyn](https://github.com/DiffEqML/torchdyn): Useful for trying other ODE solvers during research and development
- [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For the RoPE implementation
[![MatchaTTS](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https://shivammehta25.github.io/Matcha-TTS&count_bg=%23409CFF&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Matcha-TTS&edge_flat=false)][this_page]

4
_config.yaml Normal file
View File

@@ -0,0 +1,4 @@
title: Matcha-TTS
theme: jekyll-theme-dinky
description: A fast TTS architecture with conditional flow matching
show_downloads: False

View File

@@ -1 +0,0 @@
# this file is needed here to include configs when building project as a package

View File

@@ -1,5 +0,0 @@
defaults:
- model_checkpoint.yaml
- model_summary.yaml
- rich_progress_bar.yaml
- _self_

View File

@@ -1,17 +0,0 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
model_checkpoint:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
dirpath: ${paths.output_dir}/checkpoints # directory to save the model file
filename: checkpoint_{epoch:03d} # checkpoint filename
monitor: epoch # name of the logged metric which determines when model is improving
verbose: False # verbosity mode
save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 10 # save k best models (determined by above metric)
mode: "max" # "max" means higher metric value is better, can be also "min"
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
save_weights_only: False # if True, then only the models weights will be saved
every_n_train_steps: null # number of training steps between checkpoints
train_time_interval: null # checkpoints are monitored at the specified time interval
every_n_epochs: 100 # number of epochs between checkpoints
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation

View File

@@ -1,5 +0,0 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
model_summary:
_target_: lightning.pytorch.callbacks.RichModelSummary
max_depth: 3 # the maximum depth of layer nesting that the summary will include

View File

@@ -1,4 +0,0 @@
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
rich_progress_bar:
_target_: lightning.pytorch.callbacks.RichProgressBar

View File

@@ -1,14 +0,0 @@
defaults:
- ljspeech
- _self_
# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female
train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
batch_size: 32
cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset
mel_mean: -6.38385
mel_std: 2.541796

View File

@@ -1,10 +0,0 @@
defaults:
- ljspeech
- _self_
name: joe_spont_only
train_filelist_path: data/filelists/joe_spontonly_train.txt
valid_filelist_path: data/filelists/joe_spontonly_val.txt
data_statistics:
mel_mean: -5.882903
mel_std: 2.458284

View File

@@ -1,21 +0,0 @@
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
batch_size: 32
num_workers: 20
pin_memory: True
cleaners: [english_cleaners2]
add_blank: True
n_spks: 1
n_fft: 1024
n_feats: 80
sample_rate: 22050
hop_length: 256
win_length: 1024
f_min: 0
f_max: 8000
data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622
mel_std: 2.116101
seed: ${seed}

View File

@@ -1,10 +0,0 @@
defaults:
- ljspeech
- _self_
name: ryan
train_filelist_path: data/filelists/ryan_train.csv
valid_filelist_path: data/filelists/ryan_val.csv
data_statistics:
mel_mean: -4.715779
mel_std: 2.124502

View File

@@ -1,10 +0,0 @@
defaults:
- ljspeech
- _self_
name: tsg2
train_filelist_path: data/filelists/cormac_train.txt
valid_filelist_path: data/filelists/cormac_val.txt
data_statistics:
mel_mean: -5.536622
mel_std: 2.116101

View File

@@ -1,14 +0,0 @@
defaults:
- ljspeech
- _self_
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: vctk
train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt
valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt
batch_size: 32
add_blank: True
n_spks: 109
data_statistics: # Computed for vctk dataset
mel_mean: -6.630575
mel_std: 2.482914

View File

@@ -1,35 +0,0 @@
# @package _global_
# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one
# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"
# disable callbacks and loggers during debugging
# callbacks: null
# logger: null
extras:
ignore_warnings: False
enforce_tags: False
# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
job_logging:
root:
level: DEBUG
# use this to also set hydra loggers to 'DEBUG'
# verbose: True
trainer:
max_epochs: 1
accelerator: cpu # debuggers don't like gpus
devices: 1 # debuggers don't like multiprocessing
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
data:
num_workers: 0 # debuggers don't like multiprocessing
pin_memory: False # disable gpu memory pin

View File

@@ -1,9 +0,0 @@
# @package _global_
# runs 1 train, 1 validation and 1 test step
defaults:
- default
trainer:
fast_dev_run: true

View File

@@ -1,12 +0,0 @@
# @package _global_
# uses only 1% of the training data and 5% of validation/test data
defaults:
- default
trainer:
max_epochs: 3
limit_train_batches: 0.01
limit_val_batches: 0.05
limit_test_batches: 0.05

View File

@@ -1,13 +0,0 @@
# @package _global_
# overfits to 3 batches
defaults:
- default
trainer:
max_epochs: 20
overfit_batches: 3
# model ckpt and early stopping need to be disabled during overfitting
callbacks: null

View File

@@ -1,15 +0,0 @@
# @package _global_
# runs with execution time profiling
defaults:
- default
trainer:
max_epochs: 1
# profiler: "simple"
profiler: "advanced"
# profiler: "pytorch"
accelerator: gpu
limit_train_batches: 0.02

View File

@@ -1,18 +0,0 @@
# @package _global_
defaults:
- _self_
- data: mnist # choose datamodule with `test_dataloader()` for evaluation
- model: mnist
- logger: null
- trainer: default
- paths: default
- extras: default
- hydra: default
task_name: "eval"
tags: ["dev"]
# passing checkpoint path is necessary for evaluation
ckpt_path: ???

View File

@@ -1,14 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: hi-fi_en-US_female.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]
run_name: hi-fi_en-US_female_piper_phonemizer

View File

@@ -1,14 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_det_dur

View File

@@ -1,20 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_stoc_dur
model:
duration_predictor:
p_dropout: 0.2

View File

@@ -1,14 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech

View File

@@ -1,18 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech_min
model:
out_size: 172

View File

@@ -1,16 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech

View File

@@ -1,14 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: vctk.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["multispeaker"]
run_name: multispeaker

View File

@@ -1,18 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_det_dur
trainer:
max_epochs: 3000

View File

@@ -1,24 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_stoc_dur
model:
duration_predictor:
p_dropout: 0.2
trainer:
max_epochs: 3000

View File

@@ -1,14 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_det_dur

View File

@@ -1,20 +0,0 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_stoc_dur
model:
duration_predictor:
p_dropout: 0.5

View File

@@ -1,8 +0,0 @@
# disable python warnings if they annoy you
ignore_warnings: False
# ask user for tags if none are provided in the config
enforce_tags: True
# pretty print config tree at the start of the run using Rich library
print_config: True

View File

@@ -1,52 +0,0 @@
# @package _global_
# example hyperparameter optimization of some experiment with Optuna:
# python train.py -m hparams_search=mnist_optuna experiment=example
defaults:
- override /hydra/sweeper: optuna
# choose metric which will be optimized by Optuna
# make sure this is the correct name of some metric logged in lightning module!
optimized_metric: "val/acc_best"
# here we define Optuna hyperparameter search
# it optimizes for value returned from function with @hydra.main decorator
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
sweeper:
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
# storage URL to persist optimization results
# for example, you can use SQLite if you set 'sqlite:///example.db'
storage: null
# name of the study to persist optimization results
study_name: null
# number of parallel workers
n_jobs: 1
# 'minimize' or 'maximize' the objective
direction: maximize
# total number of runs that will be executed
n_trials: 20
# choose Optuna hyperparameter sampler
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
sampler:
_target_: optuna.samplers.TPESampler
seed: 1234
n_startup_trials: 10 # number of random sampling runs before optimization starts
# define hyperparameter search space
params:
model.optimizer.lr: interval(0.0001, 0.1)
data.batch_size: choice(32, 64, 128, 256)
model.net.lin1_size: choice(64, 128, 256)
model.net.lin2_size: choice(64, 128, 256)
model.net.lin3_size: choice(32, 64, 128, 256)

View File

@@ -1,19 +0,0 @@
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging
defaults:
- override hydra_logging: colorlog
- override job_logging: colorlog
# output directory, generated dynamically on each run
run:
dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
subdir: ${hydra.job.num}
job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log

View File

View File

@@ -1,28 +0,0 @@
# https://aimstack.io/
# example usage in lightning module:
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
# `aim up`
aim:
_target_: aim.pytorch_lightning.AimLogger
repo: ${paths.root_dir} # .aim folder will be created here
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
# aim allows to group runs under experiment name
experiment: null # any string, set to "default" if not specified
train_metric_prefix: "train/"
val_metric_prefix: "val/"
test_metric_prefix: "test/"
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
system_tracking_interval: 10 # set to null to disable system metrics tracking
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
log_system_params: true
# enable/disable tracking console logs (default value is true)
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550

View File

@@ -1,12 +0,0 @@
# https://www.comet.ml
comet:
_target_: lightning.pytorch.loggers.comet.CometLogger
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
save_dir: "${paths.output_dir}"
project_name: "lightning-hydra-template"
rest_api_key: null
# experiment_name: ""
experiment_key: null # set to resume experiment
offline: False
prefix: ""

View File

@@ -1,7 +0,0 @@
# csv logger built in lightning
csv:
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
save_dir: "${paths.output_dir}"
name: "csv/"
prefix: ""

View File

@@ -1,9 +0,0 @@
# train with many loggers at once
defaults:
# - comet
- csv
# - mlflow
# - neptune
- tensorboard
- wandb

View File

@@ -1,12 +0,0 @@
# https://mlflow.org
mlflow:
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
# experiment_name: ""
# run_name: ""
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
tags: null
# save_dir: "./mlruns"
prefix: ""
artifact_location: null
# run_id: ""

View File

@@ -1,9 +0,0 @@
# https://neptune.ai
neptune:
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
project: username/lightning-hydra-template
# name: ""
log_model_checkpoints: True
prefix: ""

View File

@@ -1,10 +0,0 @@
# https://www.tensorflow.org/tensorboard/
tensorboard:
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${paths.output_dir}/tensorboard/"
name: null
log_graph: False
default_hp_metric: True
prefix: ""
# version: ""

View File

@@ -1,16 +0,0 @@
# https://wandb.ai
wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "lightning-hydra-template"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
# entity: "" # set to name of your wandb team
group: ""
tags: []
job_type: ""

View File

@@ -1,3 +0,0 @@
name: CFM
solver: euler
sigma_min: 1e-4

View File

@@ -1,7 +0,0 @@
channels: [256, 256]
dropout: 0.05
attention_head_dim: 64
n_blocks: 1
num_mid_blocks: 2
num_heads: 2
act_fn: snakebeta

View File

@@ -1,7 +0,0 @@
name: deterministic
n_spks: ${model.n_spks}
spk_emb_dim: ${model.spk_emb_dim}
filter_channels: 256
kernel_size: 3
n_channels: ${model.encoder.encoder_params.n_channels}
p_dropout: ${model.encoder.encoder_params.p_dropout}

View File

@@ -1,7 +0,0 @@
defaults:
- deterministic.yaml
- _self_
sigma_min: 1e-4
n_steps: 10
name: flow_matching

View File

@@ -1,10 +0,0 @@
encoder_type: RoPE Encoder
encoder_params:
n_feats: ${model.n_feats}
n_channels: 192
filter_channels: 768
n_heads: 2
n_layers: 6
kernel_size: 3
p_dropout: 0.1
prenet: true

View File

@@ -1,16 +0,0 @@
defaults:
- _self_
- encoder: default.yaml
- duration_predictor: deterministic.yaml
- decoder: default.yaml
- cfm: default.yaml
- optimizer: adam.yaml
_target_: matcha.models.matcha_tts.MatchaTTS
n_vocab: 178
n_spks: ${data.n_spks}
spk_emb_dim: 64
n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4
prior_loss: true

View File

@@ -1,4 +0,0 @@
_target_: torch.optim.Adam
_partial_: true
lr: 1e-4
weight_decay: 0.0

View File

@@ -1,18 +0,0 @@
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
root_dir: ${oc.env:PROJECT_ROOT}
# path to data directory
data_dir: ${paths.root_dir}/data/
# path to logging directory
log_dir: ${paths.root_dir}/logs/
# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}
# path to working directory
work_dir: ${hydra:runtime.cwd}

View File

@@ -1,51 +0,0 @@
# @package _global_
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: ljspeech
- model: matcha
- callbacks: default
- logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
- hydra: default
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null
# config for hyperparameter optimization
- hparams_search: null
# optional local config for machine/user specific settings
# it's optional since it doesn't need to exist and is excluded from version control
- optional local: default
# debugging config (enable through command line, e.g. `python train.py debug=default)
- debug: null
# task name, determines output directory path
task_name: "train"
run_name: ???
# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]
# set False to skip model training
train: True
# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: True
# simply provide checkpoint path to resume training
ckpt_path: null
# seed for random number generators in pytorch, numpy and python.random
seed: 1234

View File

@@ -1,5 +0,0 @@
defaults:
- default
accelerator: cpu
devices: 1

View File

@@ -1,9 +0,0 @@
defaults:
- default
strategy: ddp
accelerator: gpu
devices: [0,1]
num_nodes: 1
sync_batchnorm: True

View File

@@ -1,7 +0,0 @@
defaults:
- default
# simulate DDP on CPU, useful for debugging
accelerator: cpu
devices: 2
strategy: ddp_spawn

View File

@@ -1,20 +0,0 @@
_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.output_dir}
max_epochs: -1
accelerator: gpu
devices: [0]
# mixed precision for extra speed-up
precision: 16-mixed
# perform a validation loop every N training epochs
check_val_every_n_epoch: 1
# set True to to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
gradient_clip_val: 5.0

View File

@@ -1,5 +0,0 @@
defaults:
- default
accelerator: gpu
devices: 1

View File

@@ -1,5 +0,0 @@
defaults:
- default
accelerator: mps
devices: 1

1
data
View File

@@ -1 +0,0 @@
/home/smehta/Projects/Speech-Backbones/Grad-TTS/data

BIN
favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

BIN
images/architecture.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 928 KiB

BIN
images/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 352 KiB

BIN
images/play_button.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.8 KiB

View File

@@ -1 +0,0 @@
0.0.5.1

View File

View File

@@ -1,357 +0,0 @@
import tempfile
from argparse import Namespace
from pathlib import Path
import gradio as gr
import soundfile as sf
import torch
from matcha.cli import (
MATCHA_URLS,
VOCODER_URLS,
assert_model_downloaded,
get_device,
load_matcha,
load_vocoder,
process_text,
to_waveform,
)
from matcha.utils.utils import get_user_data_dir, plot_tensor
LOCATION = Path(get_user_data_dir())
args = Namespace(
cpu=False,
model="matcha_vctk",
vocoder="hifigan_univ_v1",
spk=0,
)
CURRENTLY_LOADED_MODEL = args.model
def MATCHA_TTS_LOC(x):
return LOCATION / f"{x}.ckpt"
def VOCODER_LOC(x):
return LOCATION / f"{x}"
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
RADIO_OPTIONS = {
"Multi Speaker (VCTK)": {
"model": "matcha_vctk",
"vocoder": "hifigan_univ_v1",
},
"Single Speaker (LJ Speech)": {
"model": "matcha_ljspeech",
"vocoder": "hifigan_T2_v1",
},
}
# Ensure all the required models are downloaded
assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
device = get_device(args)
# Load default model
model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
def load_model(model_name, vocoder_name):
model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
return model, vocoder, denoiser
def load_model_ui(model_type, textbox):
model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
if CURRENTLY_LOADED_MODEL != model_name:
model, vocoder, denoiser = load_model(model_name, vocoder_name)
CURRENTLY_LOADED_MODEL = model_name
if model_name == "matcha_ljspeech":
spk_slider = gr.update(visible=False, value=-1)
single_speaker_examples = gr.update(visible=True)
multi_speaker_examples = gr.update(visible=False)
length_scale = gr.update(value=0.95)
else:
spk_slider = gr.update(visible=True, value=0)
single_speaker_examples = gr.update(visible=False)
multi_speaker_examples = gr.update(visible=True)
length_scale = gr.update(value=0.85)
return (
textbox,
gr.update(interactive=True),
spk_slider,
single_speaker_examples,
multi_speaker_examples,
length_scale,
)
@torch.inference_mode()
def process_text_gradio(text):
output = process_text(1, text, device)
return output["x_phones"][1::2], output["x"], output["x_lengths"]
@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
output = model.synthesise(
text,
text_length,
n_timesteps=n_timesteps,
temperature=temperature,
spks=spk,
length_scale=length_scale,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
sf.write(fp.name, output["waveform"], 22050, "PCM_24")
return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
if CURRENTLY_LOADED_MODEL != "matcha_vctk":
global model, vocoder, denoiser # pylint: disable=global-statement
model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
CURRENTLY_LOADED_MODEL = "matcha_vctk"
phones, text, text_lengths = process_text_gradio(text)
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
return phones, audio, mel_spectrogram
def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
global model, vocoder, denoiser # pylint: disable=global-statement
model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
CURRENTLY_LOADED_MODEL = "matcha_ljspeech"
phones, text, text_lengths = process_text_gradio(text)
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
return phones, audio, mel_spectrogram
def main():
description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
* Is probabilistic
* Has compact memory footprint
* Sounds highly natural
* Is very fast to synthesise from
Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.
Cached examples are available at the bottom of the page.
"""
with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
processed_text = gr.State(value=None)
processed_text_len = gr.State(value=None)
with gr.Box():
with gr.Row():
gr.Markdown(description, scale=3)
with gr.Column():
gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
html = '<br><iframe width="560" height="315" src="https://www.youtube.com/embed/xmvJkz3bqw0?si=jN7ILyDsbPwJCGoa" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>'
gr.HTML(html)
with gr.Box():
radio_options = list(RADIO_OPTIONS.keys())
model_type = gr.Radio(
radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
)
with gr.Row():
gr.Markdown("# Text Input")
with gr.Row():
text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
spk_slider = gr.Slider(
minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
)
with gr.Row():
gr.Markdown("### Hyper parameters")
with gr.Row():
n_timesteps = gr.Slider(
label="Number of ODE steps",
minimum=1,
maximum=100,
step=1,
value=10,
interactive=True,
)
length_scale = gr.Slider(
label="Length scale (Speaking rate)",
minimum=0.5,
maximum=1.5,
step=0.05,
value=1.0,
interactive=True,
)
mel_temp = gr.Slider(
label="Sampling temperature",
minimum=0.00,
maximum=2.001,
step=0.16675,
value=0.667,
interactive=True,
)
synth_btn = gr.Button("Synthesise")
with gr.Box():
with gr.Row():
gr.Markdown("### Phonetised text")
phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")
with gr.Box():
with gr.Row():
mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
# with gr.Row():
audio = gr.Audio(interactive=False, label="Audio")
with gr.Row(visible=False) as example_row_lj_speech:
examples = gr.Examples( # pylint: disable=unused-variable
examples=[
[
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
50,
0.677,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
2,
0.677,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
4,
0.677,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
10,
0.677,
0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
50,
0.677,
0.95,
],
[
"The narrative of these events is based largely on the recollections of the participants.",
10,
0.677,
0.95,
],
[
"The jury did not believe him, and the verdict was for the defendants.",
10,
0.677,
0.95,
],
],
fn=ljspeech_example_cacher,
inputs=[text, n_timesteps, mel_temp, length_scale],
outputs=[phonetised_text, audio, mel_spectrogram],
cache_examples=True,
)
with gr.Row() as example_row_multispeaker:
multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable
examples=[
[
"Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
10,
0.677,
0.85,
0,
],
[
"Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
10,
0.677,
0.85,
16,
],
[
"Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
50,
0.677,
0.85,
44,
],
[
"Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
50,
0.677,
0.85,
45,
],
[
"Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
4,
0.677,
0.85,
58,
],
],
fn=multispeaker_example_cacher,
inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
outputs=[phonetised_text, audio, mel_spectrogram],
cache_examples=True,
label="Multi Speaker Examples",
)
model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
load_model_ui,
inputs=[model_type, text],
outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
)
synth_btn.click(
fn=process_text_gradio,
inputs=[
text,
],
outputs=[phonetised_text, processed_text, processed_text_len],
api_name="matcha_tts",
queue=True,
).then(
fn=synthesise_mel,
inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
outputs=[audio, mel_spectrogram],
)
demo.queue().launch(share=True)
if __name__ == "__main__":
main()

View File

@@ -1,418 +0,0 @@
import argparse
import datetime as dt
import os
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
MATCHA_URLS = {
"matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt",
"matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt",
}
VOCODER_URLS = {
"hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1", # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link
"hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000", # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link
}
MULTISPEAKER_MODEL = {
"matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)}
}
SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}}
def plot_spectrogram_to_numpy(spectrogram, filename):
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.title("Synthesised Mel-Spectrogram")
fig.canvas.draw()
plt.savefig(filename)
def process_text(i: int, text: str, device: torch.device):
print(f"[{i}] - Input text: {text}")
x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
dtype=torch.long,
device=device,
)[None]
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
x_phones = sequence_to_text(x.squeeze(0).tolist())
print(f"[{i}] - Phonetised text: {x_phones[1::2]}")
return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
def get_texts(args):
if args.text:
texts = [args.text]
else:
with open(args.file, encoding="utf-8") as f:
texts = f.readlines()
return texts
def assert_required_models_available(args):
save_dir = get_user_data_dir()
if not hasattr(args, "checkpoint_path") and args.checkpoint_path is None:
model_path = args.checkpoint_path
else:
model_path = save_dir / f"{args.model}.ckpt"
assert_model_downloaded(model_path, MATCHA_URLS[args.model])
vocoder_path = save_dir / f"{args.vocoder}"
assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
return {"matcha": model_path, "vocoder": vocoder_path}
def load_hifigan(checkpoint_path, device):
h = AttrDict(v1)
hifigan = HiFiGAN(h).to(device)
hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"])
_ = hifigan.eval()
hifigan.remove_weight_norm()
return hifigan
def load_vocoder(vocoder_name, checkpoint_path, device):
print(f"[!] Loading {vocoder_name}!")
vocoder = None
if vocoder_name in ("hifigan_T2_v1", "hifigan_univ_v1"):
vocoder = load_hifigan(checkpoint_path, device)
else:
raise NotImplementedError(
f"Vocoder {vocoder_name} not implemented! define a load_<<vocoder_name>> method for it"
)
denoiser = Denoiser(vocoder, mode="zeros")
print(f"[+] {vocoder_name} loaded!")
return vocoder, denoiser
def load_matcha(model_name, checkpoint_path, device):
print(f"[!] Loading {model_name}!")
model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
_ = model.eval()
print(f"[+] {model_name} loaded!")
return model
def to_waveform(mel, vocoder, denoiser=None):
audio = vocoder(mel).clamp(-1, 1)
if denoiser is not None:
audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze()
def save_to_folder(filename: str, output: dict, folder: str):
folder = Path(folder)
folder.mkdir(exist_ok=True, parents=True)
plot_spectrogram_to_numpy(np.array(output["mel"].squeeze().float().cpu()), f"{filename}.png")
np.save(folder / f"{filename}", output["mel"].cpu().numpy())
sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
return folder.resolve() / f"{filename}.wav"
def validate_args(args):
assert (
args.text or args.file
), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
assert args.temperature >= 0, "Sampling temperature cannot be negative"
assert args.steps > 0, "Number of ODE steps must be greater than 0"
if args.checkpoint_path is None:
# When using pretrained models
if args.model in SINGLESPEAKER_MODEL:
args = validate_args_for_single_speaker_model(args)
if args.model in MULTISPEAKER_MODEL:
args = validate_args_for_multispeaker_model(args)
else:
# When using a custom model
if args.vocoder != "hifigan_univ_v1":
warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
warnings.warn(warn_, UserWarning)
if args.speaking_rate is None:
args.speaking_rate = 1.0
if args.batched:
assert args.batch_size > 0, "Batch size must be greater than 0"
assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
return args
def validate_args_for_multispeaker_model(args):
if args.vocoder is not None:
if args.vocoder != MULTISPEAKER_MODEL[args.model]["vocoder"]:
warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}"
warnings.warn(warn_, UserWarning)
else:
args.vocoder = MULTISPEAKER_MODEL[args.model]["vocoder"]
if args.speaking_rate is None:
args.speaking_rate = MULTISPEAKER_MODEL[args.model]["speaking_rate"]
spk_range = MULTISPEAKER_MODEL[args.model]["spk_range"]
if args.spk is not None:
assert (
args.spk >= spk_range[0] and args.spk <= spk_range[-1]
), f"Speaker ID must be between {spk_range} for this model."
else:
available_spk_id = MULTISPEAKER_MODEL[args.model]["spk"]
warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}"
warnings.warn(warn_, UserWarning)
args.spk = available_spk_id
return args
def validate_args_for_single_speaker_model(args):
if args.vocoder is not None:
if args.vocoder != SINGLESPEAKER_MODEL[args.model]["vocoder"]:
warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}"
warnings.warn(warn_, UserWarning)
else:
args.vocoder = SINGLESPEAKER_MODEL[args.model]["vocoder"]
if args.speaking_rate is None:
args.speaking_rate = SINGLESPEAKER_MODEL[args.model]["speaking_rate"]
if args.spk != SINGLESPEAKER_MODEL[args.model]["spk"]:
warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}"
warnings.warn(warn_, UserWarning)
args.spk = SINGLESPEAKER_MODEL[args.model]["spk"]
return args
@torch.inference_mode()
def cli():
parser = argparse.ArgumentParser(
description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
)
parser.add_argument(
"--model",
type=str,
default="matcha_ljspeech",
help="Model to use",
choices=MATCHA_URLS.keys(),
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Path to the custom model checkpoint",
)
parser.add_argument(
"--vocoder",
type=str,
default="hifigan_univ_v1",
help="Vocoder to use (default: will use the one suggested with the pretrained model))",
choices=VOCODER_URLS.keys(),
)
parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
parser.add_argument(
"--temperature",
type=float,
default=0.667,
help="Variance of the x0 noise (default: 0.667)",
)
parser.add_argument(
"--speaking_rate",
type=float,
default=None,
help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
)
parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
parser.add_argument("--cpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
parser.add_argument(
"--denoiser_strength",
type=float,
default=0.00025,
help="Strength of the vocoder bias denoiser (default: 0.00025)",
)
parser.add_argument(
"--output_folder",
type=str,
default=os.getcwd(),
help="Output folder to save results (default: current dir)",
)
parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)")
parser.add_argument(
"--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)"
)
args = parser.parse_args()
args = validate_args(args)
device = get_device(args)
print_config(args)
paths = assert_required_models_available(args)
if args.checkpoint_path is not None:
print(f"[🍵] Loading custom model from {args.checkpoint_path}")
paths["matcha"] = args.checkpoint_path
args.model = "custom_model"
model = load_matcha(args.model, paths["matcha"], device)
vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device)
texts = get_texts(args)
spk = torch.tensor([args.spk], device=device, dtype=torch.long) if args.spk is not None else None
if len(texts) == 1 or not args.batched:
unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
else:
batched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
class BatchedSynthesisDataset(torch.utils.data.Dataset):
def __init__(self, processed_texts):
self.processed_texts = processed_texts
def __len__(self):
return len(self.processed_texts)
def __getitem__(self, idx):
return self.processed_texts[idx]
def batched_collate_fn(batch):
x = []
x_lengths = []
for b in batch:
x.append(b["x"].squeeze(0))
x_lengths.append(b["x_lengths"])
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
x_lengths = torch.concat(x_lengths, dim=0)
return {"x": x, "x_lengths": x_lengths}
def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
total_rtf = []
total_rtf_w = []
processed_text = [process_text(i, text, "cpu") for i, text in enumerate(texts)]
dataloader = torch.utils.data.DataLoader(
BatchedSynthesisDataset(processed_text),
batch_size=args.batch_size,
collate_fn=batched_collate_fn,
num_workers=8,
)
for i, batch in enumerate(dataloader):
i = i + 1
start_t = dt.datetime.now()
output = model.synthesise(
batch["x"].to(device),
batch["x_lengths"].to(device),
n_timesteps=args.steps,
temperature=args.temperature,
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
total_rtf.append(output["rtf"])
total_rtf_w.append(rtf_w)
for j in range(output["mel"].shape[0]):
base_name = f"utterance_{j:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{j:03d}"
length = output["mel_lengths"][j]
new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]}
location = save_to_folder(base_name, new_dict, args.output_folder)
print(f"[🍵-{j}] Waveform saved: {location}")
print("".join(["="] * 100))
print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
total_rtf = []
total_rtf_w = []
for i, text in enumerate(texts):
i = i + 1
base_name = f"utterance_{i:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{i:03d}"
print("".join(["="] * 100))
text = text.strip()
text_processed = process_text(i, text, device)
print(f"[🍵] Whisking Matcha-T(ea)TS for: {i}")
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
text_processed["x_lengths"],
n_timesteps=args.steps,
temperature=args.temperature,
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
# RTF with HiFiGAN
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-{i}] Matcha-TTS RTF: {output['rtf']:.4f}")
print(f"[🍵-{i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
total_rtf.append(output["rtf"])
total_rtf_w.append(rtf_w)
location = save_to_folder(base_name, output, args.output_folder)
print(f"[+] Waveform saved: {location}")
print("".join(["="] * 100))
print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
def print_config(args):
print("[!] Configurations: ")
print(f"\t- Model: {args.model}")
print(f"\t- Vocoder: {args.vocoder}")
print(f"\t- Temperature: {args.temperature}")
print(f"\t- Speaking rate: {args.speaking_rate}")
print(f"\t- Number of ODE steps: {args.steps}")
print(f"\t- Speaker: {args.spk}")
def get_device(args):
if torch.cuda.is_available() and not args.cpu:
print("[+] GPU Available! Using GPU")
device = torch.device("cuda")
else:
print("[-] GPU not available or forced CPU run! Using CPU")
device = torch.device("cpu")
return device
if __name__ == "__main__":
cli()

View File

@@ -1,242 +0,0 @@
import random
from typing import Any, Dict, Optional
import torch
import torchaudio as ta
from lightning import LightningDataModule
from torch.utils.data.dataloader import DataLoader
from matcha.text import text_to_sequence
from matcha.utils.audio import mel_spectrogram
from matcha.utils.model import fix_len_compatibility, normalize
from matcha.utils.utils import intersperse
def parse_filelist(filelist_path, split_char="|"):
with open(filelist_path, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split_char) for line in f]
return filepaths_and_text
class TextMelDataModule(LightningDataModule):
def __init__( # pylint: disable=unused-argument
self,
name,
train_filelist_path,
valid_filelist_path,
batch_size,
num_workers,
pin_memory,
cleaners,
add_blank,
n_spks,
n_fft,
n_feats,
sample_rate,
hop_length,
win_length,
f_min,
f_max,
data_statistics,
seed,
):
super().__init__()
# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)
def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute things like random split twice!
"""
# load and split datasets only if not loaded already
self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.train_filelist_path,
self.hparams.n_spks,
self.hparams.cleaners,
self.hparams.add_blank,
self.hparams.n_fft,
self.hparams.n_feats,
self.hparams.sample_rate,
self.hparams.hop_length,
self.hparams.win_length,
self.hparams.f_min,
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
)
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path,
self.hparams.n_spks,
self.hparams.cleaners,
self.hparams.add_blank,
self.hparams.n_fft,
self.hparams.n_feats,
self.hparams.sample_rate,
self.hparams.hop_length,
self.hparams.win_length,
self.hparams.f_min,
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
)
def train_dataloader(self):
return DataLoader(
dataset=self.trainset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
collate_fn=TextMelBatchCollate(self.hparams.n_spks),
)
def val_dataloader(self):
return DataLoader(
dataset=self.validset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
collate_fn=TextMelBatchCollate(self.hparams.n_spks),
)
def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
pass # pylint: disable=unnecessary-pass
class TextMelDataset(torch.utils.data.Dataset):
def __init__(
self,
filelist_path,
n_spks,
cleaners,
add_blank=True,
n_fft=1024,
n_mels=80,
sample_rate=22050,
hop_length=256,
win_length=1024,
f_min=0.0,
f_max=8000,
data_parameters=None,
seed=None,
):
self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks
self.cleaners = cleaners
self.add_blank = add_blank
self.n_fft = n_fft
self.n_mels = n_mels
self.sample_rate = sample_rate
self.hop_length = hop_length
self.win_length = win_length
self.f_min = f_min
self.f_max = f_max
if data_parameters is not None:
self.data_parameters = data_parameters
else:
self.data_parameters = {"mel_mean": 0, "mel_std": 1}
random.seed(seed)
random.shuffle(self.filepaths_and_text)
def get_datapoint(self, filepath_and_text):
if self.n_spks > 1:
filepath, spk, text = (
filepath_and_text[0],
int(filepath_and_text[1]),
filepath_and_text[2],
)
else:
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
assert sr == self.sample_rate
mel = mel_spectrogram(
audio,
self.n_fft,
self.n_mels,
self.sample_rate,
self.hop_length,
self.win_length,
self.f_min,
self.f_max,
center=False,
).squeeze()
mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
return mel
def get_text(self, text, add_blank=True):
text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
if self.add_blank:
text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm)
return text_norm, cleaned_text
def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index])
return datapoint
def __len__(self):
return len(self.filepaths_and_text)
class TextMelBatchCollate:
def __init__(self, n_spks):
self.n_spks = n_spks
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths, x_texts = [], []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
x_lengths.append(x_.shape[-1])
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {
"x": x,
"x_lengths": x_lengths,
"y": y,
"y_lengths": y_lengths,
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
}

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,101 +0,0 @@
# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
In our [paper](https://arxiv.org/abs/2010.05646),
we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
We provide our implementation and pretrained models as open source in this repository.
**Abstract :**
Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
Although such methods improve the sampling efficiency and memory usage,
their sample quality has not yet reached that of autoregressive and flow-based generative models.
In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
As speech audio consists of sinusoidal signals with various periods,
we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
faster than real-time on CPU with comparable quality to an autoregressive counterpart.
Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
## Pre-requisites
1. Python >= 3.6
2. Clone this repository.
3. Install python requirements. Please refer [requirements.txt](requirements.txt)
4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
And move all wav files to `LJSpeech-1.1/wavs`
## Training
```
python train.py --config config_v1.json
```
To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
You can change the path by adding `--checkpoint_path` option.
Validation loss during training with V1 generator.<br>
![validation loss](./validation_loss.png)
## Pretrained Model
You can also use pretrained models we provide.<br/>
[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
Details of each folder are as in follows:
| Folder Name | Generator | Dataset | Fine-Tuned |
| ------------ | --------- | --------- | ------------------------------------------------------ |
| LJ_V1 | V1 | LJSpeech | No |
| LJ_V2 | V2 | LJSpeech | No |
| LJ_V3 | V3 | LJSpeech | No |
| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
| VCTK_V1 | V1 | VCTK | No |
| VCTK_V2 | V2 | VCTK | No |
| VCTK_V3 | V3 | VCTK | No |
| UNIVERSAL_V1 | V1 | Universal | No |
We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
## Fine-Tuning
1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
Example:
` Audio File : LJ001-0001.wav
Mel-Spectrogram File : LJ001-0001.npy`
2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
3. Run the following command.
```
python train.py --fine_tuning True --config config_v1.json
```
For other command line options, please refer to the training section.
## Inference from wav file
1. Make `test_files` directory and copy wav files into the directory.
2. Run the following command.
` python inference.py --checkpoint_file [generator checkpoint file path]`
Generated wav files are saved in `generated_files` by default.<br>
You can change the path by adding `--output_dir` option.
## Inference for end-to-end speech synthesis
1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
[Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
2. Run the following command.
` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
Generated wav files are saved in `generated_files_from_mel` by default.<br>
You can change the path by adding `--output_dir` option.
## Acknowledgements
We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.

View File

@@ -1,28 +0,0 @@
v1 = {
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0004,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_initial_channel": 256,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_loss": None,
"num_workers": 4,
"dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
}

View File

@@ -1,64 +0,0 @@
# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
import torch
class Denoiser(torch.nn.Module):
"""Removes model bias from audio produced with waveglow"""
def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
super().__init__()
self.filter_length = filter_length
self.hop_length = int(filter_length / n_overlap)
self.win_length = win_length
dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
self.device = device
if mode == "zeros":
mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
elif mode == "normal":
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
else:
raise Exception(f"Mode {mode} if not supported")
def stft_fn(audio, n_fft, hop_length, win_length, window):
spec = torch.stft(
audio,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
return_complex=True,
)
spec = torch.view_as_real(spec)
return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
self.stft = lambda x: stft_fn(
audio=x,
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=torch.hann_window(self.win_length, device=device),
)
self.istft = lambda x, y: torch.istft(
torch.complex(x * torch.cos(y), x * torch.sin(y)),
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=torch.hann_window(self.win_length, device=device),
)
with torch.no_grad():
bias_audio = vocoder(mel_input).float().squeeze(0)
bias_spec, _ = self.stft(bias_audio)
self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
@torch.inference_mode()
def forward(self, audio, strength=0.0005):
audio_spec, audio_angles = self.stft(audio)
audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
audio_denoised = self.istft(audio_spec_denoised, audio_angles)
return audio_denoised

View File

@@ -1,17 +0,0 @@
""" from https://github.com/jik876/hifi-gan """
import os
import shutil
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))

View File

@@ -1,217 +0,0 @@
""" from https://github.com/jik876/hifi-gan """
import math
import os
import random
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from librosa.util import normalize
from scipy.io.wavfile import read
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
def get_dataset_filelist(a):
with open(a.input_training_file, encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
with open(a.input_validation_file, encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
return training_files, validation_files
class MelDataset(torch.utils.data.Dataset):
def __init__(
self,
training_files,
segment_size,
n_fft,
num_mels,
hop_size,
win_size,
sampling_rate,
fmin,
fmax,
split=True,
shuffle=True,
n_cache_reuse=1,
device=None,
fmax_loss=None,
fine_tuning=False,
base_mels_path=None,
):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.cached_wav = None
self.n_cache_reuse = n_cache_reuse
self._cache_ref_count = 0
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
def __getitem__(self, index):
filename = self.audio_files[index]
if self._cache_ref_count == 0:
audio, sampling_rate = load_wav(filename)
audio = audio / MAX_WAV_VALUE
if not self.fine_tuning:
audio = normalize(audio) * 0.95
self.cached_wav = audio
if sampling_rate != self.sampling_rate:
raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
self._cache_ref_count = self.n_cache_reuse
else:
audio = self.cached_wav
self._cache_ref_count -= 1
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
if not self.fine_tuning:
if self.split:
if audio.size(1) >= self.segment_size:
max_audio_start = audio.size(1) - self.segment_size
audio_start = random.randint(0, max_audio_start)
audio = audio[:, audio_start : audio_start + self.segment_size]
else:
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
mel = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax,
center=False,
)
else:
mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
mel = torch.from_numpy(mel)
if len(mel.shape) < 3:
mel = mel.unsqueeze(0)
if self.split:
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
if audio.size(1) >= self.segment_size:
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
else:
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
mel_loss = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax_loss,
center=False,
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
def __len__(self):
return len(self.audio_files)

View File

@@ -1,368 +0,0 @@
""" from https://github.com/jik876/hifi-gan """
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from .xutils import get_padding, init_weights
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.h = h
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super().__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super().__init__()
self.period = period
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for _, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
]
)
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i - 1](y)
y_hat = self.meanpools[i - 1](y_hat)
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses

View File

@@ -1,60 +0,0 @@
""" from https://github.com/jik876/hifi-gan """
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
weight_norm(m)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def save_checkpoint(filepath, obj):
print(f"Saving checkpoint to {filepath}")
torch.save(obj, filepath)
print("Complete.")
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + "????????")
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None
return sorted(cp_list)[-1]

View File

@@ -1,209 +0,0 @@
"""
This is a base lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
"""
import inspect
from abc import ABC
from typing import Any, Dict
import torch
from lightning import LightningModule
from lightning.pytorch.utilities import grad_norm
from matcha import utils
from matcha.utils.utils import plot_tensor
log = utils.get_pylogger(__name__)
class BaseLightningClass(LightningModule, ABC):
def update_data_statistics(self, data_statistics):
if data_statistics is None:
data_statistics = {
"mel_mean": 0.0,
"mel_std": 1.0,
}
self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
def configure_optimizers(self) -> Any:
optimizer = self.hparams.optimizer(params=self.parameters())
if self.hparams.scheduler not in (None, {}):
scheduler_args = {}
# Manage last epoch for exponential schedulers
if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
if hasattr(self, "ckpt_loaded_epoch"):
current_epoch = self.ckpt_loaded_epoch - 1
else:
current_epoch = -1
scheduler_args.update({"optimizer": optimizer})
scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
scheduler.last_epoch = current_epoch
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": self.hparams.scheduler.lightning_args.interval,
"frequency": self.hparams.scheduler.lightning_args.frequency,
"name": "learning_rate",
},
}
return {"optimizer": optimizer}
def get_losses(self, batch):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
dur_loss, prior_loss, diff_loss, *_ = self(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
out_size=self.out_size,
)
return {
"dur_loss": dur_loss,
"prior_loss": prior_loss,
"diff_loss": diff_loss,
}
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init
def training_step(self, batch: Any, batch_idx: int):
loss_dict = self.get_losses(batch)
self.log(
"step",
float(self.global_step),
on_step=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/train_dur_loss",
loss_dict["dur_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/train_prior_loss",
loss_dict["prior_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/train_diff_loss",
loss_dict["diff_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
total_loss = sum(loss_dict.values())
self.log(
"loss/train",
total_loss,
on_step=True,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)
return {"loss": total_loss, "log": loss_dict}
def validation_step(self, batch: Any, batch_idx: int):
loss_dict = self.get_losses(batch)
self.log(
"sub_loss/val_dur_loss",
loss_dict["dur_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/val_prior_loss",
loss_dict["prior_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"sub_loss/val_diff_loss",
loss_dict["diff_loss"],
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
total_loss = sum(loss_dict.values())
self.log(
"loss/val",
total_loss,
on_step=True,
on_epoch=True,
logger=True,
prog_bar=True,
sync_dist=True,
)
return total_loss
def on_validation_end(self) -> None:
if self.trainer.is_global_zero:
one_batch = next(iter(self.trainer.val_dataloaders))
if self.current_epoch == 0:
log.debug("Plotting original samples")
for i in range(2):
y = one_batch["y"][i].unsqueeze(0).to(self.device)
self.logger.experiment.add_image(
f"original/{i}",
plot_tensor(y.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
log.debug("Synthesising...")
for i in range(2):
x = one_batch["x"][i].unsqueeze(0).to(self.device)
x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
attn = output["attn"]
self.logger.experiment.add_image(
f"generated_enc/{i}",
plot_tensor(y_enc.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
self.logger.experiment.add_image(
f"generated_dec/{i}",
plot_tensor(y_dec.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
self.logger.experiment.add_image(
f"alignment/{i}",
plot_tensor(attn.squeeze().cpu()),
self.current_epoch,
dataformats="HWC",
)
def on_before_optimizer_step(self, optimizer):
self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})

View File

@@ -1,443 +0,0 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat
from matcha.models.components.transformer import BasicTransformerBlock
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Block1D(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
torch.nn.GroupNorm(groups, dim_out),
nn.Mish(),
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class ResnetBlock1D(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super().__init__()
self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.block1 = Block1D(dim, dim_out, groups=groups)
self.block2 = Block1D(dim_out, dim_out, groups=groups)
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
class Downsample1D(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs):
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
if self.use_conv:
outputs = self.conv(outputs)
return outputs
class ConformerWrapper(ConformerBlock):
def __init__( # pylint: disable=useless-super-delegation
self,
*,
dim,
dim_head=64,
heads=8,
ff_mult=4,
conv_expansion_factor=2,
conv_kernel_size=31,
attn_dropout=0,
ff_dropout=0,
conv_dropout=0,
conv_causal=False,
):
super().__init__(
dim=dim,
dim_head=dim_head,
heads=heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
conv_dropout=conv_dropout,
conv_causal=conv_causal,
)
def forward(
self,
hidden_states,
attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
):
return super().forward(x=hidden_states, mask=attention_mask.bool())
class Decoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
down_block_type="transformer",
mid_block_type="transformer",
up_block_type="transformer",
):
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
self.get_block(
down_block_type,
output_channel,
attention_head_dim,
num_heads,
dropout,
act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for i in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
self.get_block(
mid_block_type,
output_channel,
attention_head_dim,
num_heads,
dropout,
act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i]
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=2 * input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
self.get_block(
up_block_type,
output_channel,
attention_head_dim,
num_heads,
dropout,
act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
# nn.init.normal_(self.final_proj.weight)
@staticmethod
def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
if block_type == "conformer":
block = ConformerWrapper(
dim=dim,
dim_head=attention_head_dim,
heads=num_heads,
ff_mult=1,
conv_expansion_factor=2,
ff_dropout=dropout,
attn_dropout=dropout,
conv_dropout=dropout,
conv_kernel_size=31,
)
elif block_type == "transformer":
block = BasicTransformerBlock(
dim=dim,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
else:
raise ValueError(f"Unknown block type {block_type}")
return block
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c")
mask_down = rearrange(mask_down, "b 1 t -> b t")
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=mask_down,
timestep=t,
)
x = rearrange(x, "b t c -> b c t")
mask_down = rearrange(mask_down, "b t -> b 1 t")
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c")
mask_mid = rearrange(mask_mid, "b 1 t -> b t")
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=mask_mid,
timestep=t,
)
x = rearrange(x, "b t c -> b c t")
mask_mid = rearrange(mask_mid, "b t -> b 1 t")
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
x = rearrange(x, "b c t -> b t c")
mask_up = rearrange(mask_up, "b 1 t -> b t")
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=mask_up,
timestep=t,
)
x = rearrange(x, "b t c -> b c t")
mask_up = rearrange(mask_up, "b t -> b 1 t")
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask

View File

@@ -1,448 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack
from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
from matcha.models.components.text_encoder import LayerNorm
# Define available networks
class DurationPredictorNetwork(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class DurationPredictorNetworkWithTimeStep(nn.Module):
"""Similar architecture but with a time embedding support"""
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.time_embeddings = SinusoidalPosEmb(filter_channels)
self.time_mlp = TimestepEmbedding(
in_channels=filter_channels,
time_embed_dim=filter_channels,
act_fn="silu",
)
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask, enc_outputs, t):
t = self.time_embeddings(t)
t = self.time_mlp(t).unsqueeze(-1)
x = pack([x, enc_outputs], "b * t")[0]
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
# Define available methods to compute loss
# Simple MSE deterministic
class DeterministicDurationPredictor(nn.Module):
def __init__(self, params):
super().__init__()
self.estimator = DurationPredictorNetwork(
params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0),
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
@torch.inference_mode()
def forward(self, x, x_mask):
return self.estimator(x, x_mask)
def compute_loss(self, durations, enc_outputs, x_mask):
return F.mse_loss(self.estimator(enc_outputs, x_mask), durations, reduction="sum") / torch.sum(x_mask)
# Flow Matching duration predictor
class FlowMatchingDurationPrediction(nn.Module):
def __init__(self, params) -> None:
super().__init__()
self.estimator = DurationPredictorNetworkWithTimeStep(
1
+ params.n_channels
+ (
params.spk_emb_dim if params.n_spks > 1 else 0
), # 1 for the durations and n_channels for encoder outputs
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
self.sigma_min = params.sigma_min
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
if n_timesteps is None:
n_timesteps = self.n_steps
b, _, t = enc_outputs.shape
z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
return self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
def solve_euler(self, x, t_span, enc_outputs, mask):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, enc_outputs, t)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, enc_outputs, mask):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
enc_outputs = enc_outputs.detach() # don't update encoder from the duration predictor
b, _, t = enc_outputs.shape
# random timestep
t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, enc_outputs, t.squeeze()), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss
# VITS discrete normalising flow based duration predictor
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = Log()
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
# Meta class to wrap all duration predictors
class DP(nn.Module):
def __init__(self, params):
super().__init__()
self.name = params.name
if params.name == "deterministic":
self.dp = DeterministicDurationPredictor(
params,
)
elif params.name == "flow_matching":
self.dp = FlowMatchingDurationPrediction(
params,
)
else:
raise ValueError(f"Invalid duration predictor configuration: {params.name}")
@torch.inference_mode()
def forward(self, enc_outputs, mask):
return self.dp(enc_outputs, mask)
def compute_loss(self, durations, enc_outputs, mask):
return self.dp.compute_loss(durations, enc_outputs, mask)

View File

@@ -1,132 +0,0 @@
from abc import ABC
import torch
import torch.nn.functional as F
from matcha.models.components.decoder import Decoder
from matcha.utils.pylogger import get_pylogger
log = get_pylogger(__name__)
class BASECFM(torch.nn.Module, ABC):
def __init__(
self,
n_feats,
cfm_params,
n_spks=1,
spk_emb_dim=128,
):
super().__init__()
self.n_feats = n_feats
self.n_spks = n_spks
self.spk_emb_dim = spk_emb_dim
self.solver = cfm_params.solver
if hasattr(cfm_params, "sigma_min"):
self.sigma_min = cfm_params.sigma_min
else:
self.sigma_min = 1e-4
self.estimator = None
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss, y
class CFM(BASECFM):
def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
# Just change the architecture of the estimator here
self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)

View File

@@ -1,376 +0,0 @@
""" from https://github.com/jaywalnut310/glow-tts """
import math
import torch
import torch.nn as nn
from einops import rearrange
import matcha.utils as utils
from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-4):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = torch.nn.Parameter(torch.ones(channels))
self.beta = torch.nn.Parameter(torch.zeros(channels))
def forward(self, x):
n_dims = len(x.shape)
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
shape = [1, -1] + [1] * (n_dims - 2)
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
return x
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.conv_layers = torch.nn.ModuleList()
self.norm_layers = torch.nn.ModuleList()
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class RotaryPositionalEmbeddings(nn.Module):
"""
## RoPE module
Rotary encoding transforms pairs of features by rotating in the 2D plane.
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
by an angle depending on the position of the token.
"""
def __init__(self, d: int, base: int = 10_000):
r"""
* `d` is the number of features $d$
* `base` is the constant used for calculating $\Theta$
"""
super().__init__()
self.base = base
self.d = int(d)
self.cos_cached = None
self.sin_cached = None
def _build_cache(self, x: torch.Tensor):
r"""
Cache $\cos$ and $\sin$ values
"""
# Return if cache is already built
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
return
# Get sequence length
seq_len = x.shape[0]
# $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
# Concatenate so that for row $m$ we have
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
# Cache them
self.cos_cached = idx_theta2.cos()[:, None, None, :]
self.sin_cached = idx_theta2.sin()[:, None, None, :]
def _neg_half(self, x: torch.Tensor):
# $\frac{d}{2}$
d_2 = self.d // 2
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
def forward(self, x: torch.Tensor):
"""
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
"""
# Cache $\cos$ and $\sin$ values
x = rearrange(x, "b h t d -> t b h d")
self._build_cache(x)
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
# Calculate
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
neg_half_x = self._neg_half(x_rope)
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
heads_share=True,
p_dropout=0.0,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.heads_share = heads_share
self.proximal_bias = proximal_bias
self.p_dropout = p_dropout
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
# from https://nn.labml.ai/transformers/rope/index.html
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)
torch.nn.init.xavier_uniform_(self.conv_q.weight)
torch.nn.init.xavier_uniform_(self.conv_k.weight)
if proximal_init:
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
torch.nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
b, d, t_s, t_t = (*key.size(), query.size(2))
query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
query = self.query_rotary_pe(query)
key = self.key_rotary_pe(key)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = torch.nn.functional.softmax(scores, dim=-1)
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
return output, p_attn
@staticmethod
def _attention_bias_proximal(length):
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.drop = torch.nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.attn_layers = torch.nn.ModuleList()
self.norm_layers_1 = torch.nn.ModuleList()
self.ffn_layers = torch.nn.ModuleList()
self.norm_layers_2 = torch.nn.ModuleList()
for _ in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.n_layers):
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class TextEncoder(nn.Module):
def __init__(
self,
encoder_type,
encoder_params,
n_vocab,
n_spks=1,
spk_emb_dim=128,
):
super().__init__()
self.encoder_type = encoder_type
self.n_vocab = n_vocab
self.n_feats = encoder_params.n_feats
self.n_channels = encoder_params.n_channels
self.spk_emb_dim = spk_emb_dim
self.n_spks = n_spks
self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
if encoder_params.prenet:
self.prenet = ConvReluNorm(
self.n_channels,
self.n_channels,
self.n_channels,
kernel_size=5,
n_layers=3,
p_dropout=0.5,
)
else:
self.prenet = lambda x, x_mask: x
self.encoder = Encoder(
encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
encoder_params.filter_channels,
encoder_params.n_heads,
encoder_params.n_layers,
encoder_params.kernel_size,
encoder_params.p_dropout,
)
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
def forward(self, x, x_lengths, spks=None):
"""Run forward pass to the transformer based encoder and duration predictor
Args:
x (torch.Tensor): text input
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): text input lengths
shape: (batch_size,)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size,)
Returns:
mu (torch.Tensor): average output of the encoder
shape: (batch_size, n_feats, max_text_length)
logw (torch.Tensor): log duration predicted by the duration predictor
shape: (batch_size, 1, max_text_length)
x_mask (torch.Tensor): mask for the text input
shape: (batch_size, 1, max_text_length)
"""
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.prenet(x, x_mask)
if self.n_spks > 1:
x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask
# x_dp = torch.detach(x)
# logw = self.proj_w(x_dp, x_mask)
return mu, x, x_mask

View File

@@ -1,316 +0,0 @@
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from diffusers.models.attention import (
GEGLU,
GELU,
AdaLayerNorm,
AdaLayerNormZero,
ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super().__init__()
self.in_features = out_features if isinstance(out_features, list) else [out_features]
self.proj = LoRACompatibleLinear(in_features, out_features)
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
"""
x = self.proj(x)
if self.alpha_logscale:
alpha = torch.exp(self.alpha)
beta = torch.exp(self.beta)
else:
alpha = self.alpha
beta = self.beta
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
return x
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
elif activation_fn == "snakebeta":
act_fn = SnakeBeta(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
# scale_qk=False, # uncomment this to not to use flash attention
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states

View File

@@ -1,244 +0,0 @@
import datetime as dt
import math
import random
import torch
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.duration_predictors import DP
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
denormalize,
fix_len_compatibility,
generate_path,
sequence_mask,
)
log = utils.get_pylogger(__name__)
class MatchaTTS(BaseLightningClass): # 🍵
def __init__(
self,
n_vocab,
n_spks,
spk_emb_dim,
n_feats,
encoder,
duration_predictor,
decoder,
cfm,
data_statistics,
out_size,
optimizer=None,
scheduler=None,
prior_loss=True,
):
super().__init__()
self.save_hyperparameters(logger=False)
self.n_vocab = n_vocab
self.n_spks = n_spks
self.spk_emb_dim = spk_emb_dim
self.n_feats = n_feats
self.out_size = out_size
self.prior_loss = prior_loss
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
self.encoder = TextEncoder(
encoder.encoder_type,
encoder.encoder_params,
n_vocab,
n_spks,
spk_emb_dim,
)
self.dp = DP(duration_predictor)
self.decoder = CFM(
in_channels=2 * encoder.encoder_params.n_feats,
out_channel=encoder.encoder_params.n_feats,
cfm_params=cfm,
decoder_params=decoder,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.update_data_statistics(data_statistics)
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
"""
Generates mel-spectrogram from text. Returns:
1. encoder outputs
2. decoder outputs
3. generated alignment
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
temperature (float, optional): controls variance of terminal distribution.
spks (bool, optional): speaker ids.
shape: (batch_size,)
length_scale (float, optional): controls speech pace.
Increase value to slow down generated speech and vice versa.
Returns:
dict: {
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Average mel spectrogram generated by the encoder
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Refined mel spectrogram improved by the CFM
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
# Alignment map between text and mel spectrogram
"mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Denormalized mel spectrogram
"mel_lengths": torch.Tensor, shape: (batch_size,),
# Lengths of mel spectrograms
"rtf": float,
# Real-time factor
"""
# For RTF computation
t = dt.datetime.now()
if self.n_spks > 1:
# Get speaker embedding
spks = self.spk_emb(spks.long())
# Get encoder_outputs `mu_x` and encoded text `enc_output`
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
# Get log-scaled token durations `logw`
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()
y_max_length_ = fix_len_compatibility(y_max_length)
# Using obtained durations `w` construct alignment map `attn`
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
# Align encoded text and get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length]
# Generate sample tracing the probability flow
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
decoder_outputs = decoder_outputs[:, :, :y_max_length]
t = (dt.datetime.now() - t).total_seconds()
rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)
return {
"encoder_outputs": encoder_outputs,
"decoder_outputs": decoder_outputs,
"attn": attn[:, :, :y_max_length],
"mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
"mel_lengths": y_lengths,
"rtf": rtf,
}
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
y (torch.Tensor): batch of corresponding mel-spectrograms.
shape: (batch_size, n_feats, max_mel_length)
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
shape: (batch_size,)
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
spks (torch.Tensor, optional): speaker ids.
shape: (batch_size,)
"""
if self.n_spks > 1:
# Get speaker embedding
spks = self.spk_emb(spks)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
y_max_length = y.shape[-1]
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
# Cut a small segment of mel-spectrogram in order to increase batch size
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
# - Do not need this hack for Matcha-TTS, but it works with it as well
if not isinstance(out_size, type(None)):
max_offset = (y_lengths - out_size).clamp(0)
offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
out_offset = torch.LongTensor(
[torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
).to(y_lengths)
attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
y_cut_lengths = []
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
y_cut_lengths.append(y_cut_length)
cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
y_cut_lengths = torch.LongTensor(y_cut_lengths)
y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
attn = attn_cut
y = y_cut
y_mask = y_cut_mask
# Align encoded text with mel-spectrogram and get mu_y segment
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
# Compute loss of the decoder
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
if self.prior_loss:
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss, attn

Some files were not shown because too many files have changed in this diff Show More