Adding the possibility to train with durations

This commit is contained in:
Shivam Mehta
2024-05-27 13:24:21 +02:00
parent e658aee6a5
commit aa496aa13f
7 changed files with 54 additions and 12 deletions

View File

@@ -19,3 +19,4 @@ data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622 mel_mean: -5.536622
mel_std: 2.116101 mel_std: 2.116101
seed: ${seed} seed: ${seed}
load_durations: false

View File

@@ -13,3 +13,4 @@ n_feats: 80
data_statistics: ${data.data_statistics} data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4 out_size: null # Must be divisible by 4
prior_loss: true prior_loss: true
use_precomputed_durations: ${data.load_durations}

View File

@@ -1,6 +1,8 @@
import random import random
from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import numpy as np
import torch import torch
import torchaudio as ta import torchaudio as ta
from lightning import LightningDataModule from lightning import LightningDataModule
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
f_max, f_max,
data_statistics, data_statistics,
seed, seed,
load_durations,
): ):
super().__init__() super().__init__()
@@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max, self.hparams.f_max,
self.hparams.data_statistics, self.hparams.data_statistics,
self.hparams.seed, self.hparams.seed,
self.hparams.load_durations,
) )
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path, self.hparams.valid_filelist_path,
@@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max, self.hparams.f_max,
self.hparams.data_statistics, self.hparams.data_statistics,
self.hparams.seed, self.hparams.seed,
self.hparams.load_durations,
) )
def train_dataloader(self): def train_dataloader(self):
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
f_max=8000, f_max=8000,
data_parameters=None, data_parameters=None,
seed=None, seed=None,
load_durations=False,
): ):
self.filepaths_and_text = parse_filelist(filelist_path) self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks self.n_spks = n_spks
@@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset):
self.win_length = win_length self.win_length = win_length
self.f_min = f_min self.f_min = f_min
self.f_max = f_max self.f_max = f_max
self.load_durations = load_durations
if data_parameters is not None: if data_parameters is not None:
self.data_parameters = data_parameters self.data_parameters = data_parameters
else: else:
@@ -167,7 +175,26 @@ class TextMelDataset(torch.utils.data.Dataset):
text, cleaned_text = self.get_text(text, add_blank=self.add_blank) text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath) mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text} durations = self.get_durations(filepath, text) if self.load_durations else None
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
def get_durations(self, filepath, text):
    """Load precomputed per-token durations for an utterance from disk.

    Durations are expected at ``<data_dir>/durations/<stem>.npy``, where
    ``data_dir`` is the grandparent directory of the audio file (i.e. the
    dataset root that contains the ``wavs/`` folder).

    Args:
        filepath: Path (str or Path) to the audio file whose durations
            should be loaded; only its location and stem are used — the
            audio file itself is not opened.
        text: Processed text/token sequence; its length must equal the
            number of duration entries.

    Returns:
        torch.Tensor: 1-D integer tensor of durations, one per text token.

    Raises:
        FileNotFoundError: if the ``.npy`` durations file does not exist.
        AssertionError: if duration and text lengths disagree.
    """
    filepath = Path(filepath)
    # Durations live next to the dataset root, two levels above the audio file.
    data_dir, name = filepath.parent.parent, filepath.stem
    # Build the path outside the try so it is always bound for the error message.
    dur_loc = data_dir / "durations" / f"{name}.npy"
    try:
        durs = torch.from_numpy(np.load(dur_loc).astype(int))
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generated the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
        ) from e

    assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"

    return durs
def get_mel(self, filepath): def get_mel(self, filepath):
audio, sr = ta.load(filepath) audio, sr = ta.load(filepath)
@@ -214,6 +241,8 @@ class TextMelBatchCollate:
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long) x = torch.zeros((B, x_max_length), dtype=torch.long)
durations = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], [] y_lengths, x_lengths = [], []
spks = [] spks = []
filepaths, x_texts = [], [] filepaths, x_texts = [], []
@@ -226,6 +255,8 @@ class TextMelBatchCollate:
spks.append(item["spk"]) spks.append(item["spk"])
filepaths.append(item["filepath"]) filepaths.append(item["filepath"])
x_texts.append(item["x_text"]) x_texts.append(item["x_text"])
if item["durations"] is not None:
durations[i, : item["durations"].shape[-1]] = item["durations"]
y_lengths = torch.tensor(y_lengths, dtype=torch.long) y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long) x_lengths = torch.tensor(x_lengths, dtype=torch.long)
@@ -239,4 +270,5 @@ class TextMelBatchCollate:
"spks": spks, "spks": spks,
"filepaths": filepaths, "filepaths": filepaths,
"x_texts": x_texts, "x_texts": x_texts,
"durations": durations if not torch.eq(durations, 0).all() else None,
} }

View File

@@ -65,6 +65,7 @@ class BaseLightningClass(LightningModule, ABC):
y_lengths=y_lengths, y_lengths=y_lengths,
spks=spks, spks=spks,
out_size=self.out_size, out_size=self.out_size,
durations=batch["durations"],
) )
return { return {
"dur_loss": dur_loss, "dur_loss": dur_loss,

View File

@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
optimizer=None, optimizer=None,
scheduler=None, scheduler=None,
prior_loss=True, prior_loss=True,
use_precomputed_durations=False,
): ):
super().__init__() super().__init__()
@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.n_feats = n_feats self.n_feats = n_feats
self.out_size = out_size self.out_size = out_size
self.prior_loss = prior_loss self.prior_loss = prior_loss
self.use_precomputed_durations = use_precomputed_durations
if n_spks > 1: if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
"rtf": rtf, "rtf": rtf,
} }
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None): def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
""" """
Computes 3 losses: Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS). 1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
@@ -179,17 +181,20 @@ class MatchaTTS(BaseLightningClass): # 🍵
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram if self.use_precomputed_durations:
with torch.no_grad(): attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
const = -0.5 * math.log(2 * math.pi) * self.n_feats else:
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
y_square = torch.matmul(factor.transpose(1, 2), y**2) with torch.no_grad():
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) const = -0.5 * math.log(2 * math.pi) * self.n_feats
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
log_prior = y_square - y_mu_double + mu_square + const y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach() attn = attn.detach() # b, t_text, T_mel
# Compute loss between predicted log-scaled durations and those obtained from MAS # Compute loss between predicted log-scaled durations and those obtained from MAS
# referred to as prior loss in the paper # referred to as prior loss in the paper

View File

@@ -94,6 +94,7 @@ def main():
cfg["batch_size"] = args.batch_size cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
text_mel_datamodule = TextMelDataModule(**cfg) text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup() text_mel_datamodule.setup()

View File

@@ -140,6 +140,7 @@ def main():
cfg["batch_size"] = args.batch_size cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
if args.output_folder is not None: if args.output_folder is not None:
output_folder = Path(args.output_folder) output_folder = Path(args.output_folder)