mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
Adding the possibility of get durations out of pretrained model
This commit is contained in:
194
matcha/utils/get_durations_from_trained_model.py
Normal file
194
matcha/utils/get_durations_from_trained_model.py
Normal file
@@ -0,0 +1,194 @@
|
||||
r"""
|
||||
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
|
||||
when needed.
|
||||
|
||||
Parameters from hparam.py will be used
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import lightning
|
||||
import numpy as np
|
||||
import rootutils
|
||||
import torch
|
||||
from hydra import compose, initialize
|
||||
from omegaconf import open_dict
|
||||
from torch import nn
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from matcha.cli import get_device
|
||||
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||
from matcha.models.matcha_tts import MatchaTTS
|
||||
from matcha.utils.logging_utils import pylogger
|
||||
from matcha.utils.utils import get_phoneme_durations
|
||||
|
||||
log = pylogger.get_pylogger(__name__)
|
||||
|
||||
|
||||
def save_durations_to_folder(
|
||||
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
|
||||
):
|
||||
durations = attn.squeeze().sum(1)[:x_length].numpy()
|
||||
durations_json = get_phoneme_durations(durations, text)
|
||||
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
|
||||
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
|
||||
json.dump(durations_json, f, indent=4, ensure_ascii=False)
|
||||
|
||||
np.save(output, durations)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
|
||||
"""Generate durations from the model for each datapoint and save it in a folder
|
||||
|
||||
Args:
|
||||
data_loader (torch.utils.data.DataLoader): Dataloader
|
||||
model (nn.Module): MatchaTTS model
|
||||
device (torch.device): GPU or CPU
|
||||
"""
|
||||
|
||||
for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
|
||||
x, x_lengths = batch["x"], batch["x_lengths"]
|
||||
y, y_lengths = batch["y"], batch["y_lengths"]
|
||||
spks = batch["spks"]
|
||||
x = x.to(device)
|
||||
y = y.to(device)
|
||||
x_lengths = x_lengths.to(device)
|
||||
y_lengths = y_lengths.to(device)
|
||||
spks = spks.to(device) if spks is not None else None
|
||||
|
||||
_, _, _, attn = model(
|
||||
x=x,
|
||||
x_lengths=x_lengths,
|
||||
y=y,
|
||||
y_lengths=y_lengths,
|
||||
spks=spks,
|
||||
)
|
||||
attn = attn.cpu()
|
||||
for i in range(attn.shape[0]):
|
||||
save_durations_to_folder(
|
||||
attn[i],
|
||||
x_lengths[i].item(),
|
||||
y_lengths[i].item(),
|
||||
batch["filepaths"][i],
|
||||
output_folder,
|
||||
batch["x_texts"][i],
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-config",
|
||||
type=str,
|
||||
default="ljspeech.yaml",
|
||||
help="The name of the yaml config file under configs/data",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default="32",
|
||||
help="Can have increased batch size for faster computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--force",
|
||||
action="store_true",
|
||||
default=False,
|
||||
required=False,
|
||||
help="force overwrite the file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint file to load the model from",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output folder to save the data statistics",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with initialize(version_base="1.3", config_path="../../configs/data"):
|
||||
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
|
||||
|
||||
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
|
||||
|
||||
with open_dict(cfg):
|
||||
del cfg["hydra"]
|
||||
del cfg["_target_"]
|
||||
cfg["seed"] = 1234
|
||||
cfg["batch_size"] = args.batch_size
|
||||
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
|
||||
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
|
||||
|
||||
if args.output_folder is not None:
|
||||
output_folder = Path(args.output_folder)
|
||||
else:
|
||||
output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
|
||||
|
||||
print(f"Output folder set to: {output_folder}")
|
||||
|
||||
if os.path.exists(output_folder) and not args.force:
|
||||
print("Folder already exists. Use -f to force overwrite")
|
||||
sys.exit(1)
|
||||
|
||||
output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
|
||||
print("Loading model...")
|
||||
device = get_device(args)
|
||||
model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
|
||||
|
||||
text_mel_datamodule = TextMelDataModule(**cfg)
|
||||
text_mel_datamodule.setup()
|
||||
try:
|
||||
print("Computing stats for training set if exists...")
|
||||
train_dataloader = text_mel_datamodule.train_dataloader()
|
||||
compute_durations(train_dataloader, model, device, output_folder)
|
||||
except lightning.fabric.utilities.exceptions.MisconfigurationException:
|
||||
print("No training set found")
|
||||
|
||||
try:
|
||||
print("Computing stats for validation set if exists...")
|
||||
val_dataloader = text_mel_datamodule.val_dataloader()
|
||||
compute_durations(val_dataloader, model, device, output_folder)
|
||||
except lightning.fabric.utilities.exceptions.MisconfigurationException:
|
||||
print("No validation set found")
|
||||
|
||||
try:
|
||||
print("Computing stats for test set if exists...")
|
||||
test_dataloader = text_mel_datamodule.test_dataloader()
|
||||
compute_durations(test_dataloader, model, device, output_folder)
|
||||
except lightning.fabric.utilities.exceptions.MisconfigurationException:
|
||||
print("No test set found")
|
||||
|
||||
print(f"[+] Done! Data statistics saved to: {output_folder}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Helps with generating durations for the dataset to train other architectures
|
||||
# that cannot learn to align due to limited size of dataset
|
||||
# Example usage:
|
||||
# python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
|
||||
# This will create a folder in data/processed_data/durations/ljspeech with the durations
|
||||
main()
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
import warnings
|
||||
from importlib.util import find_spec
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
|
||||
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
|
||||
else:
|
||||
wget.download(url=url, out=checkpoint_path)
|
||||
|
||||
|
||||
def get_phoneme_durations(durations, phones):
|
||||
prev = durations[0]
|
||||
merged_durations = []
|
||||
# Convolve with stride 2
|
||||
for i in range(1, len(durations), 2):
|
||||
if i == len(durations) - 2:
|
||||
# if it is last take full value
|
||||
next_half = durations[i + 1]
|
||||
else:
|
||||
next_half = ceil(durations[i + 1] / 2)
|
||||
|
||||
curr = prev + durations[i] + next_half
|
||||
prev = durations[i + 1] - next_half
|
||||
merged_durations.append(curr)
|
||||
|
||||
assert len(phones) == len(merged_durations)
|
||||
assert len(merged_durations) == (len(durations) - 1) // 2
|
||||
|
||||
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
|
||||
start = torch.tensor(0)
|
||||
duration_json = []
|
||||
for i, duration in enumerate(merged_durations):
|
||||
duration_json.append(
|
||||
{
|
||||
phones[i]: {
|
||||
"starttime": start.item(),
|
||||
"endtime": duration.item(),
|
||||
"duration": duration.item() - start.item(),
|
||||
}
|
||||
}
|
||||
)
|
||||
start = duration
|
||||
|
||||
assert list(duration_json[-1].values())[0]["endtime"] == sum(
|
||||
durations
|
||||
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
|
||||
return duration_json
|
||||
|
||||
Reference in New Issue
Block a user