Adding the possibility to train with durations

This commit is contained in:
Shivam Mehta
2024-05-27 13:24:21 +02:00
parent e658aee6a5
commit aa496aa13f
7 changed files with 54 additions and 12 deletions

View File

@@ -65,6 +65,7 @@ class BaseLightningClass(LightningModule, ABC):
y_lengths=y_lengths,
spks=spks,
out_size=self.out_size,
durations=batch["durations"],
)
return {
"dur_loss": dur_loss,

View File

@@ -35,6 +35,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
optimizer=None,
scheduler=None,
prior_loss=True,
use_precomputed_durations=False,
):
super().__init__()
@@ -46,6 +47,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.n_feats = n_feats
self.out_size = out_size
self.prior_loss = prior_loss
self.use_precomputed_durations = use_precomputed_durations
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -147,7 +149,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
"rtf": rtf,
}
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
"""
Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
@@ -179,17 +181,20 @@ class MatchaTTS(BaseLightningClass): # 🍵
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
if self.use_precomputed_durations:
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
else:
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach()
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach() # b, t_text, T_mel
# Compute loss between predicted log-scaled durations and those obtained from MAS
# refered to as prior loss in the paper