mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 02:09:21 +08:00
@@ -234,9 +234,9 @@ class TextMelBatchCollate:
|
|||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
B = len(batch)
|
B = len(batch)
|
||||||
y_max_length = max([item["y"].shape[-1] for item in batch])
|
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
|
||||||
y_max_length = fix_len_compatibility(y_max_length)
|
y_max_length = fix_len_compatibility(y_max_length)
|
||||||
x_max_length = max([item["x"].shape[-1] for item in batch])
|
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
|
||||||
n_feats = batch[0]["y"].shape[-2]
|
n_feats = batch[0]["y"].shape[-2]
|
||||||
|
|
||||||
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
||||||
|
|||||||
@@ -4,6 +4,10 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ModeException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Denoiser(torch.nn.Module):
|
class Denoiser(torch.nn.Module):
|
||||||
"""Removes model bias from audio produced with waveglow"""
|
"""Removes model bias from audio produced with waveglow"""
|
||||||
|
|
||||||
@@ -20,7 +24,7 @@ class Denoiser(torch.nn.Module):
|
|||||||
elif mode == "normal":
|
elif mode == "normal":
|
||||||
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
|
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Mode {mode} if not supported")
|
raise ModeException(f"Mode {mode} if not supported")
|
||||||
|
|
||||||
def stft_fn(audio, n_fft, hop_length, win_length, window):
|
def stft_fn(audio, n_fft, hop_length, win_length, window):
|
||||||
spec = torch.stft(
|
spec = torch.stft(
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
|
|||||||
if torch.max(y) > 1.0:
|
if torch.max(y) > 1.0:
|
||||||
print("max value is ", torch.max(y))
|
print("max value is ", torch.max(y))
|
||||||
|
|
||||||
global mel_basis, hann_window # pylint: disable=global-statement
|
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
|
||||||
if fmax not in mel_basis:
|
if fmax not in mel_basis:
|
||||||
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
||||||
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
""" from https://github.com/jik876/hifi-gan """
|
""" from https://github.com/jik876/hifi-gan """
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import math
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from conformer import ConformerBlock
|
from conformer import ConformerBlock
|
||||||
from diffusers.models.activations import get_activation
|
from diffusers.models.activations import get_activation
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
import matcha.utils as utils
|
import matcha.utils as utils # pylint: disable=consider-using-from-import
|
||||||
from matcha.utils.model import sequence_mask
|
from matcha.utils.model import sequence_mask
|
||||||
|
|
||||||
log = utils.get_pylogger(__name__)
|
log = utils.get_pylogger(__name__)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||||
from diffusers.models.attention import (
|
from diffusers.models.attention import (
|
||||||
GEGLU,
|
GEGLU,
|
||||||
GELU,
|
GELU,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import random
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import matcha.utils.monotonic_align as monotonic_align
|
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
|
||||||
from matcha import utils
|
from matcha import utils
|
||||||
from matcha.models.baselightningmodule import BaseLightningClass
|
from matcha.models.baselightningmodule import BaseLightningClass
|
||||||
from matcha.models.components.flow_matching import CFM
|
from matcha.models.components.flow_matching import CFM
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
|||||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
|
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownCleanerException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def text_to_sequence(text, cleaner_names):
|
def text_to_sequence(text, cleaner_names):
|
||||||
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||||
Args:
|
Args:
|
||||||
@@ -48,6 +52,6 @@ def _clean_text(text, cleaner_names):
|
|||||||
for name in cleaner_names:
|
for name in cleaner_names:
|
||||||
cleaner = getattr(cleaners, name)
|
cleaner = getattr(cleaners, name)
|
||||||
if not cleaner:
|
if not cleaner:
|
||||||
raise Exception("Unknown cleaner: %s" % name)
|
raise UnknownCleanerException(f"Unknown cleaner: {name}")
|
||||||
text = cleaner(text)
|
text = cleaner(text)
|
||||||
return text
|
return text
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ _brackets_re = re.compile(r"[\[\]\(\)\{\}]")
|
|||||||
|
|
||||||
# List of (regular expression, replacement) pairs for abbreviations:
|
# List of (regular expression, replacement) pairs for abbreviations:
|
||||||
_abbreviations = [
|
_abbreviations = [
|
||||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
(re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("mrs", "misess"),
|
("mrs", "misess"),
|
||||||
("mr", "mister"),
|
("mr", "mister"),
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
|
|||||||
if torch.max(y) > 1.0:
|
if torch.max(y) > 1.0:
|
||||||
print("max value is ", torch.max(y))
|
print("max value is ", torch.max(y))
|
||||||
|
|
||||||
global mel_basis, hann_window # pylint: disable=global-statement
|
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
|
||||||
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
||||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||||
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
||||||
|
|||||||
@@ -102,10 +102,8 @@ def main():
|
|||||||
log.info("Dataloader loaded! Now computing stats...")
|
log.info("Dataloader loaded! Now computing stats...")
|
||||||
params = compute_data_statistics(data_loader, cfg["n_feats"])
|
params = compute_data_statistics(data_loader, cfg["n_feats"])
|
||||||
print(params)
|
print(params)
|
||||||
json.dump(
|
with open(output_file, "w", encoding="utf-8") as dumpfile:
|
||||||
params,
|
json.dump(params, dumpfile)
|
||||||
open(output_file, "w"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ def print_config_tree(
|
|||||||
|
|
||||||
# save config tree to file
|
# save config tree to file
|
||||||
if save_to_file:
|
if save_to_file:
|
||||||
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
|
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
|
||||||
rich.print(tree, file=file)
|
rich.print(tree, file=file)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
|
|||||||
log.info(f"Tags: {cfg.tags}")
|
log.info(f"Tags: {cfg.tags}")
|
||||||
|
|
||||||
if save_to_file:
|
if save_to_file:
|
||||||
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
|
with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
|
||||||
rich.print(cfg.tags, file=file)
|
rich.print(cfg.tags, file=file)
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -16,9 +16,16 @@ with open("README.md", encoding="utf-8") as readme_file:
|
|||||||
README = readme_file.read()
|
README = readme_file.read()
|
||||||
|
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
|
with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
|
||||||
version = fin.read().strip()
|
version = fin.read().strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_requires():
|
||||||
|
requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
||||||
|
with open(requirements, encoding="utf-8") as reqfile:
|
||||||
|
return [str(r).strip() for r in reqfile]
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="matcha-tts",
|
name="matcha-tts",
|
||||||
version=version,
|
version=version,
|
||||||
@@ -28,7 +35,7 @@ setup(
|
|||||||
author="Shivam Mehta",
|
author="Shivam Mehta",
|
||||||
author_email="shivam.mehta25@gmail.com",
|
author_email="shivam.mehta25@gmail.com",
|
||||||
url="https://shivammehta25.github.io/Matcha-TTS",
|
url="https://shivammehta25.github.io/Matcha-TTS",
|
||||||
install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
|
install_requires=get_requires(),
|
||||||
include_dirs=[numpy.get_include()],
|
include_dirs=[numpy.get_include()],
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
|
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
|
||||||
|
|||||||
Reference in New Issue
Block a user