diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py index 3141293..96c6e40 100644 --- a/matcha/data/text_mel_datamodule.py +++ b/matcha/data/text_mel_datamodule.py @@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset): filepath, text = filepath_and_text[0], filepath_and_text[1] spk = None - text = self.get_text(text, add_blank=self.add_blank) + text, cleaned_text = self.get_text(text, add_blank=self.add_blank) mel = self.get_mel(filepath) - return {"x": text, "y": mel, "spk": spk, "filepath": filepath} + return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text} def get_mel(self, filepath): audio, sr = ta.load(filepath) @@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset): return mel def get_text(self, text, add_blank=True): - text_norm = text_to_sequence(text, self.cleaners) + text_norm, cleaned_text = text_to_sequence(text, self.cleaners) if self.add_blank: text_norm = intersperse(text_norm, 0) text_norm = torch.IntTensor(text_norm) - return text_norm + return text_norm, cleaned_text def __getitem__(self, index): datapoint = self.get_datapoint(self.filepaths_and_text[index]) @@ -216,7 +216,7 @@ class TextMelBatchCollate: x = torch.zeros((B, x_max_length), dtype=torch.long) y_lengths, x_lengths = [], [] spks = [] - filepaths = [] + filepaths, x_texts = [], [] for i, item in enumerate(batch): y_, x_ = item["y"], item["x"] y_lengths.append(y_.shape[-1]) @@ -225,9 +225,18 @@ class TextMelBatchCollate: x[i, : x_.shape[-1]] = x_ spks.append(item["spk"]) filepaths.append(item["filepath"]) + x_texts.append(item["x_text"]) y_lengths = torch.tensor(y_lengths, dtype=torch.long) x_lengths = torch.tensor(x_lengths, dtype=torch.long) spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None - return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "filepaths": filepaths} + return { + "x": x, + "x_lengths": x_lengths, + "y": y, + "y_lengths": y_lengths, + "spks": spks, + "filepaths": filepaths, + "x_texts": x_texts, + } diff --git a/matcha/models/components/duration_predictors.py b/matcha/models/components/duration_predictors.py index d34f5e3..a660fae 100644 --- a/matcha/models/components/duration_predictors.py +++ b/matcha/models/components/duration_predictors.py @@ -126,7 +126,7 @@ class FlowMatchingDurationPrediction(nn.Module): self.n_steps = params.n_steps @torch.inference_mode() - def forward(self, enc_outputs, mask, n_timesteps=None, temperature=1): + def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1): """Forward diffusion Args: diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index ae951a2..3000f04 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -121,7 +121,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 logw = self.dp(enc_output, x_mask) w = torch.exp(logw) * x_mask - w_ceil = torch.round(w) * length_scale + w_ceil = torch.ceil(w) * length_scale # print(w_ceil) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_max_length = y_lengths.max() diff --git a/matcha/text/__init__.py b/matcha/text/__init__.py index 71a4b57..8c75d6b 100644 --- a/matcha/text/__init__.py +++ b/matcha/text/__init__.py @@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names): for symbol in clean_text: symbol_id = _symbol_to_id[symbol] sequence += [symbol_id] - return sequence + return sequence, clean_text def cleaned_text_to_sequence(cleaned_text): diff --git a/matcha/utils/get_durations_from_trained_model.py b/matcha/utils/get_durations_from_trained_model.py index d3c967b..d339699 100644 --- a/matcha/utils/get_durations_from_trained_model.py +++ b/matcha/utils/get_durations_from_trained_model.py @@ -5,6 +5,7 @@ when needed. Parameters from hparam.py will be used """ import argparse +import json import os import sys from pathlib import Path @@ -22,13 +23,20 @@ from matcha.cli import get_device from matcha.data.text_mel_datamodule import TextMelDataModule from matcha.models.matcha_tts import MatchaTTS from matcha.utils.logging_utils import pylogger +from matcha.utils.utils import get_phoneme_durations log = pylogger.get_pylogger(__name__) -def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path): +def save_durations_to_folder( + attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str +): durations = attn.squeeze().sum(1)[:x_length].numpy() + durations_json = get_phoneme_durations(durations, text) output = output_folder / Path(filepath).name.replace(".wav", ".npy") + with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: + json.dump(durations_json, f, indent=4, ensure_ascii=False) + np.save(output, durations) @@ -62,7 +70,12 @@ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module attn = attn.cpu() for i in range(attn.shape[0]): save_durations_to_folder( - attn[i], x_lengths[i].item(), y_lengths[i].item(), batch["filepaths"][i], output_folder + attn[i], + x_lengths[i].item(), + y_lengths[i].item(), + batch["filepaths"][i], + output_folder, + batch["x_texts"][i], ) @@ -131,7 +144,7 @@ def main(): if args.output_folder is not None: output_folder = Path(args.output_folder) else: - output_folder = Path("data") / "processed_data" / cfg["name"] / "durations" + output_folder = Path("data") / "temp" / cfg["name"] / "durations" if os.path.exists(output_folder) and not args.force: print("Folder already exists. Use -f to force overwrite") diff --git a/matcha/utils/utils.py b/matcha/utils/utils.py index af65e09..fc3a48e 100644 --- a/matcha/utils/utils.py +++ b/matcha/utils/utils.py @@ -2,6 +2,7 @@ import os import sys import warnings from importlib.util import find_spec +from math import ceil from pathlib import Path from typing import Any, Callable, Dict, Tuple @@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True): gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) else: wget.download(url=url, out=checkpoint_path) + + +def get_phoneme_durations(durations, phones): + prev = durations[0] + merged_durations = [] + # Convolve with stride 2 + for i in range(1, len(durations), 2): + if i == len(durations) - 2: + # if it is last take full value + next_half = durations[i + 1] + else: + next_half = ceil(durations[i + 1] / 2) + + curr = prev + durations[i] + next_half + prev = durations[i + 1] - next_half + merged_durations.append(curr) + + assert len(phones) == len(merged_durations) + assert len(merged_durations) == (len(durations) - 1) // 2 + + merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) + start = torch.tensor(0) + duration_json = [] + for i, duration in enumerate(merged_durations): + duration_json.append( + { + phones[i]: { + "starttime": start.item(), + "endtime": duration.item(), + "duration": duration.item() - start.item(), + } + } + ) + start = duration + + assert list(duration_json[-1].values())[0]["endtime"] == sum( + durations + ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" + return duration_json diff --git a/scripts/get_durations.sh b/scripts/get_durations.sh new file mode 100644 index 0000000..ff24d62 --- /dev/null +++ b/scripts/get_durations.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +echo "Starting script" + +echo "Getting LJ Speech durations" +python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f + +echo "Getting TSG2 durations" +python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f + +echo "Getting Joe Spont durations" +python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f + +echo "Getting Ryan durations" +python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f \ No newline at end of file diff --git a/scripts/transcribe.sh b/scripts/transcribe.sh new file mode 100644 index 0000000..d72a056 --- /dev/null +++ b/scripts/transcribe.sh @@ -0,0 +1,7 @@ +echo "Transcribing" + +whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f + +whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f +wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions +wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions \ No newline at end of file diff --git a/scripts/wer_computer.sh b/scripts/wer_computer.sh new file mode 100644 index 0000000..6e93bc3 --- /dev/null +++ b/scripts/wer_computer.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Run from root folder with: bash scripts/wer_computer.sh + + +root_folder=${1:-"dur_wer_computation"} +echo "Running WER computation for Duration predictors" +cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/" +# echo $cmd +echo "LJ" +echo "===================================" +echo "Flow Matching" +$cmd +echo "-----------------------------------" + +echo "LJ Determinstic" +cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/" +$cmd +echo "-----------------------------------" + +echo "Cormac" +echo "===================================" +echo "Cormac Flow Matching" +cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/" +$cmd +echo "-----------------------------------" + +echo "Cormac Determinstic" +cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/" +$cmd +echo "-----------------------------------" \ No newline at end of file