From a9251ed9842d8a03f2959ede1b5416c36743cd8a Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Sun, 17 Sep 2023 06:50:46 +0000 Subject: [PATCH] Adding docstrings --- Makefile | 2 +- matcha/app.py | 20 ++++++++++++++------ matcha/cli.py | 7 ++++--- matcha/models/components/decoder.py | 2 ++ matcha/models/components/flow_matching.py | 2 +- matcha/models/components/transformer.py | 22 +++++++++++++--------- matcha/models/matcha_tts.py | 12 ++++++++---- matcha/utils/utils.py | 2 +- requirements.txt | 2 +- setup.py | 8 ++------ 10 files changed, 47 insertions(+), 32 deletions(-) diff --git a/Makefile b/Makefile index 7d8b975..3468118 100644 --- a/Makefile +++ b/Makefile @@ -38,4 +38,4 @@ train-ljspeech-min: ## Train the model with minimum memory python matcha/train.py experiment=ljspeech_min_memory start_app: ## Start the app - python matcha/app.py \ No newline at end of file + python matcha/app.py diff --git a/matcha/app.py b/matcha/app.py index 6c82354..493d2bd 100644 --- a/matcha/app.py +++ b/matcha/app.py @@ -6,9 +6,16 @@ import gradio as gr import soundfile as sf import torch -from matcha.cli import (MATCHA_URLS, VOCODER_URL, assert_model_downloaded, - get_device, load_matcha, load_vocoder, process_text, - to_waveform) +from matcha.cli import ( + MATCHA_URLS, + VOCODER_URL, + assert_model_downloaded, + get_device, + load_matcha, + load_vocoder, + process_text, + to_waveform, +) from matcha.utils.utils import get_user_data_dir, plot_tensor LOCATION = Path(get_user_data_dir()) @@ -59,6 +66,7 @@ def run_full_synthesis(text, n_timesteps, mel_temp, length_scale): audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale) return phones, audio, mel_spectrogram + def main(): description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) @@ -80,7 +88,6 @@ def main(): with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo: processed_text = gr.State(value=None) processed_text_len = gr.State(value=None) - mel_variable = gr.State(value=None) with gr.Box(): with gr.Row(): @@ -136,7 +143,7 @@ def main(): audio = gr.Audio(interactive=False, label="Audio") with gr.Row(): - examples = gr.Examples( + examples = gr.Examples( # pylint: disable=unused-variable examples=[ [ "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.", @@ -202,6 +209,7 @@ def main(): ) demo.queue(concurrency_count=5).launch(share=True) - + + if __name__ == "__main__": main() diff --git a/matcha/cli.py b/matcha/cli.py index 6217d4f..49004fb 100644 --- a/matcha/cli.py +++ b/matcha/cli.py @@ -14,8 +14,7 @@ from matcha.hifigan.env import AttrDict from matcha.hifigan.models import Generator as HiFiGAN from matcha.models.matcha_tts import MatchaTTS from matcha.text import sequence_to_text, text_to_sequence -from matcha.utils.utils import (assert_model_downloaded, get_user_data_dir, - intersperse) +from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse MATCHA_URLS = {"matcha_ljspeech": ""} # , "matcha_vctk": ""} # Coming soon @@ -146,7 +145,9 @@ def validate_args(args): @torch.inference_mode() def cli(): - parser = argparse.ArgumentParser(description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") + parser = argparse.ArgumentParser( + description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" + ) parser.add_argument( "--model", type=str, diff --git a/matcha/models/components/decoder.py b/matcha/models/components/decoder.py index 90a5139..1137cd7 100644 --- a/matcha/models/components/decoder.py +++ b/matcha/models/components/decoder.py @@ -69,6 +69,7 @@ class Downsample1D(nn.Module): def forward(self, x): return self.conv(x) + class TimestepEmbedding(nn.Module): def __init__( self, @@ -115,6 +116,7 @@ class TimestepEmbedding(nn.Module): sample = self.post_act(sample) return sample + class Upsample1D(nn.Module): """A 1D upsampling layer with an optional convolution. diff --git a/matcha/models/components/flow_matching.py b/matcha/models/components/flow_matching.py index 781deb0..4d77547 100644 --- a/matcha/models/components/flow_matching.py +++ b/matcha/models/components/flow_matching.py @@ -46,7 +46,7 @@ class BASECFM(torch.nn.Module, ABC): Returns: sample: generated mel-spectrogram - shape: (batch_size, n_feats, mel_timesteps) + shape: (batch_size, n_feats, mel_timesteps) """ z = torch.randn_like(mu) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) diff --git a/matcha/models/components/transformer.py b/matcha/models/components/transformer.py index 9e6a22d..dd1afa3 100644 --- a/matcha/models/components/transformer.py +++ b/matcha/models/components/transformer.py @@ -2,8 +2,13 @@ from typing import Any, Dict, Optional import torch import torch.nn as nn -from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm, - AdaLayerNormZero, ApproximateGELU) +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) from diffusers.models.attention_processor import Attention from diffusers.models.lora import LoRACompatibleLinear from diffusers.utils.torch_utils import maybe_allow_in_graph @@ -38,7 +43,7 @@ class SnakeBeta(nn.Module): beta is initialized to 1 by default, higher values = higher-magnitude. alpha will be trained along with the rest of your model. """ - super(SnakeBeta, self).__init__() + super().__init__() self.in_features = out_features if isinstance(out_features, list) else [out_features] self.proj = LoRACompatibleLinear(in_features, out_features) @@ -73,8 +78,8 @@ class SnakeBeta(nn.Module): x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) return x - - + + class FeedForward(nn.Module): r""" A feed-forward layer. @@ -127,8 +132,7 @@ class FeedForward(nn.Module): for module in self.net: hidden_states = module(hidden_states) return hidden_states - - + @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): @@ -217,7 +221,7 @@ class BasicTransformerBlock(nn.Module): dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, - # scale_qk=False, # uncomment this to not to use flash attention + # scale_qk=False, # uncomment this to not to use flash attention ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None @@ -309,4 +313,4 @@ class BasicTransformerBlock(nn.Module): hidden_states = ff_output + hidden_states - return hidden_states \ No newline at end of file + return hidden_states diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index c480db8..bc5ed06 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -9,9 +9,13 @@ from matcha import utils from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.components.flow_matching import CFM from matcha.models.components.text_encoder import TextEncoder -from matcha.utils.model import (denormalize, duration_loss, - fix_len_compatibility, generate_path, - sequence_mask) +from matcha.utils.model import ( + denormalize, + duration_loss, + fix_len_compatibility, + generate_path, + sequence_mask, +) log = utils.get_pylogger(__name__) @@ -83,7 +87,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 shape: (batch_size,) length_scale (float, optional): controls speech pace. Increase value to slow down generated speech and vice versa. - + Returns: dict: { "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), diff --git a/matcha/utils/utils.py b/matcha/utils/utils.py index 1b71103..5f8162d 100644 --- a/matcha/utils/utils.py +++ b/matcha/utils/utils.py @@ -199,7 +199,7 @@ def get_user_data_dir(appname="matcha_tts"): ans = Path("~/Library/Application Support/").expanduser() else: ans = Path.home().joinpath(".local/share") - + final_path = ans.joinpath(appname) final_path.mkdir(parents=True, exist_ok=True) return final_path diff --git a/requirements.txt b/requirements.txt index 2d1845c..ac1abf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,4 +41,4 @@ ipywidgets gradio gdown wget -seaborn \ No newline at end of file +seaborn diff --git a/setup.py b/setup.py index 3b2d30e..2b6b7d8 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ exts = [ ) ] -with open("README.md", "r", encoding="utf-8") as readme_file: +with open("README.md", encoding="utf-8") as readme_file: README = readme_file.read() @@ -25,10 +25,7 @@ setup( author="Shivam Mehta", author_email="shivam.mehta25@gmail.com", url="https://shivammehta25.github.io/Matcha-TTS", - install_requires=[ - str(r) - for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt")) - ], + install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], include_dirs=[numpy.get_include()], include_package_data=True, packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), @@ -42,5 +39,4 @@ setup( }, ext_modules=cythonize(exts, language_level=3), python_requires=">=3.9.0", - )