Add saving of phones while getting durations from Matcha

This commit is contained in:
Shivam Mehta
2024-03-02 12:47:08 +00:00
parent ad76016916
commit 294c6b1327
9 changed files with 126 additions and 12 deletions

matcha/data/text_mel_datamodule.py
View File

@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
         filepath, text = filepath_and_text[0], filepath_and_text[1]
         spk = None

-        text = self.get_text(text, add_blank=self.add_blank)
+        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
         mel = self.get_mel(filepath)

-        return {"x": text, "y": mel, "spk": spk, "filepath": filepath}
+        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}

     def get_mel(self, filepath):
         audio, sr = ta.load(filepath)

@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
         return mel

     def get_text(self, text, add_blank=True):
-        text_norm = text_to_sequence(text, self.cleaners)
+        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
         if self.add_blank:
             text_norm = intersperse(text_norm, 0)
         text_norm = torch.IntTensor(text_norm)
-        return text_norm
+        return text_norm, cleaned_text

     def __getitem__(self, index):
         datapoint = self.get_datapoint(self.filepaths_and_text[index])

@@ -216,7 +216,7 @@ class TextMelBatchCollate:
         x = torch.zeros((B, x_max_length), dtype=torch.long)
         y_lengths, x_lengths = [], []
         spks = []
-        filepaths = []
+        filepaths, x_texts = [], []
         for i, item in enumerate(batch):
             y_, x_ = item["y"], item["x"]
             y_lengths.append(y_.shape[-1])

@@ -225,9 +225,18 @@ class TextMelBatchCollate:
             x[i, : x_.shape[-1]] = x_
             spks.append(item["spk"])
             filepaths.append(item["filepath"])
+            x_texts.append(item["x_text"])

         y_lengths = torch.tensor(y_lengths, dtype=torch.long)
         x_lengths = torch.tensor(x_lengths, dtype=torch.long)
         spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None

-        return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "filepaths": filepaths}
+        return {
+            "x": x,
+            "x_lengths": x_lengths,
+            "y": y,
+            "y_lengths": y_lengths,
+            "spks": spks,
+            "filepaths": filepaths,
+            "x_texts": x_texts,
+        }
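With this change, every collated batch also carries the cleaned text string for each utterance. A minimal consumption sketch, assuming a DataLoader built from this datamodule (the loader setup is not shown in the diff):

    # Sketch: each collated batch now exposes per-utterance cleaned text
    for batch in data_loader:
        x, x_lengths = batch["x"], batch["x_lengths"]        # padded phone-ID matrix and true lengths
        for filepath, text in zip(batch["filepaths"], batch["x_texts"]):
            print(filepath, "->", text)                      # cleaned text saved by get_datapoint
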

View File

@@ -126,7 +126,7 @@ class FlowMatchingDurationPrediction(nn.Module):
         self.n_steps = params.n_steps

     @torch.inference_mode()
-    def forward(self, enc_outputs, mask, n_timesteps=None, temperature=1):
+    def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
         """Forward diffusion

         Args:
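The only change here is the n_timesteps default, from None to 500, so the flow-matching duration predictor can be called at inference without an explicit solver-step count. A hedged usage sketch (the predictor instance, its inputs, and the meaning of the output are assumptions, not shown in this diff):

    # Sketch: the inference call now works without passing n_timesteps
    pred = dur_predictor(enc_outputs, mask)                   # uses the 500-step default
    pred_fast = dur_predictor(enc_outputs, mask, n_timesteps=50)  # fewer steps, coarser but faster
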

matcha/models/matcha_tts.py
View File

@@ -121,7 +121,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         logw = self.dp(enc_output, x_mask)
         w = torch.exp(logw) * x_mask
-        w_ceil = torch.round(w) * length_scale
+        w_ceil = torch.ceil(w) * length_scale
         # print(w_ceil)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_max_length = y_lengths.max()
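Replacing torch.round with torch.ceil guarantees that any token with a positive predicted duration is allotted at least one output frame; rounding could collapse short tokens to zero frames. A small illustration with made-up values:

    import torch

    w = torch.tensor([0.3, 1.5, 2.3])   # predicted durations in frames (illustrative)
    torch.round(w)  # tensor([0., 2., 2.]) -- the 0.3 token vanishes entirely
    torch.ceil(w)   # tensor([1., 2., 3.]) -- every token keeps at least one frame
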

matcha/text/__init__.py
View File

@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
     for symbol in clean_text:
         symbol_id = _symbol_to_id[symbol]
         sequence += [symbol_id]
-    return sequence
+    return sequence, clean_text


 def cleaned_text_to_sequence(cleaned_text):
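text_to_sequence now also returns the cleaned (phonemized) string the IDs were derived from, which is what the dataset stores as x_text. A sketch of the new call shape; the cleaner name and outputs are illustrative, not taken from the diff:

    # Sketch: two-value return (cleaner name illustrative)
    sequence, clean_text = text_to_sequence("Hello world.", ["english_cleaners2"])
    # sequence   -> list of symbol IDs fed to the model
    # clean_text -> the cleaned phone string, now saved alongside the durations
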

matcha/utils/get_durations_from_trained_model.py
View File

@@ -5,6 +5,7 @@ when needed.
 Parameters from hparam.py will be used
 """
 import argparse
+import json
 import os
 import sys
 from pathlib import Path

@@ -22,13 +23,20 @@ from matcha.cli import get_device
 from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.utils.logging_utils import pylogger
+from matcha.utils.utils import get_phoneme_durations

 log = pylogger.get_pylogger(__name__)


-def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path):
+def save_durations_to_folder(
+    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+):
     durations = attn.squeeze().sum(1)[:x_length].numpy()
+    durations_json = get_phoneme_durations(durations, text)
     output = output_folder / Path(filepath).name.replace(".wav", ".npy")
+    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
+        json.dump(durations_json, f, indent=4, ensure_ascii=False)
+
     np.save(output, durations)
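Each utterance now gets a .json next to its .npy, holding per-phone start/end/duration in frames. A sketch of reading one back (the filename is illustrative; the entry format comes from get_phoneme_durations, added further down):

    import json

    with open("durations/LJ001-0001.json", encoding="utf-8") as f:
        phone_durations = json.load(f)
    # phone_durations is a list of single-key dicts, e.g.
    # {"h": {"starttime": 0, "endtime": 5, "duration": 5}}
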
@@ -62,7 +70,12 @@ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module
         attn = attn.cpu()
         for i in range(attn.shape[0]):
             save_durations_to_folder(
-                attn[i], x_lengths[i].item(), y_lengths[i].item(), batch["filepaths"][i], output_folder
+                attn[i],
+                x_lengths[i].item(),
+                y_lengths[i].item(),
+                batch["filepaths"][i],
+                output_folder,
+                batch["x_texts"][i],
             )

@@ -131,7 +144,7 @@ def main():
     if args.output_folder is not None:
         output_folder = Path(args.output_folder)
     else:
-        output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
+        output_folder = Path("data") / "temp" / cfg["name"] / "durations"

     if os.path.exists(output_folder) and not args.force:
         print("Folder already exists. Use -f to force overwrite")

matcha/utils/utils.py
View File

@@ -2,6 +2,7 @@ import os
 import sys
 import warnings
 from importlib.util import find_spec
+from math import ceil
 from pathlib import Path
 from typing import Any, Callable, Dict, Tuple

@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
         gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
     else:
         wget.download(url=url, out=checkpoint_path)
+
+
+def get_phoneme_durations(durations, phones):
+    prev = durations[0]
+    merged_durations = []
+    # Convolve with stride 2
+    for i in range(1, len(durations), 2):
+        if i == len(durations) - 2:
+            # if it is last take full value
+            next_half = durations[i + 1]
+        else:
+            next_half = ceil(durations[i + 1] / 2)
+
+        curr = prev + durations[i] + next_half
+        prev = durations[i + 1] - next_half
+        merged_durations.append(curr)
+
+    assert len(phones) == len(merged_durations)
+    assert len(merged_durations) == (len(durations) - 1) // 2
+
+    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
+    start = torch.tensor(0)
+    duration_json = []
+    for i, duration in enumerate(merged_durations):
+        duration_json.append(
+            {
+                phones[i]: {
+                    "starttime": start.item(),
+                    "endtime": duration.item(),
+                    "duration": duration.item() - start.item(),
+                }
+            }
+        )
+        start = duration
+
+    assert list(duration_json[-1].values())[0]["endtime"] == sum(
+        durations
+    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
+    return duration_json
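get_phoneme_durations folds the blank tokens that intersperse() inserts between phones back into the phones themselves: half of each internal blank (rounded up) is credited to the preceding phone and the remainder carried to the following one, while the leading and trailing blanks are absorbed whole by the first and last phones, so the total frame count is preserved. A small worked example with made-up values:

    # Two phones "a", "b" -> interspersed layout [blank, a, blank, b, blank]
    durations = [1, 3, 2, 4, 1]   # frames per interspersed symbol, sum = 11
    phones = "ab"
    get_phoneme_durations(durations, phones)
    # i=1: curr = 1 + 3 + ceil(2/2) = 5 ; carry prev = 2 - 1 = 1
    # i=3 (last): curr = 1 + 4 + 1 = 6  (trailing blank taken whole)
    # -> [{"a": {"starttime": 0, "endtime": 5, "duration": 5}},
    #     {"b": {"starttime": 5, "endtime": 11, "duration": 6}}]
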

15
scripts/get_durations.sh Normal file
View File

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+echo "Starting script"
+
+echo "Getting LJ Speech durations"
+python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
+
+echo "Getting TSG2 durations"
+python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
+
+echo "Getting Joe Spont durations"
+python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
+
+echo "Getting Ryan durations"
+python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f

7
scripts/transcribe.sh Normal file
View File

@@ -0,0 +1,7 @@
+echo "Transcribing"
+
+whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
+whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
+
+wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
+wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions

30
scripts/wer_computer.sh Normal file
View File

@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# Run from root folder with: bash scripts/wer_computer.sh
+
+root_folder=${1:-"dur_wer_computation"}
+
+echo "Running WER computation for Duration predictors"
+
+cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
+# echo $cmd
+echo "LJ"
+echo "==================================="
+echo "Flow Matching"
+$cmd
+echo "-----------------------------------"
+echo "LJ Deterministic"
+cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
+$cmd
+echo "-----------------------------------"
+
+echo "Cormac"
+echo "==================================="
+echo "Cormac Flow Matching"
+cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
+$cmd
+echo "-----------------------------------"
+echo "Cormac Deterministic"
+cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
+$cmd
+echo "-----------------------------------"