From 4b39f6cad0353a0a2a38375d5eabf0ac9142bd9c Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Fri, 24 May 2024 11:34:51 +0200
Subject: [PATCH 1/6] Adding the possibility to get durations out of a pretrained model

---
 .pre-commit-config.yaml                       |   2 +-
 configs/data/ljspeech.yaml                    |   4 +-
 matcha/VERSION                                |   2 +-
 matcha/cli.py                                 |   2 +-
 matcha/data/text_mel_datamodule.py            |  23 ++-
 matcha/models/baselightningmodule.py          |   2 +-
 matcha/models/matcha_tts.py                   |   2 +-
 matcha/text/__init__.py                       |   2 +-
 matcha/text/cleaners.py                       |  23 ++-
 .../utils/get_durations_from_trained_model.py | 194 ++++++++++++++++++
 matcha/utils/utils.py                         |  40 ++++
 requirements.txt                              |   1 -
 setup.py                                      |   1 +
 13 files changed, 274 insertions(+), 24 deletions(-)
 create mode 100644 matcha/utils/get_durations_from_trained_model.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7cda633..e6f84b1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.10
+  python: python3.11

 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/configs/data/ljspeech.yaml b/configs/data/ljspeech.yaml
index f251420..6fba3be 100644
--- a/configs/data/ljspeech.yaml
+++ b/configs/data/ljspeech.yaml
@@ -1,7 +1,7 @@
 _target_: matcha.data.text_mel_datamodule.TextMelDataModule
 name: ljspeech
-train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
-valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
+train_filelist_path: data/LJSpeech-1.1/train.txt
+valid_filelist_path: data/LJSpeech-1.1/val.txt
 batch_size: 32
 num_workers: 20
 pin_memory: True
diff --git a/matcha/VERSION b/matcha/VERSION
index 442b113..5bcb0a7 100644
--- a/matcha/VERSION
+++ b/matcha/VERSION
@@ -1 +1 @@
-0.0.5.1
+0.0.6.0
diff --git a/matcha/cli.py b/matcha/cli.py
index 579d7d6..635c586 100644
--- a/matcha/cli.py
+++ b/matcha/cli.py
@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
 def process_text(i: int, text: str, device: torch.device):
     print(f"[{i}] - Input text: {text}")
     x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
+        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
         dtype=torch.long,
         device=device,
     )[None]
diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py
index 704f936..f281bfd 100644
--- a/matcha/data/text_mel_datamodule.py
+++ b/matcha/data/text_mel_datamodule.py
@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
         """Clean up after fit or test."""
         pass  # pylint: disable=unnecessary-pass

-    def state_dict(self):  # pylint: disable=no-self-use
+    def state_dict(self):
         """Extra things to save to checkpoint."""
         return {}

@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
         filepath, text = filepath_and_text[0], filepath_and_text[1]
         spk = None

-        text = self.get_text(text, add_blank=self.add_blank)
+        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
         mel = self.get_mel(filepath)

-        return {"x": text, "y": mel, "spk": spk}
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}

     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
         return mel

     def get_text(self, text, add_blank=True):
-        text_norm = text_to_sequence(text, self.cleaners)
+        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
         if self.add_blank:
             text_norm = intersperse(text_norm, 0)
         text_norm = torch.IntTensor(text_norm)
-        return text_norm
+        return text_norm, cleaned_text

     def __getitem__(self, index):
         datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -216,6 +216,7 @@ class TextMelBatchCollate:
         x = torch.zeros((B, x_max_length), dtype=torch.long)
         y_lengths, x_lengths = [], []
         spks = []
+        filepaths, x_texts = [], []
         for i, item in enumerate(batch):
             y_, x_ = item["y"], item["x"]
             y_lengths.append(y_.shape[-1])
@@ -223,9 +224,19 @@ class TextMelBatchCollate:
             y[i, :, : y_.shape[-1]] = y_
             x[i, : x_.shape[-1]] = x_
             spks.append(item["spk"])
+            filepaths.append(item["filepath"])
+            x_texts.append(item["x_text"])

         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
         spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None

-        return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
+        return {
+            "x": x,
+            "x_lengths": x_lengths,
+            "y": y,
+            "y_lengths": y_lengths,
+            "spks": spks,
+            "filepaths": filepaths,
+            "x_texts": x_texts,
+        }
diff --git a/matcha/models/baselightningmodule.py b/matcha/models/baselightningmodule.py
index 3724888..5fd09a4 100644
--- a/matcha/models/baselightningmodule.py
+++ b/matcha/models/baselightningmodule.py
@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
         y, y_lengths = batch["y"], batch["y_lengths"]
         spks = batch["spks"]

-        dur_loss, prior_loss, diff_loss = self(
+        dur_loss, prior_loss, diff_loss, *_ = self(
             x=x,
             x_lengths=x_lengths,
             y=y,
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index 64b2c07..464efcd 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -236,4 +236,4 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         else:
             prior_loss = 0

-        return dur_loss, prior_loss, diff_loss
+        return dur_loss, prior_loss, diff_loss, attn
diff --git a/matcha/text/__init__.py b/matcha/text/__init__.py
index 71a4b57..8c75d6b 100644
--- a/matcha/text/__init__.py
+++ b/matcha/text/__init__.py
@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
     for symbol in clean_text:
         symbol_id = _symbol_to_id[symbol]
         sequence += [symbol_id]
-    return sequence
+    return sequence, clean_text


 def cleaned_text_to_sequence(cleaned_text):
diff --git a/matcha/text/cleaners.py b/matcha/text/cleaners.py
index 5e8d96b..36776e3 100644
--- a/matcha/text/cleaners.py
+++ b/matcha/text/cleaners.py
@@ -15,7 +15,6 @@ import logging
 import re

 import phonemizer
-import piper_phonemize
 from unidecode import unidecode

 # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -106,11 +105,17 @@ def english_cleaners2(text):
     return phonemes


-def english_cleaners_piper(text):
-    """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_abbreviations(text)
-    phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
-    phonemes = collapse_whitespace(phonemes)
-    return phonemes
+# I am removing this due to incompatibility with several versions of Python.
+# However, if you want to use it, you can uncomment it
+# and install piper-phonemize with the following command:
+# pip install piper-phonemize

+# import piper_phonemize
+# def english_cleaners_piper(text):
+#     """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
+#     text = convert_to_ascii(text)
+#     text = lowercase(text)
+#     text = expand_abbreviations(text)
+#     phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
+#     phonemes = collapse_whitespace(phonemes)
+#     return phonemes
diff --git a/matcha/utils/get_durations_from_trained_model.py b/matcha/utils/get_durations_from_trained_model.py
new file mode 100644
index 0000000..9bee56e
--- /dev/null
+++ b/matcha/utils/get_durations_from_trained_model.py
@@ -0,0 +1,194 @@
+r"""
+Computes durations for every datapoint in a dataset from the alignments of a trained model and saves them
+to disk as .npy and .json files, so that another model can load them when needed.
+
+Parameters are taken from the corresponding Hydra data config under configs/data.
+"""
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import lightning
+import numpy as np
+import rootutils
+import torch
+from hydra import compose, initialize
+from omegaconf import open_dict
+from torch import nn
+from tqdm.auto import tqdm
+
+from matcha.cli import get_device
+from matcha.data.text_mel_datamodule import TextMelDataModule
+from matcha.models.matcha_tts import MatchaTTS
+from matcha.utils.logging_utils import pylogger
+from matcha.utils.utils import get_phoneme_durations
+
+log = pylogger.get_pylogger(__name__)
+
+
+def save_durations_to_folder(
+    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+):
+    durations = attn.squeeze().sum(1)[:x_length].numpy()
+    durations_json = get_phoneme_durations(durations, text)
+    output = output_folder / Path(filepath).name.replace(".wav", ".npy")
+    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
+        json.dump(durations_json, f, indent=4, ensure_ascii=False)
+
+    np.save(output, durations)
+
+
+@torch.inference_mode()
+def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+    """Generate durations from the model for each datapoint and save them in a folder
+
+    Args:
+        data_loader (torch.utils.data.DataLoader): Dataloader
+        model (nn.Module): MatchaTTS model
+        device (torch.device): GPU or CPU
+    """
+
+    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
+        x, x_lengths = batch["x"], batch["x_lengths"]
+        y, y_lengths = batch["y"], batch["y_lengths"]
+        spks = batch["spks"]
+        x = x.to(device)
+        y = y.to(device)
+        x_lengths = x_lengths.to(device)
+        y_lengths = y_lengths.to(device)
+        spks = spks.to(device) if spks is not None else None
+
+        _, _, _, attn = model(
+            x=x,
+            x_lengths=x_lengths,
+            y=y,
+            y_lengths=y_lengths,
+            spks=spks,
+        )
+        attn = attn.cpu()
+        for i in range(attn.shape[0]):
+            save_durations_to_folder(
+                attn[i],
+                x_lengths[i].item(),
+                y_lengths[i].item(),
+                batch["filepaths"][i],
+                output_folder,
+                batch["x_texts"][i],
+            )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input-config",
+        type=str,
+        default="ljspeech.yaml",
+        help="The name of the yaml config file under configs/data",
+    )
+
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size; can be increased for faster computation",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        default=False,
+        required=False,
+        help="force overwrite of an existing output folder",
+    )
+    parser.add_argument(
+        "-c",
+        "--checkpoint_path",
+        type=str,
+        required=True,
+        help="Path to the checkpoint file to load the model from",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output-folder",
+        type=str,
+        default=None,
+        help="Output folder to save the durations",
statistics", + ) + + parser.add_argument( + "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" + ) + + args = parser.parse_args() + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + + if args.output_folder is not None: + output_folder = Path(args.output_folder) + else: + output_folder = Path(cfg["train_filelist_path"]).parent / "durations" + + print(f"Output folder set to: {output_folder}") + + if os.path.exists(output_folder) and not args.force: + print("Folder already exists. Use -f to force overwrite") + sys.exit(1) + + output_folder.mkdir(parents=True, exist_ok=True) + + print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") + print("Loading model...") + device = get_device(args) + model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + try: + print("Computing stats for training set if exists...") + train_dataloader = text_mel_datamodule.train_dataloader() + compute_durations(train_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No training set found") + + try: + print("Computing stats for validation set if exists...") + val_dataloader = text_mel_datamodule.val_dataloader() + compute_durations(val_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No validation set found") + + try: + print("Computing stats for test set if exists...") + test_dataloader = text_mel_datamodule.test_dataloader() + compute_durations(test_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No test set found") + + print(f"[+] Done! 
+
+
+if __name__ == "__main__":
+    # Helps with generating durations for a dataset, e.g. to train other architectures
+    # that cannot learn to align due to the limited size of the dataset
+    # Example usage:
+    # python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
+    # This will create a "durations" folder next to the filelists (e.g. data/LJSpeech-1.1/durations)
+    main()
diff --git a/matcha/utils/utils.py b/matcha/utils/utils.py
index af65e09..fc3a48e 100644
--- a/matcha/utils/utils.py
+++ b/matcha/utils/utils.py
@@ -2,6 +2,7 @@ import os
 import sys
 import warnings
 from importlib.util import find_spec
+from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple

@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
             gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
         else:
             wget.download(url=url, out=checkpoint_path)
+
+
+def get_phoneme_durations(durations, phones):
+    prev = durations[0]
+    merged_durations = []
+    # Convolve with stride 2
+    for i in range(1, len(durations), 2):
+        if i == len(durations) - 2:
+            # if it is the last one, take the full value
+            next_half = durations[i + 1]
+        else:
+            next_half = ceil(durations[i + 1] / 2)
+
+        curr = prev + durations[i] + next_half
+        prev = durations[i + 1] - next_half
+        merged_durations.append(curr)
+
+    assert len(phones) == len(merged_durations)
+    assert len(merged_durations) == (len(durations) - 1) // 2
+
+    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
+    start = torch.tensor(0)
+    duration_json = []
+    for i, duration in enumerate(merged_durations):
+        duration_json.append(
+            {
+                phones[i]: {
+                    "starttime": start.item(),
+                    "endtime": duration.item(),
+                    "duration": duration.item() - start.item(),
+                }
+            }
+        )
+        start = duration
+
+    assert list(duration_json[-1].values())[0]["endtime"] == sum(
+        durations
+    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
+    return duration_json
diff --git a/requirements.txt b/requirements.txt
index 0a7e14c..d25358d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -42,4 +42,3 @@ gradio
 gdown
 wget
 seaborn
-piper_phonemize
diff --git a/setup.py b/setup.py
index 80d4aac..608de9f 100644
--- a/setup.py
+++ b/setup.py
@@ -38,6 +38,7 @@ setup(
             "matcha-data-stats=matcha.utils.generate_data_statistics:main",
             "matcha-tts=matcha.cli:cli",
             "matcha-tts-app=matcha.app:main",
+            "matcha-get-durations=matcha.utils.get_durations_from_trained_model:main",
         ]
     },
     ext_modules=cythonize(exts, language_level=3),

From d816c40e3d578f1a318b9ce6298dc000955ea17a Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Fri, 24 May 2024 11:46:03 +0200
Subject: [PATCH 2/6] Updating the notebook to adjust to the change

---
 synthesis.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/synthesis.ipynb b/synthesis.ipynb
index dfbde30..1e47c53 100644
--- a/synthesis.ipynb
+++ b/synthesis.ipynb
@@ -19,7 +19,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992",
     "metadata": {},
     "outputs": [
@@ -192,7 +192,7 @@
    "source": [
     "@torch.inference_mode()\n",
     "def process_text(text: str):\n",
-    "    x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2']), 0),dtype=torch.long, device=device)[None]\n",
+    "    x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2'])[0], 0),dtype=torch.long, device=device)[None]\n",
     "    x_lengths = 
torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n", " x_phones = sequence_to_text(x.squeeze(0).tolist())\n", " return {\n", From e658aee6a56ecbeeda6b1b97dc2be17fbe3cc710 Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Sat, 25 May 2024 20:15:17 +0200 Subject: [PATCH 3/6] Pinning gradio --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d25358d..b3b9109 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,7 +38,7 @@ conformer==0.3.2 diffusers==0.25.0 notebook ipywidgets -gradio +gradio==3.43.2 gdown wget seaborn From aa496aa13f33c8da3f2c90557e2c16136fc3b287 Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Mon, 27 May 2024 13:24:21 +0200 Subject: [PATCH 4/6] Adding the possibility to train with durations --- configs/data/ljspeech.yaml | 1 + configs/model/matcha.yaml | 1 + matcha/data/text_mel_datamodule.py | 34 ++++++++++++++++++- matcha/models/baselightningmodule.py | 1 + matcha/models/matcha_tts.py | 27 +++++++++------ matcha/utils/generate_data_statistics.py | 1 + .../utils/get_durations_from_trained_model.py | 1 + 7 files changed, 54 insertions(+), 12 deletions(-) diff --git a/configs/data/ljspeech.yaml b/configs/data/ljspeech.yaml index 6fba3be..ee87a6a 100644 --- a/configs/data/ljspeech.yaml +++ b/configs/data/ljspeech.yaml @@ -19,3 +19,4 @@ data_statistics: # Computed for ljspeech dataset mel_mean: -5.536622 mel_std: 2.116101 seed: ${seed} +load_durations: false diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml index 36f6eaf..e2b5c78 100644 --- a/configs/model/matcha.yaml +++ b/configs/model/matcha.yaml @@ -13,3 +13,4 @@ n_feats: 80 data_statistics: ${data.data_statistics} out_size: null # Must be divisible by 4 prior_loss: true +use_precomputed_durations: ${data.load_durations} diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py index f281bfd..e10dfcb 100644 --- a/matcha/data/text_mel_datamodule.py +++ b/matcha/data/text_mel_datamodule.py @@ -1,6 +1,8 @@ import random +from pathlib import Path from typing import Any, Dict, Optional +import numpy as np import torch import torchaudio as ta from lightning import LightningDataModule @@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule): f_max, data_statistics, seed, + load_durations, ): super().__init__() @@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule): self.hparams.f_max, self.hparams.data_statistics, self.hparams.seed, + self.hparams.load_durations, ) self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init self.hparams.valid_filelist_path, @@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule): self.hparams.f_max, self.hparams.data_statistics, self.hparams.seed, + self.hparams.load_durations, ) def train_dataloader(self): @@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset): f_max=8000, data_parameters=None, seed=None, + load_durations=False, ): self.filepaths_and_text = parse_filelist(filelist_path) self.n_spks = n_spks @@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset): self.win_length = win_length self.f_min = f_min self.f_max = f_max + self.load_durations = load_durations + if data_parameters is not None: self.data_parameters = data_parameters else: @@ -167,7 +175,26 @@ class TextMelDataset(torch.utils.data.Dataset): text, cleaned_text = self.get_text(text, add_blank=self.add_blank) mel = self.get_mel(filepath) - return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text} + 
durations = self.get_durations(filepath, text) if self.load_durations else None
+
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
+
+    def get_durations(self, filepath, text):
+        filepath = Path(filepath)
+        data_dir, name = filepath.parent.parent, filepath.stem
+
+        try:
+            dur_loc = data_dir / "durations" / f"{name}.npy"
+            durs = torch.from_numpy(np.load(dur_loc).astype(int))
+
+        except FileNotFoundError as e:
+            raise FileNotFoundError(
+                f"Tried loading the durations, but they didn't exist at {dur_loc}; make sure you've generated the durations first, using: python matcha/utils/get_durations_from_trained_model.py \n"
+            ) from e
+
+        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
+
+        return durs

     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)
@@ -214,6 +241,8 @@ class TextMelBatchCollate:

         y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
         x = torch.zeros((B, x_max_length), dtype=torch.long)
+        durations = torch.zeros((B, x_max_length), dtype=torch.long)
+
         y_lengths, x_lengths = [], []
         spks = []
         filepaths, x_texts = [], []
@@ -226,6 +255,8 @@ class TextMelBatchCollate:
             spks.append(item["spk"])
             filepaths.append(item["filepath"])
             x_texts.append(item["x_text"])
+            if item["durations"] is not None:
+                durations[i, : item["durations"].shape[-1]] = item["durations"]

         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
@@ -239,4 +270,5 @@ class TextMelBatchCollate:
             "spks": spks,
             "filepaths": filepaths,
             "x_texts": x_texts,
+            "durations": durations if not torch.eq(durations, 0).all() else None,
         }
diff --git a/matcha/models/baselightningmodule.py b/matcha/models/baselightningmodule.py
index 5fd09a4..f8abe7b 100644
--- a/matcha/models/baselightningmodule.py
+++ b/matcha/models/baselightningmodule.py
@@ -65,6 +65,7 @@ class BaseLightningClass(LightningModule, ABC):
             y_lengths=y_lengths,
             spks=spks,
             out_size=self.out_size,
+            durations=batch["durations"],
         )
         return {
             "dur_loss": dur_loss,
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index 464efcd..07f95ad 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         optimizer=None,
         scheduler=None,
         prior_loss=True,
+        use_precomputed_durations=False,
     ):
         super().__init__()

@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.n_feats = n_feats
         self.out_size = out_size
         self.prior_loss = prior_loss
+        self.use_precomputed_durations = use_precomputed_durations

         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             "rtf": rtf,
         }

-    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
+    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
         """
         Computes 3 losses:
             1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
@@ -179,17 +181,20 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

-        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
-        with torch.no_grad():
-            const = -0.5 * math.log(2 * math.pi) * self.n_feats
-            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
-            y_square = torch.matmul(factor.transpose(1, 2), y**2)
-            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
-            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
-            log_prior = y_square - y_mu_double + mu_square + const
+        if self.use_precomputed_durations:
+            attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
+        else:
+            # Use MAS to find the most likely alignment `attn` between text and mel-spectrogram
+            with torch.no_grad():
+                const = -0.5 * math.log(2 * math.pi) * self.n_feats
+                factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+                y_square = torch.matmul(factor.transpose(1, 2), y**2)
+                y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
+                mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
+                log_prior = y_square - y_mu_double + mu_square + const

-            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
-            attn = attn.detach()
+                attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
+                attn = attn.detach()  # b, t_text, T_mel

         # Compute loss between predicted log-scaled durations and those obtained from MAS
         # referred to as prior loss in the paper
diff --git a/matcha/utils/generate_data_statistics.py b/matcha/utils/generate_data_statistics.py
index 96a5382..49ed3c1 100644
--- a/matcha/utils/generate_data_statistics.py
+++ b/matcha/utils/generate_data_statistics.py
@@ -94,6 +94,7 @@ def main():
         cfg["batch_size"] = args.batch_size
         cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
         cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+        cfg["load_durations"] = False

     text_mel_datamodule = TextMelDataModule(**cfg)
     text_mel_datamodule.setup()
diff --git a/matcha/utils/get_durations_from_trained_model.py b/matcha/utils/get_durations_from_trained_model.py
index 9bee56e..0fe2f35 100644
--- a/matcha/utils/get_durations_from_trained_model.py
+++ b/matcha/utils/get_durations_from_trained_model.py
@@ -140,6 +140,7 @@ def main():
         cfg["batch_size"] = args.batch_size
         cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
         cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+        cfg["load_durations"] = False

     if args.output_folder is not None:
         output_folder = Path(args.output_folder)

From de910380bcc41249d7e4857b46f46fa83f2335b4 Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Mon, 27 May 2024 13:40:02 +0200
Subject: [PATCH 5/6] Fixing batched synthesis for multispeaker model

---
 matcha/cli.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/matcha/cli.py b/matcha/cli.py
index 635c586..7daf130 100644
--- a/matcha/cli.py
+++ b/matcha/cli.py
@@ -326,12 +326,13 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
     for i, batch in enumerate(dataloader):
         i = i + 1
         start_t = dt.datetime.now()
+        b = batch["x"].shape[0]
         output = model.synthesise(
             batch["x"].to(device),
             batch["x_lengths"].to(device),
             n_timesteps=args.steps,
             temperature=args.temperature,
-            spks=spk,
+            spks=spk.expand(b) if spk is not None else spk,
             length_scale=args.speaking_rate,
         )

From ac0b258f805d1f30d6ba758a8f631f3bdf7f9382 Mon Sep 17 00:00:00 2001
From: Shivam Mehta
Date: Mon, 27 May 2024 13:50:21 +0200
Subject: [PATCH 6/6] Adding configuration for training from durations

---
 .../experiment/ljspeech_from_durations.yaml   | 19 +++++++++++++++++++
 setup.py                                      |  2 +-
 2 files changed, 20 insertions(+), 1 deletion(-)
 create mode 100644 configs/experiment/ljspeech_from_durations.yaml

diff --git a/configs/experiment/ljspeech_from_durations.yaml b/configs/experiment/ljspeech_from_durations.yaml
new file mode 100644
index 0000000..63f7d29
--- /dev/null
+++ b/configs/experiment/ljspeech_from_durations.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=ljspeech_from_durations
+
+defaults:
+  - override /data: ljspeech.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["ljspeech"]
+
+run_name: ljspeech
+
+
+data:
+  load_durations: True
+  batch_size: 64
diff --git a/setup.py b/setup.py
index 608de9f..a49c2cc 100644
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,7 @@ setup(
             "matcha-data-stats=matcha.utils.generate_data_statistics:main",
             "matcha-tts=matcha.cli:cli",
             "matcha-tts-app=matcha.app:main",
-            "matcha-get-durations=matcha.utils.get_durations_from_trained_model:main",
+            "matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main",
        ]
    },
    ext_modules=cythonize(exts, language_level=3),
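
Notes on the series (usage):

Taken together, the six patches give a two-step workflow. First, extract durations from a trained checkpoint with the new entry point:

    matcha-tts-get-durations -i ljspeech.yaml -c <path/to/checkpoint.ckpt>

(the checkpoint path is a placeholder; the script can also be run directly as python matcha/utils/get_durations_from_trained_model.py). This writes one <utterance>.npy and one <utterance>.json per wav into a "durations" folder next to the filelists. Second, train with the precomputed alignments instead of MAS by setting data.load_durations=True, most simply via the new experiment config:

    python matcha/train.py experiment=ljspeech_from_durations

(matcha/train.py is assumed here as the usual training entry point of this repository).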
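A note on get_phoneme_durations() in matcha/utils/utils.py: because the dataset is built with add_blank=True, the per-token durations read off the attention matrix have length 2 * len(phones) + 1 (a blank before, between, and after every phone), and the stride-2 merge folds each blank's frames into the neighbouring phones, half on each side, with the trailing blank going entirely to the last phone. A minimal sketch with toy numbers rather than real model output (only the import path comes from this series):

    # Toy example: two phones "a", "b" with interspersed blanks.
    # Raw per-token durations (frames): <blank> a <blank> b <blank>
    from matcha.utils.utils import get_phoneme_durations

    phones = "ab"
    durations = [1, 4, 2, 5, 3]  # len == 2 * len(phones) + 1

    print(get_phoneme_durations(durations, phones))
    # [{'a': {'starttime': 0, 'endtime': 6, 'duration': 6}},
    #  {'b': {'starttime': 6, 'endtime': 15, 'duration': 9}}]
    # 'a' gets its own 4 frames plus the leading blank (1) and ceil(2/2) = 1
    # frame of the middle blank; 'b' gets the remainder. The last endtime
    # always equals sum(durations), which the function asserts.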
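For downstream consumers of the generated files: the .npy holds the raw per-token durations (blanks included, as summed attention weights, hence floats), which is exactly what TextMelDataset.get_durations() loads and casts to int at training time; the .json sidecar holds the merged per-phone segmentation. A sketch of reading both, with an illustrative LJSpeech path:

    import json

    import numpy as np

    # Paths are hypothetical; by default the files land next to the filelists.
    durs = np.load("data/LJSpeech-1.1/durations/LJ001-0001.npy")
    print(durs.shape)  # (2 * n_phones + 1,) per-token frame counts, blanks included

    with open("data/LJSpeech-1.1/durations/LJ001-0001.json", encoding="utf-8") as f:
        segmentation = json.load(f)  # [{phone: {"starttime", "endtime", "duration"}}, ...]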