diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py index e10dfcb..48f8266 100644 --- a/matcha/data/text_mel_datamodule.py +++ b/matcha/data/text_mel_datamodule.py @@ -234,9 +234,9 @@ class TextMelBatchCollate: def __call__(self, batch): B = len(batch) - y_max_length = max([item["y"].shape[-1] for item in batch]) + y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator y_max_length = fix_len_compatibility(y_max_length) - x_max_length = max([item["x"].shape[-1] for item in batch]) + x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator n_feats = batch[0]["y"].shape[-2] y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) diff --git a/matcha/hifigan/denoiser.py b/matcha/hifigan/denoiser.py index 9fd3331..452be6a 100644 --- a/matcha/hifigan/denoiser.py +++ b/matcha/hifigan/denoiser.py @@ -4,6 +4,10 @@ import torch +class ModeException(Exception): + pass + + class Denoiser(torch.nn.Module): """Removes model bias from audio produced with waveglow""" @@ -20,7 +24,7 @@ class Denoiser(torch.nn.Module): elif mode == "normal": mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) else: - raise Exception(f"Mode {mode} if not supported") + raise ModeException(f"Mode {mode} if not supported") def stft_fn(audio, n_fft, hop_length, win_length, window): spec = torch.stft( diff --git a/matcha/hifigan/meldataset.py b/matcha/hifigan/meldataset.py index 8b43ea7..d1b3a90 100644 --- a/matcha/hifigan/meldataset.py +++ b/matcha/hifigan/meldataset.py @@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, if torch.max(y) > 1.0: print("max value is ", torch.max(y)) - global mel_basis, hann_window # pylint: disable=global-statement + global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned if fmax not in mel_basis: mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) diff --git a/matcha/hifigan/models.py b/matcha/hifigan/models.py index d209d9a..57305ef 100644 --- a/matcha/hifigan/models.py +++ b/matcha/hifigan/models.py @@ -1,7 +1,7 @@ """ from https://github.com/jik876/hifi-gan """ import torch -import torch.nn as nn +import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn.functional as F from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm diff --git a/matcha/models/components/decoder.py b/matcha/models/components/decoder.py index 1137cd7..504f88b 100644 --- a/matcha/models/components/decoder.py +++ b/matcha/models/components/decoder.py @@ -2,7 +2,7 @@ import math from typing import Optional import torch -import torch.nn as nn +import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn.functional as F from conformer import ConformerBlock from diffusers.models.activations import get_activation diff --git a/matcha/models/components/text_encoder.py b/matcha/models/components/text_encoder.py index a388d05..66e7164 100644 --- a/matcha/models/components/text_encoder.py +++ b/matcha/models/components/text_encoder.py @@ -3,10 +3,10 @@ import math import torch -import torch.nn as nn +import torch.nn as nn # pylint: disable=consider-using-from-import from einops import rearrange -import matcha.utils as utils +import matcha.utils as utils # pylint: disable=consider-using-from-import from matcha.utils.model import sequence_mask log = utils.get_pylogger(__name__) diff --git a/matcha/models/components/transformer.py b/matcha/models/components/transformer.py index dd1afa3..4d604f5 100644 --- a/matcha/models/components/transformer.py +++ b/matcha/models/components/transformer.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Optional import torch -import torch.nn as nn +import torch.nn as nn # pylint: disable=consider-using-from-import from diffusers.models.attention import ( GEGLU, GELU, diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index 07f95ad..092fa27 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -4,7 +4,7 @@ import random import torch -import matcha.utils.monotonic_align as monotonic_align +import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import from matcha import utils from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.components.flow_matching import CFM diff --git a/matcha/text/__init__.py b/matcha/text/__init__.py index 8c75d6b..50d7b90 100644 --- a/matcha/text/__init__.py +++ b/matcha/text/__init__.py @@ -7,6 +7,10 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)} _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension +class UnknownCleanerException(Exception): + pass + + def text_to_sequence(text, cleaner_names): """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. Args: @@ -48,6 +52,6 @@ def _clean_text(text, cleaner_names): for name in cleaner_names: cleaner = getattr(cleaners, name) if not cleaner: - raise Exception("Unknown cleaner: %s" % name) + raise UnknownCleanerException(f"Unknown cleaner: {name}") text = cleaner(text) return text diff --git a/matcha/text/cleaners.py b/matcha/text/cleaners.py index 788776b..74cd1a2 100644 --- a/matcha/text/cleaners.py +++ b/matcha/text/cleaners.py @@ -41,7 +41,7 @@ _brackets_re = re.compile(r"[\[\]\(\)\{\}]") # List of (regular expression, replacement) pairs for abbreviations: _abbreviations = [ - (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1]) for x in [ ("mrs", "misess"), ("mr", "mister"), diff --git a/matcha/utils/audio.py b/matcha/utils/audio.py index 0bcd74d..d257f0d 100644 --- a/matcha/utils/audio.py +++ b/matcha/utils/audio.py @@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, if torch.max(y) > 1.0: print("max value is ", torch.max(y)) - global mel_basis, hann_window # pylint: disable=global-statement + global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned if f"{str(fmax)}_{str(y.device)}" not in mel_basis: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) diff --git a/matcha/utils/generate_data_statistics.py b/matcha/utils/generate_data_statistics.py index 49ed3c1..305d806 100644 --- a/matcha/utils/generate_data_statistics.py +++ b/matcha/utils/generate_data_statistics.py @@ -102,10 +102,8 @@ def main(): log.info("Dataloader loaded! Now computing stats...") params = compute_data_statistics(data_loader, cfg["n_feats"]) print(params) - json.dump( - params, - open(output_file, "w"), - ) + with open(output_file, "w", encoding="utf-8") as dumpfile: + json.dump(params, dumpfile) if __name__ == "__main__": diff --git a/matcha/utils/rich_utils.py b/matcha/utils/rich_utils.py index f602f6e..1f1d6fe 100644 --- a/matcha/utils/rich_utils.py +++ b/matcha/utils/rich_utils.py @@ -72,7 +72,7 @@ def print_config_tree( # save config tree to file if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file: rich.print(tree, file=file) @@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: log.info(f"Tags: {cfg.tags}") if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file: rich.print(cfg.tags, file=file) diff --git a/setup.py b/setup.py index a49c2cc..5375723 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,16 @@ with open("README.md", encoding="utf-8") as readme_file: README = readme_file.read() cwd = os.path.dirname(os.path.abspath(__file__)) -with open(os.path.join(cwd, "matcha", "VERSION")) as fin: +with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin: version = fin.read().strip() + +def get_requires(): + requirements = os.path.join(os.path.dirname(__file__), "requirements.txt") + with open(requirements, encoding="utf-8") as reqfile: + return [str(r).strip() for r in reqfile] + + setup( name="matcha-tts", version=version, @@ -28,7 +35,7 @@ setup( author="Shivam Mehta", author_email="shivam.mehta25@gmail.com", url="https://shivammehta25.github.io/Matcha-TTS", - install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], + install_requires=get_requires(), include_dirs=[numpy.get_include()], include_package_data=True, packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),