Adding docstrings

This commit is contained in:
Shivam Mehta
2023-09-17 06:49:12 +00:00
parent 0554a5b87c
commit c079e5254a
3 changed files with 82 additions and 25 deletions

View File

@@ -9,13 +9,9 @@ from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
)
from matcha.utils.model import (denormalize, duration_loss,
fix_len_compatibility, generate_path,
sequence_mask)
log = utils.get_pylogger(__name__)
@@ -78,13 +74,30 @@ class MatchaTTS(BaseLightningClass): # 🍵
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
temperature (float, optional): controls variance of terminal distribution.
stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
Usually, does not provide synthesis improvements.
spks (bool, optional): speaker ids.
shape: (batch_size,)
length_scale (float, optional): controls speech pace.
Increase value to slow down generated speech and vice versa.
Returns:
dict: {
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Average mel spectrogram generated by the encoder
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Refined mel spectrogram improved by the CFM
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
# Alignment map between text and mel spectrogram
"mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Denormalized mel spectrogram
"mel_lengths": torch.Tensor, shape: (batch_size,),
# Lengths of mel spectrograms
"rtf": float,
# Real-time factor
"""
# For RTF computation
t = dt.datetime.now()
@@ -112,7 +125,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length]
# Generate sample by performing reverse dynamics
# Generate sample tracing the probability flow
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
decoder_outputs = decoder_outputs[:, :, :y_max_length]
@@ -133,15 +146,21 @@ class MatchaTTS(BaseLightningClass): # 🍵
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
y (torch.Tensor): batch of corresponding mel-spectrograms.
shape: (batch_size, n_feats, max_mel_length)
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
shape: (batch_size,)
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
spks (torch.Tensor, optional): speaker ids.
shape: (batch_size,)
"""
if self.n_spks > 1:
# Get speaker embedding