Adding docstrings

This commit is contained in:
Shivam Mehta
2023-09-17 06:50:46 +00:00
parent c079e5254a
commit a9251ed984
10 changed files with 47 additions and 32 deletions

View File

@@ -38,4 +38,4 @@ train-ljspeech-min: ## Train the model with minimum memory
python matcha/train.py experiment=ljspeech_min_memory python matcha/train.py experiment=ljspeech_min_memory
start_app: ## Start the app start_app: ## Start the app
python matcha/app.py python matcha/app.py

View File

@@ -6,9 +6,16 @@ import gradio as gr
import soundfile as sf import soundfile as sf
import torch import torch
from matcha.cli import (MATCHA_URLS, VOCODER_URL, assert_model_downloaded, from matcha.cli import (
get_device, load_matcha, load_vocoder, process_text, MATCHA_URLS,
to_waveform) VOCODER_URL,
assert_model_downloaded,
get_device,
load_matcha,
load_vocoder,
process_text,
to_waveform,
)
from matcha.utils.utils import get_user_data_dir, plot_tensor from matcha.utils.utils import get_user_data_dir, plot_tensor
LOCATION = Path(get_user_data_dir()) LOCATION = Path(get_user_data_dir())
@@ -59,6 +66,7 @@ def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale) audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
return phones, audio, mel_spectrogram return phones, audio, mel_spectrogram
def main(): def main():
description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
@@ -80,7 +88,6 @@ def main():
with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo: with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
processed_text = gr.State(value=None) processed_text = gr.State(value=None)
processed_text_len = gr.State(value=None) processed_text_len = gr.State(value=None)
mel_variable = gr.State(value=None)
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
@@ -136,7 +143,7 @@ def main():
audio = gr.Audio(interactive=False, label="Audio") audio = gr.Audio(interactive=False, label="Audio")
with gr.Row(): with gr.Row():
examples = gr.Examples( examples = gr.Examples( # pylint: disable=unused-variable
examples=[ examples=[
[ [
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.", "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
@@ -202,6 +209,7 @@ def main():
) )
demo.queue(concurrency_count=5).launch(share=True) demo.queue(concurrency_count=5).launch(share=True)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -14,8 +14,7 @@ from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.matcha_tts import MatchaTTS from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import (assert_model_downloaded, get_user_data_dir, from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
intersperse)
MATCHA_URLS = {"matcha_ljspeech": ""} # , "matcha_vctk": ""} # Coming soon MATCHA_URLS = {"matcha_ljspeech": ""} # , "matcha_vctk": ""} # Coming soon
@@ -146,7 +145,9 @@ def validate_args(args):
@torch.inference_mode() @torch.inference_mode()
def cli(): def cli():
parser = argparse.ArgumentParser(description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") parser = argparse.ArgumentParser(
description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
)
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,

View File

@@ -69,6 +69,7 @@ class Downsample1D(nn.Module):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class TimestepEmbedding(nn.Module): class TimestepEmbedding(nn.Module):
def __init__( def __init__(
self, self,
@@ -115,6 +116,7 @@ class TimestepEmbedding(nn.Module):
sample = self.post_act(sample) sample = self.post_act(sample)
return sample return sample
class Upsample1D(nn.Module): class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution. """A 1D upsampling layer with an optional convolution.

View File

@@ -46,7 +46,7 @@ class BASECFM(torch.nn.Module, ABC):
Returns: Returns:
sample: generated mel-spectrogram sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps) shape: (batch_size, n_feats, mel_timesteps)
""" """
z = torch.randn_like(mu) * temperature z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)

View File

@@ -2,8 +2,13 @@ from typing import Any, Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm, from diffusers.models.attention import (
AdaLayerNormZero, ApproximateGELU) GEGLU,
GELU,
AdaLayerNorm,
AdaLayerNormZero,
ApproximateGELU,
)
from diffusers.models.attention_processor import Attention from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.utils.torch_utils import maybe_allow_in_graph
@@ -38,7 +43,7 @@ class SnakeBeta(nn.Module):
beta is initialized to 1 by default, higher values = higher-magnitude. beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model. alpha will be trained along with the rest of your model.
""" """
super(SnakeBeta, self).__init__() super().__init__()
self.in_features = out_features if isinstance(out_features, list) else [out_features] self.in_features = out_features if isinstance(out_features, list) else [out_features]
self.proj = LoRACompatibleLinear(in_features, out_features) self.proj = LoRACompatibleLinear(in_features, out_features)
@@ -73,8 +78,8 @@ class SnakeBeta(nn.Module):
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
return x return x
class FeedForward(nn.Module): class FeedForward(nn.Module):
r""" r"""
A feed-forward layer. A feed-forward layer.
@@ -127,8 +132,7 @@ class FeedForward(nn.Module):
for module in self.net: for module in self.net:
hidden_states = module(hidden_states) hidden_states = module(hidden_states)
return hidden_states return hidden_states
@maybe_allow_in_graph @maybe_allow_in_graph
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
@@ -217,7 +221,7 @@ class BasicTransformerBlock(nn.Module):
dropout=dropout, dropout=dropout,
bias=attention_bias, bias=attention_bias,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
# scale_qk=False, # uncomment this to not to use flash attention # scale_qk=False, # uncomment this to not to use flash attention
) # is self-attn if encoder_hidden_states is none ) # is self-attn if encoder_hidden_states is none
else: else:
self.norm2 = None self.norm2 = None
@@ -309,4 +313,4 @@ class BasicTransformerBlock(nn.Module):
hidden_states = ff_output + hidden_states hidden_states = ff_output + hidden_states
return hidden_states return hidden_states

View File

@@ -9,9 +9,13 @@ from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (denormalize, duration_loss, from matcha.utils.model import (
fix_len_compatibility, generate_path, denormalize,
sequence_mask) duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
)
log = utils.get_pylogger(__name__) log = utils.get_pylogger(__name__)
@@ -83,7 +87,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
shape: (batch_size,) shape: (batch_size,)
length_scale (float, optional): controls speech pace. length_scale (float, optional): controls speech pace.
Increase value to slow down generated speech and vice versa. Increase value to slow down generated speech and vice versa.
Returns: Returns:
dict: { dict: {
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),

View File

@@ -199,7 +199,7 @@ def get_user_data_dir(appname="matcha_tts"):
ans = Path("~/Library/Application Support/").expanduser() ans = Path("~/Library/Application Support/").expanduser()
else: else:
ans = Path.home().joinpath(".local/share") ans = Path.home().joinpath(".local/share")
final_path = ans.joinpath(appname) final_path = ans.joinpath(appname)
final_path.mkdir(parents=True, exist_ok=True) final_path.mkdir(parents=True, exist_ok=True)
return final_path return final_path

View File

@@ -41,4 +41,4 @@ ipywidgets
gradio gradio
gdown gdown
wget wget
seaborn seaborn

View File

@@ -12,7 +12,7 @@ exts = [
) )
] ]
with open("README.md", "r", encoding="utf-8") as readme_file: with open("README.md", encoding="utf-8") as readme_file:
README = readme_file.read() README = readme_file.read()
@@ -25,10 +25,7 @@ setup(
author="Shivam Mehta", author="Shivam Mehta",
author_email="shivam.mehta25@gmail.com", author_email="shivam.mehta25@gmail.com",
url="https://shivammehta25.github.io/Matcha-TTS", url="https://shivammehta25.github.io/Matcha-TTS",
install_requires=[ install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
str(r)
for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
],
include_dirs=[numpy.get_include()], include_dirs=[numpy.get_include()],
include_package_data=True, include_package_data=True,
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
@@ -42,5 +39,4 @@ setup(
}, },
ext_modules=cythonize(exts, language_level=3), ext_modules=cythonize(exts, language_level=3),
python_requires=">=3.9.0", python_requires=">=3.9.0",
) )