mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 02:09:21 +08:00
Minor changes moving option to disable prior loss in config
This commit is contained in:
@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
out_size,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
prior_loss=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
self.spk_emb_dim = spk_emb_dim
|
||||
self.n_feats = n_feats
|
||||
self.out_size = out_size
|
||||
self.prior_loss = prior_loss
|
||||
|
||||
if n_spks > 1:
|
||||
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
||||
@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
# Compute loss of the decoder
|
||||
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
|
||||
|
||||
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
|
||||
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
||||
if self.prior_loss:
|
||||
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
|
||||
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
||||
else:
|
||||
prior_loss = 0
|
||||
|
||||
return dur_loss, prior_loss, diff_loss
|
||||
|
||||
Reference in New Issue
Block a user