Mirror of https://github.com/shivammehta25/Matcha-TTS.git, synced 2026-02-05 18:29:19 +08:00
Merge pull request #42 from shivammehta25/dev
Merging dev: adding another dataset, the piper phonemizer, and refactoring
configs/data/hi-fi_en-US_female.yaml (new file, 14 lines)
@@ -0,0 +1,14 @@
+defaults:
+  - ljspeech
+  - _self_
+
+# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
+_target_: matcha.data.text_mel_datamodule.TextMelDataModule
+name: hi-fi_en-US_female
+train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
+valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
+batch_size: 32
+cleaners: [english_cleaners_piper]
+data_statistics: # Computed for this dataset
+  mel_mean: -6.38385
+  mel_std: 2.541796
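`mel_mean`/`mel_std` are per-dataset normalization statistics, so they have to be recomputed for any new corpus (the repository ships its own data-statistics helper for this). A minimal sketch of what that computation looks like; the batch keys `y`/`y_lengths` are assumptions about the data module's collated output:

import torch

def compute_mel_statistics(dataloader):
    """Accumulate mean/std of log-mel values over all non-padded frames (sketch)."""
    total_sum, total_sq_sum, total_count = 0.0, 0.0, 0
    for batch in dataloader:  # assumed keys: "y" (padded mels), "y_lengths"
        for mel, length in zip(batch["y"], batch["y_lengths"]):
            valid = mel[:, :length]  # drop padding frames
            total_sum += valid.sum().item()
            total_sq_sum += (valid**2).sum().item()
            total_count += valid.numel()
    mean = total_sum / total_count
    std = (total_sq_sum / total_count - mean**2) ** 0.5
    return {"mel_mean": mean, "mel_std": std}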
configs/experiment/hifi_dataset_piper_phonemizer.yaml (new file, 14 lines)
@@ -0,0 +1,14 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=multispeaker
+
+defaults:
+  - override /data: hi-fi_en-US_female.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]
+
+run_name: hi-fi_en-US_female_piper_phonemizer
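The `# python train.py experiment=multispeaker` comment looks carried over from the template; given the file name, the intended invocation is presumably `experiment=hifi_dataset_piper_phonemizer`. A sketch of how Hydra merges this experiment file, assuming the top-level config is `configs/train.yaml` as in the lightning-hydra-template layout this repo follows:

from hydra import compose, initialize

# config_path is relative to this script; run from the repo root (assumption).
with initialize(version_base="1.3", config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=hifi_dataset_piper_phonemizer"])

print(cfg.data.name)      # hi-fi_en-US_female
print(cfg.data.cleaners)  # ['english_cleaners_piper']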
@@ -12,3 +12,4 @@ spk_emb_dim: 64
 n_feats: 80
 data_statistics: ${data.data_statistics}
 out_size: null # Must be divisible by 4
+prior_loss: true
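Defaulting to `prior_loss: true` preserves existing behavior; the flag can then be flipped per run with a Hydra-style dotted override instead of editing the file (the `model.prior_loss=false` CLI form is an assumption about the config layout). A tiny OmegaConf sketch of that merge semantics:

from omegaconf import OmegaConf

# Hypothetical stand-in for the merged model config after this change:
cfg = OmegaConf.create({"n_feats": 80, "out_size": None, "prior_loss": True})
# What a dotted override like `model.prior_loss=false` does under the hood:
cfg.merge_with(OmegaConf.from_dotlist(["prior_loss=false"]))
print(cfg.prior_loss)  # False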
@@ -1 +1 @@
-0.0.4
+0.0.5
@@ -81,7 +81,7 @@ class BaseLightningClass(LightningModule, ABC):
             "step",
             float(self.global_step),
             on_step=True,
-            on_epoch=True,
+            prog_bar=True,
             logger=True,
             sync_dist=True,
         )
@@ -73,16 +73,14 @@ class BASECFM(torch.nn.Module, ABC):
         # Or in future might add like a return_all_steps flag
         sol = []
 
-        steps = 1
-        while steps <= len(t_span) - 1:
+        for step in range(1, len(t_span)):
             dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
 
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
-            if steps < len(t_span) - 1:
-                dt = t_span[steps + 1] - t
-            steps += 1
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
 
         return sol[-1]
 
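The refactor replaces the manual `steps` counter with a `for` loop over the same indices, so the fixed-step Euler behavior should be unchanged. A self-contained toy with the same loop structure (the scalar `estimator` here is hypothetical, not the repo's network, and the mask/mu/spks/cond conditioning is dropped):

import torch

def euler_solve(x, t_span, estimator):
    """Fixed-step Euler ODE solve over the time grid t_span (sketch)."""
    t, dt = t_span[0], t_span[1] - t_span[0]
    sol = []
    for step in range(1, len(t_span)):
        dphi_dt = estimator(x, t)
        x = x + dt * dphi_dt
        t = t + dt
        sol.append(x)
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return sol[-1]

# dx/dt = -x has the exact solution x0 * exp(-t); 100 Euler steps land close.
x0 = torch.tensor([1.0])
print(euler_solve(x0, torch.linspace(0, 1, 101), lambda x, t: -x))  # ~0.3660
print(torch.exp(torch.tensor(-1.0)))                                # ~0.3679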
@@ -34,6 +34,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         out_size,
         optimizer=None,
         scheduler=None,
+        prior_loss=True,
     ):
         super().__init__()
 
@@ -44,6 +45,7 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         self.spk_emb_dim = spk_emb_dim
         self.n_feats = n_feats
         self.out_size = out_size
+        self.prior_loss = prior_loss
 
         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -228,7 +230,10 @@ class MatchaTTS(BaseLightningClass):  # 🍵
         # Compute loss of the decoder
         diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
 
-        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
-        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        if self.prior_loss:
+            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        else:
+            prior_loss = 0
 
         return dur_loss, prior_loss, diff_loss
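The newly guarded term is the Gaussian prior loss: the negative log-density of the target mel under N(mu_y, I), summed over unmasked positions and normalized by frame count times mel channels. A toy check (shapes assumed from the data module: `y` is (batch, n_feats, frames), `y_mask` is (batch, 1, frames)) that it equals the per-element Gaussian NLL when nothing is masked:

import math
import torch

torch.manual_seed(0)
B, C, T = 2, 80, 100                      # batch, mel channels, frames
y, mu_y = torch.randn(B, C, T), torch.randn(B, C, T)
y_mask = torch.ones(B, 1, T)              # all frames valid in this toy

prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * C)

# Same quantity via the closed-form unit-variance Gaussian log-density:
nll = -torch.distributions.Normal(mu_y, 1.0).log_prob(y)
print(prior_loss.item(), nll.mean().item())  # the two values match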
@@ -15,6 +15,7 @@ import logging
 import re
 
 import phonemizer
+import piper_phonemize
 from unidecode import unidecode
 
 # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -103,3 +104,13 @@ def english_cleaners2(text):
     phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
     phonemes = collapse_whitespace(phonemes)
     return phonemes
+
+
+def english_cleaners_piper(text):
+    """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = expand_abbreviations(text)
+    phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
+    phonemes = collapse_whitespace(phonemes)
+    return phonemes
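For reference, `piper_phonemize.phonemize_espeak` returns one phoneme list per sentence, which is why the new cleaner joins the first entry into a string. A usage sketch (assumes `piper_phonemize` and its espeak-ng data are installed; the exact phonemes depend on the espeak version):

import piper_phonemize

# One list of phoneme strings per input sentence:
phoneme_lists = piper_phonemize.phonemize_espeak(text="hello world", voice="en-US")
print(phoneme_lists)              # e.g. [['h', 'ə', 'l', 'ˈoʊ', ...]] (output may vary)
print("".join(phoneme_lists[0]))  # flattened string, as english_cleaners_piper produces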
@@ -35,10 +35,11 @@ torchaudio
 matplotlib
 pandas
 conformer==0.3.2
-diffusers==0.21.3
+diffusers==0.25.0
 notebook
 ipywidgets
 gradio
 gdown
 wget
 seaborn
+piper_phonemize