Adding saving phones while getting durations from matcha

This commit is contained in:
Shivam Mehta
2024-03-02 12:47:08 +00:00
parent ad76016916
commit 294c6b1327
9 changed files with 126 additions and 12 deletions

View File

@@ -5,6 +5,7 @@ when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
@@ -22,13 +23,20 @@ from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path):
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@@ -62,7 +70,12 @@ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i], x_lengths[i].item(), y_lengths[i].item(), batch["filepaths"][i], output_folder
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
@@ -131,7 +144,7 @@ def main():
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
output_folder = Path("data") / "temp" / cfg["name"] / "durations"
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")

View File

@@ -2,6 +2,7 @@ import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json