From 4b39f6cad0353a0a2a38375d5eabf0ac9142bd9c Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Fri, 24 May 2024 11:34:51 +0200
Subject: [PATCH] Add the ability to extract durations from a pretrained model

---
 .pre-commit-config.yaml                       |   2 +-
 configs/data/ljspeech.yaml                    |   4 +-
 matcha/VERSION                                |   2 +-
 matcha/cli.py                                 |   2 +-
 matcha/data/text_mel_datamodule.py            |  23 ++-
 matcha/models/baselightningmodule.py          |   2 +-
 matcha/models/matcha_tts.py                   |   2 +-
 matcha/text/__init__.py                       |   2 +-
 matcha/text/cleaners.py                       |  23 ++-
 .../utils/get_durations_from_trained_model.py | 194 ++++++++++++++++++
 matcha/utils/utils.py                         |  40 ++++
 requirements.txt                              |   1 -
 setup.py                                      |   1 +
 13 files changed, 274 insertions(+), 24 deletions(-)
 create mode 100644 matcha/utils/get_durations_from_trained_model.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7cda633..e6f84b1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.10
+  python: python3.11
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/configs/data/ljspeech.yaml b/configs/data/ljspeech.yaml
index f251420..6fba3be 100644
--- a/configs/data/ljspeech.yaml
+++ b/configs/data/ljspeech.yaml
@@ -1,7 +1,7 @@
 _target_: matcha.data.text_mel_datamodule.TextMelDataModule
 name: ljspeech
-train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
-valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
+train_filelist_path: data/LJSpeech-1.1/train.txt
+valid_filelist_path: data/LJSpeech-1.1/val.txt
 batch_size: 32
 num_workers: 20
 pin_memory: True
diff --git a/matcha/VERSION b/matcha/VERSION
index 442b113..5bcb0a7 100644
--- a/matcha/VERSION
+++ b/matcha/VERSION
@@ -1 +1 @@
-0.0.5.1
+0.0.6.0
diff --git a/matcha/cli.py b/matcha/cli.py
index 579d7d6..635c586 100644
--- a/matcha/cli.py
+++ b/matcha/cli.py
@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
 def process_text(i: int, text: str, device: torch.device):
     print(f"[{i}] - Input text: {text}")
     x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
+        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
         dtype=torch.long,
         device=device,
     )[None]
diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py
index 704f936..f281bfd 100644
--- a/matcha/data/text_mel_datamodule.py
+++ b/matcha/data/text_mel_datamodule.py
@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
         """Clean up after fit or test."""
         pass  # pylint: disable=unnecessary-pass
 
-    def state_dict(self):  # pylint: disable=no-self-use
+    def state_dict(self):
         """Extra things to save to checkpoint."""
         return {}
 
@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
         filepath, text = filepath_and_text[0], filepath_and_text[1]
         spk = None
 
-        text = self.get_text(text, add_blank=self.add_blank)
+        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
         mel = self.get_mel(filepath)
 
-        return {"x": text, "y": mel, "spk": spk}
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
 
     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
         return mel
 
     def get_text(self, text, add_blank=True):
-        text_norm = text_to_sequence(text, self.cleaners)
+        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
         if self.add_blank:
             text_norm = intersperse(text_norm, 0)
         text_norm = torch.IntTensor(text_norm)
-        return text_norm
+        return text_norm, cleaned_text
 
     def __getitem__(self, index):
         datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -216,6 +216,7 @@ class TextMelBatchCollate:
         x = torch.zeros((B, x_max_length), dtype=torch.long)
         y_lengths, x_lengths = [], []
         spks = []
+        filepaths, x_texts = [], []
         for i, item in enumerate(batch):
             y_, x_ = item["y"], item["x"]
             y_lengths.append(y_.shape[-1])
@@ -223,9 +224,19 @@ class TextMelBatchCollate:
             y[i, :, : y_.shape[-1]] = y_
             x[i, : x_.shape[-1]] = x_
             spks.append(item["spk"])
+            filepaths.append(item["filepath"])
+            x_texts.append(item["x_text"])
 
         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
         spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
 
-        return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
+        return {
+            "x": x,
+            "x_lengths": x_lengths,
+            "y": y,
+            "y_lengths": y_lengths,
+            "spks": spks,
+            "filepaths": filepaths,
+            "x_texts": x_texts,
+        }
diff --git a/matcha/models/baselightningmodule.py b/matcha/models/baselightningmodule.py
index 3724888..5fd09a4 100644
--- a/matcha/models/baselightningmodule.py
+++ b/matcha/models/baselightningmodule.py
@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
         y, y_lengths = batch["y"], batch["y_lengths"]
         spks = batch["spks"]
 
-        dur_loss, prior_loss, diff_loss = self(
+        dur_loss, prior_loss, diff_loss, *_ = self(
             x=x,
             x_lengths=x_lengths,
             y=y,
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index 64b2c07..464efcd 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -236,4 +236,4 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         else:
             prior_loss = 0
 
-        return dur_loss, prior_loss, diff_loss
+        return dur_loss, prior_loss, diff_loss, attn
diff --git a/matcha/text/__init__.py b/matcha/text/__init__.py
index 71a4b57..8c75d6b 100644
--- a/matcha/text/__init__.py
+++ b/matcha/text/__init__.py
@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
     for symbol in clean_text:
         symbol_id = _symbol_to_id[symbol]
         sequence += [symbol_id]
-    return sequence
+    return sequence, clean_text
 
 
 def cleaned_text_to_sequence(cleaned_text):
diff --git a/matcha/text/cleaners.py b/matcha/text/cleaners.py
index 5e8d96b..36776e3 100644
--- a/matcha/text/cleaners.py
+++ b/matcha/text/cleaners.py
@@ -15,7 +15,6 @@ import logging
 import re
 
 import phonemizer
-import piper_phonemize
 from unidecode import unidecode
 
 # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -106,11 +105,17 @@ def english_cleaners2(text):
     return phonemes
 
 
-def english_cleaners_piper(text):
-    """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_abbreviations(text)
-    phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
-    phonemes = collapse_whitespace(phonemes)
-    return phonemes
+# I am removing this due to incompatibility with several versions of Python.
+# However, if you want to use it, you can uncomment it
+# and install piper-phonemize with the following command:
+# pip install piper-phonemize

+# import piper_phonemize
+# def english_cleaners_piper(text):
+#     """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
+#     text = convert_to_ascii(text)
+#     text = lowercase(text)
+#     text = expand_abbreviations(text)
+#     phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
+#     phonemes = collapse_whitespace(phonemes)
+#     return phonemes
diff --git a/matcha/utils/get_durations_from_trained_model.py b/matcha/utils/get_durations_from_trained_model.py
new file mode 100644
index 0000000..9bee56e
--- /dev/null
+++ b/matcha/utils/get_durations_from_trained_model.py
@@ -0,0 +1,194 @@
+r"""
+This script runs a trained MatchaTTS checkpoint over a dataset and saves, for each
+utterance, the phoneme durations extracted from the learned alignment so that other
+models can load them when needed.
+
+Parameters are taken from the Hydra data config under configs/data.
+"""
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import lightning
+import numpy as np
+import rootutils
+import torch
+from hydra import compose, initialize
+from omegaconf import open_dict
+from torch import nn
+from tqdm.auto import tqdm
+
+from matcha.cli import get_device
+from matcha.data.text_mel_datamodule import TextMelDataModule
+from matcha.models.matcha_tts import MatchaTTS
+from matcha.utils.logging_utils import pylogger
+from matcha.utils.utils import get_phoneme_durations
+
+log = pylogger.get_pylogger(__name__)
+
+
+def save_durations_to_folder(
+    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+):
+    durations = attn.squeeze().sum(1)[:x_length].numpy()
+    durations_json = get_phoneme_durations(durations, text)
+    output = output_folder / Path(filepath).name.replace(".wav", ".npy")
+    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
+        json.dump(durations_json, f, indent=4, ensure_ascii=False)
+
+    np.save(output, durations)
+
+
+@torch.inference_mode()
+def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+    """Generate durations from the model for each datapoint and save them in a folder.
+
+    Args:
+        data_loader (torch.utils.data.DataLoader): Dataloader
+        model (nn.Module): MatchaTTS model
+        device (torch.device): GPU or CPU
+        output_folder (Path): folder to save the durations in
+    """
+
+    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
+        x, x_lengths = batch["x"], batch["x_lengths"]
+        y, y_lengths = batch["y"], batch["y_lengths"]
+        spks = batch["spks"]
+        x = x.to(device)
+        y = y.to(device)
+        x_lengths = x_lengths.to(device)
+        y_lengths = y_lengths.to(device)
+        spks = spks.to(device) if spks is not None else None
+
+        _, _, _, attn = model(
+            x=x,
+            x_lengths=x_lengths,
+            y=y,
+            y_lengths=y_lengths,
+            spks=spks,
+        )
+        attn = attn.cpu()
+        for i in range(attn.shape[0]):
+            save_durations_to_folder(
+                attn[i],
+                x_lengths[i].item(),
+                y_lengths[i].item(),
+                batch["filepaths"][i],
+                output_folder,
+                batch["x_texts"][i],
+            )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input-config",
+        type=str,
+        default="ljspeech.yaml",
+        help="The name of the yaml config file under configs/data",
+    )
+
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size for inference; can be increased for faster computation",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        default=False,
+        required=False,
+        help="force overwrite of an existing output folder",
+    )
+    parser.add_argument(
+        "-c",
+        "--checkpoint_path",
+        type=str,
+        required=True,
+        help="Path to the checkpoint file to load the model from",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output-folder",
+        type=str,
+        default=None,
+        help="Output folder to save the durations in",
statistics", + ) + + parser.add_argument( + "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" + ) + + args = parser.parse_args() + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + + if args.output_folder is not None: + output_folder = Path(args.output_folder) + else: + output_folder = Path(cfg["train_filelist_path"]).parent / "durations" + + print(f"Output folder set to: {output_folder}") + + if os.path.exists(output_folder) and not args.force: + print("Folder already exists. Use -f to force overwrite") + sys.exit(1) + + output_folder.mkdir(parents=True, exist_ok=True) + + print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") + print("Loading model...") + device = get_device(args) + model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + try: + print("Computing stats for training set if exists...") + train_dataloader = text_mel_datamodule.train_dataloader() + compute_durations(train_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No training set found") + + try: + print("Computing stats for validation set if exists...") + val_dataloader = text_mel_datamodule.val_dataloader() + compute_durations(val_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No validation set found") + + try: + print("Computing stats for test set if exists...") + test_dataloader = text_mel_datamodule.test_dataloader() + compute_durations(test_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No test set found") + + print(f"[+] Done! 
+
+
+if __name__ == "__main__":
+    # Helps with generating durations for the dataset, to train other architectures
+    # that cannot learn to align due to the limited size of the dataset
+    # Example usage:
+    # python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
+    # By default this saves the durations to a "durations" folder next to the training filelist
+    main()
diff --git a/matcha/utils/utils.py b/matcha/utils/utils.py
index af65e09..fc3a48e 100644
--- a/matcha/utils/utils.py
+++ b/matcha/utils/utils.py
@@ -2,6 +2,7 @@ import os
 import sys
 import warnings
 from importlib.util import find_spec
+from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple
 
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
             gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
         else:
             wget.download(url=url, out=checkpoint_path)
+
+
+def get_phoneme_durations(durations, phones):
+    prev = durations[0]
+    merged_durations = []
+    # Convolve with stride 2: fold each blank token's duration into its neighbouring phones
+    for i in range(1, len(durations), 2):
+        if i == len(durations) - 2:
+            # if it is the last blank, take its full value
+            next_half = durations[i + 1]
+        else:
+            next_half = ceil(durations[i + 1] / 2)
+
+        curr = prev + durations[i] + next_half
+        prev = durations[i + 1] - next_half
+        merged_durations.append(curr)
+
+    assert len(phones) == len(merged_durations)
+    assert len(merged_durations) == (len(durations) - 1) // 2
+
+    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
+    start = torch.tensor(0)
+    duration_json = []
+    for i, duration in enumerate(merged_durations):
+        duration_json.append(
+            {
+                phones[i]: {
+                    "starttime": start.item(),
+                    "endtime": duration.item(),
+                    "duration": duration.item() - start.item(),
+                }
+            }
+        )
+        start = duration
+
+    assert list(duration_json[-1].values())[0]["endtime"] == sum(
+        durations
+    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
+    return duration_json
diff --git a/requirements.txt b/requirements.txt
index 0a7e14c..d25358d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -42,4 +42,3 @@ gradio
 gdown
 wget
 seaborn
-piper_phonemize
diff --git a/setup.py b/setup.py
index 80d4aac..608de9f 100644
--- a/setup.py
+++ b/setup.py
@@ -38,6 +38,7 @@ setup(
             "matcha-data-stats=matcha.utils.generate_data_statistics:main",
             "matcha-tts=matcha.cli:cli",
             "matcha-tts-app=matcha.app:main",
+            "matcha-get-durations=matcha.utils.get_durations_from_trained_model:main",
         ]
     },
     ext_modules=cythonize(exts, language_level=3),
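
Usage note: after applying this patch, durations can be generated with the new console
script (matcha-get-durations) or by running the module directly. Each utterance then
gets a .npy file (per-token durations in mel frames, over the blank-interspersed symbol
sequence) and a .json file with the per-phoneme timings produced by
get_phoneme_durations. The sketch below shows how a downstream training script might
load these files; the durations folder path and the utterance ID "LJ001-0001" are
illustrative assumptions, not part of the patch.

import json
from pathlib import Path

import numpy as np

# Default output location: a "durations" folder next to the training filelist
durations_dir = Path("data/LJSpeech-1.1/durations")
utt_id = "LJ001-0001"  # hypothetical LJSpeech utterance ID

# Per-token durations (in mel frames), aligned with the interspersed symbol sequence
token_durations = np.load(durations_dir / f"{utt_id}.npy")

# Per-phoneme timings: [{phone: {"starttime": ..., "endtime": ..., "duration": ...}}, ...]
with open(durations_dir / f"{utt_id}.json", encoding="utf-8") as f:
    phoneme_durations = json.load(f)

for entry in phoneme_durations[:5]:
    for phone, timing in entry.items():
        print(f"{phone!r}: frames {timing['starttime']} to {timing['endtime']} ({timing['duration']})")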