Mirror of https://github.com/shivammehta25/Matcha-TTS.git
Add option to do flow-matching-based duration prediction
Summary: the TextEncoder no longer predicts log-durations itself; it now returns encoded text features (`enc_output`), and a new standalone DP module predicts durations via flow matching, replacing the MSE duration loss with DP.compute_loss.
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -7,11 +7,11 @@ import torch
 import matcha.utils.monotonic_align as monotonic_align
 from matcha import utils
 from matcha.models.baselightningmodule import BaseLightningClass
+from matcha.models.components.duration_predictors import DP
 from matcha.models.components.flow_matching import CFM
 from matcha.models.components.text_encoder import TextEncoder
 from matcha.utils.model import (
     denormalize,
-    duration_loss,
     fix_len_compatibility,
     generate_path,
     sequence_mask,
@@ -28,6 +28,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         spk_emb_dim,
         n_feats,
         encoder,
+        duration_predictor,
         decoder,
         cfm,
         data_statistics,
@@ -53,12 +54,13 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.encoder = TextEncoder(
             encoder.encoder_type,
             encoder.encoder_params,
-            encoder.duration_predictor_params,
             n_vocab,
             n_spks,
             spk_emb_dim,
         )
 
+        self.dp = DP(duration_predictor)
+
         self.decoder = CFM(
             in_channels=2 * encoder.encoder_params.n_feats,
             out_channel=encoder.encoder_params.n_feats,
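The DP module itself is not shown in this commit (it lives in matcha/models/components/duration_predictors.py). As a rough illustration only, a flow-matching duration predictor exposing the interface used above could look like the sketch below; the estimator architecture and the params.channels, params.filter_channels, and params.n_steps names are assumptions, not the repo's actual code.

import torch
import torch.nn as nn

class DP(nn.Module):
    """Hypothetical sketch of a flow-matching duration predictor.

    The real module is in matcha/models/components/duration_predictors.py,
    which this commit does not display; all parameters here are illustrative.
    """

    def __init__(self, params):
        super().__init__()
        # Vector-field estimator v(x_t, t | enc_output): input channels are
        # noisy log-durations (1) + time (1) + encoder features (params.channels).
        self.estimator = nn.Sequential(
            nn.Conv1d(params.channels + 2, params.filter_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(params.filter_channels, 1, 3, padding=1),
        )
        self.n_steps = params.n_steps

    def forward(self, enc_output, x_mask):
        # Sample log-durations `logw` by Euler-integrating the learned ODE
        # from Gaussian noise at t=0 toward data at t=1.
        b, _, t = enc_output.shape
        x = torch.randn(b, 1, t, device=enc_output.device) * x_mask
        dt = 1.0 / self.n_steps
        for i in range(self.n_steps):
            s = torch.full((b, 1, t), i * dt, device=enc_output.device)
            v = self.estimator(torch.cat([x, s, enc_output], dim=1))
            x = (x + dt * v) * x_mask
        return x  # (B, 1, T) log-scaled durations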
@@ -112,8 +114,11 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             # Get speaker embedding
             spks = self.spk_emb(spks.long())
 
-        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
-        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
+        # Get encoder_outputs `mu_x` and encoded text `enc_output`
+        mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
+
+        # Get log-scaled token durations `logw`
+        logw = self.dp(enc_output, x_mask)
 
         w = torch.exp(logw) * x_mask
         w_ceil = torch.ceil(w) * length_scale
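Downstream of the new predictor, duration decoding is unchanged: `logw` is exponentiated, masked, and ceiled into per-token frame counts. A toy example of what these two lines compute (the final `y_lengths` reduction follows the style of the surrounding Matcha-TTS code and is shown only for context):

import torch

logw = torch.tensor([[[0.0, 1.0, -0.5]]])  # (B=1, 1, T=3) predicted log-durations
x_mask = torch.ones(1, 1, 3)
length_scale = 1.0

w = torch.exp(logw) * x_mask           # per-token durations: 1.00, 2.72, 0.61
w_ceil = torch.ceil(w) * length_scale  # frame counts: 1., 3., 1.
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
print(y_lengths)                       # tensor([5]) -> 5 mel frames total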
@@ -173,7 +178,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
             spks = self.spk_emb(spks)
 
         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
-        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
+        mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
         y_max_length = y.shape[-1]
 
         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
@@ -192,9 +197,8 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         attn = attn.detach()
 
         # Compute loss between predicted log-scaled durations and those obtained from MAS
-        # refered to as prior loss in the paper
         logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
-        dur_loss = duration_loss(logw, logw_, x_lengths)
+        dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
 
         # Cut a small segment of mel-spectrogram in order to increase batch size
         # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
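The diff also does not show DP.compute_loss. Under a standard conditional flow-matching objective on the MAS log-durations (the likely intent, given the commit title), it would regress the estimator's predicted velocity onto the straight-line path between noise and `logw_`. A minimal sketch, assuming the same estimator shape as in the DP sketch above:

import torch

def compute_loss(estimator, logw_, enc_output, x_mask):
    # Conditional flow matching on log-durations: draw t ~ U(0, 1), build the
    # linear (optimal-transport) path x_t = (1 - t) * x0 + t * x1 from noise x0
    # to MAS durations x1, and regress the predicted velocity onto u = x1 - x0.
    x1 = logw_                                         # (B, 1, T) targets from MAS
    x0 = torch.randn_like(x1)                          # noise endpoint
    t = torch.rand(x1.size(0), 1, 1, device=x1.device)
    xt = (1 - t) * x0 + t * x1
    u = x1 - x0                                        # target velocity field
    s = t.expand(-1, 1, x1.size(2))                    # broadcast time over tokens
    v = estimator(torch.cat([xt, s, enc_output], dim=1))
    return torch.sum(((v - u) * x_mask) ** 2) / torch.sum(x_mask)

At synthesis time, the Euler loop in the earlier sketch then integrates this learned field from noise to a duration sample, which is what replaces the deterministic MSE-trained predictor.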