diff --git a/configs/data/hi-fi_en-US_female.yaml b/configs/data/hi-fi_en-US_female.yaml
new file mode 100644
index 0000000..1269f9b
--- /dev/null
+++ b/configs/data/hi-fi_en-US_female.yaml
@@ -0,0 +1,14 @@
+defaults:
+  - ljspeech
+  - _self_
+
+# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
+_target_: matcha.data.text_mel_datamodule.TextMelDataModule
+name: hi-fi_en-US_female
+train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
+valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
+batch_size: 32
+cleaners: [english_cleaners_piper]
+data_statistics: # Computed for this dataset
+  mel_mean: -6.38385
+  mel_std: 2.541796
diff --git a/configs/experiment/hifi_dataset_piper_phonemizer.yaml b/configs/experiment/hifi_dataset_piper_phonemizer.yaml
new file mode 100644
index 0000000..7e6c57a
--- /dev/null
+++ b/configs/experiment/hifi_dataset_piper_phonemizer.yaml
@@ -0,0 +1,14 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=hifi_dataset_piper_phonemizer
+
+defaults:
+  - override /data: hi-fi_en-US_female.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]
+
+run_name: hi-fi_en-US_female_piper_phonemizer
diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml
index 4700855..36f6eaf 100644
--- a/configs/model/matcha.yaml
+++ b/configs/model/matcha.yaml
@@ -12,3 +12,4 @@ spk_emb_dim: 64
 n_feats: 80
 data_statistics: ${data.data_statistics}
 out_size: null # Must be divisible by 4
+prior_loss: true
diff --git a/matcha/VERSION b/matcha/VERSION
index 81340c7..bbdeab6 100644
--- a/matcha/VERSION
+++ b/matcha/VERSION
@@ -1 +1 @@
-0.0.4
+0.0.5
diff --git a/matcha/models/baselightningmodule.py b/matcha/models/baselightningmodule.py
index 29f4927..3724888 100644
--- a/matcha/models/baselightningmodule.py
+++ b/matcha/models/baselightningmodule.py
@@ -81,7 +81,7 @@ class BaseLightningClass(LightningModule, ABC):
             "step",
             float(self.global_step),
             on_step=True,
-            on_epoch=True,
+            prog_bar=True,
             logger=True,
             sync_dist=True,
         )
diff --git a/matcha/models/components/flow_matching.py b/matcha/models/components/flow_matching.py
index 4d77547..5cad743 100644
--- a/matcha/models/components/flow_matching.py
+++ b/matcha/models/components/flow_matching.py
@@ -73,16 +73,14 @@ class BASECFM(torch.nn.Module, ABC):
         # Or in future might add like a return_all_steps flag
         sol = []

-        steps = 1
-        while steps <= len(t_span) - 1:
+        for step in range(1, len(t_span)):
             dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)

-            if steps < len(t_span) - 1:
-                dt = t_span[steps + 1] - t
-            steps += 1
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t

         return sol[-1]
diff --git a/matcha/models/matcha_tts.py b/matcha/models/matcha_tts.py
index 6feb9e7..64b2c07 100644
--- a/matcha/models/matcha_tts.py
+++ b/matcha/models/matcha_tts.py
@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         out_size,
         optimizer=None,
         scheduler=None,
+        prior_loss=True,
     ):
         super().__init__()

@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.spk_emb_dim = spk_emb_dim
         self.n_feats = n_feats
         self.out_size = out_size
+        self.prior_loss = prior_loss

         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         # Compute loss of the decoder
         diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)

-        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
-        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        if self.prior_loss:
+            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        else:
+            prior_loss = 0

         return dur_loss, prior_loss, diff_loss
diff --git a/matcha/text/cleaners.py b/matcha/text/cleaners.py
index 26b91d7..5e8d96b 100644
--- a/matcha/text/cleaners.py
+++ b/matcha/text/cleaners.py
@@ -15,6 +15,7 @@ import logging
 import re

 import phonemizer
+import piper_phonemize
 from unidecode import unidecode

 # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -103,3 +104,13 @@ def english_cleaners2(text):
     phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
     phonemes = collapse_whitespace(phonemes)
     return phonemes
+
+
+def english_cleaners_piper(text):
+    """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = expand_abbreviations(text)
+    phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
+    phonemes = collapse_whitespace(phonemes)
+    return phonemes
diff --git a/requirements.txt b/requirements.txt
index c1be781..0a7e14c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,10 +35,11 @@ torchaudio
 matplotlib
 pandas
 conformer==0.3.2
-diffusers==0.21.3
+diffusers==0.25.0
 notebook
 ipywidgets
 gradio
 gdown
 wget
 seaborn
+piper_phonemize
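The `flow_matching.py` hunk replaces the manual `steps` counter and `while` loop with an equivalent `for` loop; behaviour is unchanged. The loop is a fixed-step Euler ODE solver over the (possibly non-uniform) time grid `t_span`. A standalone sketch of the same stepping scheme on a toy ODE, with illustrative names not taken from the repo:

```python
import torch

def euler_solve(f, x0, t_span):
    """Same stepping scheme as the refactored BASECFM.solve_euler loop."""
    t, dt = t_span[0], t_span[1] - t_span[0]
    x, sol = x0, []
    for step in range(1, len(t_span)):
        x = x + dt * f(x, t)  # Euler update: x_{n+1} = x_n + dt * dx/dt
        t = t + dt
        sol.append(x)
        if step < len(t_span) - 1:  # next step size; t_span may be non-uniform
            dt = t_span[step + 1] - t
    return sol[-1]

# Toy check: dx/dt = -x from x0 = 1 over [0, 1] should approach exp(-1) ≈ 0.368
print(euler_solve(lambda x, t: -x, torch.tensor(1.0), torch.linspace(0, 1, 101)))
```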
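The `prior_loss` flag added to `matcha_tts.py` gates the encoder prior term: the masked negative log-likelihood of the target mel `y` under a unit-variance Gaussian centred on the encoder output `mu_y`. With `model.prior_loss=false` that pull towards the target is dropped, leaving only the duration and flow-matching losses. A self-contained restatement of the gated computation (the function name is illustrative, and the shapes assume the usual `(batch, n_feats, frames)` layout with a `(batch, 1, frames)` mask):

```python
import math
import torch

def gaussian_prior_nll(y, mu_y, y_mask, n_feats):
    """Per-element NLL of y under N(mu_y, I): 0.5 * ((y - mu_y)^2 + log(2*pi)),
    masked to valid frames and averaged over mel bins and unmasked timesteps."""
    nll = 0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
    return nll.sum() / (y_mask.sum() * n_feats)

y, mu_y = torch.randn(2, 80, 100), torch.zeros(2, 80, 100)
y_mask = torch.ones(2, 1, 100)
print(gaussian_prior_nll(y, mu_y, y_mask, n_feats=80))  # ≈ 0.5 * (1 + log(2*pi)) ≈ 1.42
```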
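The new `english_cleaners_piper` cleaner can also be exercised on its own. A minimal sketch, assuming `piper_phonemize` is installed and that, as the `[0]` indexing in the diff implies, `phonemize_espeak` returns one list of phoneme strings per sentence:

```python
import piper_phonemize

text = "dr. smith reads at 5 o'clock."

# Mirror the cleaner: take the first sentence's phonemes and join them
phoneme_lists = piper_phonemize.phonemize_espeak(text=text, voice="en-US")
phonemes = "".join(phoneme_lists[0])
print(phonemes)  # IPA-style phoneme string consumed by the text frontend
```

During training the cleaner is selected through the `cleaners: [english_cleaners_piper]` entry in the new data config above.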