mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 02:09:21 +08:00
Adding the possibility to train with durations
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio as ta
|
||||
from lightning import LightningDataModule
|
||||
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
f_max,
|
||||
data_statistics,
|
||||
seed,
|
||||
load_durations,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
self.hparams.f_max,
|
||||
self.hparams.data_statistics,
|
||||
self.hparams.seed,
|
||||
self.hparams.load_durations,
|
||||
)
|
||||
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
|
||||
self.hparams.valid_filelist_path,
|
||||
@@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
self.hparams.f_max,
|
||||
self.hparams.data_statistics,
|
||||
self.hparams.seed,
|
||||
self.hparams.load_durations,
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
f_max=8000,
|
||||
data_parameters=None,
|
||||
seed=None,
|
||||
load_durations=False,
|
||||
):
|
||||
self.filepaths_and_text = parse_filelist(filelist_path)
|
||||
self.n_spks = n_spks
|
||||
@@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
self.win_length = win_length
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.load_durations = load_durations
|
||||
|
||||
if data_parameters is not None:
|
||||
self.data_parameters = data_parameters
|
||||
else:
|
||||
@@ -167,7 +175,26 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
|
||||
mel = self.get_mel(filepath)
|
||||
|
||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
|
||||
durations = self.get_durations(filepath, text) if self.load_durations else None
|
||||
|
||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
|
||||
|
||||
def get_durations(self, filepath, text):
    """Load precomputed per-token durations for one utterance.

    The durations file is resolved relative to the audio path: for an
    audio file at ``<data_dir>/<subdir>/<name>.<ext>`` the durations are
    expected at ``<data_dir>/durations/<name>.npy`` (i.e. two directory
    levels up from the audio file, in a sibling ``durations`` folder).

    Args:
        filepath: Path (str or Path) to the audio file whose durations
            should be loaded.
        text: Processed token sequence; its length must equal the number
            of duration entries.

    Returns:
        A 1-D ``torch`` integer tensor of per-token durations.

    Raises:
        FileNotFoundError: If the ``.npy`` durations file does not exist
            (i.e. durations were never generated for this dataset).
        AssertionError: If the number of durations does not match the
            number of tokens.
    """
    filepath = Path(filepath)
    data_dir, name = filepath.parent.parent, filepath.stem

    # Build the path outside the try-block: only np.load can raise
    # FileNotFoundError, and this guarantees dur_loc is always defined
    # when formatting the error message below.
    dur_loc = data_dir / "durations" / f"{name}.npy"
    try:
        durs = torch.from_numpy(np.load(dur_loc).astype(int))
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generated the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
        ) from e

    assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"

    return durs
|
||||
|
||||
def get_mel(self, filepath):
|
||||
audio, sr = ta.load(filepath)
|
||||
@@ -214,6 +241,8 @@ class TextMelBatchCollate:
|
||||
|
||||
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
||||
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||
durations = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||
|
||||
y_lengths, x_lengths = [], []
|
||||
spks = []
|
||||
filepaths, x_texts = [], []
|
||||
@@ -226,6 +255,8 @@ class TextMelBatchCollate:
|
||||
spks.append(item["spk"])
|
||||
filepaths.append(item["filepath"])
|
||||
x_texts.append(item["x_text"])
|
||||
if item["durations"] is not None:
|
||||
durations[i, : item["durations"].shape[-1]] = item["durations"]
|
||||
|
||||
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
||||
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
||||
@@ -239,4 +270,5 @@ class TextMelBatchCollate:
|
||||
"spks": spks,
|
||||
"filepaths": filepaths,
|
||||
"x_texts": x_texts,
|
||||
"durations": durations if not torch.eq(durations, 0).all() else None,
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ class BaseLightningClass(LightningModule, ABC):
|
||||
y_lengths=y_lengths,
|
||||
spks=spks,
|
||||
out_size=self.out_size,
|
||||
durations=batch["durations"],
|
||||
)
|
||||
return {
|
||||
"dur_loss": dur_loss,
|
||||
|
||||
@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
prior_loss=True,
|
||||
use_precomputed_durations=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
self.n_feats = n_feats
|
||||
self.out_size = out_size
|
||||
self.prior_loss = prior_loss
|
||||
self.use_precomputed_durations = use_precomputed_durations
|
||||
|
||||
if n_spks > 1:
|
||||
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
||||
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
"rtf": rtf,
|
||||
}
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
|
||||
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
|
||||
"""
|
||||
Computes 3 losses:
|
||||
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
|
||||
@@ -179,17 +181,20 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
|
||||
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
||||
|
||||
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
||||
with torch.no_grad():
|
||||
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
||||
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
|
||||
y_square = torch.matmul(factor.transpose(1, 2), y**2)
|
||||
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
|
||||
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
|
||||
log_prior = y_square - y_mu_double + mu_square + const
|
||||
if self.use_precomputed_durations:
|
||||
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
|
||||
else:
|
||||
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
||||
with torch.no_grad():
|
||||
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
||||
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
|
||||
y_square = torch.matmul(factor.transpose(1, 2), y**2)
|
||||
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
|
||||
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
|
||||
log_prior = y_square - y_mu_double + mu_square + const
|
||||
|
||||
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
|
||||
attn = attn.detach()
|
||||
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
|
||||
attn = attn.detach() # b, t_text, T_mel
|
||||
|
||||
# Compute loss between predicted log-scaled durations and those obtained from MAS
|
||||
# referred to as prior loss in the paper
|
||||
|
||||
@@ -94,6 +94,7 @@ def main():
|
||||
cfg["batch_size"] = args.batch_size
|
||||
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
||||
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
||||
cfg["load_durations"] = False
|
||||
|
||||
text_mel_datamodule = TextMelDataModule(**cfg)
|
||||
text_mel_datamodule.setup()
|
||||
|
||||
@@ -140,6 +140,7 @@ def main():
|
||||
cfg["batch_size"] = args.batch_size
|
||||
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
||||
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
||||
cfg["load_durations"] = False
|
||||
|
||||
if args.output_folder is not None:
|
||||
output_folder = Path(args.output_folder)
|
||||
|
||||
Reference in New Issue
Block a user