11 Commits

Author SHA1 Message Date
Shivam Mehta
4c35836fa5 minor fix 2024-03-02 12:48:54 +00:00
Shivam Mehta
294c6b1327 Adding saving phones while getting durations from matcha 2024-03-02 12:47:08 +00:00
Shivam Mehta
ad76016916 Fixing configs 2024-02-26 09:11:22 +00:00
Shivam Mehta
05c8f9b4a8 updating configs and experiments 2024-02-25 22:02:36 +00:00
Shivam Mehta
4d5b62cea9 Adding a bit of comments 2024-02-24 15:20:13 +00:00
Shivam Mehta
8e87111a98 Adding possibility of getting durations out 2024-02-24 15:10:19 +00:00
Shivam Mehta
def0855608 Adding other experiment configs 2024-01-22 11:46:08 +00:00
Shivam Mehta
6976a91348 Merge branch 'main' into stoc_dur 2024-01-12 11:58:41 +00:00
Shivam Mehta
458e9df236 Adding synthesis 2024-01-10 11:05:22 +00:00
Shivam Mehta
d03bba82bb In the middle of adding discrete nf based duration predictor 2024-01-10 11:04:46 +00:00
Shivam Mehta
a58bab5403 Adding option to do flow matching based duration prediction 2024-01-05 11:13:07 +00:00
48 changed files with 1972 additions and 566 deletions

1
.gitignore vendored
View File

@@ -161,4 +161,3 @@ generator_v1
g_02500000
gradio_cached_examples/
synth_output/
/data

View File

@@ -1,5 +1,5 @@
default_language_version:
python: python3.11
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks

View File

