mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
Adding docstrings
This commit is contained in:
2
Makefile
2
Makefile
@@ -38,4 +38,4 @@ train-ljspeech-min: ## Train the model with minimum memory
|
|||||||
python matcha/train.py experiment=ljspeech_min_memory
|
python matcha/train.py experiment=ljspeech_min_memory
|
||||||
|
|
||||||
start_app: ## Start the app
|
start_app: ## Start the app
|
||||||
python matcha/app.py
|
python matcha/app.py
|
||||||
|
|||||||
@@ -6,9 +6,16 @@ import gradio as gr
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from matcha.cli import (MATCHA_URLS, VOCODER_URL, assert_model_downloaded,
|
from matcha.cli import (
|
||||||
get_device, load_matcha, load_vocoder, process_text,
|
MATCHA_URLS,
|
||||||
to_waveform)
|
VOCODER_URL,
|
||||||
|
assert_model_downloaded,
|
||||||
|
get_device,
|
||||||
|
load_matcha,
|
||||||
|
load_vocoder,
|
||||||
|
process_text,
|
||||||
|
to_waveform,
|
||||||
|
)
|
||||||
from matcha.utils.utils import get_user_data_dir, plot_tensor
|
from matcha.utils.utils import get_user_data_dir, plot_tensor
|
||||||
|
|
||||||
LOCATION = Path(get_user_data_dir())
|
LOCATION = Path(get_user_data_dir())
|
||||||
@@ -59,6 +66,7 @@ def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
|
|||||||
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
|
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
|
||||||
return phones, audio, mel_spectrogram
|
return phones, audio, mel_spectrogram
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
|
description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
|
||||||
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
|
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
|
||||||
@@ -80,7 +88,6 @@ def main():
|
|||||||
with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
|
with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
|
||||||
processed_text = gr.State(value=None)
|
processed_text = gr.State(value=None)
|
||||||
processed_text_len = gr.State(value=None)
|
processed_text_len = gr.State(value=None)
|
||||||
mel_variable = gr.State(value=None)
|
|
||||||
|
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@@ -136,7 +143,7 @@ def main():
|
|||||||
audio = gr.Audio(interactive=False, label="Audio")
|
audio = gr.Audio(interactive=False, label="Audio")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
examples = gr.Examples(
|
examples = gr.Examples( # pylint: disable=unused-variable
|
||||||
examples=[
|
examples=[
|
||||||
[
|
[
|
||||||
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
|
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
|
||||||
@@ -202,6 +209,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
demo.queue(concurrency_count=5).launch(share=True)
|
demo.queue(concurrency_count=5).launch(share=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ from matcha.hifigan.env import AttrDict
|
|||||||
from matcha.hifigan.models import Generator as HiFiGAN
|
from matcha.hifigan.models import Generator as HiFiGAN
|
||||||
from matcha.models.matcha_tts import MatchaTTS
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
from matcha.text import sequence_to_text, text_to_sequence
|
from matcha.text import sequence_to_text, text_to_sequence
|
||||||
from matcha.utils.utils import (assert_model_downloaded, get_user_data_dir,
|
from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
|
||||||
intersperse)
|
|
||||||
|
|
||||||
MATCHA_URLS = {"matcha_ljspeech": ""} # , "matcha_vctk": ""} # Coming soon
|
MATCHA_URLS = {"matcha_ljspeech": ""} # , "matcha_vctk": ""} # Coming soon
|
||||||
|
|
||||||
@@ -146,7 +145,9 @@ def validate_args(args):
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def cli():
|
def cli():
|
||||||
parser = argparse.ArgumentParser(description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching")
|
parser = argparse.ArgumentParser(
|
||||||
|
description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class Downsample1D(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedding(nn.Module):
|
class TimestepEmbedding(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -115,6 +116,7 @@ class TimestepEmbedding(nn.Module):
|
|||||||
sample = self.post_act(sample)
|
sample = self.post_act(sample)
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
class Upsample1D(nn.Module):
|
class Upsample1D(nn.Module):
|
||||||
"""A 1D upsampling layer with an optional convolution.
|
"""A 1D upsampling layer with an optional convolution.
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class BASECFM(torch.nn.Module, ABC):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
sample: generated mel-spectrogram
|
sample: generated mel-spectrogram
|
||||||
shape: (batch_size, n_feats, mel_timesteps)
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
"""
|
"""
|
||||||
z = torch.randn_like(mu) * temperature
|
z = torch.randn_like(mu) * temperature
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
||||||
|
|||||||
@@ -2,8 +2,13 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
|
from diffusers.models.attention import (
|
||||||
AdaLayerNormZero, ApproximateGELU)
|
GEGLU,
|
||||||
|
GELU,
|
||||||
|
AdaLayerNorm,
|
||||||
|
AdaLayerNormZero,
|
||||||
|
ApproximateGELU,
|
||||||
|
)
|
||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
from diffusers.models.lora import LoRACompatibleLinear
|
from diffusers.models.lora import LoRACompatibleLinear
|
||||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||||
@@ -38,7 +43,7 @@ class SnakeBeta(nn.Module):
|
|||||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||||
alpha will be trained along with the rest of your model.
|
alpha will be trained along with the rest of your model.
|
||||||
"""
|
"""
|
||||||
super(SnakeBeta, self).__init__()
|
super().__init__()
|
||||||
self.in_features = out_features if isinstance(out_features, list) else [out_features]
|
self.in_features = out_features if isinstance(out_features, list) else [out_features]
|
||||||
self.proj = LoRACompatibleLinear(in_features, out_features)
|
self.proj = LoRACompatibleLinear(in_features, out_features)
|
||||||
|
|
||||||
@@ -73,8 +78,8 @@ class SnakeBeta(nn.Module):
|
|||||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
A feed-forward layer.
|
A feed-forward layer.
|
||||||
@@ -127,8 +132,7 @@ class FeedForward(nn.Module):
|
|||||||
for module in self.net:
|
for module in self.net:
|
||||||
hidden_states = module(hidden_states)
|
hidden_states = module(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
@maybe_allow_in_graph
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
@@ -217,7 +221,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
bias=attention_bias,
|
bias=attention_bias,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
# scale_qk=False, # uncomment this to not to use flash attention
|
# scale_qk=False, # uncomment this to not to use flash attention
|
||||||
) # is self-attn if encoder_hidden_states is none
|
) # is self-attn if encoder_hidden_states is none
|
||||||
else:
|
else:
|
||||||
self.norm2 = None
|
self.norm2 = None
|
||||||
@@ -309,4 +313,4 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
hidden_states = ff_output + hidden_states
|
hidden_states = ff_output + hidden_states
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -9,9 +9,13 @@ from matcha import utils
|
|||||||
from matcha.models.baselightningmodule import BaseLightningClass
|
from matcha.models.baselightningmodule import BaseLightningClass
|
||||||
from matcha.models.components.flow_matching import CFM
|
from matcha.models.components.flow_matching import CFM
|
||||||
from matcha.models.components.text_encoder import TextEncoder
|
from matcha.models.components.text_encoder import TextEncoder
|
||||||
from matcha.utils.model import (denormalize, duration_loss,
|
from matcha.utils.model import (
|
||||||
fix_len_compatibility, generate_path,
|
denormalize,
|
||||||
sequence_mask)
|
duration_loss,
|
||||||
|
fix_len_compatibility,
|
||||||
|
generate_path,
|
||||||
|
sequence_mask,
|
||||||
|
)
|
||||||
|
|
||||||
log = utils.get_pylogger(__name__)
|
log = utils.get_pylogger(__name__)
|
||||||
|
|
||||||
@@ -83,7 +87,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
shape: (batch_size,)
|
shape: (batch_size,)
|
||||||
length_scale (float, optional): controls speech pace.
|
length_scale (float, optional): controls speech pace.
|
||||||
Increase value to slow down generated speech and vice versa.
|
Increase value to slow down generated speech and vice versa.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: {
|
dict: {
|
||||||
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ def get_user_data_dir(appname="matcha_tts"):
|
|||||||
ans = Path("~/Library/Application Support/").expanduser()
|
ans = Path("~/Library/Application Support/").expanduser()
|
||||||
else:
|
else:
|
||||||
ans = Path.home().joinpath(".local/share")
|
ans = Path.home().joinpath(".local/share")
|
||||||
|
|
||||||
final_path = ans.joinpath(appname)
|
final_path = ans.joinpath(appname)
|
||||||
final_path.mkdir(parents=True, exist_ok=True)
|
final_path.mkdir(parents=True, exist_ok=True)
|
||||||
return final_path
|
return final_path
|
||||||
|
|||||||
@@ -41,4 +41,4 @@ ipywidgets
|
|||||||
gradio
|
gradio
|
||||||
gdown
|
gdown
|
||||||
wget
|
wget
|
||||||
seaborn
|
seaborn
|
||||||
|
|||||||
8
setup.py
8
setup.py
@@ -12,7 +12,7 @@ exts = [
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
with open("README.md", "r", encoding="utf-8") as readme_file:
|
with open("README.md", encoding="utf-8") as readme_file:
|
||||||
README = readme_file.read()
|
README = readme_file.read()
|
||||||
|
|
||||||
|
|
||||||
@@ -25,10 +25,7 @@ setup(
|
|||||||
author="Shivam Mehta",
|
author="Shivam Mehta",
|
||||||
author_email="shivam.mehta25@gmail.com",
|
author_email="shivam.mehta25@gmail.com",
|
||||||
url="https://shivammehta25.github.io/Matcha-TTS",
|
url="https://shivammehta25.github.io/Matcha-TTS",
|
||||||
install_requires=[
|
install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
|
||||||
str(r)
|
|
||||||
for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
|
||||||
],
|
|
||||||
include_dirs=[numpy.get_include()],
|
include_dirs=[numpy.get_include()],
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
|
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
|
||||||
@@ -42,5 +39,4 @@ setup(
|
|||||||
},
|
},
|
||||||
ext_modules=cythonize(exts, language_level=3),
|
ext_modules=cythonize(exts, language_level=3),
|
||||||
python_requires=">=3.9.0",
|
python_requires=">=3.9.0",
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user