Adding possibility of getting durations out

This commit is contained in:
Shivam Mehta
2024-02-24 15:10:19 +00:00
parent def0855608
commit 8e87111a98
6 changed files with 516 additions and 25 deletions

View File

@@ -58,7 +58,7 @@ class BaseLightningClass(LightningModule, ABC):
y, y_lengths = batch["y"], batch["y_lengths"]
spks = batch["spks"]
dur_loss, prior_loss, diff_loss = self(
dur_loss, prior_loss, diff_loss, *_ = self(
x=x,
x_lengths=x_lengths,
y=y,

View File

@@ -4,7 +4,7 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.duration_predictors import DP
@@ -241,4 +241,4 @@ class MatchaTTS(BaseLightningClass): # 🍵
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss
return dur_loss, prior_loss, diff_loss, attn