@@ -10,7 +10,7 @@
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![PyPI Downloads](https://static.pepy.tech/personalized-badge/matcha-tts?period=total&units=INTERNATIONAL_SYSTEM&left_color=BLACK&right_color=GREEN&left_text=downloads)](https://pepy.tech/projects/matcha-tts)
<p style="text-align: center;">
<img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/>
</p>
@@ -252,43 +252,6 @@ python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vo
This will write `.wav` audio files to the output directory.
## Extract phoneme alignments from Matcha-TTS
If the dataset is structured as
```bash
data/
└── LJSpeech-1.1
├── metadata.csv
├── README
├── test.txt
├── train.txt
├── val.txt
└── wavs
```
Then you can extract the phoneme level alignments from a Trained Matcha-TTS model using:
```bash
python matcha/utils/get_durations_from_trained_model.py -i dataset_yaml -c <checkpoint>
```
Example:
```bash
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt
```
or simply:
```bash
matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt
```
---
## Train using extracted alignments
In the datasetconfig turn on load duration.
Example: `ljspeech.yaml`
```
load_durations: True
```
or see an examples in configs/experiment/ljspeech_from_durations.yaml
## Citation information
If you use our code or otherwise find this work useful, please cite our paper:

View File

@@ -5,8 +5,8 @@ defaults:
# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female
train_filelist_path: data/hi-fi_en-US_female/train.txt
valid_filelist_path: data/hi-fi_en-US_female/val.txt
train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
batch_size: 32
cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset

View File

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: joe_spont_only
train_filelist_path: data/filelists/joe_spontonly_train.txt
valid_filelist_path: data/filelists/joe_spontonly_val.txt
data_statistics:
mel_mean: -5.882903
mel_std: 2.458284

View File

@@ -1,7 +1,7 @@
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech
train_filelist_path: data/LJSpeech-1.1/train.txt
valid_filelist_path: data/LJSpeech-1.1/val.txt
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
batch_size: 32
num_workers: 20
pin_memory: True
@@ -19,4 +19,3 @@ data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622
mel_std: 2.116101
seed: ${seed}
load_durations: false

10
configs/data/ryan.yaml Normal file
View File

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: ryan
train_filelist_path: data/filelists/ryan_train.csv
valid_filelist_path: data/filelists/ryan_val.csv
data_statistics:
mel_mean: -4.715779
mel_std: 2.124502

10
configs/data/tsg2.yaml Normal file
View File

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: tsg2
train_filelist_path: data/filelists/cormac_train.txt
valid_filelist_path: data/filelists/cormac_val.txt
data_statistics:
mel_mean: -5.536622
mel_std: 2.116101

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_det_dur

View File

@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_stoc_dur
model:
duration_predictor:
p_dropout: 0.2

View File

@@ -5,15 +5,12 @@
defaults:
- override /data: ljspeech.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech
data:
load_durations: True
batch_size: 64

View File

@@ -0,0 +1,18 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_det_dur
trainer:
max_epochs: 3000

View File

@@ -0,0 +1,24 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_stoc_dur
model:
duration_predictor:
p_dropout: 0.2
trainer:
max_epochs: 3000

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_det_dur

View File

@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_stoc_dur
model:
duration_predictor:
p_dropout: 0.5

View File

@@ -0,0 +1,7 @@
name: deterministic
n_spks: ${model.n_spks}
spk_emb_dim: ${model.spk_emb_dim}
filter_channels: 256
kernel_size: 3
n_channels: ${model.encoder.encoder_params.n_channels}
p_dropout: ${model.encoder.encoder_params.p_dropout}

View File

@@ -0,0 +1,7 @@
defaults:
- deterministic.yaml
- _self_
sigma_min: 1e-4
n_steps: 10
name: flow_matching

View File

@@ -3,16 +3,8 @@ encoder_params:
n_feats: ${model.n_feats}
n_channels: 192
filter_channels: 768
filter_channels_dp: 256
n_heads: 2
n_layers: 6
kernel_size: 3
p_dropout: 0.1
spk_emb_dim: 64
n_spks: 1
prenet: true
duration_predictor_params:
filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
kernel_size: 3
p_dropout: ${model.encoder.encoder_params.p_dropout}

View File

@@ -1,6 +1,7 @@
defaults:
- _self_
- encoder: default.yaml
- duration_predictor: deterministic.yaml
- decoder: default.yaml
- cfm: default.yaml
- optimizer: adam.yaml
@@ -13,4 +14,3 @@ n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4
prior_loss: true
use_precomputed_durations: ${data.load_durations}

1
data Symbolic link
View File

@@ -0,0 +1 @@
/home/smehta/Projects/Speech-Backbones/Grad-TTS/data

View File

@@ -1 +1 @@
0.0.7.2
0.0.5.1

View File

@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
def process_text(i: int, text: str, device: torch.device):
print(f"[{i}] - Input text: {text}")
x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
dtype=torch.long,
device=device,
)[None]
@@ -114,10 +114,10 @@ def load_matcha(model_name, checkpoint_path, device):
return model
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
def to_waveform(mel, vocoder, denoiser=None):
audio = vocoder(mel).clamp(-1, 1)
if denoiser is not None:
audio = denoiser(audio.squeeze(), strength=denoiser_strength).cpu().squeeze()
audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze()
@@ -227,7 +227,7 @@ def cli():
parser.add_argument(
"--vocoder",
type=str,
default=None,
default="hifigan_univ_v1",
help="Vocoder to use (default: will use the one suggested with the pretrained model))",
choices=VOCODER_URLS.keys(),
)
@@ -326,17 +326,16 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
for i, batch in enumerate(dataloader):
i = i + 1
start_t = dt.datetime.now()
b = batch["x"].shape[0]
output = model.synthesise(
batch["x"].to(device),
batch["x_lengths"].to(device),
n_timesteps=args.steps,
temperature=args.temperature,
spks=spk.expand(b) if spk is not None else spk,
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
@@ -377,7 +376,7 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
spks=spk,
length_scale=args.speaking_rate,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
# RTF with HiFiGAN
t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1])

View File

