mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 09:49:21 +08:00
Adding the possibility to train with durations
This commit is contained in:
@@ -19,3 +19,4 @@ data_statistics: # Computed for ljspeech dataset
|
|||||||
mel_mean: -5.536622
|
mel_mean: -5.536622
|
||||||
mel_std: 2.116101
|
mel_std: 2.116101
|
||||||
seed: ${seed}
|
seed: ${seed}
|
||||||
|
load_durations: false
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ n_feats: 80
|
|||||||
data_statistics: ${data.data_statistics}
|
data_statistics: ${data.data_statistics}
|
||||||
out_size: null # Must be divisible by 4
|
out_size: null # Must be divisible by 4
|
||||||
prior_loss: true
|
prior_loss: true
|
||||||
|
use_precomputed_durations: ${data.load_durations}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio as ta
|
import torchaudio as ta
|
||||||
from lightning import LightningDataModule
|
from lightning import LightningDataModule
|
||||||
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
|
|||||||
f_max,
|
f_max,
|
||||||
data_statistics,
|
data_statistics,
|
||||||
seed,
|
seed,
|
||||||
|
load_durations,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule):
|
|||||||
self.hparams.f_max,
|
self.hparams.f_max,
|
||||||
self.hparams.data_statistics,
|
self.hparams.data_statistics,
|
||||||
self.hparams.seed,
|
self.hparams.seed,
|
||||||
|
self.hparams.load_durations,
|
||||||
)
|
)
|
||||||
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
|
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
|
||||||
self.hparams.valid_filelist_path,
|
self.hparams.valid_filelist_path,
|
||||||
@@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule):
|
|||||||
self.hparams.f_max,
|
self.hparams.f_max,
|
||||||
self.hparams.data_statistics,
|
self.hparams.data_statistics,
|
||||||
self.hparams.seed,
|
self.hparams.seed,
|
||||||
|
self.hparams.load_durations,
|
||||||
)
|
)
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
|
|||||||
f_max=8000,
|
f_max=8000,
|
||||||
data_parameters=None,
|
data_parameters=None,
|
||||||
seed=None,
|
seed=None,
|
||||||
|
load_durations=False,
|
||||||
):
|
):
|
||||||
self.filepaths_and_text = parse_filelist(filelist_path)
|
self.filepaths_and_text = parse_filelist(filelist_path)
|
||||||
self.n_spks = n_spks
|
self.n_spks = n_spks
|
||||||
@@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset):
|
|||||||
self.win_length = win_length
|
self.win_length = win_length
|
||||||
self.f_min = f_min
|
self.f_min = f_min
|
||||||
self.f_max = f_max
|
self.f_max = f_max
|
||||||
|
self.load_durations = load_durations
|
||||||
|
|
||||||
if data_parameters is not None:
|
if data_parameters is not None:
|
||||||
self.data_parameters = data_parameters
|
self.data_parameters = data_parameters
|
||||||
else:
|
else:
|
||||||
@@ -167,7 +175,26 @@ class TextMelDataset(torch.utils.data.Dataset):
|
|||||||
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
|
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
|
||||||
mel = self.get_mel(filepath)
|
mel = self.get_mel(filepath)
|
||||||
|
|
||||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
|
durations = self.get_durations(filepath, text) if self.load_durations else None
|
||||||
|
|
||||||
|
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
|
||||||
|
|
||||||
|
def get_durations(self, filepath, text):
    """Load the precomputed durations that belong to one audio file.

    The durations are read from ``<data_dir>/durations/<name>.npy``, where
    ``<data_dir>`` is the grandparent directory of ``filepath`` and ``<name>``
    is the file name of ``filepath`` without its extension.

    Args:
        filepath: Path of the audio file (anything ``pathlib.Path`` accepts).
        text: Processed token sequence of the utterance; only its length is
            used, to validate the loaded durations.

    Returns:
        A tensor with the values stored in the ``.npy`` file, cast to integer.

    Raises:
        FileNotFoundError: If the durations file does not exist; the message
            names the expected location and the script that generates it.
        AssertionError: If the number of durations differs from ``len(text)``.
    """
    filepath = Path(filepath)
    data_dir, name = filepath.parent.parent, filepath.stem
    # Built outside the ``try`` so the ``try`` only wraps the call that can fail.
    dur_loc = data_dir / "durations" / f"{name}.npy"

    try:
        durs = torch.from_numpy(np.load(dur_loc).astype(int))
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
        ) from e

    # Raised explicitly instead of via ``assert`` so the check still runs under
    # ``python -O``; the exception type is kept for backward compatibility.
    if len(durs) != len(text):
        raise AssertionError(f"Length of durations {len(durs)} and text {len(text)} do not match")

    return durs
