diff --git a/configs/experiment/ljspeech_no_prior_loss.yaml b/configs/experiment/ljspeech_no_prior_loss.yaml new file mode 100644 index 0000000..6181950 --- /dev/null +++ b/configs/experiment/ljspeech_no_prior_loss.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=ljspeech_no_prior_loss + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech + +model: + prior_loss: false diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml index 4700855..36f6eaf 100644 --- a/configs/model/matcha.yaml +++ b/configs/model/matcha.yaml @@ -12,3 +12,4 @@ spk_emb_dim: 64 n_feats: 80 data_statistics: ${data.data_statistics} out_size: null # Must be divisible by 4 +prior_loss: true diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py index 6feb9e7..64b2c07 100644 --- a/matcha/models/matcha_tts.py +++ b/matcha/models/matcha_tts.py @@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 out_size, optimizer=None, scheduler=None, + prior_loss=True, ): super().__init__() @@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 self.spk_emb_dim = spk_emb_dim self.n_feats = n_feats self.out_size = out_size + self.prior_loss = prior_loss if n_spks > 1: self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) @@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass): # 🍵 # Compute loss of the decoder diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) - prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) - prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + if self.prior_loss: + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + else: + prior_loss = 0 return 
dur_loss, prior_loss, diff_loss