@@ -1,8 +1,6 @@
import random
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
import torch
import torchaudio as ta
from lightning import LightningDataModule
@@ -41,7 +39,6 @@ class TextMelDataModule(LightningDataModule):
f_max,
data_statistics,
seed,
load_durations,
):
super().__init__()
@@ -71,7 +68,6 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
self.hparams.load_durations,
)
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path,
@@ -87,7 +83,6 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
self.hparams.load_durations,
)
def train_dataloader(self):
@@ -139,7 +134,6 @@ class TextMelDataset(torch.utils.data.Dataset):
f_max=8000,
data_parameters=None,
seed=None,
load_durations=False,
):
self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks
@@ -152,8 +146,6 @@ class TextMelDataset(torch.utils.data.Dataset):
self.win_length = win_length
self.f_min = f_min
self.f_max = f_max
self.load_durations = load_durations
if data_parameters is not None:
self.data_parameters = data_parameters
else:
@@ -175,26 +167,7 @@ class TextMelDataset(torch.utils.data.Dataset):
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
durations = self.get_durations(filepath, text) if self.load_durations else None
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
def get_durations(self, filepath, text):
filepath = Path(filepath)
data_dir, name = filepath.parent.parent, filepath.stem
try:
dur_loc = data_dir / "durations" / f"{name}.npy"
durs = torch.from_numpy(np.load(dur_loc).astype(int))
except FileNotFoundError as e:
raise FileNotFoundError(
f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
) from e
assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
return durs
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
@@ -241,8 +214,6 @@ class TextMelBatchCollate:
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
durations = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths, x_texts = [], []
@@ -255,8 +226,6 @@ class TextMelBatchCollate:
spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
if item["durations"] is not None:
durations[i, : item["durations"].shape[-1]] = item["durations"]
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
@@ -270,5 +239,4 @@ class TextMelBatchCollate:
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
"durations": durations if not torch.eq(durations, 0).all() else None,
}

View File

@@ -4,10 +4,6 @@
import torch
class ModeException(Exception):
pass
class Denoiser(torch.nn.Module):
"""Removes model bias from audio produced with waveglow"""
@@ -24,7 +20,7 @@ class Denoiser(torch.nn.Module):
elif mode == "normal":
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
else:
raise ModeException(f"Mode {mode} if not supported")
raise Exception(f"Mode {mode} if not supported")
def stft_fn(audio, n_fft, hop_length, win_length, window):
spec = torch.stft(

View File

@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
global mel_basis, hann_window # pylint: disable=global-statement
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

View File

@@ -1,7 +1,7 @@
""" from https://github.com/jik876/hifi-gan """
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

View File

@@ -65,7 +65,6 @@ class BaseLightningClass(LightningModule, ABC):
y_lengths=y_lengths,
spks=spks,
out_size=self.out_size,
durations=batch["durations"],
)
return {
"dur_loss": dur_loss,

View File

@@ -2,7 +2,7 @@ import math
from typing import Optional
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation

View File

@@ -0,0 +1,448 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack
from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
from matcha.models.components.text_encoder import LayerNorm
# Define available networks
class DurationPredictorNetwork(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class DurationPredictorNetworkWithTimeStep(nn.Module):
"""Similar architecture but with a time embedding support"""
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.time_embeddings = SinusoidalPosEmb(filter_channels)
self.time_mlp = TimestepEmbedding(
in_channels=filter_channels,
time_embed_dim=filter_channels,
act_fn="silu",
)
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask, enc_outputs, t):
t = self.time_embeddings(t)
t = self.time_mlp(t).unsqueeze(-1)
x = pack([x, enc_outputs], "b * t")[0]
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
# Define available methods to compute loss
# Simple MSE deterministic
class DeterministicDurationPredictor(nn.Module):
def __init__(self, params):
super().__init__()
self.estimator = DurationPredictorNetwork(
params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0),
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
@torch.inference_mode()
def forward(self, x, x_mask):
return self.estimator(x, x_mask)
def compute_loss(self, durations, enc_outputs, x_mask):
return F.mse_loss(self.estimator(enc_outputs, x_mask), durations, reduction="sum") / torch.sum(x_mask)
# Flow Matching duration predictor
class FlowMatchingDurationPrediction(nn.Module):
def __init__(self, params) -> None:
super().__init__()
self.estimator = DurationPredictorNetworkWithTimeStep(
1
+ params.n_channels
+ (
params.spk_emb_dim if params.n_spks > 1 else 0
), # 1 for the durations and n_channels for encoder outputs
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
self.sigma_min = params.sigma_min
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
if n_timesteps is None:
n_timesteps = self.n_steps
b, _, t = enc_outputs.shape
z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
return self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
def solve_euler(self, x, t_span, enc_outputs, mask):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, enc_outputs, t)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, enc_outputs, mask):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
enc_outputs = enc_outputs.detach() # don't update encoder from the duration predictor
b, _, t = enc_outputs.shape
# random timestep
t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, enc_outputs, t.squeeze()), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss
# VITS discrete normalising flow based duration predictor
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = Log()
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
# Meta class to wrap all duration predictors
class DP(nn.Module):
def __init__(self, params):
super().__init__()
self.name = params.name
if params.name == "deterministic":
self.dp = DeterministicDurationPredictor(
params,
)
elif params.name == "flow_matching":
self.dp = FlowMatchingDurationPrediction(
params,
)
else:
raise ValueError(f"Invalid duration predictor configuration: {params.name}")
@torch.inference_mode()
def forward(self, enc_outputs, mask):
return self.dp(enc_outputs, mask)
def compute_loss(self, durations, enc_outputs, mask):
return self.dp.compute_loss(durations, enc_outputs, mask)

