diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml index 36f6eaf..4700855 100644 --- a/configs/model/matcha.yaml +++ b/configs/model/matcha.yaml @@ -12,4 +12,3 @@ spk_emb_dim: 64 n_feats: 80 data_statistics: ${data.data_statistics} out_size: null # Must be divisible by 4 -prior_loss: true diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index 64b2c07..6feb9e7 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -34,7 +34,6 @@ class MatchaTTS(BaseLightningClass): # 🍵 out_size, optimizer=None, scheduler=None, - prior_loss=True, ): super().__init__() @@ -45,7 +44,6 @@ class MatchaTTS(BaseLightningClass): # 🍵 self.spk_emb_dim = spk_emb_dim self.n_feats = n_feats self.out_size = out_size - self.prior_loss = prior_loss if n_spks > 1: self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) @@ -230,10 +228,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 # Compute loss of the decoder diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) - if self.prior_loss: - prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) - prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) - else: - prior_loss = 0 + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) return dur_loss, prior_loss, diff_loss