From c079e5254a246a632729bb4ae66f245aa08e1021 Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Sun, 17 Sep 2023 06:49:12 +0000 Subject: [PATCH] Adding docstrings --- matcha/models/components/flow_matching.py | 48 ++++++++++++++++------- matcha/models/components/text_encoder.py | 18 +++++++++ matcha/models/matcha_tts.py | 41 +++++++++++++------ 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/matcha/models/components/flow_matching.py b/matcha/models/components/flow_matching.py index 1f53043..781deb0 100644 --- a/matcha/models/components/flow_matching.py +++ b/matcha/models/components/flow_matching.py @@ -34,15 +34,19 @@ class BASECFM(torch.nn.Module, ABC): """Forward diffusion Args: - z (_type_): mu + noise (we don't need this in this formulation), we will sample the noise again - mask (_type_): output_mask - mu (_type_): output of encoder - n_timesteps (_type_): number of diffusion steps - stoc (bool, optional): _description_. Defaults to False. - spks (_type_, optional): _description_. Defaults to None. + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes Returns: - sample: _description_ + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) """ z = torch.randn_like(mu) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) @@ -52,10 +56,21 @@ class BASECFM(torch.nn.Module, ABC): """ Fixed euler solver for ODEs. 
Args: - x (_type_): _description_ - t (_type_): _description_ + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # Intermediate states are stored so they can be inspected or plotted later (e.g. from a debugger and saved to a file). + # A return_all_steps flag may be added in the future to expose them. sol = [] steps = 1 @@ -75,14 +90,19 @@ class BASECFM(torch.nn.Module, ABC): """Computes diffusion loss Args: - x1 (_type_): Target - mask (_type_): target mask - mu (_type_): output of encoder - spks (_type_, optional): speaker embedding. Defaults to None. + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) Returns: - loss: diffusion loss + loss: conditional flow matching loss y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) """ b, _, t = mu.shape diff --git a/matcha/models/components/text_encoder.py b/matcha/models/components/text_encoder.py index 33cc099..a388d05 100644 --- a/matcha/models/components/text_encoder.py +++ b/matcha/models/components/text_encoder.py @@ -376,6 +376,24 @@ class TextEncoder(nn.Module): ) def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ x = self.emb(x) * math.sqrt(self.n_channels) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index 41caf1b..c480db8 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -9,13 +9,9 @@ from matcha import utils from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.components.flow_matching import CFM from matcha.models.components.text_encoder import TextEncoder -from matcha.utils.model import ( - denormalize, - duration_loss, - fix_len_compatibility, - generate_path, - sequence_mask, -) +from matcha.utils.model import (denormalize, duration_loss, + fix_len_compatibility, generate_path, + sequence_mask) log = utils.get_pylogger(__name__) @@ -78,13 +74,30 @@ 
class MatchaTTS(BaseLightningClass): # 🍵 Args: x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) n_timesteps (int): number of steps to use for reverse diffusion in decoder. temperature (float, optional): controls variance of terminal distribution. - stoc (bool, optional): flag that adds stochastic term to the decoder sampler. - Usually, does not provide synthesis improvements. + spks (torch.Tensor, optional): speaker ids. + shape: (batch_size,) length_scale (float, optional): controls speech pace. Increase value to slow down generated speech and vice versa. + + Returns: + dict: { + "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Average mel spectrogram generated by the encoder + "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Refined mel spectrogram improved by the CFM + "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), + # Alignment map between text and mel spectrogram + "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Denormalized mel spectrogram + "mel_lengths": torch.Tensor, shape: (batch_size,), + # Lengths of mel spectrograms + "rtf": float, + # Real-time factor """ # For RTF computation t = dt.datetime.now() @@ -112,7 +125,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 mu_y = mu_y.transpose(1, 2) encoder_outputs = mu_y[:, :, :y_max_length] - # Generate sample by performing reverse dynamics + # Generate sample tracing the probability flow decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) decoder_outputs = decoder_outputs[:, :, :y_max_length] @@ -133,15 +146,21 @@ class MatchaTTS(BaseLightningClass): # 🍵 Computes 3 losses: 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). 2. 
prior loss: loss between mel-spectrogram and encoder outputs. - 3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. + 3. flow matching loss: loss between mel-spectrogram and decoder outputs. Args: x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) y (torch.Tensor): batch of corresponding mel-spectrograms. + shape: (batch_size, n_feats, max_mel_length) y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. + shape: (batch_size,) out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. + spks (torch.Tensor, optional): speaker ids. + shape: (batch_size,) """ if self.n_spks > 1: # Get speaker embedding