mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 10:19:19 +08:00
Adding saving phones while getting durations from matcha
This commit is contained in:
@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
|
|||||||
filepath, text = filepath_and_text[0], filepath_and_text[1]
|
filepath, text = filepath_and_text[0], filepath_and_text[1]
|
||||||
spk = None
|
spk = None
|
||||||
|
|
||||||
text = self.get_text(text, add_blank=self.add_blank)
|
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
|
||||||
mel = self.get_mel(filepath)
|
mel = self.get_mel(filepath)
|
||||||
|
|
||||||
return {"x": text, "y": mel, "spk": spk, "filepath": filepath}
|
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
|
||||||
|
|
||||||
def get_mel(self, filepath):
|
def get_mel(self, filepath):
|
||||||
audio, sr = ta.load(filepath)
|
audio, sr = ta.load(filepath)
|
||||||
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
|
|||||||
return mel
|
return mel
|
||||||
|
|
||||||
def get_text(self, text, add_blank=True):
|
def get_text(self, text, add_blank=True):
|
||||||
text_norm = text_to_sequence(text, self.cleaners)
|
text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
|
||||||
if self.add_blank:
|
if self.add_blank:
|
||||||
text_norm = intersperse(text_norm, 0)
|
text_norm = intersperse(text_norm, 0)
|
||||||
text_norm = torch.IntTensor(text_norm)
|
text_norm = torch.IntTensor(text_norm)
|
||||||
return text_norm
|
return text_norm, cleaned_text
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
datapoint = self.get_datapoint(self.filepaths_and_text[index])
|
datapoint = self.get_datapoint(self.filepaths_and_text[index])
|
||||||
@@ -216,7 +216,7 @@ class TextMelBatchCollate:
|
|||||||
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
x = torch.zeros((B, x_max_length), dtype=torch.long)
|
||||||
y_lengths, x_lengths = [], []
|
y_lengths, x_lengths = [], []
|
||||||
spks = []
|
spks = []
|
||||||
filepaths = []
|
filepaths, x_texts = [], []
|
||||||
for i, item in enumerate(batch):
|
for i, item in enumerate(batch):
|
||||||
y_, x_ = item["y"], item["x"]
|
y_, x_ = item["y"], item["x"]
|
||||||
y_lengths.append(y_.shape[-1])
|
y_lengths.append(y_.shape[-1])
|
||||||
@@ -225,9 +225,18 @@ class TextMelBatchCollate:
|
|||||||
x[i, : x_.shape[-1]] = x_
|
x[i, : x_.shape[-1]] = x_
|
||||||
spks.append(item["spk"])
|
spks.append(item["spk"])
|
||||||
filepaths.append(item["filepath"])
|
filepaths.append(item["filepath"])
|
||||||
|
x_texts.append(item["x_text"])
|
||||||
|
|
||||||
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
|
||||||
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
|
||||||
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
|
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
|
||||||
|
|
||||||
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "filepaths": filepaths}
|
return {
|
||||||
|
"x": x,
|
||||||
|
"x_lengths": x_lengths,
|
||||||
|
"y": y,
|
||||||
|
"y_lengths": y_lengths,
|
||||||
|
"spks": spks,
|
||||||
|
"filepaths": filepaths,
|
||||||
|
"x_texts": x_texts,
|
||||||
|
}
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class FlowMatchingDurationPrediction(nn.Module):
|
|||||||
self.n_steps = params.n_steps
|
self.n_steps = params.n_steps
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward(self, enc_outputs, mask, n_timesteps=None, temperature=1):
|
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
|
||||||
"""Forward diffusion
|
"""Forward diffusion
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
logw = self.dp(enc_output, x_mask)
|
logw = self.dp(enc_output, x_mask)
|
||||||
|
|
||||||
w = torch.exp(logw) * x_mask
|
w = torch.exp(logw) * x_mask
|
||||||
w_ceil = torch.round(w) * length_scale
|
w_ceil = torch.ceil(w) * length_scale
|
||||||
# print(w_ceil)
|
# print(w_ceil)
|
||||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||||
y_max_length = y_lengths.max()
|
y_max_length = y_lengths.max()
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
|
|||||||
for symbol in clean_text:
|
for symbol in clean_text:
|
||||||
symbol_id = _symbol_to_id[symbol]
|
symbol_id = _symbol_to_id[symbol]
|
||||||
sequence += [symbol_id]
|
sequence += [symbol_id]
|
||||||
return sequence
|
return sequence, clean_text
|
||||||
|
|
||||||
|
|
||||||
def cleaned_text_to_sequence(cleaned_text):
|
def cleaned_text_to_sequence(cleaned_text):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ when needed.
|
|||||||
Parameters from hparam.py will be used
|
Parameters from hparam.py will be used
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -22,13 +23,20 @@ from matcha.cli import get_device
|
|||||||
from matcha.data.text_mel_datamodule import TextMelDataModule
|
from matcha.data.text_mel_datamodule import TextMelDataModule
|
||||||
from matcha.models.matcha_tts import MatchaTTS
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
from matcha.utils.logging_utils import pylogger
|
from matcha.utils.logging_utils import pylogger
|
||||||
|
from matcha.utils.utils import get_phoneme_durations
|
||||||
|
|
||||||
log = pylogger.get_pylogger(__name__)
|
log = pylogger.get_pylogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def save_durations_to_folder(attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path):
|
def save_durations_to_folder(
|
||||||
|
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
|
||||||
|
):
|
||||||
durations = attn.squeeze().sum(1)[:x_length].numpy()
|
durations = attn.squeeze().sum(1)[:x_length].numpy()
|
||||||
|
durations_json = get_phoneme_durations(durations, text)
|
||||||
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
|
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
|
||||||
|
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
|
||||||
|
json.dump(durations_json, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
np.save(output, durations)
|
np.save(output, durations)
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +70,12 @@ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module
|
|||||||
attn = attn.cpu()
|
attn = attn.cpu()
|
||||||
for i in range(attn.shape[0]):
|
for i in range(attn.shape[0]):
|
||||||
save_durations_to_folder(
|
save_durations_to_folder(
|
||||||
attn[i], x_lengths[i].item(), y_lengths[i].item(), batch["filepaths"][i], output_folder
|
attn[i],
|
||||||
|
x_lengths[i].item(),
|
||||||
|
y_lengths[i].item(),
|
||||||
|
batch["filepaths"][i],
|
||||||
|
output_folder,
|
||||||
|
batch["x_texts"][i],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -131,7 +144,7 @@ def main():
|
|||||||
if args.output_folder is not None:
|
if args.output_folder is not None:
|
||||||
output_folder = Path(args.output_folder)
|
output_folder = Path(args.output_folder)
|
||||||
else:
|
else:
|
||||||
output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
|
output_folder = Path("data") / "temp" / cfg["name"] / "durations"
|
||||||
|
|
||||||
if os.path.exists(output_folder) and not args.force:
|
if os.path.exists(output_folder) and not args.force:
|
||||||
print("Folder already exists. Use -f to force overwrite")
|
print("Folder already exists. Use -f to force overwrite")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Tuple
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
|
|||||||
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
|
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
|
||||||
else:
|
else:
|
||||||
wget.download(url=url, out=checkpoint_path)
|
wget.download(url=url, out=checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_phoneme_durations(durations, phones):
|
||||||
|
prev = durations[0]
|
||||||
|
merged_durations = []
|
||||||
|
# Convolve with stride 2
|
||||||
|
for i in range(1, len(durations), 2):
|
||||||
|
if i == len(durations) - 2:
|
||||||
|
# if it is last take full value
|
||||||
|
next_half = durations[i + 1]
|
||||||
|
else:
|
||||||
|
next_half = ceil(durations[i + 1] / 2)
|
||||||
|
|
||||||
|
curr = prev + durations[i] + next_half
|
||||||
|
prev = durations[i + 1] - next_half
|
||||||
|
merged_durations.append(curr)
|
||||||
|
|
||||||
|
assert len(phones) == len(merged_durations)
|
||||||
|
assert len(merged_durations) == (len(durations) - 1) // 2
|
||||||
|
|
||||||
|
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
|
||||||
|
start = torch.tensor(0)
|
||||||
|
duration_json = []
|
||||||
|
for i, duration in enumerate(merged_durations):
|
||||||
|
duration_json.append(
|
||||||
|
{
|
||||||
|
phones[i]: {
|
||||||
|
"starttime": start.item(),
|
||||||
|
"endtime": duration.item(),
|
||||||
|
"duration": duration.item() - start.item(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
start = duration
|
||||||
|
|
||||||
|
assert list(duration_json[-1].values())[0]["endtime"] == sum(
|
||||||
|
durations
|
||||||
|
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
|
||||||
|
return duration_json
|
||||||
|
|||||||
15
scripts/get_durations.sh
Normal file
15
scripts/get_durations.sh
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Starting script"
|
||||||
|
|
||||||
|
echo "Getting LJ Speech durations"
|
||||||
|
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
|
||||||
|
|
||||||
|
echo "Getting TSG2 durations"
|
||||||
|
python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
|
||||||
|
|
||||||
|
echo "Getting Joe Spont durations"
|
||||||
|
python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
|
||||||
|
|
||||||
|
echo "Getting Ryan durations"
|
||||||
|
python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f
|
||||||
7
scripts/transcribe.sh
Normal file
7
scripts/transcribe.sh
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
echo "Transcribing"
|
||||||
|
|
||||||
|
whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
|
||||||
|
|
||||||
|
whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
|
||||||
|
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
|
||||||
|
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions
|
||||||
30
scripts/wer_computer.sh
Normal file
30
scripts/wer_computer.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Run from root folder with: bash scripts/wer_computer.sh
|
||||||
|
|
||||||
|
|
||||||
|
root_folder=${1:-"dur_wer_computation"}
|
||||||
|
echo "Running WER computation for Duration predictors"
|
||||||
|
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
|
||||||
|
# echo $cmd
|
||||||
|
echo "LJ"
|
||||||
|
echo "==================================="
|
||||||
|
echo "Flow Matching"
|
||||||
|
$cmd
|
||||||
|
echo "-----------------------------------"
|
||||||
|
|
||||||
|
echo "LJ Determinstic"
|
||||||
|
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
|
||||||
|
$cmd
|
||||||
|
echo "-----------------------------------"
|
||||||
|
|
||||||
|
echo "Cormac"
|
||||||
|
echo "==================================="
|
||||||
|
echo "Cormac Flow Matching"
|
||||||
|
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
|
||||||
|
$cmd
|
||||||
|
echo "-----------------------------------"
|
||||||
|
|
||||||
|
echo "Cormac Determinstic"
|
||||||
|
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
|
||||||
|
$cmd
|
||||||
|
echo "-----------------------------------"
|
||||||
Reference in New Issue
Block a user