diff --git a/configs/experiment/ljspeech_stoc_dur.yaml b/configs/experiment/ljspeech_stoc_dur.yaml
new file mode 100644
index 0000000..89b7e59
--- /dev/null
+++ b/configs/experiment/ljspeech_stoc_dur.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=ljspeech_stoc_dur
+
+defaults:
+  - override /data: ljspeech.yaml
+  - override /model/duration_predictor: flow_matching.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["ljspeech"]
+
+
+run_name: ljspeech
diff --git a/configs/model/duration_predictor/deterministic.yaml b/configs/model/duration_predictor/deterministic.yaml
new file mode 100644
index 0000000..1eaedf5
--- /dev/null
+++ b/configs/model/duration_predictor/deterministic.yaml
@@ -0,0 +1,7 @@
+name: deterministic
+n_spks: ${model.n_spks}
+spk_emb_dim: ${model.spk_emb_dim}
+filter_channels: 256
+kernel_size: 3
+n_channels: ${model.encoder.encoder_params.n_channels}
+p_dropout: ${model.encoder.encoder_params.p_dropout}
diff --git a/configs/model/duration_predictor/flow_matching.yaml b/configs/model/duration_predictor/flow_matching.yaml
new file mode 100644
index 0000000..2a394a8
--- /dev/null
+++ b/configs/model/duration_predictor/flow_matching.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - deterministic.yaml
+  - _self_
+
+sigma_min: 1e-4
+n_steps: 10
+name: flow_matching
diff --git a/configs/model/encoder/default.yaml b/configs/model/encoder/default.yaml
index d4d5e5a..6a14ec8 100644
--- a/configs/model/encoder/default.yaml
+++ b/configs/model/encoder/default.yaml
@@ -3,16 +3,8 @@ encoder_params:
   n_feats: ${model.n_feats}
   n_channels: 192
   filter_channels: 768
-  filter_channels_dp: 256
   n_heads: 2
   n_layers: 6
   kernel_size: 3
   p_dropout: 0.1
-  spk_emb_dim: 64
-  n_spks: 1
   prenet: true
-
-duration_predictor_params:
-  filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
-  kernel_size: 3
-  p_dropout: ${model.encoder.encoder_params.p_dropout}
diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml
index 36f6eaf..fd6d348 100644
--- a/configs/model/matcha.yaml
+++ b/configs/model/matcha.yaml
@@ -1,6 +1,7 @@
 defaults:
   - _self_
   - encoder: default.yaml
+  - duration_predictor: deterministic.yaml
   - decoder: default.yaml
   - cfm: default.yaml
   - optimizer: adam.yaml
diff --git a/matcha/models/components/duration_predictors.py b/matcha/models/components/duration_predictors.py
new file mode 100644
index 0000000..787057b
--- /dev/null
+++ b/matcha/models/components/duration_predictors.py
@@ -0,0 +1,243 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import pack
+
+from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
+from matcha.models.components.text_encoder import LayerNorm
+
+# Define available networks
+
+
+class DurationPredictorNetwork(nn.Module):
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.p_dropout = p_dropout
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_1 = LayerNorm(filter_channels)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_2 = LayerNorm(filter_channels)
+        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
+
+
+class DurationPredictorNetworkWithTimeStep(nn.Module):
+    """Same architecture as DurationPredictorNetwork, extended with a time-step embedding and encoder-output conditioning"""
+
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.p_dropout = p_dropout
+
+        self.time_embeddings = SinusoidalPosEmb(filter_channels)
+        self.time_mlp = TimestepEmbedding(
+            in_channels=filter_channels,
+            time_embed_dim=filter_channels,
+            act_fn="silu",
+        )
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_1 = LayerNorm(filter_channels)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_2 = LayerNorm(filter_channels)
+        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
+
+    def forward(self, x, x_mask, enc_outputs, t):
+        t = self.time_embeddings(t)
+        t = self.time_mlp(t).unsqueeze(-1)
+
+        x = pack([x, enc_outputs], "b * t")[0]
+
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = x + t
+        x = self.norm_1(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = x + t
+        x = self.norm_2(x)
+        x = self.drop(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
+
+
+# Define available methods to compute loss
+
+# Simple MSE deterministic
+
+
+class DeterministicDurationPredictor(nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        self.estimator = DurationPredictorNetwork(
+            params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0),
+            params.filter_channels,
+            params.kernel_size,
+            params.p_dropout,
+        )
+
+    @torch.inference_mode()
+    def forward(self, x, x_mask):
+        return self.estimator(x, x_mask)
+
+    def compute_loss(self, durations, enc_outputs, x_mask):
+        return F.mse_loss(self.estimator(enc_outputs, x_mask), durations, reduction="sum") / torch.sum(x_mask)
+
+
+# Flow Matching duration predictor
+
+
+class FlowMatchingDurationPrediction(nn.Module):
+    def __init__(self, params) -> None:
+        super().__init__()
+
+        self.estimator = DurationPredictorNetworkWithTimeStep(
+            1
+            + params.n_channels
+            + (
+                params.spk_emb_dim if params.n_spks > 1 else 0
+            ),  # 1 for the durations and n_channels for encoder outputs
+            params.filter_channels,
+            params.kernel_size,
+            params.p_dropout,
+        )
+        self.sigma_min = params.sigma_min
+        self.n_steps = params.n_steps
+
+    @torch.inference_mode()
+    def forward(self, enc_outputs, mask, n_timesteps=None, temperature=1):
+        """Sample log-scaled durations by integrating the learnt ODE from noise.
+
+        The encoder outputs are used as conditioning for the vector field
+        estimator at every Euler step.
+
+        Args:
+            enc_outputs (torch.Tensor): output of the text encoder
+                shape: (batch_size, n_channels, text_timesteps)
+            mask (torch.Tensor): text mask
+                shape: (batch_size, 1, text_timesteps)
+            n_timesteps (int, optional): number of Euler steps. Defaults to self.n_steps.
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.
+
+        Returns:
+            sample: predicted log-scaled durations
+                shape: (batch_size, 1, text_timesteps)
+        """
+        if n_timesteps is None:
+            n_timesteps = self.n_steps
+
+        b, _, t = enc_outputs.shape
+        z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
+        return self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
+
+    def solve_euler(self, x, t_span, enc_outputs, mask):
+        """
+        Fixed euler solver for ODEs.
+
+        Args:
+            x (torch.Tensor): random noise
+                shape: (batch_size, 1, text_timesteps)
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            enc_outputs (torch.Tensor): output of the text encoder
+                shape: (batch_size, n_channels, text_timesteps)
+            mask (torch.Tensor): text mask
+                shape: (batch_size, 1, text_timesteps)
+        """
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
+        sol = []
+
+        for step in range(1, len(t_span)):
+            dphi_dt = self.estimator(x, mask, enc_outputs, t)
+
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+
+        return sol[-1]
+
+    def compute_loss(self, x1, enc_outputs, mask):
+        """Computes the conditional flow matching loss for the durations
+
+        A random timestep t ~ U[0, 1] and noise z ~ N(0, I) are sampled, and the
+        estimator is trained to regress the vector field u = x1 - (1 - sigma_min) * z
+        at the point y = (1 - (1 - sigma_min) * t) * z + t * x1.
+
+        Args:
+            x1 (torch.Tensor): target log-scaled durations (obtained from MAS)
+                shape: (batch_size, 1, text_timesteps)
+            enc_outputs (torch.Tensor): output of the text encoder, used as conditioning
+                shape: (batch_size, n_channels, text_timesteps)
+            mask (torch.Tensor): text mask
+                shape: (batch_size, 1, text_timesteps)
+
+        Returns:
+            loss: conditional flow matching loss
+        """
+        enc_outputs = enc_outputs.detach()  # don't update encoder from the duration predictor
+        b, _, t = enc_outputs.shape
+
+        # random timestep
+        t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
+        # sample noise p(x_0)
+        z = torch.randn_like(x1)
+
+        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+        u = x1 - (1 - self.sigma_min) * z
+
+        loss = F.mse_loss(self.estimator(y, mask, enc_outputs, t.squeeze()), u, reduction="sum") / (
+            torch.sum(mask) * u.shape[1]
+        )
+        return loss
+
+
+# Meta class to wrap all duration predictors
+
+
+class DP(nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        self.name = params.name
+
+        if params.name == "deterministic":
+            self.dp = DeterministicDurationPredictor(
+                params,
+            )
+        elif params.name == "flow_matching":
+            self.dp = FlowMatchingDurationPrediction(
+                params,
+            )
+        else:
+            raise ValueError(f"Invalid duration predictor configuration: {params.name}")
+
+    @torch.inference_mode()
+    def forward(self, enc_outputs, mask):
+        return self.dp(enc_outputs, mask)
+
+    def compute_loss(self, durations, enc_outputs, mask):
+        return self.dp.compute_loss(durations, enc_outputs, mask)
diff --git a/matcha/models/components/text_encoder.py b/matcha/models/components/text_encoder.py
index a388d05..13a6b3d 100644
--- a/matcha/models/components/text_encoder.py
+++ b/matcha/models/components/text_encoder.py
@@ -67,33 +67,6 @@ class ConvReluNorm(nn.Module):
         return x * x_mask
 
 
-class DurationPredictor(nn.Module):
-    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
-        super().__init__()
-        self.in_channels = in_channels
-        self.filter_channels = filter_channels
-        self.p_dropout = p_dropout
-
-        self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.norm_2 = LayerNorm(filter_channels)
-        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
-
-    def forward(self, x, x_mask):
-        x = self.conv_1(x * x_mask)
-        x = torch.relu(x)
-        x = self.norm_1(x)
-        x = self.drop(x)
-        x = self.conv_2(x * x_mask)
-        x = torch.relu(x)
-        x = self.norm_2(x)
-        x = self.drop(x)
-        x = self.proj(x * x_mask)
-        return x * x_mask
-
-
 class RotaryPositionalEmbeddings(nn.Module):
     """
     ## RoPE module
@@ -330,7 +303,6 @@ class TextEncoder(nn.Module):
         self,
         encoder_type,
         encoder_params,
-        duration_predictor_params,
         n_vocab,
         n_spks=1,
         spk_emb_dim=128,
@@ -368,12 +340,6 @@ class TextEncoder(nn.Module):
         )
 
         self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
-        self.proj_w = DurationPredictor(
-            self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
-            duration_predictor_params.filter_channels_dp,
-            duration_predictor_params.kernel_size,
-            duration_predictor_params.p_dropout,
-        )
 
     def forward(self, x, x_lengths, spks=None):
         """Run forward pass to the transformer based encoder and duration predictor
@@ -404,7 +370,7 @@ class TextEncoder(nn.Module):
         x = self.encoder(x, x_mask)
         mu = self.proj_m(x) * x_mask
 
-        x_dp = torch.detach(x)
-        logw = self.proj_w(x_dp, x_mask)
+        # x_dp = torch.detach(x)
+        # logw = self.proj_w(x_dp, x_mask)
 
-        return mu, logw, x_mask
+        return mu, x, x_mask
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index 64b2c07..325667e 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -7,11 +7,11 @@ import torch
 import matcha.utils.monotonic_align as monotonic_align
 from matcha import utils
 from matcha.models.baselightningmodule import BaseLightningClass
+from matcha.models.components.duration_predictors import DP
 from matcha.models.components.flow_matching import CFM
 from matcha.models.components.text_encoder import TextEncoder
 from matcha.utils.model import (
     denormalize,
-    duration_loss,
     fix_len_compatibility,
     generate_path,
     sequence_mask,
@@ -28,6 +28,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         spk_emb_dim,
         n_feats,
         encoder,
+        duration_predictor,
         decoder,
         cfm,
         data_statistics,
@@ -53,12 +54,13 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.encoder = TextEncoder(
             encoder.encoder_type,
             encoder.encoder_params,
-            encoder.duration_predictor_params,
             n_vocab,
             n_spks,
             spk_emb_dim,
         )
 
+        self.dp = DP(duration_predictor)
+
         self.decoder = CFM(
             in_channels=2 * encoder.encoder_params.n_feats,
             out_channel=encoder.encoder_params.n_feats,
@@ -112,8 +114,11 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             # Get speaker embedding
             spks = self.spk_emb(spks.long())
 
-        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
-        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
+        # Get encoder_outputs `mu_x` and encoded text `enc_output`
+        mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
+
+        # Get log-scaled token durations `logw`
+        logw = self.dp(enc_output, x_mask)
 
         w = torch.exp(logw) * x_mask
         w_ceil = torch.ceil(w) * length_scale
@@ -173,7 +178,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             spks = self.spk_emb(spks)
 
-        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
-        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
+        # Get encoder_outputs `mu_x` and encoded text `enc_output`
+        mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
 
         y_max_length = y.shape[-1]
         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
@@ -192,9 +197,8 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         attn = attn.detach()
 
         # Compute loss between predicted log-scaled durations and those obtained from MAS
-        # refered to as prior loss in the paper
        logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
-        dur_loss = duration_loss(logw, logw_, x_lengths)
+        dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
 
         # Cut a small segment of mel-spectrogram in order to increase batch size
         # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
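Usage note (reviewer sketch, not part of the diff): the new `DP` wrapper can be exercised on its own before wiring it through Hydra. The field values below mirror `configs/model/duration_predictor/deterministic.yaml` with the interpolations resolved against the encoder defaults (`n_channels=192`, `p_dropout=0.1`); the `SimpleNamespace` stand-in for the resolved config and the tensor shapes are illustrative assumptions, not code from this PR.

```python
from types import SimpleNamespace

import torch

from matcha.models.components.duration_predictors import DP

# Mirrors configs/model/duration_predictor/deterministic.yaml (interpolations resolved by hand).
params = SimpleNamespace(
    name="deterministic",
    n_spks=1,
    spk_emb_dim=64,
    filter_channels=256,
    kernel_size=3,
    n_channels=192,
    p_dropout=0.1,
)
dp = DP(params)

batch, tokens = 2, 37
enc_output = torch.randn(batch, params.n_channels, tokens)  # stand-in for the text encoder output
x_mask = torch.ones(batch, 1, tokens)                       # text mask

# Inference path used in MatchaTTS.synthesise: predicted log-scaled durations, shape (batch, 1, tokens)
logw = dp(enc_output, x_mask)

# Training path used in MatchaTTS.forward: MSE against MAS-derived log-scaled durations
logw_mas = torch.randn(batch, 1, tokens)  # stand-in for logw_ from monotonic alignment search
dur_loss = dp.compute_loss(logw_mas, enc_output, x_mask)
```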
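The stochastic variant selected by `flow_matching.yaml` keeps the same interface: `compute_loss` regresses the conditional vector field u = x1 - (1 - sigma_min) * z at the interpolated point y = (1 - (1 - sigma_min) * t) * z + t * x1, and the forward pass integrates the learnt ODE from Gaussian noise with a fixed-step Euler solver (`n_steps` steps). A sketch under the same illustrative assumptions, with `sigma_min` and `n_steps` taken from the new config:

```python
from types import SimpleNamespace

import torch

from matcha.models.components.duration_predictors import DP

fm_params = SimpleNamespace(
    name="flow_matching",
    n_spks=1,
    spk_emb_dim=64,
    filter_channels=256,
    kernel_size=3,
    n_channels=192,
    p_dropout=0.1,
    sigma_min=1e-4,  # from configs/model/duration_predictor/flow_matching.yaml
    n_steps=10,
)
stoc_dp = DP(fm_params)

batch, tokens = 2, 37
enc_output = torch.randn(batch, fm_params.n_channels, tokens)
x_mask = torch.ones(batch, 1, tokens)
logw_mas = torch.randn(batch, 1, tokens)  # stand-in for MAS-derived log-scaled durations

# Training: conditional flow matching loss on the log-duration targets
cfm_loss = stoc_dp.compute_loss(logw_mas, enc_output, x_mask)

# Inference: sample log-scaled durations by running n_steps Euler steps from noise
logw_sampled = stoc_dp(enc_output, x_mask)
```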