mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 18:29:19 +08:00
Adding option to do flow-matching-based duration prediction
This commit is contained in:
@@ -67,33 +67,6 @@ class ConvReluNorm(nn.Module):
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
    """Two-stage 1-D convolutional stack followed by a single-channel projection.

    Each stage applies Conv1d -> ReLU -> LayerNorm -> Dropout, and the input
    of every convolution (and the final output) is multiplied by ``x_mask``.

    NOTE(review): the caller in this file names the result ``logw``, so the
    output is presumably a per-position log-duration -- confirm against the
    training loss.

    Args:
        in_channels: Channel count of the incoming feature tensor.
        filter_channels: Channel count used inside both convolutional stages.
        kernel_size: Kernel width of both stage convolutions. Padding is
            ``kernel_size // 2``, which keeps the sequence length unchanged
            for odd kernel sizes only.
        p_dropout: Dropout probability shared by both stages.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        # Same padding for both stage convolutions.
        same_pad = kernel_size // 2

        # One dropout module is reused after each stage.
        self.drop = torch.nn.Dropout(p_dropout)

        # Stage 1: in_channels -> filter_channels.
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_1 = LayerNorm(filter_channels)

        # Stage 2: filter_channels -> filter_channels.
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_2 = LayerNorm(filter_channels)

        # Pointwise (kernel size 1) projection down to one output channel.
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        """Run both convolutional stages and the projection head.

        Args:
            x: Input features; assumed (batch, in_channels, time) as expected
                by Conv1d -- TODO confirm.
            x_mask: Mask multiplied element-wise with the activations;
                presumably (batch, 1, time) and broadcast over channels --
                verify against the caller.

        Returns:
            The masked output of the single-channel projection.
        """
        stages = (
            (self.conv_1, self.norm_1),
            (self.conv_2, self.norm_2),
        )
        hidden = x
        for conv, norm in stages:
            # Mask before each convolution, as in the projection below.
            hidden = conv(hidden * x_mask)
            hidden = torch.relu(hidden)
            hidden = norm(hidden)
            hidden = self.drop(hidden)

        projected = self.proj(hidden * x_mask)
        return projected * x_mask
|
||||
|
||||
|
||||
class RotaryPositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
## RoPE module
|
||||
@@ -330,7 +303,6 @@ class TextEncoder(nn.Module):
|
||||
self,
|
||||
encoder_type,
|
||||
encoder_params,
|
||||
duration_predictor_params,
|
||||
n_vocab,
|
||||
n_spks=1,
|
||||
spk_emb_dim=128,
|
||||
@@ -368,12 +340,6 @@ class TextEncoder(nn.Module):
|
||||
)
|
||||
|
||||
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
|
||||
self.proj_w = DurationPredictor(
|
||||
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
|
||||
duration_predictor_params.filter_channels_dp,
|
||||
duration_predictor_params.kernel_size,
|
||||
duration_predictor_params.p_dropout,
|
||||
)
|
||||
|
||||
def forward(self, x, x_lengths, spks=None):
|
||||
"""Run forward pass to the transformer based encoder and duration predictor
|
||||
@@ -404,7 +370,7 @@ class TextEncoder(nn.Module):
|
||||
x = self.encoder(x, x_mask)
|
||||
mu = self.proj_m(x) * x_mask
|
||||
|
||||
x_dp = torch.detach(x)
|
||||
logw = self.proj_w(x_dp, x_mask)
|
||||
# x_dp = torch.detach(x)
|
||||
# logw = self.proj_w(x_dp, x_mask)
|
||||
|
||||
return mu, logw, x_mask
|
||||
return mu, x, x_mask
|
||||
|
||||
Reference in New Issue
Block a user