View File

@@ -3,10 +3,10 @@
import math
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn as nn
from einops import rearrange
import matcha.utils as utils # pylint: disable=consider-using-from-import
import matcha.utils as utils
from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__)
@@ -67,33 +67,6 @@ class ConvReluNorm(nn.Module):
return x * x_mask
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class RotaryPositionalEmbeddings(nn.Module):
"""
## RoPE module
@@ -330,7 +303,6 @@ class TextEncoder(nn.Module):
self,
encoder_type,
encoder_params,
duration_predictor_params,
n_vocab,
n_spks=1,
spk_emb_dim=128,
@@ -368,12 +340,6 @@ class TextEncoder(nn.Module):
)
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
self.proj_w = DurationPredictor(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
duration_predictor_params.filter_channels_dp,
duration_predictor_params.kernel_size,
duration_predictor_params.p_dropout,
)
def forward(self, x, x_lengths, spks=None):
"""Run forward pass to the transformer based encoder and duration predictor
@@ -404,7 +370,7 @@ class TextEncoder(nn.Module):
x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
# x_dp = torch.detach(x)
# logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask
return mu, x, x_mask

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn as nn
from diffusers.models.attention import (
GEGLU,
GELU,

View File

@@ -7,11 +7,11 @@ import torch
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.duration_predictors import DP
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
@@ -28,6 +28,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
spk_emb_dim,
n_feats,
encoder,
duration_predictor,
decoder,
cfm,
data_statistics,
@@ -35,7 +36,6 @@ class MatchaTTS(BaseLightningClass): # 🍵
optimizer=None,
scheduler=None,
prior_loss=True,
use_precomputed_durations=False,
):
super().__init__()
@@ -47,7 +47,6 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.n_feats = n_feats
self.out_size = out_size
self.prior_loss = prior_loss
self.use_precomputed_durations = use_precomputed_durations
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -55,12 +54,13 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.encoder = TextEncoder(
encoder.encoder_type,
encoder.encoder_params,
encoder.duration_predictor_params,
n_vocab,
n_spks,
spk_emb_dim,
)
self.dp = DP(duration_predictor)
self.decoder = CFM(
in_channels=2 * encoder.encoder_params.n_feats,
out_channel=encoder.encoder_params.n_feats,
@@ -106,7 +106,6 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Lengths of mel spectrograms
"rtf": float,
# Real-time factor
}
"""
# For RTF computation
t = dt.datetime.now()
@@ -115,11 +114,15 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Get speaker embedding
spks = self.spk_emb(spks.long())
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
# Get encoder_outputs `mu_x` and encoded text `enc_output`
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
# Get log-scaled token durations `logw`
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()
y_max_length_ = fix_len_compatibility(y_max_length)
@@ -150,10 +153,10 @@ class MatchaTTS(BaseLightningClass): # 🍵
"rtf": rtf,
}
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
@@ -176,31 +179,27 @@ class MatchaTTS(BaseLightningClass): # 🍵
spks = self.spk_emb(spks)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
y_max_length = y.shape[-1]
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
if self.use_precomputed_durations:
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
else:
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach() # b, t_text, T_mel
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS
# refered to as prior loss in the paper
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
# Cut a small segment of mel-spectrogram in order to increase batch size
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it

View File

@@ -7,10 +7,6 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
class UnknownCleanerException(Exception):
pass
def text_to_sequence(text, cleaner_names):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
@@ -52,6 +48,6 @@ def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise UnknownCleanerException(f"Unknown cleaner: {name}")
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text

View File

@@ -15,6 +15,7 @@ import logging
import re
import phonemizer
import piper_phonemize
from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -36,12 +37,9 @@ global_phonemizer = phonemizer.backend.EspeakBackend(
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# Remove brackets
_brackets_re = re.compile(r"[\[\]\(\)\{\}]")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
@@ -75,10 +73,6 @@ def lowercase(text):
return text.lower()
def remove_brackets(text):
return re.sub(_brackets_re, "", text)
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
@@ -108,37 +102,15 @@ def english_cleaners2(text):
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
# Added in some cases espeak is not removing brackets
phonemes = remove_brackets(phonemes)
phonemes = collapse_whitespace(phonemes)
return phonemes
def ipa_simplifier(text):
replacements = [
("ɐ", "ə"),
("ˈə", "ə"),
("ʤ", ""),
("ʧ", ""),
("", "ɪ"),
]
for replacement in replacements:
text = text.replace(replacement[0], replacement[1])
phonemes = collapse_whitespace(text)
def english_cleaners_piper(text):
"""Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
phonemes = collapse_whitespace(phonemes)
return phonemes
# I am removing this due to incompatibility with several version of python
# However, if you want to use it, you can uncomment it
# and install piper-phonemize with the following command:
# pip install piper-phonemize
# import piper_phonemize
# def english_cleaners_piper(text):
# """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
# text = convert_to_ascii(text)
# text = lowercase(text)
# text = expand_abbreviations(text)
# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
# phonemes = collapse_whitespace(phonemes)
# return phonemes

