Mirror of https://github.com/shivammehta25/Matcha-TTS.git
Minor changes: add an option to disable the prior loss via the config
The commit adds one new experiment config that turns the prior loss off:

configs/experiment/ljspeech_no_prior_loss.yaml (new file, +17 lines)

@@ -0,0 +1,17 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=ljspeech_no_prior_loss
+
+defaults:
+  - override /data: ljspeech.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["ljspeech"]
+
+run_name: ljspeech
+
+model:
+  prior_loss: false
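To sanity-check that the experiment file actually flips the flag, the config tree can be composed programmatically. This is a minimal sketch, assuming the repository follows the lightning-hydra-template layout with a top-level configs/train.yaml entry point (the entry point itself is not part of this diff):

    # Illustrative only: assumes a configs/train.yaml root config.
    from hydra import compose, initialize

    with initialize(version_base="1.3", config_path="configs"):
        cfg = compose(config_name="train", overrides=["experiment=ljspeech_no_prior_loss"])

    print(cfg.model.prior_loss)  # expected: False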
The matching default goes into the model config, so existing setups keep the prior loss enabled:

@@ -12,3 +12,4 @@ spk_emb_dim: 64
 n_feats: 80
 data_statistics: ${data.data_statistics}
 out_size: null # Must be divisible by 4
+prior_loss: true
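Because prior_loss is now an ordinary field of the model config, it can also be flipped for a one-off run straight from the Hydra command line, without the dedicated experiment file (using the same train.py entry point the config comments refer to):

    python train.py model.prior_loss=false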
The flag is then threaded through the MatchaTTS constructor:

@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         out_size,
         optimizer=None,
         scheduler=None,
+        prior_loss=True,
     ):
         super().__init__()
 
and stored on the instance alongside the other hyperparameters:

@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.spk_emb_dim = spk_emb_dim
         self.n_feats = n_feats
         self.out_size = out_size
+        self.prior_loss = prior_loss
 
         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
Finally, in the loss computation the prior term is guarded by the flag and reported as 0 when disabled, leaving dur_loss and diff_loss untouched:

@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         # Compute loss of the decoder
         diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
 
-        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
-        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        if self.prior_loss:
+            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        else:
+            prior_loss = 0
 
         return dur_loss, prior_loss, diff_loss
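For context, the toggled term is the negative log-likelihood of the target mel frames y under a unit-variance Gaussian centred on the encoder output mu_y, normalised by the number of unmasked frames times n_feats. With mask m_t over time steps t and mel bins f:

    \mathcal{L}_{\text{prior}} = \frac{1}{n_{\text{feats}} \sum_t m_t} \sum_{t,f} \frac{1}{2}\left[(y_{t,f} - \mu_{t,f})^2 + \log 2\pi\right] m_t

A quick self-contained check of that identity against torch.distributions; the (batch, n_feats, time) layout for y and mu_y and (batch, 1, time) for y_mask are assumptions for illustration, not taken from this diff:

    import math
    import torch

    # Toy tensors with assumed Matcha-TTS-style shapes.
    y, mu_y = torch.randn(2, 80, 100), torch.randn(2, 80, 100)
    y_mask = torch.ones(2, 1, 100)

    # The prior loss exactly as written in the commit (n_feats = 80 here).
    prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
    prior_loss = prior_loss / (torch.sum(y_mask) * 80)

    # The same quantity as a masked mean of -log N(y; mu_y, I).
    log_prob = torch.distributions.Normal(mu_y, 1.0).log_prob(y)
    reference = -(log_prob * y_mask).sum() / (y_mask.sum() * 80)

    assert torch.allclose(prior_loss, reference, atol=1e-5)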