mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 18:29:19 +08:00
Adding docstrings
This commit is contained in:
@@ -34,15 +34,19 @@ class BASECFM(torch.nn.Module, ABC):
|
|||||||
"""Forward diffusion
|
"""Forward diffusion
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
z (_type_): mu + noise (we don't need this in this formulation), we will sample the noise again
|
mu (torch.Tensor): output of encoder
|
||||||
mask (_type_): output_mask
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
mu (_type_): output of encoder
|
mask (torch.Tensor): output_mask
|
||||||
n_timesteps (_type_): number of diffusion steps
|
shape: (batch_size, 1, mel_timesteps)
|
||||||
stoc (bool, optional): _description_. Defaults to False.
|
n_timesteps (int): number of diffusion steps
|
||||||
spks (_type_, optional): _description_. Defaults to None.
|
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||||
|
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||||
|
shape: (batch_size, spk_emb_dim)
|
||||||
|
cond: Not used but kept for future purposes
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
sample: _description_
|
sample: generated mel-spectrogram
|
||||||
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
"""
|
"""
|
||||||
z = torch.randn_like(mu) * temperature
|
z = torch.randn_like(mu) * temperature
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
||||||
@@ -52,10 +56,21 @@ class BASECFM(torch.nn.Module, ABC):
|
|||||||
"""
|
"""
|
||||||
Fixed euler solver for ODEs.
|
Fixed euler solver for ODEs.
|
||||||
Args:
|
Args:
|
||||||
x (_type_): _description_
|
x (torch.Tensor): random noise
|
||||||
t (_type_): _description_
|
t_span (torch.Tensor): n_timesteps interpolated
|
||||||
|
shape: (n_timesteps + 1,)
|
||||||
|
mu (torch.Tensor): output of encoder
|
||||||
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
|
mask (torch.Tensor): output_mask
|
||||||
|
shape: (batch_size, 1, mel_timesteps)
|
||||||
|
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||||
|
shape: (batch_size, spk_emb_dim)
|
||||||
|
cond: Not used but kept for future purposes
|
||||||
"""
|
"""
|
||||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||||
|
|
||||||
|
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||||
|
# Or in future might add like a return_all_steps flag
|
||||||
sol = []
|
sol = []
|
||||||
|
|
||||||
steps = 1
|
steps = 1
|
||||||
@@ -75,14 +90,19 @@ class BASECFM(torch.nn.Module, ABC):
|
|||||||
"""Computes diffusion loss
|
"""Computes diffusion loss
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x1 (_type_): Target
|
x1 (torch.Tensor): Target
|
||||||
mask (_type_): target mask
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
mu (_type_): output of encoder
|
mask (torch.Tensor): target mask
|
||||||
spks (_type_, optional): speaker embedding. Defaults to None.
|
shape: (batch_size, 1, mel_timesteps)
|
||||||
|
mu (torch.Tensor): output of encoder
|
||||||
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
|
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
||||||
|
shape: (batch_size, spk_emb_dim)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
loss: diffusion loss
|
loss: conditional flow matching loss
|
||||||
y: conditional flow
|
y: conditional flow
|
||||||
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
"""
|
"""
|
||||||
b, _, t = mu.shape
|
b, _, t = mu.shape
|
||||||
|
|
||||||
|
|||||||
@@ -376,6 +376,24 @@ class TextEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, x_lengths, spks=None):
|
def forward(self, x, x_lengths, spks=None):
|
||||||
|
"""Run forward pass to the transformer based encoder and duration predictor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): text input
|
||||||
|
shape: (batch_size, max_text_length)
|
||||||
|
x_lengths (torch.Tensor): text input lengths
|
||||||
|
shape: (batch_size,)
|
||||||
|
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||||
|
shape: (batch_size,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mu (torch.Tensor): average output of the encoder
|
||||||
|
shape: (batch_size, n_feats, max_text_length)
|
||||||
|
logw (torch.Tensor): log duration predicted by the duration predictor
|
||||||
|
shape: (batch_size, 1, max_text_length)
|
||||||
|
x_mask (torch.Tensor): mask for the text input
|
||||||
|
shape: (batch_size, 1, max_text_length)
|
||||||
|
"""
|
||||||
x = self.emb(x) * math.sqrt(self.n_channels)
|
x = self.emb(x) * math.sqrt(self.n_channels)
|
||||||
x = torch.transpose(x, 1, -1)
|
x = torch.transpose(x, 1, -1)
|
||||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||||
|
|||||||
@@ -9,13 +9,9 @@ from matcha import utils
|
|||||||
from matcha.models.baselightningmodule import BaseLightningClass
|
from matcha.models.baselightningmodule import BaseLightningClass
|
||||||
from matcha.models.components.flow_matching import CFM
|
from matcha.models.components.flow_matching import CFM
|
||||||
from matcha.models.components.text_encoder import TextEncoder
|
from matcha.models.components.text_encoder import TextEncoder
|
||||||
from matcha.utils.model import (
|
from matcha.utils.model import (denormalize, duration_loss,
|
||||||
denormalize,
|
fix_len_compatibility, generate_path,
|
||||||
duration_loss,
|
sequence_mask)
|
||||||
fix_len_compatibility,
|
|
||||||
generate_path,
|
|
||||||
sequence_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
log = utils.get_pylogger(__name__)
|
log = utils.get_pylogger(__name__)
|
||||||
|
|
||||||
@@ -78,13 +74,30 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
||||||
|
shape: (batch_size, max_text_length)
|
||||||
x_lengths (torch.Tensor): lengths of texts in batch.
|
x_lengths (torch.Tensor): lengths of texts in batch.
|
||||||
|
shape: (batch_size,)
|
||||||
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
|
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
|
||||||
temperature (float, optional): controls variance of terminal distribution.
|
temperature (float, optional): controls variance of terminal distribution.
|
||||||
stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
|
spks (bool, optional): speaker ids.
|
||||||
Usually, does not provide synthesis improvements.
|
shape: (batch_size,)
|
||||||
length_scale (float, optional): controls speech pace.
|
length_scale (float, optional): controls speech pace.
|
||||||
Increase value to slow down generated speech and vice versa.
|
Increase value to slow down generated speech and vice versa.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {
|
||||||
|
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||||
|
# Average mel spectrogram generated by the encoder
|
||||||
|
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||||
|
# Refined mel spectrogram improved by the CFM
|
||||||
|
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
|
||||||
|
# Alignment map between text and mel spectrogram
|
||||||
|
"mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||||
|
# Denormalized mel spectrogram
|
||||||
|
"mel_lengths": torch.Tensor, shape: (batch_size,),
|
||||||
|
# Lengths of mel spectrograms
|
||||||
|
"rtf": float,
|
||||||
|
# Real-time factor
|
||||||
"""
|
"""
|
||||||
# For RTF computation
|
# For RTF computation
|
||||||
t = dt.datetime.now()
|
t = dt.datetime.now()
|
||||||
@@ -112,7 +125,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
mu_y = mu_y.transpose(1, 2)
|
mu_y = mu_y.transpose(1, 2)
|
||||||
encoder_outputs = mu_y[:, :, :y_max_length]
|
encoder_outputs = mu_y[:, :, :y_max_length]
|
||||||
|
|
||||||
# Generate sample by performing reverse dynamics
|
# Generate sample tracing the probability flow
|
||||||
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
|
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
|
||||||
decoder_outputs = decoder_outputs[:, :, :y_max_length]
|
decoder_outputs = decoder_outputs[:, :, :y_max_length]
|
||||||
|
|
||||||
@@ -133,15 +146,21 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
Computes 3 losses:
|
Computes 3 losses:
|
||||||
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
|
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
|
||||||
2. prior loss: loss between mel-spectrogram and encoder outputs.
|
2. prior loss: loss between mel-spectrogram and encoder outputs.
|
||||||
3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
|
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
||||||
|
shape: (batch_size, max_text_length)
|
||||||
x_lengths (torch.Tensor): lengths of texts in batch.
|
x_lengths (torch.Tensor): lengths of texts in batch.
|
||||||
|
shape: (batch_size,)
|
||||||
y (torch.Tensor): batch of corresponding mel-spectrograms.
|
y (torch.Tensor): batch of corresponding mel-spectrograms.
|
||||||
|
shape: (batch_size, n_feats, max_mel_length)
|
||||||
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
|
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
|
||||||
|
shape: (batch_size,)
|
||||||
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
|
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
|
||||||
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
|
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
|
||||||
|
spks (torch.Tensor, optional): speaker ids.
|
||||||
|
shape: (batch_size,)
|
||||||
"""
|
"""
|
||||||
if self.n_spks > 1:
|
if self.n_spks > 1:
|
||||||
# Get speaker embedding
|
# Get speaker embedding
|
||||||
|
|||||||
Reference in New Issue
Block a user