View File

@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
global mel_basis, hann_window # pylint: disable=global-statement
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)

View File

@@ -1,148 +0,0 @@
#!/usr/bin/env python
import argparse
import os
import sys
import tempfile
from pathlib import Path
import torchaudio
from torch.hub import download_url_to_file
from tqdm import tqdm
from matcha.utils.data.utils import _extract_zip
URLS = {
"en-US": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip",
},
"ja-JP": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip",
},
}
INFO_PAGE = "https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/"
# On their website they say "We NICT open-sourced Hi-Fi-CAPTAIN",
# but they use this very-much-not-open-source licence.
# Dunno if this is open washing or stupidity.
LICENCE = "CC BY-NC-SA 4.0"
# I'd normally put the citation here. It's on their website.
# Boo to non-open-source stuff.
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"-r",
"--skip-resampling",
action="store_true",
default=False,
help="Skip resampling the data (from 48 to 22.05)",
)
parser.add_argument(
"-l", "--language", type=str, choices=["en-US", "ja-JP"], default="en-US", help="The language to download"
)
parser.add_argument(
"-g",
"--gender",
type=str,
choices=["male", "female"],
default="female",
help="The gender of the speaker to download",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
default="data",
help="Place to store the converted data. Top-level only, the subdirectory will be created",
)
return parser.parse_args()
def process_text(infile, outpath: Path):
outmode = "w"
if infile.endswith("dev.txt"):
outfile = outpath / "valid.txt"
elif infile.endswith("eval.txt"):
outfile = outpath / "test.txt"
else:
outfile = outpath / "train.txt"
if outfile.exists():
outmode = "a"
with (
open(infile, encoding="utf-8") as inf,
open(outfile, outmode, encoding="utf-8") as of,
):
for line in inf.readlines():
line = line.strip()
fileid, rest = line.split(" ", maxsplit=1)
outfile = str(outpath / f"{fileid}.wav")
of.write(f"{outfile}|{rest}\n")
def process_files(zipfile, outpath, resample=True):
with tempfile.TemporaryDirectory() as tmpdirname:
for filename in tqdm(_extract_zip(zipfile, tmpdirname)):
if not filename.startswith(tmpdirname):
filename = os.path.join(tmpdirname, filename)
if filename.endswith(".txt"):
process_text(filename, outpath)
elif filename.endswith(".wav"):
filepart = filename.rsplit("/", maxsplit=1)[-1]
outfile = str(outpath / filepart)
arr, sr = torchaudio.load(filename)
if resample:
arr = torchaudio.functional.resample(arr, orig_freq=sr, new_freq=22050)
torchaudio.save(outfile, arr, 22050)
else:
continue
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
if not args.output_dir:
print("output directory not specified, exiting")
sys.exit(1)
URL = URLS[args.language][args.gender]
dirname = f"hi-fi_{args.language}_{args.gender}"
outbasepath = Path(args.output_dir)
if not outbasepath.is_dir():
outbasepath.mkdir()
outpath = outbasepath / dirname
if not outpath.is_dir():
outpath.mkdir()
resample = True
if args.skip_resampling:
resample = False
if save_dir:
zipname = URL.rsplit("/", maxsplit=1)[-1]
zipfile = save_dir / zipname
if not zipfile.exists():
download_url_to_file(URL, zipfile, progress=True)
process_files(zipfile, outpath, resample)
else:
with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
process_files(zf.name, outpath, resample)
if __name__ == "__main__":
main()

