Merge pull request #101 from jimregan/pylint

Make pylint happy
This commit is contained in:
Shivam Mehta
2024-11-13 22:13:36 -08:00
committed by GitHub
14 changed files with 34 additions and 21 deletions

View File

@@ -234,9 +234,9 @@ class TextMelBatchCollate:
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch])
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch])
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)

View File

@@ -4,6 +4,10 @@
import torch
class ModeException(Exception):
pass
class Denoiser(torch.nn.Module):
"""Removes model bias from audio produced with waveglow"""
@@ -20,7 +24,7 @@ class Denoiser(torch.nn.Module):
elif mode == "normal":
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
else:
raise Exception(f"Mode {mode} if not supported")
raise ModeException(f"Mode {mode} if not supported")
def stft_fn(audio, n_fft, hop_length, win_length, window):
spec = torch.stft(

View File

@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

View File

@@ -1,7 +1,7 @@
""" from https://github.com/jik876/hifi-gan """
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

View File

@@ -2,7 +2,7 @@ import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation

View File

@@ -3,10 +3,10 @@
import math
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
from einops import rearrange
import matcha.utils as utils
import matcha.utils as utils # pylint: disable=consider-using-from-import
from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__)

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
from diffusers.models.attention import (
GEGLU,
GELU,

View File

@@ -4,7 +4,7 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM

View File

@@ -7,6 +7,10 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
class UnknownCleanerException(Exception):
pass
def text_to_sequence(text, cleaner_names):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
@@ -48,6 +52,6 @@ def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
raise UnknownCleanerException(f"Unknown cleaner: {name}")
text = cleaner(text)
return text

View File

@@ -41,7 +41,7 @@ _brackets_re = re.compile(r"[\[\]\(\)\{\}]")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
(re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),

View File

@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

View File

@@ -102,10 +102,8 @@ def main():
log.info("Dataloader loaded! Now computing stats...")
params = compute_data_statistics(data_loader, cfg["n_feats"])
print(params)
json.dump(
params,
open(output_file, "w"),
)
with open(output_file, "w", encoding="utf-8") as dumpfile:
json.dump(params, dumpfile)
if __name__ == "__main__":

View File

@@ -72,7 +72,7 @@ def print_config_tree(
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
rich.print(tree, file=file)
@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
rich.print(cfg.tags, file=file)

View File

@@ -16,9 +16,16 @@ with open("README.md", encoding="utf-8") as readme_file:
README = readme_file.read()
cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
version = fin.read().strip()
def get_requires():
requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
with open(requirements, encoding="utf-8") as reqfile:
return [str(r).strip() for r in reqfile]
setup(
name="matcha-tts",
version=version,
@@ -28,7 +35,7 @@ setup(
author="Shivam Mehta",
author_email="shivam.mehta25@gmail.com",
url="https://shivammehta25.github.io/Matcha-TTS",
install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
install_requires=get_requires(),
include_dirs=[numpy.get_include()],
include_package_data=True,
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),