Add prior loss as a configuration option

Author: Shivam Mehta
Date:   2023-12-05 09:57:37 +00:00
Parent: ae2417c175
Commit: 6e71dc8b8f

2 changed files with 8 additions and 2 deletions


@@ -12,3 +12,4 @@ spk_emb_dim: 64
 n_feats: 80
 data_statistics: ${data.data_statistics}
 out_size: null # Must be divisible by 4
+prior_loss: true
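
The new key defaults to true, so existing configs keep the prior loss unless they opt out. A minimal sketch of flipping the flag with omegaconf, which the ${data.data_statistics} interpolation suggests this project uses; the config path below is a hypothetical stand-in, not the repo's actual file name:

    # Sketch: load this model config and disable the new prior loss term.
    # "model.yaml" is a placeholder for the actual config path in the repo.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("model.yaml")
    cfg.prior_loss = False  # train on duration + diffusion losses only
    print(OmegaConf.to_yaml(cfg))  # interpolations print unresolved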


@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         out_size,
         optimizer=None,
         scheduler=None,
+        prior_loss=True,
     ):
         super().__init__()
@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.spk_emb_dim = spk_emb_dim
         self.n_feats = n_feats
         self.out_size = out_size
+        self.prior_loss = prior_loss

         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         # Compute loss of the decoder
         diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)

-        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
-        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        if self.prior_loss:
+            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        else:
+            prior_loss = 0

         return dur_loss, prior_loss, diff_loss
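
The gated expression is the negative log-likelihood of the target mels y under a unit-variance Gaussian centered at the encoder output mu_y, normalized by the number of unmasked values (sum of y_mask times n_feats). A self-contained check, using random stand-in tensors in the (batch, n_feats, frames) layout the diff implies, showing the closed form matches torch.distributions:

    # Verify: 0.5 * ((y - mu_y)**2 + log(2*pi)) is the per-element NLL of
    # y under Normal(mu_y, 1). All tensors below are random stand-ins.
    import math
    import torch

    y = torch.randn(2, 80, 100)     # (batch, n_feats, frames) target mels
    mu_y = torch.randn(2, 80, 100)  # encoder means aligned to y
    y_mask = torch.ones(2, 1, 100)  # 1 = real frame, 0 = padding

    manual = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
    nll = -torch.distributions.Normal(mu_y, 1.0).log_prob(y) * y_mask
    assert torch.allclose(manual, nll.sum())

With the flag off, prior_loss is simply 0, which still sums and logs cleanly alongside the duration and diffusion losses.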