mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 02:09:21 +08:00
Adding possibility of getting durations out
This commit is contained in:
@@ -227,7 +227,7 @@ def cli():
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
type=str,
|
||||
default=None,
|
||||
default="hifigan_univ_v1",
|
||||
help="Vocoder to use (default: will use the one suggested with the pretrained model))",
|
||||
choices=VOCODER_URLS.keys(),
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
|
||||
"""Clean up after fit or test."""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def state_dict(self): # pylint: disable=no-self-use
|
||||
def state_dict(self):
|
||||
"""Extra things to save to checkpoint."""
|
||||
return {}
|
||||
|
||||
@@ -167,7 +167,7 @@ class TextMelDataset(torch.utils.data.Dataset):
|
||||
text = self.get_text(text, add_blank=self.add_blank)
|
||||
mel = self.get_mel(filepath)
|
||||
|
||||
return {"x": text, "y": mel, "spk": spk}
|
||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath}
|
||||
|
||||
def get_mel(self, filepath):
|
||||
audio, sr = ta.load(filepath)
|
||||
@@ -207,15 +207,16 @@ class TextMelBatchCollate:
|
||||
|
||||
def __call__(self, batch):
|
||||
B = len(batch)
|
||||
y_max_length = max([item["y"].shape[-1] for item in batch])
|
||||
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
|
||||
y_max_length = fix_len_compatibility(y_max_length)
|
||||
x_max_length = max([item["x"].shape[-1] for item in batch])
|
||||
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
|
||||
n_feats = batch[0]["y"].shape[-2]
|
||||
|
||||
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
|
||||
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||
y_lengths, x_lengths = [], []
|
||||
spks = []
|
||||
filepaths = []
|
||||
for i, item in enumerate(batch):
|
||||
y_, x_ = item["y"], item["x"]
|
||||
y_lengths.append(y_.shape[-1])
|
||||
@@ -223,9 +224,10 @@ class TextMelBatchCollate:
|
||||
y[i, :, : y_.shape[-1]] = y_
|
||||
x[i, : x_.shape[-1]] = x_
|
||||
spks.append(item["spk"])
|
||||
filepaths.append(item["filepath"])
|
||||
|
||||
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
||||
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
||||
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
|
||||
|
||||
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
|
||||
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "filepaths": filepaths}
|
||||
|
||||
@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
|
||||
y, y_lengths = batch["y"], batch["y_lengths"]
|
||||
spks = batch["spks"]
|
||||
|
||||
dur_loss, prior_loss, diff_loss = self(
|
||||
dur_loss, prior_loss, diff_loss, *_ = self(
|
||||
x=x,
|
||||
x_lengths=x_lengths,
|
||||
y=y,
|
||||
|
||||
@@ -4,7 +4,7 @@ import random
|
||||
|
||||
import torch
|
||||
|
||||
import matcha.utils.monotonic_align as monotonic_align
|
||||
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
|
||||
from matcha import utils
|
||||
from matcha.models.baselightningmodule import BaseLightningClass
|
||||
from matcha.models.components.duration_predictors import DP
|
||||
@@ -241,4 +241,4 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
else:
|
||||
prior_loss = 0
|
||||
|
||||
return dur_loss, prior_loss, diff_loss
|
||||
return dur_loss, prior_loss, diff_loss, attn
|
||||
|
||||
174
matcha/utils/get_durations_from_trained_model.py
Normal file
174
matcha/utils/get_durations_from_trained_model.py
Normal file
@@ -0,0 +1,174 @@
|
||||
r"""
|
||||
Compute the per-token durations of every utterance in a dataset with a trained MatchaTTS checkpoint and save them
as ``.npy`` files (one per audio file) in an output folder.

The dataset is described by a yaml config file under configs/data (see ``--input-config``).
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import lightning
|
||||
import numpy as np
|
||||
import rootutils
|
||||
import torch
|
||||
from hydra import compose, initialize
|
||||
from omegaconf import open_dict
|
||||
from torch import nn
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from matcha.cli import get_device
|
||||
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||
from matcha.models.matcha_tts import MatchaTTS
|
||||
from matcha.utils.logging_utils import pylogger
|
||||
|
||||
log = pylogger.get_pylogger(__name__)
|
||||
|
||||
|
||||
def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path):
    """Save the per-token durations of one utterance as a ``.npy`` file.

    The duration of a token is the sum of its row in the alignment matrix.

    Args:
        attn (torch.Tensor): Alignment of a single utterance, on the CPU. Its last two axes are
            used as (text tokens, frames); any leading axes must have size 1.
            NOTE(review): presumably a hard 0/1 monotonic alignment -- confirm against the model.
        x_length (int): Number of valid text tokens; entries beyond it are dropped.
        y_length (int): Number of valid frames. Currently unused, kept for interface compatibility.
        filepath (str): Path of the source audio file; only its stem is used to name the output.
        output_folder (Path): Existing folder that receives ``<stem>.npy``.
    """
    # Collapse only leading singleton axes. A blanket ``squeeze()`` would also drop the token
    # or frame axis whenever it has length 1 (e.g. a single-token text) and break ``sum(1)``.
    alignment = attn.reshape(attn.shape[-2], attn.shape[-1])
    durations = alignment.sum(1)[:x_length].numpy()
    # Name by stem so non-".wav" inputs do not end up with a double extension (".flac.npy").
    output = Path(output_folder) / f"{Path(filepath).stem}.npy"
    np.save(output, durations)
|
||||
|
||||
|
||||
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
    """Run the model over a dataloader and write one duration file per utterance.

    Args:
        data_loader (torch.utils.data.DataLoader): Yields batch dicts with the keys
            ``x``, ``x_lengths``, ``y``, ``y_lengths``, ``spks`` and ``filepaths``.
        model (nn.Module): Callable returning four values, the last being the alignment ``attn``.
        device (torch.device): Device the batch tensors are moved to before the forward pass.
        output_folder: Folder handed through to ``save_durations_to_folder``.
    """
    tensor_keys = ("x", "x_lengths", "y", "y_lengths")

    for batch in tqdm(data_loader, leave=False):
        inputs = {key: batch[key].to(device) for key in tensor_keys}
        # ``spks`` may be None in the batch; in that case it is forwarded as None.
        speakers = batch["spks"]
        inputs["spks"] = None if speakers is None else speakers.to(device)

        _, _, _, alignments = model(**inputs)
        alignments = alignments.cpu()

        paths = batch["filepaths"]
        for idx, alignment in enumerate(alignments):
            save_durations_to_folder(
                alignment,
                inputs["x_lengths"][idx].item(),
                inputs["y_lengths"][idx].item(),
                paths[idx],
                output_folder,
            )
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: compute and save durations for every split of a dataset.

    Loads the data config given by ``--input-config`` from ``configs/data``, restores a MatchaTTS
    model from ``--checkpoint_path`` and writes the durations of the training, validation and
    test sets (whichever the datamodule provides) to the output folder.

    Exits with status 1 if the output folder already exists and ``-f`` was not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=32,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="Write into the output folder even if it already exists",
    )
    parser.add_argument(
        "-c",
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the checkpoint file to load the model from",
    )

    parser.add_argument(
        "-o",
        "--output-folder",
        type=str,
        default=None,
        help="Output folder to save the durations",
    )

    parser.add_argument(
        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
    )

    args = parser.parse_args()

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # cfg is expanded into TextMelDataModule(**cfg) below, so drop the hydra bookkeeping keys.
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        # Filelist paths in the config are resolved against the project root.
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))

    if args.output_folder is not None:
        output_folder = Path(args.output_folder)
    else:
        output_folder = Path("data") / "processed_data" / "durations" / cfg["name"]

    if output_folder.exists() and not args.force:
        print("Folder already exists. Use -f to force overwrite")
        sys.exit(1)

    output_folder.mkdir(parents=True, exist_ok=True)

    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
    print("Loading model...")
    device = get_device(args)
    model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
    # Make sure train-only layers such as dropout are disabled so that the extracted
    # durations are deterministic; eval() is a no-op if the model is already in eval mode.
    model.eval()

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()

    splits = (
        ("training", text_mel_datamodule.train_dataloader),
        ("validation", text_mel_datamodule.val_dataloader),
        ("test", text_mel_datamodule.test_dataloader),
    )
    for split_name, get_dataloader in splits:
        print(f"Computing durations for {split_name} set if exists...")
        # Only the dataloader lookup is guarded: a missing split is expected, whereas an
        # error raised while computing durations must not be mistaken for a missing split.
        try:
            dataloader = get_dataloader()
        except lightning.fabric.utilities.exceptions.MisconfigurationException:
            print(f"No {split_name} set found")
            continue
        compute_durations(dataloader, model, device, output_folder)

    print(f"[+] Done! Durations saved to: {output_folder}")
|
||||
|
||||
|
||||
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user