7 Commits

Author SHA1 Message Date
Shivam Mehta
068d135e20 Adding alginment information to readme 2024-05-27 13:57:10 +02:00
Shivam Mehta
ac0b258f80 Adding configuration for training from durations 2024-05-27 13:50:21 +02:00
Shivam Mehta
de910380bc Fixing batched synthesis for multispeaker model 2024-05-27 13:40:02 +02:00
Shivam Mehta
aa496aa13f Adding the possibility to train with durations 2024-05-27 13:24:21 +02:00
Shivam Mehta
e658aee6a5 Pinning gradio 2024-05-25 20:15:17 +02:00
Shivam Mehta
d816c40e3d Updating the notebook to adjust to the change 2024-05-24 11:46:03 +02:00
Shivam Mehta
4b39f6cad0 Adding the possibility of get durations out of pretrained model 2024-05-24 11:34:51 +02:00
18 changed files with 389 additions and 40 deletions

View File

@@ -1,5 +1,5 @@
default_language_version:
python: python3.10
python: python3.11
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks

View File

@@ -252,6 +252,43 @@ python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vo
This will write `.wav` audio files to the output directory.
## Extract phoneme alignments from Matcha-TTS
If the dataset is structured as
```bash
data/
└── LJSpeech-1.1
├── metadata.csv
├── README
├── test.txt
├── train.txt
├── val.txt
└── wavs
```
Then you can extract the phoneme level alignments from a Trained Matcha-TTS model using:
```bash
python matcha/utils/get_durations_from_trained_model.py -i dataset_yaml -c <checkpoint>
```
Example:
```bash
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt
```
or simply:
```bash
matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt
```
---
## Train using extracted alignments
In the datasetconfig turn on load duration.
Example: `ljspeech.yaml`
```
load_durations: True
```
or see an examples in configs/experiment/ljspeech_from_durations.yaml
## Citation information
If you use our code or otherwise find this work useful, please cite our paper:

View File

@@ -1,7 +1,7 @@
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
train_filelist_path: data/LJSpeech-1.1/train.txt
valid_filelist_path: data/LJSpeech-1.1/val.txt
batch_size: 32
num_workers: 20
pin_memory: True
@@ -19,3 +19,4 @@ data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622
mel_std: 2.116101
seed: ${seed}
load_durations: false

View File

@@ -0,0 +1,19 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech
data:
load_durations: True
batch_size: 64

View File

@@ -13,3 +13,4 @@ n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4
prior_loss: true
use_precomputed_durations: ${data.load_durations}

View File

@@ -1 +1 @@
0.0.5.1
0.0.6.0

View File

@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
def process_text(i: int, text: str, device: torch.device):
print(f"[{i}] - Input text: {text}")
x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
dtype=torch.long,
device=device,
)[None]
@@ -326,12 +326,13 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
for i, batch in enumerate(dataloader):
i = i + 1
start_t = dt.datetime.now()
b = batch["x"].shape[0]
output = model.synthesise(
batch["x"].to(device),
batch["x_lengths"].to(device),
n_timesteps=args.steps,
temperature=args.temperature,
spks=spk,
spks=spk.expand(b) if spk is not None else spk,
length_scale=args.speaking_rate,
)

View File

