mirror of https://github.com/shivammehta25/Matcha-TTS.git (synced 2026-02-05 10:19:19 +08:00)
Add saving of phones while getting durations from Matcha
@@ -5,6 +5,7 @@ when needed.
 Parameters from hparam.py will be used
 """
 import argparse
+import json
 import os
 import sys
 from pathlib import Path

@@ -22,13 +23,20 @@ from matcha.cli import get_device
 from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.utils.logging_utils import pylogger
+from matcha.utils.utils import get_phoneme_durations

 log = pylogger.get_pylogger(__name__)


-def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path):
+def save_durations_to_folder(
+    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+):
     durations = attn.squeeze().sum(1)[:x_length].numpy()
+    durations_json = get_phoneme_durations(durations, text)
     output = output_folder / Path(filepath).name.replace(".wav", ".npy")
+    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
+        json.dump(durations_json, f, indent=4, ensure_ascii=False)
+
     np.save(output, durations)

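With this change save_durations_to_folder writes two files per utterance: the raw per-token durations as before (.npy) and a per-phone timing map (.json) built by get_phoneme_durations. A minimal read-back sketch, assuming a hypothetical utterance stem "LJ001-0001" and an output folder named durations/ (both made up for illustration, not part of this commit):

import json

import numpy as np

durations = np.load("durations/LJ001-0001.npy")  # per-token durations, blanks included
with open("durations/LJ001-0001.json", encoding="utf-8") as f:
    phone_timings = json.load(f)  # list of {phone: {"starttime", "endtime", "duration"}}

print(durations.shape, phone_timings[0])
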
@@ -62,7 +70,12 @@ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module
         attn = attn.cpu()
         for i in range(attn.shape[0]):
             save_durations_to_folder(
-                attn[i], x_lengths[i].item(), y_lengths[i].item(), batch["filepaths"][i], output_folder
+                attn[i],
+                x_lengths[i].item(),
+                y_lengths[i].item(),
+                batch["filepaths"][i],
+                output_folder,
+                batch["x_texts"][i],
             )

@@ -131,7 +144,7 @@ def main():
     if args.output_folder is not None:
         output_folder = Path(args.output_folder)
     else:
-        output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
+        output_folder = Path("data") / "temp" / cfg["name"] / "durations"

     if os.path.exists(output_folder) and not args.force:
         print("Folder already exists. Use -f to force overwrite")

@@ -2,6 +2,7 @@ import os
 import sys
 import warnings
 from importlib.util import find_spec
+from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple

@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
         gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
     else:
         wget.download(url=url, out=checkpoint_path)
+
+
+def get_phoneme_durations(durations, phones):
+    prev = durations[0]
+    merged_durations = []
+    # Convolve with stride 2
+    for i in range(1, len(durations), 2):
+        if i == len(durations) - 2:
+            # if it is last take full value
+            next_half = durations[i + 1]
+        else:
+            next_half = ceil(durations[i + 1] / 2)
+
+        curr = prev + durations[i] + next_half
+        prev = durations[i + 1] - next_half
+        merged_durations.append(curr)
+
+    assert len(phones) == len(merged_durations)
+    assert len(merged_durations) == (len(durations) - 1) // 2
+
+    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
+    start = torch.tensor(0)
+    duration_json = []
+    for i, duration in enumerate(merged_durations):
+        duration_json.append(
+            {
+                phones[i]: {
+                    "starttime": start.item(),
+                    "endtime": duration.item(),
+                    "duration": duration.item() - start.item(),
+                }
+            }
+        )
+        start = duration
+
+    assert list(duration_json[-1].values())[0]["endtime"] == sum(
+        durations
+    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
+    return duration_json

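For intuition: get_phoneme_durations folds the interleaved blank/phone duration sequence (odd length, as produced by interspersing a blank token) back onto the phones, so each phone keeps its own frames plus roughly half of each neighbouring blank (the final blank is absorbed whole), and the cumulative sums become frame-level start/end times. A toy check with made-up numbers, not from a real alignment, assuming this commit is applied so the helper is importable:

from matcha.utils.utils import get_phoneme_durations

# 5 interleaved durations (blank, phone, blank, phone, blank) and 2 phones
durations = [1, 4, 2, 3, 1]
phones = ["HH", "AH"]  # hypothetical phone labels

print(get_phoneme_durations(durations, phones))
# [{'HH': {'starttime': 0, 'endtime': 6, 'duration': 6}},
#  {'AH': {'starttime': 6, 'endtime': 11, 'duration': 5}}]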