mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
Adding the possibility to train with durations
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio as ta
|
||||
from lightning import LightningDataModule
|
||||
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
f_max,
|
||||
data_statistics,
|
||||
seed,
|
||||
load_durations,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
self.hparams.f_max,
|
||||
self.hparams.data_statistics,
|
||||
self.hparams.seed,
|
||||
self.hparams.load_durations,
|
||||
)
|
||||
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
|
||||
self.hparams.valid_filelist_path,
|
||||
@@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
self.hparams.f_max,
|
||||
self.hparams.data_statistics,
|
||||
self.hparams.seed,
|
||||
self.hparams.load_durations,
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
f_max=8000,
|
||||
data_parameters=None,
|
||||
seed=None,
|
||||
load_durations=False,
|
||||
):
|
||||
self.filepaths_and_text = parse_filelist(filelist_path)
|
||||
self.n_spks = n_spks
|
||||
@@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
self.win_length = win_length
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.load_durations = load_durations
|
||||
|
||||
if data_parameters is not None:
|
||||
self.data_parameters = data_parameters
|
||||
else:
|
||||
@@ -167,7 +175,26 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
|
||||
mel = self.get_mel(filepath)
|
||||
|
||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
|
||||
durations = self.get_durations(filepath, text) if self.load_durations else None
|
||||
|
||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
|
||||
|
||||
def get_durations(self, filepath, text):
|
||||
filepath = Path(filepath)
|
||||
data_dir, name = filepath.parent.parent, filepath.stem
|
||||
|
||||
try:
|
||||
dur_loc = data_dir / "durations" / f"{name}.npy"
|
||||
durs = torch.from_numpy(np.load(dur_loc).astype(int))
|
||||
|
||||
except FileNotFoundError as e:
|
||||
raise FileNotFoundError(
|
||||
f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
|
||||
) from e
|
||||
|
||||
assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
|
||||
|
||||
return durs
|
||||
|
||||
def get_mel(self, filepath):
|
||||
audio, sr = ta.load(filepath)
|
||||
@@ -214,6 +241,8 @@ class TextMelBatchCollate:
|
||||
|
||||
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
||||
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||
durations = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||
|
||||
y_lengths, x_lengths = [], []
|
||||
spks = []
|
||||
filepaths, x_texts = [], []
|
||||
@@ -226,6 +255,8 @@ class TextMelBatchCollate:
|
||||
spks.append(item["spk"])
|
||||
filepaths.append(item["filepath"])
|
||||
x_texts.append(item["x_text"])
|
||||
if item["durations"] is not None:
|
||||
durations[i, : item["durations"].shape[-1]] = item["durations"]
|
||||
|
||||
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
||||
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
||||
@@ -239,4 +270,5 @@ class TextMelBatchCollate:
|
||||
"spks": spks,
|
||||
"filepaths": filepaths,
|
||||
"x_texts": x_texts,
|
||||
"durations": durations if not torch.eq(durations, 0).all() else None,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user