mirror of https://github.com/shivammehta25/Matcha-TTS.git
Adding docstrings
@@ -34,15 +34,19 @@ class BASECFM(torch.nn.Module, ABC):
         """Forward diffusion
 
         Args:
-            z (_type_): mu + noise (we don't need this in this formulation), we will sample the noise again
-            mask (_type_): output_mask
-            mu (_type_): output of encoder
-            n_timesteps (_type_): number of diffusion steps
-            stoc (bool, optional): _description_. Defaults to False.
-            spks (_type_, optional): _description_. Defaults to None.
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
 
         Returns:
-            sample: _description_
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
         """
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
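Read together, the two surviving code lines are the entire sampling setup the rewritten docstring describes: draw Gaussian noise at the requested temperature, then lay out a uniform time grid for the solver documented in the next hunk. A minimal runnable sketch of just these two steps, with illustrative sizes plugged into the docstring shapes:

import torch

batch_size, n_feats, mel_timesteps = 2, 80, 100        # illustrative sizes
mu = torch.randn(batch_size, n_feats, mel_timesteps)   # stand-in for the encoder output
n_timesteps, temperature = 10, 1.0

z = torch.randn_like(mu) * temperature  # noise with std = temperature, same shape as mu
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)  # (n_timesteps + 1,)
print(z.shape, t_span.shape)  # torch.Size([2, 80, 100]) torch.Size([11])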
@@ -52,10 +56,21 @@ class BASECFM(torch.nn.Module, ABC):
         """
         Fixed euler solver for ODEs.
         Args:
-            x (_type_): _description_
-            t (_type_): _description_
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
         """
         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
         sol = []
 
         steps = 1
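With the arguments now documented, the solver itself is easy to state: starting from the noise x, take len(t_span) - 1 explicit Euler steps x <- x + dt * f(x, t). A generic fixed-step Euler integrator in the same spirit, where the vector field f is a stand-in for the learned estimator (which in the real method is additionally conditioned on mu, mask, spks, and cond):

import torch

def euler_solve(f, x, t_span):
    # Fixed-step Euler: t_span is a 1-D tensor of time points, e.g. torch.linspace(0, 1, n + 1)
    t, dt = t_span[0], t_span[1] - t_span[0]
    sol = [x]  # keep intermediate states, mirroring the `sol` list in the hunk
    for _ in range(1, len(t_span)):
        x = x + dt * f(x, t)  # x_{t+dt} = x_t + dt * dx/dt
        t = t + dt
        sol.append(x)
    return sol[-1]

# Sanity check on dx/dt = -x: the exact solution at t=1 is exp(-1) * x0 ~ 0.368
x1 = euler_solve(lambda x, t: -x, torch.ones(3), torch.linspace(0, 1, 11))
print(x1)  # ~0.349 with 10 Euler steps (0.9 ** 10)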
@@ -75,14 +90,19 @@ class BASECFM(torch.nn.Module, ABC):
         """Computes diffusion loss
 
         Args:
-            x1 (_type_): Target
-            mask (_type_): target mask
-            mu (_type_): output of encoder
-            spks (_type_, optional): speaker embedding. Defaults to None.
+            x1 (torch.Tensor): Target
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): target mask
+                shape: (batch_size, 1, mel_timesteps)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
 
         Returns:
-            loss: diffusion loss
+            loss: conditional flow matching loss
+            y: conditional flow
+                shape: (batch_size, n_feats, mel_timesteps)
         """
         b, _, t = mu.shape
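The return values changed meaning here: loss is the conditional flow matching loss rather than a generic diffusion loss, and y is the conditional flow sample. A sketch of that objective under the usual OT-CFM construction (the estimator callable and the sigma_min value are illustrative assumptions, not taken from this diff): interpolate on a straight line between noise and target, and regress the network onto the constant vector field of that path.

import torch
import torch.nn.functional as F

def compute_cfm_loss(estimator, x1, mask, mu, sigma_min=1e-4):
    b, _, _ = mu.shape  # mirrors the `b, _, t = mu.shape` context line above
    t = torch.rand(b, 1, 1, dtype=x1.dtype, device=x1.device)  # flow time ~ U(0, 1)
    z = torch.randn_like(x1)                                   # noise endpoint of the path

    # Conditional flow y_t and the target vector field u_t it induces
    y = (1 - (1 - sigma_min) * t) * z + t * x1
    u = x1 - (1 - sigma_min) * z

    pred = estimator(y, mask, mu, t.squeeze())  # network predicts the vector field
    # Masked MSE, normalised by the number of valid mel positions
    loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (mask.sum() * u.shape[1])
    return loss, y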
@@ -376,6 +376,24 @@ class TextEncoder(nn.Module):
         )
 
     def forward(self, x, x_lengths, spks=None):
+        """Run forward pass to the transformer based encoder and duration predictor
+
+        Args:
+            x (torch.Tensor): text input
+                shape: (batch_size, max_text_length)
+            x_lengths (torch.Tensor): text input lengths
+                shape: (batch_size,)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size,)
+
+        Returns:
+            mu (torch.Tensor): average output of the encoder
+                shape: (batch_size, n_feats, max_text_length)
+            logw (torch.Tensor): log duration predicted by the duration predictor
+                shape: (batch_size, 1, max_text_length)
+            x_mask (torch.Tensor): mask for the text input
+                shape: (batch_size, 1, max_text_length)
+        """
         x = self.emb(x) * math.sqrt(self.n_channels)
         x = torch.transpose(x, 1, -1)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
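The context lines after the new docstring do two things: scale the token embeddings by sqrt(n_channels) (the standard Transformer embedding scaling) and build x_mask from the per-sample lengths. sequence_mask is a helper from the repo; a plausible reimplementation for illustration, showing how the documented (batch_size, 1, max_text_length) mask is produced:

import torch

def sequence_mask(lengths, max_length=None):
    # True where position index < sequence length; shape (batch_size, max_length)
    if max_length is None:
        max_length = lengths.max()
    positions = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

x_lengths = torch.tensor([5, 3])
x = torch.randn(2, 192, 5)  # (batch_size, n_channels, max_text_length), post-transpose
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
print(x_mask)  # [[[1., 1., 1., 1., 1.]], [[1., 1., 1., 0., 0.]]]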