28 Commits

Author SHA1 Message Date
pre-commit-ci[bot]
66178aea04 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2026-01-19 22:11:39 +00:00
pre-commit-ci[bot]
7ebef67773 [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v6.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.5.0...v6.0.0)
- repo moved: https://github.com/psf/black → https://github.com/psf/black-pre-commit-mirror
- [github.com/psf/black-pre-commit-mirror: 23.12.1 → 26.1.0](https://github.com/psf/black-pre-commit-mirror/compare/23.12.1...26.1.0)
- [github.com/PyCQA/isort: 5.13.2 → 7.0.0](https://github.com/PyCQA/isort/compare/5.13.2...7.0.0)
- [github.com/asottile/pyupgrade: v3.15.0 → v3.21.2](https://github.com/asottile/pyupgrade/compare/v3.15.0...v3.21.2)
- [github.com/PyCQA/flake8: 7.0.0 → 7.3.0](https://github.com/PyCQA/flake8/compare/7.0.0...7.3.0)
- [github.com/pycqa/pylint: v3.0.3 → v4.0.4](https://github.com/pycqa/pylint/compare/v3.0.3...v4.0.4)
2026-01-19 22:10:05 +00:00
Shivam Mehta
bd4d90d932 Update README.md 2025-09-17 08:49:57 -07:00
Shivam Mehta
108906c603 Merge pull request #121 from jimregan/english-data
ljspeech/hificaptain from #99
2024-12-02 09:02:41 -06:00
Shivam Mehta
354f5dc69f Merge pull request #123 from jimregan/patch-1
Fix a typo
2024-12-02 08:26:00 -06:00
Jim O’Regan
8e5f98476e Fix a typo 2024-12-02 15:21:31 +01:00
Jim O'Regan
7e499df0b2 ljspeech/hificaptain from #99 2024-12-02 11:01:04 +00:00
Shivam Mehta
0735e653fc Merge pull request #103 from jimregan/mmconv-cleaner
add a cleaner for IPA data (pre-phonetised)
2024-11-13 22:15:47 -08:00
Shivam Mehta
f9843cfca4 Merge pull request #101 from jimregan/pylint
Make pylint happy
2024-11-13 22:13:36 -08:00
Shivam Mehta
289ef51578 Fixing the usage of denoiser_strength from the command line. 2024-11-14 06:55:51 +01:00
Shivam Mehta
7a65f83b17 Updating the version 2024-11-14 06:42:06 +01:00
Shivam Mehta
7275764a48 Fixing espeak not removing brackets in some cases 2024-11-14 06:39:58 +01:00
Jim O'Regan
863bfbdd8b rename method, it's more generic than the previous name suggested 2024-10-03 18:51:47 +00:00
Jim O'Regan
4bc541705a add a cleaner for the mmconv data
Different versions of espeak represent things differently, it seems
(also, there are some distinctions none of our speakers make, so
normalising those away reduces perplexity a tiny amount).
2024-10-03 17:18:58 +00:00
pre-commit-ci[bot]
a3fea22988 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-10-02 14:31:11 +00:00
Jim O'Regan
d56f40765c disable consider-using-from-import instead (missed one) 2024-10-02 14:30:18 +00:00
Jim O'Regan
b0ba920dc1 disable consider-using-from-import instead 2024-10-02 14:29:06 +00:00
Jim O'Regan
a220f283e3 disable consider-using-generator 2024-10-02 13:57:12 +00:00
Jim O'Regan
1df73ef43e disable global-variable-not-assigned 2024-10-02 13:55:44 +00:00
Jim O'Regan
404b045b65 add dummy exception (W0719) 2024-10-02 13:51:17 +00:00
Jim O'Regan
7cfae6bed4 add dummy exception (W0719) 2024-10-02 13:49:47 +00:00
Jim O'Regan
a83fd29829 C0209 2024-10-02 13:45:27 +00:00
pre-commit-ci[bot]
c8178bf2cd [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-10-02 13:32:45 +00:00
Jim O'Regan
8b1284993a W1514 + R1732 2024-10-02 13:31:57 +00:00
Jim O'Regan
0000f93021 R1732 + W1514 2024-10-02 13:25:02 +00:00
pre-commit-ci[bot]
c2569a1018 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-10-02 13:21:37 +00:00
Jim O'Regan
bd058a68f7 R0402 2024-10-02 13:21:00 +00:00
Jim O'Regan
362ba2dce7 C0209 2024-10-02 08:38:28 +00:00
32 changed files with 389 additions and 48 deletions

.gitignore

@@ -161,3 +161,4 @@ generator_v1
g_02500000
gradio_cached_examples/
synth_output/
/data


@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v6.0.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
@@ -17,29 +17,29 @@ repos:
- id: check-added-large-files
# python code formatting
- repo: https://github.com/psf/black
rev: 23.12.1
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.1.0
hooks:
- id: black
args: [--line-length, "120"]
# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
rev: 7.0.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.21.2
hooks:
- id: pyupgrade
args: [--py38-plus]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
rev: 7.3.0
hooks:
- id: flake8
args:
@@ -54,6 +54,6 @@ repos:
# pylint
- repo: https://github.com/pycqa/pylint
rev: v3.0.3
rev: v4.0.4
hooks:
- id: pylint


@@ -10,7 +10,7 @@
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![PyPI Downloads](https://static.pepy.tech/personalized-badge/matcha-tts?period=total&units=INTERNATIONAL_SYSTEM&left_color=BLACK&right_color=GREEN&left_text=downloads)](https://pepy.tech/projects/matcha-tts)
<p style="text-align: center;">
<img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/>
</p>


@@ -5,8 +5,8 @@ defaults:
# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female
train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
train_filelist_path: data/hi-fi_en-US_female/train.txt
valid_filelist_path: data/hi-fi_en-US_female/val.txt
batch_size: 32
cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset

data

@@ -1 +0,0 @@
/home/smehta/Projects/Speech-Backbones/Grad-TTS/data


@@ -1 +1 @@
0.0.7.0
0.0.7.2


@@ -114,10 +114,10 @@ def load_matcha(model_name, checkpoint_path, device):
return model
def to_waveform(mel, vocoder, denoiser=None):
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
audio = vocoder(mel).clamp(-1, 1)
if denoiser is not None:
audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
audio = denoiser(audio.squeeze(), strength=denoiser_strength).cpu().squeeze()
return audio.cpu().squeeze()
@@ -336,7 +336,7 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
@@ -377,7 +377,7 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
# RTF with HiFiGAN
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
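This change threads `denoiser_strength` from the command line through `to_waveform` instead of hard-coding `0.00025` inside the function. A minimal sketch of the call path, with plain-Python stand-ins for the real torch vocoder and `Denoiser` modules:

```python
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
    # Run the vocoder, then optionally the denoiser with a configurable
    # strength (previously 0.00025 was fixed inside this function).
    audio = vocoder(mel)
    if denoiser is not None:
        audio = denoiser(audio, strength=denoiser_strength)
    return audio

# Stand-ins for illustration only; the real ones are HiFiGAN and Denoiser.
vocoder = lambda mel: [max(-1.0, min(1.0, x)) for x in mel]  # clamp(-1, 1)
denoiser = lambda audio, strength: [x * (1.0 - strength) for x in audio]

out = to_waveform([0.5, 2.0], vocoder, denoiser, denoiser_strength=0.1)
```

The two call sites in `batched_synthesis` and `unbatched_synthesis` then simply forward `args.denoiser_strength`.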


@@ -234,9 +234,9 @@ class TextMelBatchCollate:
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch])
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch])
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
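This commit silences pylint's `consider-using-generator` rather than rewriting the calls; the alternative fix pylint suggests is a generator expression, which avoids materializing a throwaway list before taking the max:

```python
batch = [{"y": [0.0] * 80, "x": [0] * 12}, {"y": [0.0] * 96, "x": [0] * 9}]

# Flagged form: builds an intermediate list just to reduce it.
y_max_length = max([len(item["y"]) for item in batch])
# Suggested form: generator expression, no intermediate list.
y_max_generator = max(len(item["y"]) for item in batch)
```

Both are equivalent here; the disable comment keeps the diff minimal.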


@@ -1,9 +1,14 @@
# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
import torch
class ModeException(Exception):
pass
class Denoiser(torch.nn.Module):
"""Removes model bias from audio produced with waveglow"""
@@ -20,7 +25,7 @@ class Denoiser(torch.nn.Module):
elif mode == "normal":
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
else:
raise Exception(f"Mode {mode} if not supported")
raise ModeException(f"Mode {mode} is not supported")
def stft_fn(audio, n_fft, hop_length, win_length, window):
spec = torch.stft(
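Raising a bare `Exception` trips pylint's W0719 (`broad-exception-raised`); defining a narrow subclass, as this hunk does, lets callers catch the specific failure. A sketch of the pattern with the tensor construction stubbed out (the function name and return strings are mine, not the repo's):

```python
class ModeException(Exception):
    """Raised when Denoiser is given an unsupported mode."""

def make_mel_input(mode):
    # Stub: the real code builds torch.zeros / torch.randn tensors here.
    if mode == "zeros":
        return "zeros(1, 80, 88)"
    if mode == "normal":
        return "randn(1, 80, 88)"
    raise ModeException(f"Mode {mode} is not supported")
```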


@@ -1,4 +1,4 @@
""" from https://github.com/jik876/hifi-gan """
"""from https://github.com/jik876/hifi-gan"""
import os
import shutil


@@ -1,4 +1,4 @@
""" from https://github.com/jik876/hifi-gan """
"""from https://github.com/jik876/hifi-gan"""
import math
import os
@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)


@@ -1,7 +1,7 @@
""" from https://github.com/jik876/hifi-gan """
"""from https://github.com/jik876/hifi-gan"""
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm


@@ -1,4 +1,4 @@
""" from https://github.com/jik876/hifi-gan """
"""from https://github.com/jik876/hifi-gan"""
import glob
import os


@@ -2,6 +2,7 @@
This is a base lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
"""
import inspect
from abc import ABC
from typing import Any, Dict


@@ -2,7 +2,7 @@ import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation


@@ -1,12 +1,12 @@
""" from https://github.com/jaywalnut310/glow-tts """
"""from https://github.com/jaywalnut310/glow-tts"""
import math
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
from einops import rearrange
import matcha.utils as utils
import matcha.utils as utils # pylint: disable=consider-using-from-import
from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__)


@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
import torch.nn as nn # pylint: disable=consider-using-from-import
from diffusers.models.attention import (
GEGLU,
GELU,


@@ -4,7 +4,7 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
@@ -106,6 +106,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Lengths of mel spectrograms
"rtf": float,
# Real-time factor
}
"""
# For RTF computation
t = dt.datetime.now()
@@ -152,7 +153,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. flow matching loss: loss between mel-spectrogram and decoder outputs.


@@ -1,4 +1,5 @@
""" from https://github.com/keithito/tacotron """
"""from https://github.com/keithito/tacotron"""
from matcha.text import cleaners
from matcha.text.symbols import symbols
@@ -7,6 +8,10 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
class UnknownCleanerException(Exception):
pass
def text_to_sequence(text, cleaner_names):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
@@ -48,6 +53,6 @@ def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
raise UnknownCleanerException(f"Unknown cleaner: {name}")
text = cleaner(text)
return text


@@ -1,4 +1,4 @@
""" from https://github.com/keithito/tacotron
"""from https://github.com/keithito/tacotron
Cleaners are transformations that run over the input text at both training and eval time.
@@ -36,9 +36,12 @@ global_phonemizer = phonemizer.backend.EspeakBackend(
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# Remove brackets
_brackets_re = re.compile(r"[\[\]\(\)\{\}]")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
(re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
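The hunk above replaces `%`-formatting with an f-string inside the abbreviation pattern builder (pylint C0209, `consider-using-f-string`). A self-contained version with just two of the abbreviation pairs for illustration:

```python
import re

# f-string form of the pattern builder; "\\b" in a regular string
# yields the regex word-boundary \b.
_abbreviations = [
    (re.compile(f"\\b{abbr}\\.", re.IGNORECASE), expansion)
    for abbr, expansion in [("mrs", "misess"), ("mr", "mister")]
]

def expand_abbreviations(text):
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
```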
@@ -72,6 +75,10 @@ def lowercase(text):
return text.lower()
def remove_brackets(text):
return re.sub(_brackets_re, "", text)
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
@@ -101,10 +108,26 @@ def english_cleaners2(text):
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
# Added because in some cases espeak does not remove brackets
phonemes = remove_brackets(phonemes)
phonemes = collapse_whitespace(phonemes)
return phonemes
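The fix for espeak occasionally leaving brackets in its output is a small regex pass added to `english_cleaners2`. Runnable versions of the two helpers from this hunk:

```python
import re

_whitespace_re = re.compile(r"\s+")
_brackets_re = re.compile(r"[\[\]\(\)\{\}]")

def remove_brackets(text):
    # Some espeak versions leave brackets (e.g. around language switches)
    # in the phonemized output; strip any that survive.
    return re.sub(_brackets_re, "", text)

def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)
```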
def ipa_simplifier(text):
replacements = [
("ɐ", "ə"),
("ˈə", "ə"),
("ʤ", "dʒ"),
("ʧ", "tʃ"),
("ᵻ", "ɪ"),
]
for replacement in replacements:
text = text.replace(replacement[0], replacement[1])
phonemes = collapse_whitespace(text)
return phonemes
# I am removing this due to incompatibility with several versions of Python
# However, if you want to use it, you can uncomment it
# and install piper-phonemize with the following command:


@@ -1,4 +1,4 @@
""" from https://github.com/keithito/tacotron """
"""from https://github.com/keithito/tacotron"""
import re


@@ -1,7 +1,8 @@
""" from https://github.com/keithito/tacotron
"""from https://github.com/keithito/tacotron
Defines the set of symbols used in text input to the model.
"""
_pad = "_"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)


@@ -0,0 +1,148 @@
#!/usr/bin/env python
import argparse
import os
import sys
import tempfile
from pathlib import Path
import torchaudio
from torch.hub import download_url_to_file
from tqdm import tqdm
from matcha.utils.data.utils import _extract_zip
URLS = {
"en-US": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip",
},
"ja-JP": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip",
},
}
INFO_PAGE = "https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/"
# On their website they say "We NICT open-sourced Hi-Fi-CAPTAIN",
# but they use this very-much-not-open-source licence.
# Dunno if this is open washing or stupidity.
LICENCE = "CC BY-NC-SA 4.0"
# I'd normally put the citation here. It's on their website.
# Boo to non-open-source stuff.
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"-r",
"--skip-resampling",
action="store_true",
default=False,
help="Skip resampling the data (from 48 to 22.05)",
)
parser.add_argument(
"-l", "--language", type=str, choices=["en-US", "ja-JP"], default="en-US", help="The language to download"
)
parser.add_argument(
"-g",
"--gender",
type=str,
choices=["male", "female"],
default="female",
help="The gender of the speaker to download",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
default="data",
help="Place to store the converted data. Top-level only, the subdirectory will be created",
)
return parser.parse_args()
def process_text(infile, outpath: Path):
outmode = "w"
if infile.endswith("dev.txt"):
outfile = outpath / "valid.txt"
elif infile.endswith("eval.txt"):
outfile = outpath / "test.txt"
else:
outfile = outpath / "train.txt"
if outfile.exists():
outmode = "a"
with (
open(infile, encoding="utf-8") as inf,
open(outfile, outmode, encoding="utf-8") as of,
):
for line in inf.readlines():
line = line.strip()
fileid, rest = line.split(" ", maxsplit=1)
outfile = str(outpath / f"{fileid}.wav")
of.write(f"{outfile}|{rest}\n")
def process_files(zipfile, outpath, resample=True):
with tempfile.TemporaryDirectory() as tmpdirname:
for filename in tqdm(_extract_zip(zipfile, tmpdirname)):
if not filename.startswith(tmpdirname):
filename = os.path.join(tmpdirname, filename)
if filename.endswith(".txt"):
process_text(filename, outpath)
elif filename.endswith(".wav"):
filepart = filename.rsplit("/", maxsplit=1)[-1]
outfile = str(outpath / filepart)
arr, sr = torchaudio.load(filename)
if resample:
arr = torchaudio.functional.resample(arr, orig_freq=sr, new_freq=22050)
torchaudio.save(outfile, arr, 22050)
else:
continue
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
if not args.output_dir:
print("output directory not specified, exiting")
sys.exit(1)
URL = URLS[args.language][args.gender]
dirname = f"hi-fi_{args.language}_{args.gender}"
outbasepath = Path(args.output_dir)
if not outbasepath.is_dir():
outbasepath.mkdir()
outpath = outbasepath / dirname
if not outpath.is_dir():
outpath.mkdir()
resample = True
if args.skip_resampling:
resample = False
if save_dir:
zipname = URL.rsplit("/", maxsplit=1)[-1]
zipfile = save_dir / zipname
if not zipfile.exists():
download_url_to_file(URL, zipfile, progress=True)
process_files(zipfile, outpath, resample)
else:
with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
process_files(zf.name, outpath, resample)
if __name__ == "__main__":
main()
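`process_text` in the download script maps Hi-Fi-CAPTAIN's `dev`/`eval` filelist names onto the `valid`/`test`/`train` names the rest of the recipe expects. That mapping, extracted as a standalone helper (the function name is mine, not the script's):

```python
def filelist_name(infile: str) -> str:
    # Hi-Fi-CAPTAIN ships dev/eval/train text lists; rename on the way out.
    if infile.endswith("dev.txt"):
        return "valid.txt"
    if infile.endswith("eval.txt"):
        return "test.txt"
    return "train.txt"
```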


@@ -0,0 +1,97 @@
#!/usr/bin/env python
import argparse
import random
import tempfile
from pathlib import Path
from torch.hub import download_url_to_file
from matcha.utils.data.utils import _extract_tar
URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
INFO_PAGE = "https://keithito.com/LJ-Speech-Dataset/"
LICENCE = "Public domain (LibriVox copyright disclaimer)"
CITATION = """
@misc{ljspeech17,
author = {Keith Ito and Linda Johnson},
title = {The LJ Speech Dataset},
howpublished = {\\url{https://keithito.com/LJ-Speech-Dataset/}},
year = 2017
}
"""
def decision():
return random.random() < 0.98
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"output_dir",
type=str,
nargs="?",
default="data",
help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)",
)
return parser.parse_args()
def process_csv(ljpath: Path):
if (ljpath / "metadata.csv").exists():
basepath = ljpath
elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists():
basepath = ljpath / "LJSpeech-1.1"
csvpath = basepath / "metadata.csv"
wavpath = basepath / "wavs"
with (
open(csvpath, encoding="utf-8") as csvf,
open(basepath / "train.txt", "w", encoding="utf-8") as tf,
open(basepath / "val.txt", "w", encoding="utf-8") as vf,
):
for line in csvf.readlines():
line = line.strip()
parts = line.split("|")
wavfile = str(wavpath / f"{parts[0]}.wav")
if decision():
tf.write(f"{wavfile}|{parts[1]}\n")
else:
vf.write(f"{wavfile}|{parts[1]}\n")
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
outpath = Path(args.output_dir)
if not outpath.is_dir():
outpath.mkdir()
if save_dir:
tarname = URL.rsplit("/", maxsplit=1)[-1]
tarfile = save_dir / tarname
if not tarfile.exists():
download_url_to_file(URL, str(tarfile), progress=True)
_extract_tar(tarfile, outpath)
process_csv(outpath)
else:
with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
_extract_tar(zf.name, outpath)
process_csv(outpath)
if __name__ == "__main__":
main()
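`decision()` implements the LJSpeech train/validation split as an independent coin flip per metadata line, so roughly 98% of utterances land in `train.txt` and the rest in `val.txt`. A quick check of that behaviour (the seed is only for this demonstration; the script itself does not seed):

```python
import random

def decision():
    # ~98% of lines go to train.txt, the rest to val.txt.
    return random.random() < 0.98

random.seed(1234)
n = 20_000
in_train = sum(decision() for _ in range(n))
share = in_train / n
```

Note the split is therefore non-deterministic across runs and only approximately 98/2.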


@@ -0,0 +1,53 @@
# taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala)
# Licence: BSD 2-Clause
# pylint: disable=C0123
import logging
import os
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Union
_LG = logging.getLogger(__name__)
def _extract_tar(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
from_path = str(from_path)
if to_path is None:
to_path = os.path.dirname(from_path)
with tarfile.open(from_path, "r") as tar:
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
return files
def _extract_zip(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
from_path = str(from_path)
if to_path is None:
to_path = os.path.dirname(from_path)
with zipfile.ZipFile(from_path, "r") as zfile:
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
return files


@@ -4,6 +4,7 @@ when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
@@ -102,10 +103,8 @@ def main():
log.info("Dataloader loaded! Now computing stats...")
params = compute_data_statistics(data_loader, cfg["n_feats"])
print(params)
json.dump(
params,
open(output_file, "w"),
)
with open(output_file, "w", encoding="utf-8") as dumpfile:
json.dump(params, dumpfile)
if __name__ == "__main__":
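The stats-dumping change above fixes two pylint findings at once: R1732 (`consider-using-with`, the file handle was never closed) and W1514 (unspecified encoding). The same pattern in isolation, with illustrative stats values:

```python
import json
import os
import tempfile

params = {"mel_mean": -5.517, "mel_std": 2.064}  # illustrative values only

path = os.path.join(tempfile.mkdtemp(), "stats.json")
# Context manager closes the file even on error; encoding is explicit.
with open(path, "w", encoding="utf-8") as dumpfile:
    json.dump(params, dumpfile)

with open(path, encoding="utf-8") as statsfile:
    loaded = json.load(statsfile)
```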


@@ -4,6 +4,7 @@ when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os


@@ -1,4 +1,4 @@
""" from https://github.com/jaywalnut310/glow-tts """
"""from https://github.com/jaywalnut310/glow-tts"""
import numpy as np
import torch


@@ -72,7 +72,7 @@ def print_config_tree(
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
rich.print(tree, file=file)
@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
rich.print(cfg.tags, file=file)


@@ -16,9 +16,16 @@ with open("README.md", encoding="utf-8") as readme_file:
README = readme_file.read()
cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
version = fin.read().strip()
def get_requires():
requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
with open(requirements, encoding="utf-8") as reqfile:
return [str(r).strip() for r in reqfile]
setup(
name="matcha-tts",
version=version,
@@ -28,7 +35,7 @@ setup(
author="Shivam Mehta",
author_email="shivam.mehta25@gmail.com",
url="https://shivammehta25.github.io/Matcha-TTS",
install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
install_requires=get_requires(),
include_dirs=[numpy.get_include()],
include_package_data=True,
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),