|
||||||
|
|
||||||
def get_mel(self, filepath):
|
def get_mel(self, filepath):
|
||||||
audio, sr = ta.load(filepath)
|
audio, sr = ta.load(filepath)
|
||||||
@@ -214,6 +241,8 @@ class TextMelBatchCollate:
|
|||||||
|
|
||||||
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
||||||
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||||
|
durations = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||||
|
|
||||||
y_lengths, x_lengths = [], []
|
y_lengths, x_lengths = [], []
|
||||||
spks = []
|
spks = []
|
||||||
filepaths, x_texts = [], []
|
filepaths, x_texts = [], []
|
||||||
@@ -226,6 +255,8 @@ class TextMelBatchCollate:
|
|||||||
spks.append(item["spk"])
|
spks.append(item["spk"])
|
||||||
filepaths.append(item["filepath"])
|
filepaths.append(item["filepath"])
|
||||||
x_texts.append(item["x_text"])
|
x_texts.append(item["x_text"])
|
||||||
|
if item["durations"] is not None:
|
||||||
|
durations[i, : item["durations"].shape[-1]] = item["durations"]
|
||||||
|
|
||||||
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
||||||
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
||||||
@@ -239,4 +270,5 @@ class TextMelBatchCollate:
|
|||||||
"spks": spks,
|
"spks": spks,
|
||||||
"filepaths": filepaths,
|
"filepaths": filepaths,
|
||||||
"x_texts": x_texts,
|
"x_texts": x_texts,
|
||||||
|
"durations": durations if not torch.eq(durations, 0).all() else None,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class BaseLightningClass(LightningModule, ABC):
|
|||||||
y_lengths=y_lengths,
|
y_lengths=y_lengths,
|
||||||
spks=spks,
|
spks=spks,
|
||||||
out_size=self.out_size,
|
out_size=self.out_size,
|
||||||
|
durations=batch["durations"],
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"dur_loss": dur_loss,
|
"dur_loss": dur_loss,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
optimizer=None,
|
optimizer=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
prior_loss=True,
|
prior_loss=True,
|
||||||
|
use_precomputed_durations=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
self.n_feats = n_feats
|
self.n_feats = n_feats
|
||||||
self.out_size = out_size
|
self.out_size = out_size
|
||||||
self.prior_loss = prior_loss
|
self.prior_loss = prior_loss
|
||||||
|
self.use_precomputed_durations = use_precomputed_durations
|
||||||
|
|
||||||
if n_spks > 1:
|
if n_spks > 1:
|
||||||
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
||||||
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
"rtf": rtf,
|
"rtf": rtf,
|
||||||
}
|
}
|
||||||
|
|
||||||
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
|
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
|
||||||
"""
|
"""
|
||||||
Computes 3 losses:
|
Computes 3 losses:
|
||||||
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
|
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
|
||||||
@@ -179,6 +181,9 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
|
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
|
||||||
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
||||||
|
|
||||||
|
if self.use_precomputed_durations:
|
||||||
|
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
|
||||||
|
else:
|
||||||
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
||||||
@@ -189,7 +194,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
log_prior = y_square - y_mu_double + mu_square + const
|
log_prior = y_square - y_mu_double + mu_square + const
|
||||||
|
|
||||||
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
|
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
|
||||||
attn = attn.detach()
|
attn = attn.detach() # b, t_text, T_mel
|
||||||
|
|
||||||
# Compute loss between predicted log-scaled durations and those obtained from MAS
|
# Compute loss between predicted log-scaled durations and those obtained from MAS
|
||||||
# referred to as prior loss in the paper
|
# referred to as prior loss in the paper
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ def main():
|
|||||||
cfg["batch_size"] = args.batch_size
|
cfg["batch_size"] = args.batch_size
|
||||||
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
||||||
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
||||||
|
cfg["load_durations"] = False
|
||||||
|
|
||||||
text_mel_datamodule = TextMelDataModule(**cfg)
|
text_mel_datamodule = TextMelDataModule(**cfg)
|
||||||
text_mel_datamodule.setup()
|
text_mel_datamodule.setup()
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ def main():
|
|||||||
cfg["batch_size"] = args.batch_size
|
cfg["batch_size"] = args.batch_size
|
||||||
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
||||||
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
||||||
|
cfg["load_durations"] = False
|
||||||
|
|
||||||
if args.output_folder is not None:
|
if args.output_folder is not None:
|
||||||
output_folder = Path(args.output_folder)
|
output_folder = Path(args.output_folder)
|
||||||
|
|||||||
Reference in New Issue
Block a user