View File

@@ -1,97 +0,0 @@
#!/usr/bin/env python
import argparse
import random
import tempfile
from pathlib import Path
from torch.hub import download_url_to_file
from matcha.utils.data.utils import _extract_tar
URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
INFO_PAGE = "https://keithito.com/LJ-Speech-Dataset/"
LICENCE = "Public domain (LibriVox copyright disclaimer)"
CITATION = """
@misc{ljspeech17,
author = {Keith Ito and Linda Johnson},
title = {The LJ Speech Dataset},
howpublished = {\\url{https://keithito.com/LJ-Speech-Dataset/}},
year = 2017
}
"""
def decision():
return random.random() < 0.98
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"output_dir",
type=str,
nargs="?",
default="data",
help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)",
)
return parser.parse_args()
def process_csv(ljpath: Path):
if (ljpath / "metadata.csv").exists():
basepath = ljpath
elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists():
basepath = ljpath / "LJSpeech-1.1"
csvpath = basepath / "metadata.csv"
wavpath = basepath / "wavs"
with (
open(csvpath, encoding="utf-8") as csvf,
open(basepath / "train.txt", "w", encoding="utf-8") as tf,
open(basepath / "val.txt", "w", encoding="utf-8") as vf,
):
for line in csvf.readlines():
line = line.strip()
parts = line.split("|")
wavfile = str(wavpath / f"{parts[0]}.wav")
if decision():
tf.write(f"{wavfile}|{parts[1]}\n")
else:
vf.write(f"{wavfile}|{parts[1]}\n")
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
outpath = Path(args.output_dir)
if not outpath.is_dir():
outpath.mkdir()
if save_dir:
tarname = URL.rsplit("/", maxsplit=1)[-1]
tarfile = save_dir / tarname
if not tarfile.exists():
download_url_to_file(URL, str(tarfile), progress=True)
_extract_tar(tarfile, outpath)
process_csv(outpath)
else:
with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
_extract_tar(zf.name, outpath)
process_csv(outpath)
if __name__ == "__main__":
main()

