Adding the possibility of get durations out of pretrained model

This commit is contained in:
Shivam Mehta
2024-05-24 11:34:51 +02:00
parent fb7b954de5
commit 4b39f6cad0
13 changed files with 274 additions and 24 deletions

View File

@@ -1,5 +1,5 @@
default_language_version: default_language_version:
python: python3.10 python: python3.11
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks

View File

@@ -1,7 +1,7 @@
_target_: matcha.data.text_mel_datamodule.TextMelDataModule _target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech name: ljspeech
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt train_filelist_path: data/LJSpeech-1.1/train.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt valid_filelist_path: data/LJSpeech-1.1/val.txt
batch_size: 32 batch_size: 32
num_workers: 20 num_workers: 20
pin_memory: True pin_memory: True

View File

@@ -1 +1 @@
0.0.5.1 0.0.6.0

View File

@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
def process_text(i: int, text: str, device: torch.device): def process_text(i: int, text: str, device: torch.device):
print(f"[{i}] - Input text: {text}") print(f"[{i}] - Input text: {text}")
x = torch.tensor( x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"]), 0), intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
dtype=torch.long, dtype=torch.long,
device=device, device=device,
)[None] )[None]

View File

@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
"""Clean up after fit or test.""" """Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass pass # pylint: disable=unnecessary-pass
def state_dict(self): # pylint: disable=no-self-use def state_dict(self):
"""Extra things to save to checkpoint.""" """Extra things to save to checkpoint."""
return {} return {}
@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
filepath, text = filepath_and_text[0], filepath_and_text[1] filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None spk = None
text = self.get_text(text, add_blank=self.add_blank) text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath) mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk} return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
def get_mel(self, filepath): def get_mel(self, filepath):
audio, sr = ta.load(filepath) audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
return mel return mel
def get_text(self, text, add_blank=True): def get_text(self, text, add_blank=True):
text_norm = text_to_sequence(text, self.cleaners) text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
if self.add_blank: if self.add_blank:
text_norm = intersperse(text_norm, 0) text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm) text_norm = torch.IntTensor(text_norm)
return text_norm return text_norm, cleaned_text
def __getitem__(self, index): def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index]) datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -216,6 +216,7 @@ class TextMelBatchCollate:
x = torch.zeros((B, x_max_length), dtype=torch.long) x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], [] y_lengths, x_lengths = [], []
spks = [] spks = []
filepaths, x_texts = [], []
for i, item in enumerate(batch): for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"] y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1]) y_lengths.append(y_.shape[-1])
@@ -223,9 +224,19 @@ class TextMelBatchCollate:
y[i, :, : y_.shape[-1]] = y_ y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_ x[i, : x_.shape[-1]] = x_
spks.append(item["spk"]) spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long) y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long) x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks} return {
"x": x,
"x_lengths": x_lengths,
"y": y,
"y_lengths": y_lengths,
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
}

View File

@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
y, y_lengths = batch["y"], batch["y_lengths"] y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"] spks = batch["spks"]
dur_loss, prior_loss, diff_loss = self( dur_loss, prior_loss, diff_loss, *_ = self(
x=x, x=x,
x_lengths=x_lengths, x_lengths=x_lengths,
y=y, y=y,

View File

@@ -236,4 +236,4 @@ class MatchaTTS(BaseLightningClass): # 🍵
else: else:
prior_loss = 0 prior_loss = 0
return dur_loss, prior_loss, diff_loss return dur_loss, prior_loss, diff_loss, attn

View File

@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
for symbol in clean_text: for symbol in clean_text:
symbol_id = _symbol_to_id[symbol] symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id] sequence += [symbol_id]
return sequence return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text): def cleaned_text_to_sequence(cleaned_text):

View File

@@ -15,7 +15,6 @@ import logging
import re import re
import phonemizer import phonemizer
import piper_phonemize
from unidecode import unidecode from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -106,11 +105,17 @@ def english_cleaners2(text):
return phonemes return phonemes
def english_cleaners_piper(text): # I am removing this due to incompatibility with several version of python
"""Pipeline for English text, including abbreviation expansion. + punctuation + stress""" # However, if you want to use it, you can uncomment it
text = convert_to_ascii(text) # and install piper-phonemize with the following command:
text = lowercase(text) # pip install piper-phonemize
text = expand_abbreviations(text)
phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) # import piper_phonemize
phonemes = collapse_whitespace(phonemes) # def english_cleaners_piper(text):
return phonemes # """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
# text = convert_to_ascii(text)
# text = lowercase(text)
# text = expand_abbreviations(text)
# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
# phonemes = collapse_whitespace(phonemes)
# return phonemes

View File

@@ -0,0 +1,194 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
import lightning
import numpy as np
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from torch import nn
from tqdm.auto import tqdm
from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
"""Generate durations from the model for each datapoint and save it in a folder
Args:
data_loader (torch.utils.data.DataLoader): Dataloader
model (nn.Module): MatchaTTS model
device (torch.device): GPU or CPU
"""
for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
x = x.to(device)
y = y.to(device)
x_lengths = x_lengths.to(device)
y_lengths = y_lengths.to(device)
spks = spks.to(device) if spks is not None else None
_, _, _, attn = model(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
)
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="ljspeech.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default="32",
help="Can have increased batch size for faster computation",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
parser.add_argument(
"-c",
"--checkpoint_path",
type=str,
required=True,
help="Path to the checkpoint file to load the model from",
)
parser.add_argument(
"-o",
"--output-folder",
type=str,
default=None,
help="Output folder to save the data statistics",
)
parser.add_argument(
"--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
)
args = parser.parse_args()
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
print(f"Output folder set to: {output_folder}")
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")
sys.exit(1)
output_folder.mkdir(parents=True, exist_ok=True)
print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
print("Loading model...")
device = get_device(args)
model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
try:
print("Computing stats for training set if exists...")
train_dataloader = text_mel_datamodule.train_dataloader()
compute_durations(train_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No training set found")
try:
print("Computing stats for validation set if exists...")
val_dataloader = text_mel_datamodule.val_dataloader()
compute_durations(val_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No validation set found")
try:
print("Computing stats for test set if exists...")
test_dataloader = text_mel_datamodule.test_dataloader()
compute_durations(test_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No test set found")
print(f"[+] Done! Data statistics saved to: {output_folder}")
if __name__ == "__main__":
# Helps with generating durations for the dataset to train other architectures
# that cannot learn to align due to limited size of dataset
# Example usage:
# python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
# This will create a folder in data/processed_data/durations/ljspeech with the durations
main()

View File

@@ -2,6 +2,7 @@ import os
import sys import sys
import warnings import warnings
from importlib.util import find_spec from importlib.util import find_spec
from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Tuple from typing import Any, Callable, Dict, Tuple
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else: else:
wget.download(url=url, out=checkpoint_path) wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json

View File

@@ -42,4 +42,3 @@ gradio
gdown gdown
wget wget
seaborn seaborn
piper_phonemize

View File

@@ -38,6 +38,7 @@ setup(
"matcha-data-stats=matcha.utils.generate_data_statistics:main", "matcha-data-stats=matcha.utils.generate_data_statistics:main",
"matcha-tts=matcha.cli:cli", "matcha-tts=matcha.cli:cli",
"matcha-tts-app=matcha.app:main", "matcha-tts-app=matcha.app:main",
"matcha-get-durations=matcha.utils.get_durations_from_trained_model:main",
] ]
}, },
ext_modules=cythonize(exts, language_level=3), ext_modules=cythonize(exts, language_level=3),