Mirror of https://github.com/shivammehta25/Matcha-TTS.git (synced 2026-02-04 09:49:21 +08:00)
Adding docstrings
@@ -69,6 +69,7 @@ class Downsample1D(nn.Module):
    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
@@ -115,6 +116,7 @@ class TimestepEmbedding(nn.Module):
        sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.
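For readers skimming the diff, the Upsample1D docstring above describes a common pattern: upsample the time axis, then optionally smooth with a 1D convolution. A minimal sketch of that pattern, independent of the actual Matcha-TTS Upsample1D (the class name, channel handling, and use_conv default here are assumptions for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleUpsample1D(nn.Module):
    """Illustrative 1D upsampling block: x2 nearest-neighbour interpolation plus an optional conv."""

    def __init__(self, channels: int, use_conv: bool = True):
        super().__init__()
        self.use_conv = use_conv
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1) if use_conv else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time) -> (batch, channels, 2 * time)
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x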
@@ -46,7 +46,7 @@ class BASECFM(torch.nn.Module, ABC):

        Returns:
            sample: generated mel-spectrogram
-                shape: (batch_size, n_feats, mel_timesteps)
+                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
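The two lines after the docstring show how CFM sampling starts: Gaussian noise z scaled by temperature, and a time grid t_span from 0 to 1 with n_timesteps + 1 points. A hedged sketch of the fixed-step Euler integration this sets up (the estimator call and its argument order are assumptions, not the exact BASECFM solver API):

import torch


@torch.inference_mode()
def euler_cfm_sample(estimator, mu, mask, n_timesteps=10, temperature=1.0, spks=None):
    """Sketch of flow-matching sampling by fixed-step Euler integration.

    estimator(x, mask, mu, t, spks) is an assumed signature for the vector-field
    network; the real BASECFM solver may differ.
    """
    z = torch.randn_like(mu) * temperature                  # start from scaled Gaussian noise
    t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)

    x, t = z, t_span[0]
    for i in range(1, len(t_span)):
        dt = t_span[i] - t_span[i - 1]
        t_batch = t.expand(x.shape[0])                      # same time value for every item in the batch
        dphi_dt = estimator(x, mask, mu, t_batch, spks)     # predicted velocity field
        x = x + dt * dphi_dt                                # Euler step toward the data distribution
        t = t_span[i]
    return x                                                # (batch_size, n_feats, mel_timesteps)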
@@ -2,8 +2,13 @@ from typing import Any, Dict, Optional

import torch
import torch.nn as nn
-from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
-                                        AdaLayerNormZero, ApproximateGELU)
+from diffusers.models.attention import (
+    GEGLU,
+    GELU,
+    AdaLayerNorm,
+    AdaLayerNormZero,
+    ApproximateGELU,
+)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
@@ -38,7 +43,7 @@ class SnakeBeta(nn.Module):
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
-        super(SnakeBeta, self).__init__()
+        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)
@@ -73,8 +78,8 @@ class SnakeBeta(nn.Module):
        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.
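Taken together, the two SnakeBeta hunks describe the activation x + (1 / (beta + eps)) * sin^2(alpha * x), with learnable alpha (frequency) and beta (magnitude), as the docstring notes. A self-contained sketch of just that activation, with assumed parameter shapes and without the LoRA-compatible projection the real module wraps:

import torch
import torch.nn as nn


class SnakeBetaActivation(nn.Module):
    """Illustrative Snake-Beta activation: x + (1 / (beta + eps)) * sin^2(alpha * x)."""

    def __init__(self, channels: int, eps: float = 1e-9):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(channels))  # controls the sine frequency
        self.beta = nn.Parameter(torch.ones(channels))   # controls the added magnitude
        self.eps = eps                                   # mirrors no_div_by_zero in the diff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, channels); broadcast alpha/beta over batch and time
        alpha = self.alpha.view(1, 1, -1)
        beta = self.beta.view(1, 1, -1)
        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)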
@@ -127,8 +132,7 @@ class FeedForward(nn.Module):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
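The FeedForward.forward shown above simply threads hidden_states through an ordered list of submodules. A minimal version of that pattern (the linear-GELU-dropout-linear composition and the mult expansion factor are assumptions based on common transformer feed-forward designs, not the exact diffusers-style FeedForward):

import torch
import torch.nn as nn


class TinyFeedForward(nn.Module):
    """Sketch of a transformer feed-forward block driven by a ModuleList."""

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.ModuleList(
            [
                nn.Linear(dim, inner_dim),  # expand
                nn.GELU(),                  # nonlinearity
                nn.Dropout(dropout),
                nn.Linear(inner_dim, dim),  # project back
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Apply each stage in order, exactly like the loop in the diff above.
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states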
@@ -217,7 +221,7 @@ class BasicTransformerBlock(nn.Module):
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
-                # scale_qk=False, # uncomment this to not to use flash attention
+                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
@@ -309,4 +313,4 @@ class BasicTransformerBlock(nn.Module):

        hidden_states = ff_output + hidden_states

-        return hidden_states
+        return hidden_states
@@ -9,9 +9,13 @@ from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
-from matcha.utils.model import (denormalize, duration_loss,
-                                fix_len_compatibility, generate_path,
-                                sequence_mask)
+from matcha.utils.model import (
+    denormalize,
+    duration_loss,
+    fix_len_compatibility,
+    generate_path,
+    sequence_mask,
+)

log = utils.get_pylogger(__name__)
@@ -83,7 +87,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
                shape: (batch_size,)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
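The length_scale argument documented above scales the predicted phoneme durations, so values above 1.0 slow generated speech and values below 1.0 speed it up. A hedged illustration of that scaling step (the tensor names logw and x_mask follow common Glow-TTS-style duration predictors and are assumptions, not the exact MatchaTTS.synthesise code):

import torch


def scale_durations(logw: torch.Tensor, x_mask: torch.Tensor, length_scale: float = 1.0) -> torch.Tensor:
    """Turn predicted log-durations into integer frame counts, scaled by length_scale.

    logw:   (batch, 1, text_len) predicted log-durations
    x_mask: (batch, 1, text_len) 1.0 for real tokens, 0.0 for padding
    """
    w = torch.exp(logw) * x_mask                 # durations in frames, zeroed on padding
    w_ceil = torch.ceil(w * length_scale)        # >1.0 slows speech, <1.0 speeds it up
    y_lengths = torch.clamp_min(torch.sum(w_ceil, dim=[1, 2]), 1).long()
    return y_lengths                             # mel frames per utterance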