In the middle of adding discrete nf based duration predictor

This commit is contained in:
Shivam Mehta
2024-01-10 11:04:46 +00:00
parent a58bab5403
commit d03bba82bb
2 changed files with 207 additions and 1 deletions

View File

@@ -121,7 +121,8 @@ class MatchaTTS(BaseLightningClass): # 🍵
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
w_ceil = torch.round(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()
y_max_length_ = fix_len_compatibility(y_max_length)