Adding saving phones while getting durations from matcha

This commit is contained in:
Shivam Mehta
2024-03-02 12:47:08 +00:00
parent ad76016916
commit 294c6b1327
9 changed files with 126 additions and 12 deletions

View File

@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text = self.get_text(text, add_blank=self.add_blank)
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk, "filepath": filepath}
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
return mel
def get_text(self, text, add_blank=True):
text_norm = text_to_sequence(text, self.cleaners)
text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
if self.add_blank:
text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm)
return text_norm
return text_norm, cleaned_text
def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -216,7 +216,7 @@ class TextMelBatchCollate:
x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths = []
filepaths, x_texts = [], []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
@@ -225,9 +225,18 @@ class TextMelBatchCollate:
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "filepaths": filepaths}
return {
"x": x,
"x_lengths": x_lengths,
"y": y,
"y_lengths": y_lengths,
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
}

View File

@@ -126,7 +126,7 @@ class FlowMatchingDurationPrediction(nn.Module):
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=None, temperature=1):
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args:

View File

@@ -121,7 +121,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask
w_ceil = torch.round(w) * length_scale
w_ceil = torch.ceil(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()

View File

@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
for symbol in clean_text:
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):

View File

@@ -5,6 +5,7 @@ when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
@@ -22,13 +23,20 @@ from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path):
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@@ -62,7 +70,12 @@ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i], x_lengths[i].item(), y_lengths[i].item(), batch["filepaths"][i], output_folder
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
@@ -131,7 +144,7 @@ def main():
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
output_folder = Path("data") / "temp" / cfg["name"] / "durations"
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")

View File

@@ -2,6 +2,7 @@ import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json

15
scripts/get_durations.sh Normal file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
echo "Starting script"
echo "Getting LJ Speech durations"
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
echo "Getting TSG2 durations"
python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
echo "Getting Joe Spont durations"
python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
echo "Getting Ryan durations"
python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f

7
scripts/transcribe.sh Normal file
View File

@@ -0,0 +1,7 @@
echo "Transcribing"
whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions

30
scripts/wer_computer.sh Normal file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
# Run from root folder with: bash scripts/wer_computer.sh
root_folder=${1:-"dur_wer_computation"}
echo "Running WER computation for Duration predictors"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
# echo $cmd
echo "LJ"
echo "==================================="
echo "Flow Matching"
$cmd
echo "-----------------------------------"
echo "LJ Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac"
echo "==================================="
echo "Cormac Flow Matching"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
$cmd
echo "-----------------------------------"