mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
adding prior loss as a configuration
This commit is contained in:
@@ -12,3 +12,4 @@ spk_emb_dim: 64
|
|||||||
n_feats: 80
|
n_feats: 80
|
||||||
data_statistics: ${data.data_statistics}
|
data_statistics: ${data.data_statistics}
|
||||||
out_size: null # Must be divisible by 4
|
out_size: null # Must be divisible by 4
|
||||||
|
prior_loss: true
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
out_size,
|
out_size,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
|
prior_loss=True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
self.spk_emb_dim = spk_emb_dim
|
self.spk_emb_dim = spk_emb_dim
|
||||||
self.n_feats = n_feats
|
self.n_feats = n_feats
|
||||||
self.out_size = out_size
|
self.out_size = out_size
|
||||||
|
self.prior_loss = prior_loss
|
||||||
|
|
||||||
if n_spks > 1:
|
if n_spks > 1:
|
||||||
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
||||||
@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
# Compute loss of the decoder
|
# Compute loss of the decoder
|
||||||
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
|
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
|
||||||
|
|
||||||
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
|
if self.prior_loss:
|
||||||
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
|
||||||
|
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
||||||
|
else:
|
||||||
|
prior_loss = 0
|
||||||
|
|
||||||
return dur_loss, prior_loss, diff_loss
|
return dur_loss, prior_loss, diff_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user