View File

@@ -1,53 +0,0 @@
# taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala)
# Licence: BSD 2-Clause
# pylint: disable=C0123
import logging
import os
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Union
_LG = logging.getLogger(__name__)
def _extract_tar(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
from_path = str(Path)
if to_path is None:
to_path = os.path.dirname(from_path)
with tarfile.open(from_path, "r") as tar:
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
return files
def _extract_zip(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
from_path = str(Path)
if to_path is None:
to_path = os.path.dirname(from_path)
with zipfile.ZipFile(from_path, "r") as zfile:
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
return files

View File

@@ -94,7 +94,6 @@ def main():
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
@@ -102,8 +101,10 @@ def main():
log.info("Dataloader loaded! Now computing stats...")
params = compute_data_statistics(data_loader, cfg["n_feats"])
print(params)
with open(output_file, "w", encoding="utf-8") as dumpfile:
json.dump(params, dumpfile)
json.dump(
params,
open(output_file, "w"),
)
if __name__ == "__main__":

View File

@@ -86,7 +86,7 @@ def main():
"-i",
"--input-config",
type=str,
default="ljspeech.yaml",
default="vctk.yaml",
help="The name of the yaml config file under configs/data",
)
@@ -140,14 +140,11 @@ def main():
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
cfg["load_durations"] = False
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
print(f"Output folder set to: {output_folder}")
output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")

View File

@@ -72,7 +72,7 @@ def print_config_tree(
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
rich.print(tree, file=file)
@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
rich.print(cfg.tags, file=file)

View File

@@ -35,10 +35,11 @@ torchaudio
matplotlib
pandas
conformer==0.3.2
diffusers # developed using version ==0.25.0
diffusers==0.25.0
notebook
ipywidgets
gradio==3.43.2
gradio
gdown
wget
seaborn
piper_phonemize

15
scripts/get_durations.sh Normal file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
echo "Starting script"
echo "Getting LJ Speech durations"
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
echo "Getting TSG2 durations"
python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
echo "Getting Joe Spont durations"
python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
echo "Getting Ryan durations"
python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f

7
scripts/transcribe.sh Normal file
View File

@@ -0,0 +1,7 @@
echo "Transcribing"
whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions

30
scripts/wer_computer.sh Normal file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
# Run from root folder with: bash scripts/wer_computer.sh
root_folder=${1:-"dur_wer_computation"}
echo "Running WER computation for Duration predictors"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
# echo $cmd
echo "LJ"
echo "==================================="
echo "Flow Matching"
$cmd
echo "-----------------------------------"
echo "LJ Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac"
echo "==================================="
echo "Cormac Flow Matching"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
$cmd
echo "-----------------------------------"

View File

@@ -16,16 +16,9 @@ with open("README.md", encoding="utf-8") as readme_file:
README = readme_file.read()
cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
version = fin.read().strip()
def get_requires():
requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
with open(requirements, encoding="utf-8") as reqfile:
return [str(r).strip() for r in reqfile]
setup(
name="matcha-tts",
version=version,
@@ -35,7 +28,7 @@ setup(
author="Shivam Mehta",
author_email="shivam.mehta25@gmail.com",
url="https://shivammehta25.github.io/Matcha-TTS",
install_requires=get_requires(),
install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
include_dirs=[numpy.get_include()],
include_package_data=True,
packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
@@ -45,7 +38,6 @@ setup(
"matcha-data-stats=matcha.utils.generate_data_statistics:main",
"matcha-tts=matcha.cli:cli",
"matcha-tts-app=matcha.app:main",
"matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main",
]
},
ext_modules=cythonize(exts, language_level=3),

File diff suppressed because one or more lines are too long