Mirror of https://github.com/shivammehta25/Matcha-TTS.git, synced 2026-02-05 02:09:21 +08:00

@@ -234,9 +234,9 @@ class TextMelBatchCollate:
     def __call__(self, batch):
         B = len(batch)
-        y_max_length = max([item["y"].shape[-1] for item in batch])
+        y_max_length = max([item["y"].shape[-1] for item in batch])  # pylint: disable=consider-using-generator
         y_max_length = fix_len_compatibility(y_max_length)
-        x_max_length = max([item["x"].shape[-1] for item in batch])
+        x_max_length = max([item["x"].shape[-1] for item in batch])  # pylint: disable=consider-using-generator
         n_feats = batch[0]["y"].shape[-2]

         y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)

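Note on the hunk above: pylint's consider-using-generator hint asks for a bare generator expression inside max() instead of building a temporary list; the change keeps the list comprehension and silences the check. A standalone sketch of the two equivalent spellings (the dummy batch below is made up for illustration):

# Standalone illustration, not repository code.
import torch

batch = [{"y": torch.zeros(80, n)} for n in (172, 240, 198)]  # dummy mel tensors

# What the code keeps (with the pylint check disabled):
y_max_list = max([item["y"].shape[-1] for item in batch])

# What consider-using-generator suggests (no temporary list):
y_max_gen = max(item["y"].shape[-1] for item in batch)

assert y_max_list == y_max_gen == 240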
@@ -4,6 +4,10 @@
 import torch


+class ModeException(Exception):
+    pass
+
+
 class Denoiser(torch.nn.Module):
     """Removes model bias from audio produced with waveglow"""

@@ -20,7 +24,7 @@ class Denoiser(torch.nn.Module):
         elif mode == "normal":
             mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
         else:
-            raise Exception(f"Mode {mode} if not supported")
+            raise ModeException(f"Mode {mode} if not supported")

         def stft_fn(audio, n_fft, hop_length, win_length, window):
             spec = torch.stft(

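The two denoiser hunks replace a bare Exception with a dedicated ModeException so callers can catch this specific configuration error. A minimal standalone sketch of the pattern (make_mel_input is a made-up stand-in, not the Denoiser constructor):

class ModeException(Exception):
    """Raised when an unsupported denoiser mode is requested."""


def make_mel_input(mode):
    if mode == "zeros":
        return "all-zeros mel input"
    if mode == "normal":
        return "random-normal mel input"
    raise ModeException(f"Mode {mode} is not supported")


try:
    make_mel_input("uniform")
except ModeException as err:
    print(f"caught: {err}")  # caught: Mode uniform is not supported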
@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     if torch.max(y) > 1.0:
         print("max value is ", torch.max(y))

-    global mel_basis, hann_window  # pylint: disable=global-statement
+    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
     if fmax not in mel_basis:
         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

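The widened disable reflects that mel_spectrogram only mutates the module-level mel_basis and hann_window dicts (inserting keys) and never rebinds the names, so pylint's global-variable-not-assigned fires even though the global statement is kept for readability; the same comment appears again on the second mel_spectrogram copy further down. A standalone sketch of why in-place mutation needs no global (names here are illustrative only):

# Standalone sketch: mutating a module-level dict does not require `global`.
_window_cache = {}


def get_window(size):
    # No `global _window_cache` needed: we only add keys to the existing dict,
    # we never assign a new object to the name itself.
    if size not in _window_cache:
        _window_cache[size] = list(range(size))  # stand-in for torch.hann_window(size)
    return _window_cache[size]


print(get_window(4))  # [0, 1, 2, 3]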
@@ -1,7 +1,7 @@
 """ from https://github.com/jik876/hifi-gan """

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
 from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

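This hunk and the next four apply the same inline disable: pylint's consider-using-from-import prefers `from torch import nn`, while the codebase keeps the conventional `import torch.nn as nn` spelling. Both forms bind the same module object, as this standalone check shows:

import torch.nn as nn  # pylint: disable=consider-using-from-import
from torch import nn as nn_alt

assert nn is nn_alt
print(nn.Linear(4, 2))  # Linear(in_features=4, out_features=2, bias=True)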
@@ -2,7 +2,7 @@ import math
 from typing import Optional

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
 from conformer import ConformerBlock
 from diffusers.models.activations import get_activation

@@ -3,10 +3,10 @@
 import math

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 from einops import rearrange

-import matcha.utils as utils
+import matcha.utils as utils  # pylint: disable=consider-using-from-import
 from matcha.utils.model import sequence_mask

 log = utils.get_pylogger(__name__)

@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional

 import torch
-import torch.nn as nn
+import torch.nn as nn  # pylint: disable=consider-using-from-import
 from diffusers.models.attention import (
     GEGLU,
     GELU,

@@ -4,7 +4,7 @@ import random

 import torch

-import matcha.utils.monotonic_align as monotonic_align
+import matcha.utils.monotonic_align as monotonic_align  # pylint: disable=consider-using-from-import
 from matcha import utils
 from matcha.models.baselightningmodule import BaseLightningClass
 from matcha.models.components.flow_matching import CFM

@@ -7,6 +7,10 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension


+class UnknownCleanerException(Exception):
+    pass
+
+
 def text_to_sequence(text, cleaner_names):
     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
     Args:
@@ -48,6 +52,6 @@ def _clean_text(text, cleaner_names):
     for name in cleaner_names:
         cleaner = getattr(cleaners, name)
         if not cleaner:
-            raise Exception("Unknown cleaner: %s" % name)
+            raise UnknownCleanerException(f"Unknown cleaner: {name}")
         text = cleaner(text)
     return text

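As with ModeException, UnknownCleanerException gives callers of the text pipeline a specific type to catch, and the message moves from %-formatting to an f-string. One caveat worth noting: getattr(cleaners, name) without a default raises AttributeError for a name that does not exist at all, so the `if not cleaner` branch is only reached when the attribute exists but is falsy. A standalone sketch of the variant that would always reach the new exception (the cleaners_module namespace below is made up; this is an illustration, not the repository code):

import types


class UnknownCleanerException(Exception):
    pass


cleaners_module = types.SimpleNamespace(english_cleaners=lambda s: s.lower())


def clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners_module, name, None)  # default avoids AttributeError
        if not cleaner:
            raise UnknownCleanerException(f"Unknown cleaner: {name}")
        text = cleaner(text)
    return text


print(clean_text("Hello World", ["english_cleaners"]))  # hello world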
@@ -41,7 +41,7 @@ _brackets_re = re.compile(r"[\[\]\(\)\{\}]")

 # List of (regular expression, replacement) pairs for abbreviations:
 _abbreviations = [
-    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
     for x in [
         ("mrs", "misess"),
         ("mr", "mister"),

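The abbreviation-table change only swaps %-interpolation for an f-string; the compiled pattern is byte-for-byte the same. A quick standalone check:

import re

x = ("mrs", "misess")
old_pattern = re.compile("\\b%s\\." % x[0], re.IGNORECASE)
new_pattern = re.compile(f"\\b{x[0]}\\.", re.IGNORECASE)

assert old_pattern.pattern == new_pattern.pattern == "\\bmrs\\."
print(new_pattern.sub(x[1], "Mrs. Smith"))  # misess Smith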
@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
     if torch.max(y) > 1.0:
         print("max value is ", torch.max(y))

-    global mel_basis, hann_window  # pylint: disable=global-statement
+    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

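Aside on this second mel_spectrogram copy: unlike the earlier hunk, it already calls librosa_mel_fn with keyword arguments, which matches librosa.filters.mel becoming keyword-only in librosa 0.10 (positional calls raise a TypeError there). A small sketch of the call style, assuming librosa >= 0.10 is installed; the parameter values are common 22.05 kHz TTS settings, not taken from this diff:

from librosa.filters import mel as librosa_mel_fn

mel = librosa_mel_fn(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)
print(mel.shape)  # (80, 513), i.e. n_mels x (n_fft // 2 + 1)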
@@ -102,10 +102,8 @@ def main():
     log.info("Dataloader loaded! Now computing stats...")
     params = compute_data_statistics(data_loader, cfg["n_feats"])
     print(params)
-    json.dump(
-        params,
-        open(output_file, "w"),
-    )
+    with open(output_file, "w", encoding="utf-8") as dumpfile:
+        json.dump(params, dumpfile)


 if __name__ == "__main__":

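The statistics script now writes its JSON through a with block, so the handle is closed deterministically and the text encoding is explicit, instead of passing a bare open(output_file, "w") whose closing was left to garbage collection. A standalone sketch of the pattern (file name and values are made up):

import json

params = {"mel_mean": -5.52, "mel_std": 2.06}  # illustrative values only

with open("data_stats.json", "w", encoding="utf-8") as dumpfile:
    json.dump(params, dumpfile)  # handle closed when the block exits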
@@ -72,7 +72,7 @@ def print_config_tree(

     # save config tree to file
     if save_to_file:
-        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
             rich.print(tree, file=file)


@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
     log.info(f"Tags: {cfg.tags}")

     if save_to_file:
-        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
             rich.print(cfg.tags, file=file)

setup.py

@@ -16,9 +16,16 @@ with open("README.md", encoding="utf-8") as readme_file:
     README = readme_file.read()

 cwd = os.path.dirname(os.path.abspath(__file__))
-with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
+with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
     version = fin.read().strip()


+def get_requires():
+    requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
+    with open(requirements, encoding="utf-8") as reqfile:
+        return [str(r).strip() for r in reqfile]
+
+
 setup(
     name="matcha-tts",
     version=version,

@@ -28,7 +35,7 @@ setup(
     author="Shivam Mehta",
     author_email="shivam.mehta25@gmail.com",
     url="https://shivammehta25.github.io/Matcha-TTS",
-    install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
+    install_requires=get_requires(),
     include_dirs=[numpy.get_include()],
     include_package_data=True,
     packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),

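The setup.py hunks have the same motivation: the old install_requires opened requirements.txt inside a list comprehension, leaving the handle to the garbage collector and relying on the locale's default encoding; get_requires() reads it in a with block with an explicit encoding. A standalone sketch of the helper (path handling simplified; not the packaged file):

def get_requires(requirements_path="requirements.txt"):
    # One requirement per line, whitespace stripped.
    with open(requirements_path, encoding="utf-8") as reqfile:
        return [str(r).strip() for r in reqfile]

# e.g. setup(..., install_requires=get_requires())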