16 Commits

Author SHA1 Message Date
Shivam Mehta
4c35836fa5 minor fix 2024-03-02 12:48:54 +00:00
Shivam Mehta
294c6b1327 Adding saving phones while getting durations from matcha 2024-03-02 12:47:08 +00:00
Shivam Mehta
ad76016916 Fixing configs 2024-02-26 09:11:22 +00:00
Shivam Mehta
05c8f9b4a8 updating configs and experiments 2024-02-25 22:02:36 +00:00
Shivam Mehta
4d5b62cea9 Adding a bit of comments 2024-02-24 15:20:13 +00:00
Shivam Mehta
8e87111a98 Adding possibility of getting durations out 2024-02-24 15:10:19 +00:00
Shivam Mehta
def0855608 Adding other experiment configs 2024-01-22 11:46:08 +00:00
Shivam Mehta
6976a91348 Merge branch 'main' into stoc_dur 2024-01-12 11:58:41 +00:00
Shivam Mehta
256adc55d3 Adding ICASSP 2024 2024-01-12 11:31:01 +00:00
Shivam Mehta
bfcbdbc82e Merge pull request #43 from shivammehta25/dev
Removing gdown for HifiGAN checkpoints too
2024-01-12 12:29:03 +01:00
Shivam Mehta
47a629f128 Merge pull request #42 from shivammehta25/dev
Merging dev adding another dataset, piper phonemizer and refractoring
2024-01-12 11:49:53 +01:00
Shivam Mehta
5a2a893750 Merge pull request #19 from shivammehta25/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2024-01-12 11:47:10 +01:00
Shivam Mehta
458e9df236 Adding synthesis 2024-01-10 11:05:22 +00:00
Shivam Mehta
d03bba82bb In the middle of adding discrete nf based duration predictor 2024-01-10 11:04:46 +00:00
pre-commit-ci[bot]
dc035a09f2 [pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.4.0...v4.5.0)
- [github.com/psf/black: 23.9.1 → 23.12.1](https://github.com/psf/black/compare/23.9.1...23.12.1)
- [github.com/PyCQA/isort: 5.12.0 → 5.13.2](https://github.com/PyCQA/isort/compare/5.12.0...5.13.2)
- [github.com/asottile/pyupgrade: v3.14.0 → v3.15.0](https://github.com/asottile/pyupgrade/compare/v3.14.0...v3.15.0)
- [github.com/PyCQA/flake8: 6.1.0 → 7.0.0](https://github.com/PyCQA/flake8/compare/6.1.0...7.0.0)
- [github.com/pycqa/pylint: v3.0.0 → v3.0.3](https://github.com/pycqa/pylint/compare/v3.0.0...v3.0.3)
2024-01-08 21:15:26 +00:00
Shivam Mehta
a58bab5403 Adding option to do flow matching based duration prediction 2024-01-05 11:13:07 +00:00
30 changed files with 2189 additions and 99 deletions

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
@@ -18,28 +18,28 @@ repos:
# python code formatting
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.12.1
hooks:
- id: black
args: [--line-length, "120"]
# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.14.0
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py38-plus]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
args:
@@ -54,6 +54,6 @@ repos:
# pylint
- repo: https://github.com/pycqa/pylint
rev: v3.0.0
rev: v3.0.3
hooks:
- id: pylint

View File

@@ -17,7 +17,7 @@
</div>
> This is the official code implementation of 🍵 Matcha-TTS.
> This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024].
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method:

View File

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: joe_spont_only
train_filelist_path: data/filelists/joe_spontonly_train.txt
valid_filelist_path: data/filelists/joe_spontonly_val.txt
data_statistics:
mel_mean: -5.882903
mel_std: 2.458284

10
configs/data/ryan.yaml Normal file
View File

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: ryan
train_filelist_path: data/filelists/ryan_train.csv
valid_filelist_path: data/filelists/ryan_val.csv
data_statistics:
mel_mean: -4.715779
mel_std: 2.124502

10
configs/data/tsg2.yaml Normal file
View File

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: tsg2
train_filelist_path: data/filelists/cormac_train.txt
valid_filelist_path: data/filelists/cormac_val.txt
data_statistics:
mel_mean: -5.536622
mel_std: 2.116101

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_det_dur

View File

@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_stoc_dur
model:
duration_predictor:
p_dropout: 0.2

View File

@@ -0,0 +1,16 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ljspeech.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ljspeech"]
run_name: ljspeech

View File

@@ -0,0 +1,18 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_det_dur
trainer:
max_epochs: 3000

View File

@@ -0,0 +1,24 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_stoc_dur
model:
duration_predictor:
p_dropout: 0.2
trainer:
max_epochs: 3000

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_det_dur

View File

@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_stoc_dur
model:
duration_predictor:
p_dropout: 0.5

View File

@@ -0,0 +1,7 @@
name: deterministic
n_spks: ${model.n_spks}
spk_emb_dim: ${model.spk_emb_dim}
filter_channels: 256
kernel_size: 3
n_channels: ${model.encoder.encoder_params.n_channels}
p_dropout: ${model.encoder.encoder_params.p_dropout}

View File

@@ -0,0 +1,7 @@
defaults:
- deterministic.yaml
- _self_
sigma_min: 1e-4
n_steps: 10
name: flow_matching

View File

@@ -3,16 +3,8 @@ encoder_params:
n_feats: ${model.n_feats}
n_channels: 192
filter_channels: 768
filter_channels_dp: 256
n_heads: 2
n_layers: 6
kernel_size: 3
p_dropout: 0.1
spk_emb_dim: 64
n_spks: 1
prenet: true
duration_predictor_params:
filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
kernel_size: 3
p_dropout: ${model.encoder.encoder_params.p_dropout}

View File

@@ -1,6 +1,7 @@
defaults:
- _self_
- encoder: default.yaml
- duration_predictor: deterministic.yaml
- decoder: default.yaml
- cfm: default.yaml
- optimizer: adam.yaml

View File

@@ -227,7 +227,7 @@ def cli():
parser.add_argument(
"--vocoder",
type=str,
default=None,
default="hifigan_univ_v1",
help="Vocoder to use (default: will use the one suggested with the pretrained model))",
choices=VOCODER_URLS.keys(),
)

View File

@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
"""Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass
def state_dict(self): # pylint: disable=no-self-use
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
@@ -164,10 +164,10 @@ class TextMelDataset(torch.utils.data.Dataset):
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text = self.get_text(text, add_blank=self.add_blank)
text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk}
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
@@ -187,11 +187,11 @@ class TextMelDataset(torch.utils.data.Dataset):
return mel
def get_text(self, text, add_blank=True):
text_norm = text_to_sequence(text, self.cleaners)
text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
if self.add_blank:
text_norm = intersperse(text_norm, 0)
text_norm = torch.IntTensor(text_norm)
return text_norm
return text_norm, cleaned_text
def __getitem__(self, index):
datapoint = self.get_datapoint(self.filepaths_and_text[index])
@@ -207,15 +207,16 @@ class TextMelBatchCollate:
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch])
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch])
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths, x_texts = [], []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
@@ -223,9 +224,19 @@ class TextMelBatchCollate:
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
filepaths.append(item["filepath"])
x_texts.append(item["x_text"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
return {
"x": x,
"x_lengths": x_lengths,
"y": y,
"y_lengths": y_lengths,
"spks": spks,
"filepaths": filepaths,
"x_texts": x_texts,
}

View File

@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
dur_loss, prior_loss, diff_loss = self(
dur_loss, prior_loss, diff_loss, *_ = self(
x=x,
x_lengths=x_lengths,
y=y,

View File

@@ -0,0 +1,448 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack
from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
from matcha.models.components.text_encoder import LayerNorm
# Define available networks
class DurationPredictorNetwork(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class DurationPredictorNetworkWithTimeStep(nn.Module):
"""Similar architecture but with a time embedding support"""
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.time_embeddings = SinusoidalPosEmb(filter_channels)
self.time_mlp = TimestepEmbedding(
in_channels=filter_channels,
time_embed_dim=filter_channels,
act_fn="silu",
)
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask, enc_outputs, t):
t = self.time_embeddings(t)
t = self.time_mlp(t).unsqueeze(-1)
x = pack([x, enc_outputs], "b * t")[0]
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
# Define available methods to compute loss
# Simple MSE deterministic
class DeterministicDurationPredictor(nn.Module):
def __init__(self, params):
super().__init__()
self.estimator = DurationPredictorNetwork(
params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0),
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
@torch.inference_mode()
def forward(self, x, x_mask):
return self.estimator(x, x_mask)
def compute_loss(self, durations, enc_outputs, x_mask):
return F.mse_loss(self.estimator(enc_outputs, x_mask), durations, reduction="sum") / torch.sum(x_mask)
# Flow Matching duration predictor
class FlowMatchingDurationPrediction(nn.Module):
def __init__(self, params) -> None:
super().__init__()
self.estimator = DurationPredictorNetworkWithTimeStep(
1
+ params.n_channels
+ (
params.spk_emb_dim if params.n_spks > 1 else 0
), # 1 for the durations and n_channels for encoder outputs
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
self.sigma_min = params.sigma_min
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
if n_timesteps is None:
n_timesteps = self.n_steps
b, _, t = enc_outputs.shape
z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
return self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
def solve_euler(self, x, t_span, enc_outputs, mask):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, enc_outputs, t)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, enc_outputs, mask):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
enc_outputs = enc_outputs.detach() # don't update encoder from the duration predictor
b, _, t = enc_outputs.shape
# random timestep
t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, enc_outputs, t.squeeze()), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss
# VITS discrete normalising flow based duration predictor
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = Log()
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
# Meta class to wrap all duration predictors
class DP(nn.Module):
def __init__(self, params):
super().__init__()
self.name = params.name
if params.name == "deterministic":
self.dp = DeterministicDurationPredictor(
params,
)
elif params.name == "flow_matching":
self.dp = FlowMatchingDurationPrediction(
params,
)
else:
raise ValueError(f"Invalid duration predictor configuration: {params.name}")
@torch.inference_mode()
def forward(self, enc_outputs, mask):
return self.dp(enc_outputs, mask)
def compute_loss(self, durations, enc_outputs, mask):
return self.dp.compute_loss(durations, enc_outputs, mask)

View File

@@ -67,33 +67,6 @@ class ConvReluNorm(nn.Module):
return x * x_mask
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class RotaryPositionalEmbeddings(nn.Module):
"""
## RoPE module
@@ -330,7 +303,6 @@ class TextEncoder(nn.Module):
self,
encoder_type,
encoder_params,
duration_predictor_params,
n_vocab,
n_spks=1,
spk_emb_dim=128,
@@ -368,12 +340,6 @@ class TextEncoder(nn.Module):
)
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
self.proj_w = DurationPredictor(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
duration_predictor_params.filter_channels_dp,
duration_predictor_params.kernel_size,
duration_predictor_params.p_dropout,
)
def forward(self, x, x_lengths, spks=None):
"""Run forward pass to the transformer based encoder and duration predictor
@@ -404,7 +370,7 @@ class TextEncoder(nn.Module):
x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
# x_dp = torch.detach(x)
# logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask
return mu, x, x_mask

View File

@@ -4,14 +4,14 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.duration_predictors import DP
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
@@ -28,6 +28,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
spk_emb_dim,
n_feats,
encoder,
duration_predictor,
decoder,
cfm,
data_statistics,
@@ -53,12 +54,13 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.encoder = TextEncoder(
encoder.encoder_type,
encoder.encoder_params,
encoder.duration_predictor_params,
n_vocab,
n_spks,
spk_emb_dim,
)
self.dp = DP(duration_predictor)
self.decoder = CFM(
in_channels=2 * encoder.encoder_params.n_feats,
out_channel=encoder.encoder_params.n_feats,
@@ -112,11 +114,15 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Get speaker embedding
spks = self.spk_emb(spks.long())
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
# Get encoder_outputs `mu_x` and encoded text `enc_output`
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
# Get log-scaled token durations `logw`
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()
y_max_length_ = fix_len_compatibility(y_max_length)
@@ -173,7 +179,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
spks = self.spk_emb(spks)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
y_max_length = y.shape[-1]
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
@@ -192,9 +198,8 @@ class MatchaTTS(BaseLightningClass): # 🍵
attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS
# refered to as prior loss in the paper
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
# Cut a small segment of mel-spectrogram in order to increase batch size
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
@@ -236,4 +241,4 @@ class MatchaTTS(BaseLightningClass): # 🍵
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss
return dur_loss, prior_loss, diff_loss, attn

View File

@@ -21,7 +21,7 @@ def text_to_sequence(text, cleaner_names):
for symbol in clean_text:
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):

View File

@@ -0,0 +1,192 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
import lightning
import numpy as np
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from torch import nn
from tqdm.auto import tqdm
from matcha.cli import get_device
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.models.matcha_tts import MatchaTTS
from matcha.utils.logging_utils import pylogger
from matcha.utils.utils import get_phoneme_durations
log = pylogger.get_pylogger(__name__)
def save_durations_to_folder(
attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
durations = attn.squeeze().sum(1)[:x_length].numpy()
durations_json = get_phoneme_durations(durations, text)
output = output_folder / Path(filepath).name.replace(".wav", ".npy")
with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
json.dump(durations_json, f, indent=4, ensure_ascii=False)
np.save(output, durations)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
"""Generate durations from the model for each datapoint and save it in a folder
Args:
data_loader (torch.utils.data.DataLoader): Dataloader
model (nn.Module): MatchaTTS model
device (torch.device): GPU or CPU
"""
for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
x, x_lengths = batch["x"], batch["x_lengths"]
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
x = x.to(device)
y = y.to(device)
x_lengths = x_lengths.to(device)
y_lengths = y_lengths.to(device)
spks = spks.to(device) if spks is not None else None
_, _, _, attn = model(
x=x,
x_lengths=x_lengths,
y=y,
y_lengths=y_lengths,
spks=spks,
)
attn = attn.cpu()
for i in range(attn.shape[0]):
save_durations_to_folder(
attn[i],
x_lengths[i].item(),
y_lengths[i].item(),
batch["filepaths"][i],
output_folder,
batch["x_texts"][i],
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="vctk.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default="32",
help="Can have increased batch size for faster computation",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
parser.add_argument(
"-c",
"--checkpoint_path",
type=str,
required=True,
help="Path to the checkpoint file to load the model from",
)
parser.add_argument(
"-o",
"--output-folder",
type=str,
default=None,
help="Output folder to save the data statistics",
)
parser.add_argument(
"--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
)
args = parser.parse_args()
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
if args.output_folder is not None:
output_folder = Path(args.output_folder)
else:
output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
if os.path.exists(output_folder) and not args.force:
print("Folder already exists. Use -f to force overwrite")
sys.exit(1)
output_folder.mkdir(parents=True, exist_ok=True)
print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
print("Loading model...")
device = get_device(args)
model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
try:
print("Computing stats for training set if exists...")
train_dataloader = text_mel_datamodule.train_dataloader()
compute_durations(train_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No training set found")
try:
print("Computing stats for validation set if exists...")
val_dataloader = text_mel_datamodule.val_dataloader()
compute_durations(val_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No validation set found")
try:
print("Computing stats for test set if exists...")
test_dataloader = text_mel_datamodule.test_dataloader()
compute_durations(test_dataloader, model, device, output_folder)
except lightning.fabric.utilities.exceptions.MisconfigurationException:
print("No test set found")
print(f"[+] Done! Data statistics saved to: {output_folder}")
if __name__ == "__main__":
# Helps with generating durations for the dataset to train other architectures
# that cannot learn to align due to limited size of dataset
# Example usage:
# python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
# This will create a folder in data/processed_data/durations/ljspeech with the durations
main()

View File

@@ -2,6 +2,7 @@ import os
import sys
import warnings
from importlib.util import find_spec
from math import ceil
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
@@ -217,3 +218,42 @@ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
prev = durations[0]
merged_durations = []
# Convolve with stride 2
for i in range(1, len(durations), 2):
if i == len(durations) - 2:
# if it is last take full value
next_half = durations[i + 1]
else:
next_half = ceil(durations[i + 1] / 2)
curr = prev + durations[i] + next_half
prev = durations[i + 1] - next_half
merged_durations.append(curr)
assert len(phones) == len(merged_durations)
assert len(merged_durations) == (len(durations) - 1) // 2
merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
start = torch.tensor(0)
duration_json = []
for i, duration in enumerate(merged_durations):
duration_json.append(
{
phones[i]: {
"starttime": start.item(),
"endtime": duration.item(),
"duration": duration.item() - start.item(),
}
}
)
start = duration
assert list(duration_json[-1].values())[0]["endtime"] == sum(
durations
), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
return duration_json

View File

@@ -35,7 +35,7 @@ torchaudio
matplotlib
pandas
conformer==0.3.2
diffusers==0.27.2
diffusers==0.25.0
notebook
ipywidgets
gradio

15
scripts/get_durations.sh Normal file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
echo "Starting script"
echo "Getting LJ Speech durations"
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
echo "Getting TSG2 durations"
python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
echo "Getting Joe Spont durations"
python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
echo "Getting Ryan durations"
python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f

7
scripts/transcribe.sh Normal file
View File

@@ -0,0 +1,7 @@
echo "Transcribing"
whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions

30
scripts/wer_computer.sh Normal file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
# Run from root folder with: bash scripts/wer_computer.sh
root_folder=${1:-"dur_wer_computation"}
echo "Running WER computation for Duration predictors"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
# echo $cmd
echo "LJ"
echo "==================================="
echo "Flow Matching"
$cmd
echo "-----------------------------------"
echo "LJ Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac"
echo "==================================="
echo "Cormac Flow Matching"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
$cmd
echo "-----------------------------------"

File diff suppressed because one or more lines are too long