Merge pull request #42 from shivammehta25/dev

Merging dev adding another dataset, piper phonemizer and refractoring
This commit is contained in:
Shivam Mehta
2024-01-12 11:49:53 +01:00
committed by GitHub
9 changed files with 54 additions and 10 deletions

View File

@@ -0,0 +1,14 @@
defaults:
- ljspeech
- _self_
# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female
train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
batch_size: 32
cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset
mel_mean: -6.38385
mel_std: 2.541796

View File

@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: hi-fi_en-US_female.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]
run_name: hi-fi_en-US_female_piper_phonemizer

View File

@@ -12,3 +12,4 @@ spk_emb_dim: 64
n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4
prior_loss: true

View File

@@ -1 +1 @@
0.0.4
0.0.5

View File

@@ -81,7 +81,7 @@ class BaseLightningClass(LightningModule, ABC):
"step",
float(self.global_step),
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)

View File

@@ -73,16 +73,14 @@ class BASECFM(torch.nn.Module, ABC):
# Or in future might add like a return_all_steps flag
sol = []
steps = 1
while steps <= len(t_span) - 1:
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if steps < len(t_span) - 1:
dt = t_span[steps + 1] - t
steps += 1
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]

View File

@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
out_size,
optimizer=None,
scheduler=None,
prior_loss=True,
):
super().__init__()
@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.spk_emb_dim = spk_emb_dim
self.n_feats = n_feats
self.out_size = out_size
self.prior_loss = prior_loss
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Compute loss of the decoder
diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
if self.prior_loss:
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
else:
prior_loss = 0
return dur_loss, prior_loss, diff_loss

View File

@@ -15,6 +15,7 @@ import logging
import re
import phonemizer
import piper_phonemize
from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -103,3 +104,13 @@ def english_cleaners2(text):
phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
phonemes = collapse_whitespace(phonemes)
return phonemes
def english_cleaners_piper(text):
"""Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
phonemes = collapse_whitespace(phonemes)
return phonemes

View File

@@ -35,10 +35,11 @@ torchaudio
matplotlib
pandas
conformer==0.3.2
diffusers==0.21.3
diffusers==0.25.0
notebook
ipywidgets
gradio
gdown
wget
seaborn
piper_phonemize