diff --git a/configs/data/ljspeech.yaml b/configs/data/ljspeech.yaml
index 6fba3be..ee87a6a 100644
--- a/configs/data/ljspeech.yaml
+++ b/configs/data/ljspeech.yaml
@@ -19,3 +19,4 @@ data_statistics: # Computed for ljspeech dataset
   mel_mean: -5.536622
   mel_std: 2.116101
 seed: ${seed}
+load_durations: false
diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml
index 36f6eaf..e2b5c78 100644
--- a/configs/model/matcha.yaml
+++ b/configs/model/matcha.yaml
@@ -13,3 +13,4 @@ n_feats: 80
 data_statistics: ${data.data_statistics}
 out_size: null # Must be divisible by 4
 prior_loss: true
+use_precomputed_durations: ${data.load_durations}
diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py
index f281bfd..e10dfcb 100644
--- a/matcha/data/text_mel_datamodule.py
+++ b/matcha/data/text_mel_datamodule.py
@@ -1,6 +1,8 @@
 import random
+from pathlib import Path
 from typing import Any, Dict, Optional
 
+import numpy as np
 import torch
 import torchaudio as ta
 from lightning import LightningDataModule
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
         f_max,
         data_statistics,
         seed,
+        load_durations,
     ):
         super().__init__()
 
@@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule):
             self.hparams.f_max,
             self.hparams.data_statistics,
             self.hparams.seed,
+            self.hparams.load_durations,
         )
         self.validset = TextMelDataset(  # pylint: disable=attribute-defined-outside-init
             self.hparams.valid_filelist_path,
@@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule):
             self.hparams.f_max,
             self.hparams.data_statistics,
             self.hparams.seed,
+            self.hparams.load_durations,
         )
 
     def train_dataloader(self):
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
         f_max=8000,
         data_parameters=None,
         seed=None,
+        load_durations=False,
     ):
         self.filepaths_and_text = parse_filelist(filelist_path)
         self.n_spks = n_spks
@@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset):
         self.win_length = win_length
         self.f_min = f_min
         self.f_max = f_max
+        self.load_durations = load_durations
+
         if data_parameters is not None:
             self.data_parameters = data_parameters
         else:
@@ -167,7 +175,26 @@ class TextMelDataset(torch.utils.data.Dataset):
         text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
         mel = self.get_mel(filepath)
 
-        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
+        durations = self.get_durations(filepath, text) if self.load_durations else None
+
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
+
+    def get_durations(self, filepath, text):
+        filepath = Path(filepath)
+        data_dir, name = filepath.parent.parent, filepath.stem
+
+        try:
+            dur_loc = data_dir / "durations" / f"{name}.npy"
+            durs = torch.from_numpy(np.load(dur_loc).astype(int))
+
+        except FileNotFoundError as e:
+            raise FileNotFoundError(
+                f"Tried loading the durations but none were found at {dur_loc}; make sure you've generated the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
+            ) from e
+
+        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
+
+        return durs
 
     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)
@@ -214,6 +241,8 @@ class TextMelBatchCollate:
         y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
         x = torch.zeros((B, x_max_length), dtype=torch.long)
+        durations = torch.zeros((B, x_max_length), dtype=torch.long)
+
         y_lengths, x_lengths = [], []
         spks = []
         filepaths, x_texts = [], []
 
@@ -226,6 +255,8 @@ class TextMelBatchCollate:
             spks.append(item["spk"])
             filepaths.append(item["filepath"])
             x_texts.append(item["x_text"])
+            if item["durations"] is not None:
+                durations[i, : item["durations"].shape[-1]] = item["durations"]
 
         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
@@ -239,4 +270,5 @@ class TextMelBatchCollate:
             "spks": spks,
             "filepaths": filepaths,
             "x_texts": x_texts,
+            "durations": durations if not torch.eq(durations, 0).all() else None,
         }
diff --git a/matcha/models/baselightningmodule.py b/matcha/models/baselightningmodule.py
index 5fd09a4..f8abe7b 100644
--- a/matcha/models/baselightningmodule.py
+++ b/matcha/models/baselightningmodule.py
@@ -65,6 +65,7 @@ class BaseLightningClass(LightningModule, ABC):
             y_lengths=y_lengths,
             spks=spks,
             out_size=self.out_size,
+            durations=batch["durations"],
         )
         return {
             "dur_loss": dur_loss,
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index 464efcd..07f95ad 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         optimizer=None,
         scheduler=None,
         prior_loss=True,
+        use_precomputed_durations=False,
     ):
         super().__init__()
 
@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.n_feats = n_feats
         self.out_size = out_size
         self.prior_loss = prior_loss
+        self.use_precomputed_durations = use_precomputed_durations
 
         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             "rtf": rtf,
         }
 
-    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
+    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
         """
         Computes 3 losses:
             1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
@@ -179,17 +181,20 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
 
-        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
-        with torch.no_grad():
-            const = -0.5 * math.log(2 * math.pi) * self.n_feats
-            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
-            y_square = torch.matmul(factor.transpose(1, 2), y**2)
-            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
-            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
-            log_prior = y_square - y_mu_double + mu_square + const
+        if self.use_precomputed_durations:
+            attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
+        else:
+            # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
+            with torch.no_grad():
+                const = -0.5 * math.log(2 * math.pi) * self.n_feats
+                factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+                y_square = torch.matmul(factor.transpose(1, 2), y**2)
+                y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
+                mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
+                log_prior = y_square - y_mu_double + mu_square + const
 
-        attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
-        attn = attn.detach()
+            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
+            attn = attn.detach()  # b, t_text, T_mel
 
         # Compute loss between predicted log-scaled durations and those obtained from MAS
         # refered to as prior loss in the paper
diff --git a/matcha/utils/generate_data_statistics.py b/matcha/utils/generate_data_statistics.py
index 96a5382..49ed3c1 100644
--- a/matcha/utils/generate_data_statistics.py
+++ b/matcha/utils/generate_data_statistics.py
@@ -94,6 +94,7 @@ def main():
     cfg["batch_size"] = args.batch_size
     cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
     cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["load_durations"] = False
 
     text_mel_datamodule = TextMelDataModule(**cfg)
     text_mel_datamodule.setup()
diff --git a/matcha/utils/get_durations_from_trained_model.py b/matcha/utils/get_durations_from_trained_model.py
index 9bee56e..0fe2f35 100644
--- a/matcha/utils/get_durations_from_trained_model.py
+++ b/matcha/utils/get_durations_from_trained_model.py
@@ -140,6 +140,7 @@ def main():
     cfg["batch_size"] = args.batch_size
     cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
     cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["load_durations"] = False
 
     if args.output_folder is not None:
         output_folder = Path(args.output_folder)
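
Note on the on-disk convention this patch assumes: get_durations() resolves filepath.parent.parent / "durations" / f"{name}.npy", so for a file such as .../LJSpeech-1.1/wavs/LJ001-0001.wav it looks for integer per-token durations at .../LJSpeech-1.1/durations/LJ001-0001.npy, one entry per processed text token (the assert enforces this), generated beforehand with matcha/utils/get_durations_from_trained_model.py. When use_precomputed_durations is enabled, forward() skips MAS and calls generate_path(durations, attn_mask) to expand those durations into the hard text-to-mel alignment attn. A minimal, unbatched sketch of that expansion follows; durations_to_attn is a hypothetical stand-in for illustration, not the repo's generate_path implementation:

import torch

# Hypothetical helper (NOT matcha's generate_path): expand a per-token integer
# duration vector into the hard alignment matrix used in place of MAS.
def durations_to_attn(durations: torch.Tensor, n_frames: int) -> torch.Tensor:
    """durations: (t_text,) frames per token -> attn: (t_text, n_frames)."""
    ends = torch.cumsum(durations, dim=0)  # first mel frame *after* each token
    starts = ends - durations              # first mel frame of each token
    frames = torch.arange(n_frames)        # mel frame indices 0 .. n_frames-1
    # Frame j is assigned to token i iff starts[i] <= j < ends[i].
    return ((frames >= starts.unsqueeze(1)) & (frames < ends.unsqueeze(1))).float()

durs = torch.tensor([2, 1, 3])  # e.g. np.load("LJSpeech-1.1/durations/LJ001-0001.npy")
attn = durations_to_attn(durs, n_frames=int(durs.sum()))
print(attn)
# tensor([[1., 1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 1., 1.]])

This also explains the sentinel in TextMelBatchCollate: the durations tensor is zero-initialized and filled only for items that carried a duration vector, so a batch loaded with load_durations: false collates back to None and the model falls through to the default MAS branch.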