Mirror of https://github.com/shivammehta25/Matcha-TTS.git
Commit: Adding the possibility of getting durations out of a pretrained model
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.10
+  python: python3.11

 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
@@ -1,7 +1,7 @@
 _target_: matcha.data.text_mel_datamodule.TextMelDataModule
 name: ljspeech
-train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
-valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
+train_filelist_path: data/LJSpeech-1.1/train.txt
+valid_filelist_path: data/LJSpeech-1.1/val.txt
 batch_size: 32
 num_workers: 20
 pin_memory: True
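Note: the data config now points at plain train/val filelists inside the downloaded LJSpeech-1.1 folder instead of the filelists bundled under data/filelists. Judging from how TextMelDataset.get_datapoint (further down in this diff) splits each row into a filepath and a transcript, a filelist line presumably looks like the sketch below (the delimiter and the example path are assumptions, not shown in this diff):

    data/LJSpeech-1.1/wavs/LJ001-0001.wav|<transcript of the utterance>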
@@ -1 +1 @@
-0.0.5.1
+0.0.6.0
@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
 def process_text(i: int, text: str, device: torch.device):
     print(f"[{i}] - Input text: {text}")
     x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
+        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
         dtype=torch.long,
         device=device,
     )[None]
@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
         """Clean up after fit or test."""
         pass  # pylint: disable=unnecessary-pass

-    def state_dict(self):  # pylint: disable=no-self-use
+    def state_dict(self):
         """Extra things to save to checkpoint."""
         return {}

@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
         filepath, text = filepath_and_text[0], filepath_and_text[1]
         spk = None

-        text = self.get_text(text, add_blank=self.add_blank)
+        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
         mel = self.get_mel(filepath)

-        return {"x": text, "y": mel, "spk": spk}
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}

     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
         return mel

     def get_text(self, text, add_blank=True):
-        text_norm = text_to_sequence(text, self.cleaners)
+        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
         if self.add_blank:
             text_norm = intersperse(text_norm, 0)
         text_norm = torch.IntTensor(text_norm)
-        return text_norm
+        return text_norm, cleaned_text

     def __getitem__(self, index):
         datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -216,6 +216,7 @@ class TextMelBatchCollate:
         x = torch.zeros((B, x_max_length), dtype=torch.long)
         y_lengths, x_lengths = [], []
         spks = []
+        filepaths, x_texts = [], []
         for i, item in enumerate(batch):
             y_, x_ = item["y"], item["x"]
             y_lengths.append(y_.shape[-1])
@@ -223,9 +224,19 @@ class TextMelBatchCollate:
             y[i, :, : y_.shape[-1]] = y_
             x[i, : x_.shape[-1]] = x_
             spks.append(item["spk"])
+            filepaths.append(item["filepath"])
+            x_texts.append(item["x_text"])

         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
         spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None

-        return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
+        return {
+            "x": x,
+            "x_lengths": x_lengths,
+            "y": y,
+            "y_lengths": y_lengths,
+            "spks": spks,
+            "filepaths": filepaths,
+            "x_texts": x_texts,
+        }
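With the two collate changes above, a batch now carries the raw file paths and the cleaned input texts alongside the padded tensors. A minimal sketch of what downstream code can expect (key names follow this diff; shapes and the constructor argument are assumptions based on the surrounding code):

    batch = TextMelBatchCollate(n_spks)(items)
    batch["x"]           # LongTensor [B, max_text_len], blank-interspersed token ids
    batch["y"]           # FloatTensor [B, n_feats, max_mel_len]
    batch["x_lengths"]   # LongTensor [B]
    batch["y_lengths"]   # LongTensor [B]
    batch["spks"]        # LongTensor [B], or None for single-speaker data
    batch["filepaths"]   # plain list[str] of length B (not a tensor)
    batch["x_texts"]     # plain list[str] of length B with the cleaned/phonemized text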
@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
         y, y_lengths = batch["y"], batch["y_lengths"]
         spks = batch["spks"]

-        dur_loss, prior_loss, diff_loss = self(
+        dur_loss, prior_loss, diff_loss, *_ = self(
             x=x,
             x_lengths=x_lengths,
             y=y,
@@ -236,4 +236,4 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         else:
             prior_loss = 0

-        return dur_loss, prior_loss, diff_loss
+        return dur_loss, prior_loss, diff_loss, attn
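The training-mode forward of MatchaTTS now also returns the MAS alignment attn as a fourth value; that is why the loss computation above unpacks with *_, and it is what the new duration-extraction script consumes. A minimal sketch of a caller that wants the alignment (the shape is an assumption inferred from how the new script reduces it):

    dur_loss, prior_loss, diff_loss, attn = model(x=x, x_lengths=x_lengths, y=y, y_lengths=y_lengths, spks=spks)
    # summing the alignment over the mel-frame axis gives per-token durations in frames
    durations = attn.squeeze(1).sum(-1)  # assumed shape: [B, T_text]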
@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
     for symbol in clean_text:
         symbol_id = _symbol_to_id[symbol]
         sequence += [symbol_id]
-    return sequence
+    return sequence, clean_text


 def cleaned_text_to_sequence(cleaned_text):
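Because text_to_sequence now returns both the id sequence and the cleaned text, every caller has to unpack a tuple; that is what the [0] in process_text and the two-value unpacking in TextMelDataset.get_text above are doing. A minimal sketch of the new call pattern (cleaner name as used elsewhere in this diff):

    seq, cleaned = text_to_sequence("Hello world.", ["english_cleaners2"])
    x = torch.tensor(intersperse(seq, 0), dtype=torch.long)[None]  # keep only the ids for synthesis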
@@ -15,7 +15,6 @@ import logging
 import re

 import phonemizer
-import piper_phonemize
 from unidecode import unidecode

 # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -106,11 +105,17 @@ def english_cleaners2(text):
     return phonemes


-def english_cleaners_piper(text):
-    """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_abbreviations(text)
-    phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
-    phonemes = collapse_whitespace(phonemes)
-    return phonemes
+# I am removing this due to incompatibility with several versions of Python.
+# However, if you want to use it, you can uncomment it
+# and install piper-phonemize with the following command:
+# pip install piper-phonemize
+
+# import piper_phonemize
+# def english_cleaners_piper(text):
+#     """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
+#     text = convert_to_ascii(text)
+#     text = lowercase(text)
+#     text = expand_abbreviations(text)
+#     phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
+#     phonemes = collapse_whitespace(phonemes)
+#     return phonemes
New file: matcha/utils/get_durations_from_trained_model.py (194 lines)
@@ -0,0 +1,194 @@
+r"""
+The file creates .npy and .json duration files for every utterance in the dataset, so that other models can load
+them when needed.
+
+Parameters from the data config (configs/data) will be used.
+"""
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import lightning
+import numpy as np
+import rootutils
+import torch
+from hydra import compose, initialize
+from omegaconf import open_dict
+from torch import nn
+from tqdm.auto import tqdm
+
+from matcha.cli import get_device
+from matcha.data.text_mel_datamodule import TextMelDataModule
+from matcha.models.matcha_tts import MatchaTTS
+from matcha.utils.logging_utils import pylogger
+from matcha.utils.utils import get_phoneme_durations
+
+log = pylogger.get_pylogger(__name__)
+
+
+def save_durations_to_folder(
+    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+):
+    durations = attn.squeeze().sum(1)[:x_length].numpy()
+    durations_json = get_phoneme_durations(durations, text)
+    output = output_folder / Path(filepath).name.replace(".wav", ".npy")
+    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
+        json.dump(durations_json, f, indent=4, ensure_ascii=False)
+
+    np.save(output, durations)
+
+
+@torch.inference_mode()
+def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+    """Generate durations from the model for each datapoint and save it in a folder
+
+    Args:
+        data_loader (torch.utils.data.DataLoader): Dataloader
+        model (nn.Module): MatchaTTS model
+        device (torch.device): GPU or CPU
+    """
+
+    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
+        x, x_lengths = batch["x"], batch["x_lengths"]
+        y, y_lengths = batch["y"], batch["y_lengths"]
+        spks = batch["spks"]
+        x = x.to(device)
+        y = y.to(device)
+        x_lengths = x_lengths.to(device)
+        y_lengths = y_lengths.to(device)
+        spks = spks.to(device) if spks is not None else None
+
+        _, _, _, attn = model(
+            x=x,
+            x_lengths=x_lengths,
+            y=y,
+            y_lengths=y_lengths,
+            spks=spks,
+        )
+        attn = attn.cpu()
+        for i in range(attn.shape[0]):
+            save_durations_to_folder(
+                attn[i],
+                x_lengths[i].item(),
+                y_lengths[i].item(),
+                batch["filepaths"][i],
+                output_folder,
+                batch["x_texts"][i],
+            )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input-config",
+        type=str,
+        default="ljspeech.yaml",
+        help="The name of the yaml config file under configs/data",
+    )
+
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        type=int,
+        default="32",
+        help="Can have increased batch size for faster computation",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        default=False,
+        required=False,
+        help="force overwrite the file",
+    )
+    parser.add_argument(
+        "-c",
+        "--checkpoint_path",
+        type=str,
+        required=True,
+        help="Path to the checkpoint file to load the model from",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output-folder",
+        type=str,
+        default=None,
+        help="Output folder to save the durations",
+    )
+
+    parser.add_argument(
+        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
+    )
+
+    args = parser.parse_args()
+
+    with initialize(version_base="1.3", config_path="../../configs/data"):
+        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+
+    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
+
+    with open_dict(cfg):
+        del cfg["hydra"]
+        del cfg["_target_"]
+        cfg["seed"] = 1234
+        cfg["batch_size"] = args.batch_size
+        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
+        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+
+    if args.output_folder is not None:
+        output_folder = Path(args.output_folder)
+    else:
+        output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
+
+    print(f"Output folder set to: {output_folder}")
+
+    if os.path.exists(output_folder) and not args.force:
+        print("Folder already exists. Use -f to force overwrite")
+        sys.exit(1)
+
+    output_folder.mkdir(parents=True, exist_ok=True)
+
+    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
+    print("Loading model...")
+    device = get_device(args)
+    model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
+
+    text_mel_datamodule = TextMelDataModule(**cfg)
+    text_mel_datamodule.setup()
+    try:
+        print("Computing durations for training set if it exists...")
+        train_dataloader = text_mel_datamodule.train_dataloader()
+        compute_durations(train_dataloader, model, device, output_folder)
+    except lightning.fabric.utilities.exceptions.MisconfigurationException:
+        print("No training set found")
+
+    try:
+        print("Computing durations for validation set if it exists...")
+        val_dataloader = text_mel_datamodule.val_dataloader()
+        compute_durations(val_dataloader, model, device, output_folder)
+    except lightning.fabric.utilities.exceptions.MisconfigurationException:
+        print("No validation set found")
+
+    try:
+        print("Computing durations for test set if it exists...")
+        test_dataloader = text_mel_datamodule.test_dataloader()
+        compute_durations(test_dataloader, model, device, output_folder)
+    except lightning.fabric.utilities.exceptions.MisconfigurationException:
+        print("No test set found")
+
+    print(f"[+] Done! Durations saved to: {output_folder}")
+
+
+if __name__ == "__main__":
+    # Helps with generating durations for the dataset to train other architectures
+    # that cannot learn to align due to limited size of dataset
+    # Example usage:
+    # python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
+    # This will create a folder in data/processed_data/durations/ljspeech with the durations
+    main()
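A hedged sketch of how the new script is meant to be run and what it writes, based on the argument parser and save_durations_to_folder above (the checkpoint path and utterance name are placeholders):

    python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c <checkpoint.ckpt> -b 32
    # per utterance, under <train filelist folder>/durations (or --output-folder):
    #   LJ001-0001.npy   one duration (in mel frames) per blank-interspersed token
    #   LJ001-0001.json  merged per-phone durations, e.g. [{"p": {"starttime": 0, "endtime": 7, "duration": 7}}, ...]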
@@ -2,6 +2,7 @@ import os
 import sys
 import warnings
 from importlib.util import find_spec
+from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple

@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
         gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
     else:
         wget.download(url=url, out=checkpoint_path)
+
+
+def get_phoneme_durations(durations, phones):
+    prev = durations[0]
+    merged_durations = []
+    # Convolve with stride 2
+    for i in range(1, len(durations), 2):
+        if i == len(durations) - 2:
+            # if it is last take full value
+            next_half = durations[i + 1]
+        else:
+            next_half = ceil(durations[i + 1] / 2)
+
+        curr = prev + durations[i] + next_half
+        prev = durations[i + 1] - next_half
+        merged_durations.append(curr)
+
+    assert len(phones) == len(merged_durations)
+    assert len(merged_durations) == (len(durations) - 1) // 2
+
+    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
+    start = torch.tensor(0)
+    duration_json = []
+    for i, duration in enumerate(merged_durations):
+        duration_json.append(
+            {
+                phones[i]: {
+                    "starttime": start.item(),
+                    "endtime": duration.item(),
+                    "duration": duration.item() - start.item(),
+                }
+            }
+        )
+        start = duration
+
+    assert list(duration_json[-1].values())[0]["endtime"] == sum(
+        durations
+    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
+    return duration_json
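To make the stride-2 merge in get_phoneme_durations concrete: with add_blank the model sees [blank, p1, blank, p2, ..., blank], i.e. 2*N + 1 durations for N phones, and each phone absorbs what is left of the blank on its left plus (rounded-up) half of the blank on its right, with the outermost blanks taken in full, so the merged values still sum to the original total. A small worked example under those assumptions:

    durations = [1, 4, 2, 3, 2, 5, 1]   # [blank, p1, blank, p2, blank, p3, blank]
    # p1 -> 1 + 4 + ceil(2/2)       = 6
    # p2 -> (2 - 1) + 3 + ceil(2/2) = 5
    # p3 -> (2 - 1) + 5 + 1         = 7   (last blank taken whole)
    # 6 + 5 + 7 == sum(durations) == 18, and the JSON stores cumulative start/end times per phone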
@@ -42,4 +42,3 @@ gradio
 gdown
 wget
 seaborn
-piper_phonemize
setup.py (1 addition)
@@ -38,6 +38,7 @@ setup(
             "matcha-data-stats=matcha.utils.generate_data_statistics:main",
             "matcha-tts=matcha.cli:cli",
             "matcha-tts-app=matcha.app:main",
+            "matcha-get-durations=matcha.utils.get_durations_from_trained_model:main",
         ]
     },
     ext_modules=cythonize(exts, language_level=3),
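Once the package is reinstalled, the same entry point should also be reachable as a console script; a hedged usage sketch (paths are placeholders):

    matcha-get-durations -i ljspeech.yaml -b 32 -c <checkpoint.ckpt> -o data/durations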