@@ -1,6 +1,8 @@
import random
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
import torch
import torchaudio as ta
from lightning import LightningDataModule
@@ -39,6 +41,7 @@ class TextMelDataModule(LightningDataModule):
f_max,
data_statistics,
seed,
load_durations,
):
super().__init__()
@@ -68,6 +71,7 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
self.hparams.load_durations,
)
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path,
@@ -83,6 +87,7 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
self.hparams.load_durations,
)
def train_dataloader(self):
@@ -109,7 +114,7 @@ class TextMelDataModule(LightningDataModule):
"""Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass
def state_dict(self): # pylint: disable=no-self-use
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
@@ -134,6 +139,7 @@ class TextMelDataset(torch.utils.data.Dataset):
f_max=8000,
data_parameters=None,
seed=None,
load_durations=False,
):
self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks
@@ -146,6 +152,8 @@ class TextMelDataset(torch.utils.data.Dataset):
self.win_length = win_length
self.f_min = f_min
self.f_max = f_max
self.load_durations = load_durations
if data_parameters is not None:
self.data_parameters = data_parameters
else:
@@ -164,10 +172,29 @@ class TextMelDataset(torch.utils.data.Dataset):
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text = self.get_text(text, add_blank=self.add_blank)
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk}
durations = self.get_durations(filepath, text) if self.load_durations else None
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
def get_durations(self, filepath, text):
filepath = Path(filepath)
data_dir, name = filepath.parent.parent, filepath.stem
try:
dur_loc = data_dir / "durations" / f"{name}.npy"
durs = torch.from_numpy(np.load(dur_loc).astype(int))
except FileNotFoundError as e:
raise FileNotFoundError(
f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
) from e
assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
return durs
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
@@ -187,11 +214,11 @@ class TextMelDataset(torch.utils.data.Dataset):
return mel
def get_text(self, text, add_blank=True):
text_norm = text_to_sequence(text, self.cleaners)
text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
if self.add_blank:
text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm)
return text_norm
return text_norm, cleaned_text
def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -214,8 +241,11 @@ class TextMelBatchCollate:
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
durations = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths, x_texts = [], []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
@@ -223,9 +253,22 @@ class TextMelBatchCollate:
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
if item["durations"] is not None:
durations[i, : item["durations"].shape[-1]] = item["durations"]
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
return {
"x": x,
"x_lengths": x_lengths,
"y": y,
"y_lengths": y_lengths,
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
"durations": durations if not torch.eq(durations, 0).all() else None,
}

View File

@@ -58,13 +58,14 @@ class BaseLightningClass(LightningModule, ABC):
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
dur_loss, prior_loss, diff_loss = self(
dur_loss, prior_loss, diff_loss, *_ = self(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
out_size=self.out_size,
durations=batch["durations"],
)
return {
"dur_loss": dur_loss,

View File

@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
optimizer=None,
scheduler=None,
prior_loss=True,
use_precomputed_durations=False,
):
super().__init__()
@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.n_feats = n_feats
self.out_size = out_size
self.prior_loss = prior_loss
self.use_precomputed_durations = use_precomputed_durations
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
"rtf": rtf,
}
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
@@ -179,17 +181,20 @@ class MatchaTTS(BaseLightningClass): # 🍵
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
if self.use_precomputed_durations:
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
else:
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach()
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach() # b, t_text, T_mel
# Compute loss between predicted log-scaled durations and those obtained from MAS
# refered to as prior loss in the paper
@@ -236,4 +241,4 @@ class MatchaTTS(BaseLightningClass): # 🍵
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss
return dur_loss, prior_loss, diff_loss, attn

View File

@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
for symbol in clean_text:
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):

View File

@@ -15,7 +15,6 @@ import logging
import re
import phonemizer
import piper_phonemize
from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -106,11 +105,17 @@ def english_cleaners2(text):
return phonemes
def english_cleaners_piper(text):
"""Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
phonemes = collapse_whitespace(phonemes)
return phonemes
# I am removing this due to incompatibility with several version of python
# However, if you want to use it, you can uncomment it
# and install piper-phonemize with the following command:
# pip install piper-phonemize
# import piper_phonemize
# def english_cleaners_piper(text):
# """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
# text = convert_to_ascii(text)
# text = lowercase(text)
# text = expand_abbreviations(text)
# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
# phonemes = collapse_whitespace(phonemes)
# return phonemes

View File

@@ -94,6 +94,7 @@ def main():
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()

View File

@@ -0,0 +1,195 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
import lightning
import numpy as np
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from torch import nn
from tqdm.auto import tqdm
from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
"""Generate durations from the model for each datapoint and save it in a folder
Args:
data_loader (torch.utils.data.DataLoader): Dataloader
model (nn.Module): MatchaTTS model
device (torch.device): GPU or CPU
"""
for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
x = x.to(device)
y = y.to(device)
x_lengths = x_lengths.to(device)
y_lengths = y_lengths.to(device)
spks = spks.to(device) if spks is not None else None
_, _, _, attn = model(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
)
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="ljspeech.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default="32",
help="Can have increased batch size for faster computation",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
parser.add_argument(
"-c",
"--checkpoint_path",
type=str,
required=True,
help="Path to the checkpoint file to load the model from",
)
parser.add_argument(
"-o",
"--output-folder",
type=str,
default=None,
help="Output folder to save the data statistics",
)
parser.add_argument(
"--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
)
args = parser.parse_args()
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
print(f"Output folder set to: {output_folder}")
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")
sys.exit(1)
output_folder.mkdir(parents=True, exist_ok=True)
print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
print("Loading model...")
device = get_device(args)
model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
try:
print("Computing stats for training set if exists...")
train_dataloader = text_mel_datamodule.train_dataloader()
compute_durations(train_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No training set found")
try:
print("Computing stats for validation set if exists...")
val_dataloader = text_mel_datamodule.val_dataloader()
compute_durations(val_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No validation set found")
try:
print("Computing stats for test set if exists...")
test_dataloader = text_mel_datamodule.test_dataloader()
compute_durations(test_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No test set found")
print(f"[+] Done! Data statistics saved to: {output_folder}")
if __name__ == "__main__":
# Helps with generating durations for the dataset to train other architectures
# that cannot learn to align due to limited size of dataset
# Example usage:
# python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
# This will create a folder in data/processed_data/durations/ljspeech with the durations
main()

View File

@@ -2,6 +2,7 @@ import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json

View File

@@ -35,11 +35,10 @@ torchaudio
matplotlib
pandas
conformer==0.3.2
diffusers==0.27.2
diffusers==0.25.0
notebook
ipywidgets
gradio
gradio==3.43.2
gdown
wget
seaborn
piper_phonemize

View File

@@ -38,6 +38,7 @@ setup(
"matcha-data-stats=matcha.utils.generate_data_statistics:main",
"matcha-tts=matcha.cli:cli",
"matcha-tts-app=matcha.app:main",
"matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main",
]
},
ext_modules=cythonize(exts, language_level=3),

View File

@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992",
"metadata": {},
"outputs": [
@@ -192,7 +192,7 @@
"source": [
"@torch.inference_mode()\n",
"def process_text(text: str):\n",
" x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2']), 0),dtype=torch.long, device=device)[None]\n",
" x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2'])[0], 0),dtype=torch.long, device=device)[None]\n",
" x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n",
" x_phones = sequence_to_text(x.squeeze(0).tolist())\n",
" return {\n",