11 Commits

Author          SHA1          Message                                                        Date
Shivam Mehta    4c35836fa5    minor fix                                                      2024-03-02 12:48:54 +00:00
Shivam Mehta    294c6b1327    Adding saving phones while getting durations from matcha      2024-03-02 12:47:08 +00:00
Shivam Mehta    ad76016916    Fixing configs                                                 2024-02-26 09:11:22 +00:00
Shivam Mehta    05c8f9b4a8    updating configs and experiments                               2024-02-25 22:02:36 +00:00
Shivam Mehta    4d5b62cea9    Adding a bit of comments                                       2024-02-24 15:20:13 +00:00
Shivam Mehta    8e87111a98    Adding possibility of getting durations out                    2024-02-24 15:10:19 +00:00
Shivam Mehta    def0855608    Adding other experiment configs                                2024-01-22 11:46:08 +00:00
Shivam Mehta    6976a91348    Merge branch 'main' into stoc_dur                              2024-01-12 11:58:41 +00:00
Shivam Mehta    458e9df236    Adding synthesis                                               2024-01-10 11:05:22 +00:00
Shivam Mehta    d03bba82bb    In the middle of adding discrete nf based duration predictor  2024-01-10 11:04:46 +00:00
Shivam Mehta    a58bab5403    Adding option to do flow matching based duration prediction   2024-01-05 11:13:07 +00:00
53 changed files with 1989 additions and 589 deletions

.gitignore (vendored): 1 line changed

@@ -161,4 +161,3 @@ generator_v1
g_02500000 g_02500000
gradio_cached_examples/ gradio_cached_examples/
synth_output/ synth_output/
/data


@@ -1,9 +1,9 @@
default_language_version: default_language_version:
python: python3.11 python: python3.10
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0 rev: v4.5.0
hooks: hooks:
# list of supported hooks: https://pre-commit.com/hooks.html # list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace - id: trailing-whitespace
@@ -17,29 +17,29 @@ repos:
- id: check-added-large-files - id: check-added-large-files
# python code formatting # python code formatting
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black
rev: 26.1.0 rev: 23.12.1
hooks: hooks:
- id: black - id: black
args: [--line-length, "120"] args: [--line-length, "120"]
# python import sorting # python import sorting
- repo: https://github.com/PyCQA/isort - repo: https://github.com/PyCQA/isort
rev: 7.0.0 rev: 5.13.2
hooks: hooks:
- id: isort - id: isort
args: ["--profile", "black", "--filter-files"] args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version # python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.21.2 rev: v3.15.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
args: [--py38-plus] args: [--py38-plus]
# python check (PEP8), programming errors and code complexity # python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 7.3.0 rev: 7.0.0
hooks: hooks:
- id: flake8 - id: flake8
args: args:
@@ -54,6 +54,6 @@ repos:
# pylint # pylint
- repo: https://github.com/pycqa/pylint - repo: https://github.com/pycqa/pylint
rev: v4.0.4 rev: v3.0.3
hooks: hooks:
- id: pylint - id: pylint


@@ -10,7 +10,7 @@
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![PyPI Downloads](https://static.pepy.tech/personalized-badge/matcha-tts?period=total&units=INTERNATIONAL_SYSTEM&left_color=BLACK&right_color=GREEN&left_text=downloads)](https://pepy.tech/projects/matcha-tts)
<p style="text-align: center;"> <p style="text-align: center;">
<img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/> <img src="https://shivammehta25.github.io/Matcha-TTS/images/logo.png" height="128"/>
</p> </p>
@@ -252,43 +252,6 @@ python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vo
This will write `.wav` audio files to the output directory. This will write `.wav` audio files to the output directory.
## Extract phoneme alignments from Matcha-TTS
If the dataset is structured as
```bash
data/
└── LJSpeech-1.1
├── metadata.csv
├── README
├── test.txt
├── train.txt
├── val.txt
└── wavs
```
Then you can extract the phoneme-level alignments from a trained Matcha-TTS model using:
```bash
python matcha/utils/get_durations_from_trained_model.py -i <dataset_yaml> -c <checkpoint>
```
Example:
```bash
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt
```
or simply:
```bash
matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt
```
---
## Train using extracted alignments
In the dataset config, set `load_durations` to True.
Example: `ljspeech.yaml`
```
load_durations: True
```
or see the example in `configs/experiment/ljspeech_from_durations.yaml`
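For reference, the sketch below shows how such precomputed durations are looked up per utterance, mirroring the `get_durations` helper in the datamodule changes further down (one `.npy` file per wav, one integer frame count per interspersed input token). The concrete wav path is only an example.
```python
# Minimal sketch: loading precomputed durations for one utterance.
# Assumes the default layout <data_dir>/durations/<utt>.npy produced by
# matcha/utils/get_durations_from_trained_model.py.
from pathlib import Path

import numpy as np
import torch

def load_durations(wav_path: str) -> torch.Tensor:
    wav_path = Path(wav_path)
    dur_path = wav_path.parent.parent / "durations" / f"{wav_path.stem}.npy"
    return torch.from_numpy(np.load(dur_path).astype(int))

durs = load_durations("data/LJSpeech-1.1/wavs/LJ001-0001.wav")
print(durs.shape)  # one duration per input token; must match the tokenised text length
```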
## Citation information ## Citation information
If you use our code or otherwise find this work useful, please cite our paper: If you use our code or otherwise find this work useful, please cite our paper:


@@ -5,8 +5,8 @@ defaults:
# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ # Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/
_target_: matcha.data.text_mel_datamodule.TextMelDataModule _target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: hi-fi_en-US_female name: hi-fi_en-US_female
train_filelist_path: data/hi-fi_en-US_female/train.txt train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt
valid_filelist_path: data/hi-fi_en-US_female/val.txt valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt
batch_size: 32 batch_size: 32
cleaners: [english_cleaners_piper] cleaners: [english_cleaners_piper]
data_statistics: # Computed for this dataset data_statistics: # Computed for this dataset


@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: joe_spont_only
train_filelist_path: data/filelists/joe_spontonly_train.txt
valid_filelist_path: data/filelists/joe_spontonly_val.txt
data_statistics:
mel_mean: -5.882903
mel_std: 2.458284


@@ -1,7 +1,7 @@
_target_: matcha.data.text_mel_datamodule.TextMelDataModule _target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech name: ljspeech
train_filelist_path: data/LJSpeech-1.1/train.txt train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/LJSpeech-1.1/val.txt valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
batch_size: 32 batch_size: 32
num_workers: 20 num_workers: 20
pin_memory: True pin_memory: True
@@ -19,4 +19,3 @@ data_statistics: # Computed for ljspeech dataset
mel_mean: -5.536622 mel_mean: -5.536622
mel_std: 2.116101 mel_std: 2.116101
seed: ${seed} seed: ${seed}
load_durations: false

configs/data/ryan.yaml (new file): 10 lines

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: ryan
train_filelist_path: data/filelists/ryan_train.csv
valid_filelist_path: data/filelists/ryan_val.csv
data_statistics:
mel_mean: -4.715779
mel_std: 2.124502

configs/data/tsg2.yaml (new file): 10 lines

@@ -0,0 +1,10 @@
defaults:
- ljspeech
- _self_
name: tsg2
train_filelist_path: data/filelists/cormac_train.txt
valid_filelist_path: data/filelists/cormac_val.txt
data_statistics:
mel_mean: -5.536622
mel_std: 2.116101


@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_det_dur


@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: joe_spont_only.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["joe"]
run_name: joe_stoc_dur
model:
duration_predictor:
p_dropout: 0.2


@@ -5,15 +5,12 @@
defaults: defaults:
- override /data: ljspeech.yaml - override /data: ljspeech.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above # all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters # this allows you to overwrite only specified parameters
tags: ["ljspeech"] tags: ["ljspeech"]
run_name: ljspeech run_name: ljspeech
data:
load_durations: True
batch_size: 64


@@ -0,0 +1,18 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_det_dur
trainer:
max_epochs: 3000


@@ -0,0 +1,24 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: ryan.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["ryan"]
run_name: ryan_stoc_dur
model:
duration_predictor:
p_dropout: 0.2
trainer:
max_epochs: 3000


@@ -0,0 +1,14 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_det_dur


@@ -0,0 +1,20 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=multispeaker
defaults:
- override /data: tsg2.yaml
- override /model/duration_predictor: flow_matching.yaml
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["tsg2"]
run_name: tsg2_stoc_dur
model:
duration_predictor:
p_dropout: 0.5


@@ -0,0 +1,7 @@
name: deterministic
n_spks: ${model.n_spks}
spk_emb_dim: ${model.spk_emb_dim}
filter_channels: 256
kernel_size: 3
n_channels: ${model.encoder.encoder_params.n_channels}
p_dropout: ${model.encoder.encoder_params.p_dropout}


@@ -0,0 +1,7 @@
defaults:
- deterministic.yaml
- _self_
sigma_min: 1e-4
n_steps: 10
name: flow_matching
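Since `flow_matching.yaml` pulls in `deterministic.yaml` through its defaults list and only adds `sigma_min`, `n_steps`, and a new `name`, the resolved duration-predictor config is roughly what the OmegaConf sketch below produces; the `${model...}` interpolations are replaced by illustrative literals here because they only resolve against the full config tree.
```python
# Rough emulation of the Hydra defaults-list merge for the two files above.
from omegaconf import OmegaConf

deterministic = OmegaConf.create({
    "name": "deterministic",
    "n_spks": 1,             # stands in for ${model.n_spks}
    "spk_emb_dim": 64,       # stands in for ${model.spk_emb_dim}
    "filter_channels": 256,
    "kernel_size": 3,
    "n_channels": 192,       # stands in for ${model.encoder.encoder_params.n_channels}
    "p_dropout": 0.1,        # stands in for ${model.encoder.encoder_params.p_dropout}
})
flow_matching = OmegaConf.merge(deterministic, {
    "sigma_min": 1e-4,
    "n_steps": 10,
    "name": "flow_matching",
})
print(OmegaConf.to_yaml(flow_matching))
```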


@@ -3,16 +3,8 @@ encoder_params:
n_feats: ${model.n_feats} n_feats: ${model.n_feats}
n_channels: 192 n_channels: 192
filter_channels: 768 filter_channels: 768
filter_channels_dp: 256
n_heads: 2 n_heads: 2
n_layers: 6 n_layers: 6
kernel_size: 3 kernel_size: 3
p_dropout: 0.1 p_dropout: 0.1
spk_emb_dim: 64
n_spks: 1
prenet: true prenet: true
duration_predictor_params:
filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
kernel_size: 3
p_dropout: ${model.encoder.encoder_params.p_dropout}


@@ -1,6 +1,7 @@
defaults: defaults:
- _self_ - _self_
- encoder: default.yaml - encoder: default.yaml
- duration_predictor: deterministic.yaml
- decoder: default.yaml - decoder: default.yaml
- cfm: default.yaml - cfm: default.yaml
- optimizer: adam.yaml - optimizer: adam.yaml
@@ -13,4 +14,3 @@ n_feats: 80
data_statistics: ${data.data_statistics} data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4 out_size: null # Must be divisible by 4
prior_loss: true prior_loss: true
use_precomputed_durations: ${data.load_durations}

data (symbolic link): 1 line

@@ -0,0 +1 @@
/home/smehta/Projects/Speech-Backbones/Grad-TTS/data


@@ -1 +1 @@
0.0.7.2 0.0.5.1


@@ -48,7 +48,7 @@ def plot_spectrogram_to_numpy(spectrogram, filename):
def process_text(i: int, text: str, device: torch.device): def process_text(i: int, text: str, device: torch.device):
print(f"[{i}] - Input text: {text}") print(f"[{i}] - Input text: {text}")
x = torch.tensor( x = torch.tensor(
intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0), intersperse(text_to_sequence(text, ["english_cleaners2"]), 0),
dtype=torch.long, dtype=torch.long,
device=device, device=device,
)[None] )[None]
@@ -114,10 +114,10 @@ def load_matcha(model_name, checkpoint_path, device):
return model return model
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025): def to_waveform(mel, vocoder, denoiser=None):
audio = vocoder(mel).clamp(-1, 1) audio = vocoder(mel).clamp(-1, 1)
if denoiser is not None: if denoiser is not None:
audio = denoiser(audio.squeeze(), strength=denoiser_strength).cpu().squeeze() audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze() return audio.cpu().squeeze()
@@ -227,7 +227,7 @@ def cli():
parser.add_argument( parser.add_argument(
"--vocoder", "--vocoder",
type=str, type=str,
default=None, default="hifigan_univ_v1",
help="Vocoder to use (default: will use the one suggested with the pretrained model))", help="Vocoder to use (default: will use the one suggested with the pretrained model))",
choices=VOCODER_URLS.keys(), choices=VOCODER_URLS.keys(),
) )
@@ -326,17 +326,16 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
for i, batch in enumerate(dataloader): for i, batch in enumerate(dataloader):
i = i + 1 i = i + 1
start_t = dt.datetime.now() start_t = dt.datetime.now()
b = batch["x"].shape[0]
output = model.synthesise( output = model.synthesise(
batch["x"].to(device), batch["x"].to(device),
batch["x_lengths"].to(device), batch["x_lengths"].to(device),
n_timesteps=args.steps, n_timesteps=args.steps,
temperature=args.temperature, temperature=args.temperature,
spks=spk.expand(b) if spk is not None else spk, spks=spk,
length_scale=args.speaking_rate, length_scale=args.speaking_rate,
) )
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength) output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
t = (dt.datetime.now() - start_t).total_seconds() t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1]) rtf_w = t * 22050 / (output["waveform"].shape[-1])
print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}") print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
@@ -377,7 +376,7 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
spks=spk, spks=spk,
length_scale=args.speaking_rate, length_scale=args.speaking_rate,
) )
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength) output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
# RTF with HiFiGAN # RTF with HiFiGAN
t = (dt.datetime.now() - start_t).total_seconds() t = (dt.datetime.now() - start_t).total_seconds()
rtf_w = t * 22050 / (output["waveform"].shape[-1]) rtf_w = t * 22050 / (output["waveform"].shape[-1])


@@ -1,8 +1,6 @@
import random import random
from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import numpy as np
import torch import torch
import torchaudio as ta import torchaudio as ta
from lightning import LightningDataModule from lightning import LightningDataModule
@@ -41,7 +39,6 @@ class TextMelDataModule(LightningDataModule):
f_max, f_max,
data_statistics, data_statistics,
seed, seed,
load_durations,
): ):
super().__init__() super().__init__()
@@ -71,7 +68,6 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max, self.hparams.f_max,
self.hparams.data_statistics, self.hparams.data_statistics,
self.hparams.seed, self.hparams.seed,
self.hparams.load_durations,
) )
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path, self.hparams.valid_filelist_path,
@@ -87,7 +83,6 @@ class TextMelDataModule(LightningDataModule):
self.hparams.f_max, self.hparams.f_max,
self.hparams.data_statistics, self.hparams.data_statistics,
self.hparams.seed, self.hparams.seed,
self.hparams.load_durations,
) )
def train_dataloader(self): def train_dataloader(self):
@@ -139,7 +134,6 @@ class TextMelDataset(torch.utils.data.Dataset):
f_max=8000, f_max=8000,
data_parameters=None, data_parameters=None,
seed=None, seed=None,
load_durations=False,
): ):
self.filepaths_and_text = parse_filelist(filelist_path) self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks self.n_spks = n_spks
@@ -152,8 +146,6 @@ class TextMelDataset(torch.utils.data.Dataset):
self.win_length = win_length self.win_length = win_length
self.f_min = f_min self.f_min = f_min
self.f_max = f_max self.f_max = f_max
self.load_durations = load_durations
if data_parameters is not None: if data_parameters is not None:
self.data_parameters = data_parameters self.data_parameters = data_parameters
else: else:
@@ -175,26 +167,7 @@ class TextMelDataset(torch.utils.data.Dataset):
text, cleaned_text = self.get_text(text, add_blank=self.add_blank) text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath) mel = self.get_mel(filepath)
durations = self.get_durations(filepath, text) if self.load_durations else None return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text}
return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}
def get_durations(self, filepath, text):
filepath = Path(filepath)
data_dir, name = filepath.parent.parent, filepath.stem
try:
dur_loc = data_dir / "durations" / f"{name}.npy"
durs = torch.from_numpy(np.load(dur_loc).astype(int))
except FileNotFoundError as e:
raise FileNotFoundError(
f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
) from e
assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"
return durs
def get_mel(self, filepath): def get_mel(self, filepath):
audio, sr = ta.load(filepath) audio, sr = ta.load(filepath)
@@ -241,8 +214,6 @@ class TextMelBatchCollate:
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long) x = torch.zeros((B, x_max_length), dtype=torch.long)
durations = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], [] y_lengths, x_lengths = [], []
spks = [] spks = []
filepaths, x_texts = [], [] filepaths, x_texts = [], []
@@ -255,8 +226,6 @@ class TextMelBatchCollate:
spks.append(item["spk"]) spks.append(item["spk"])
filepaths.append(item["filepath"]) filepaths.append(item["filepath"])
x_texts.append(item["x_text"]) x_texts.append(item["x_text"])
if item["durations"] is not None:
durations[i, : item["durations"].shape[-1]] = item["durations"]
y_lengths = torch.tensor(y_lengths, dtype=torch.long) y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long) x_lengths = torch.tensor(x_lengths, dtype=torch.long)
@@ -270,5 +239,4 @@ class TextMelBatchCollate:
"spks": spks, "spks": spks,
"filepaths": filepaths, "filepaths": filepaths,
"x_texts": x_texts, "x_texts": x_texts,
"durations": durations if not torch.eq(durations, 0).all() else None,
} }


@@ -1,14 +1,9 @@
# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
import torch import torch
class ModeException(Exception):
pass
class Denoiser(torch.nn.Module): class Denoiser(torch.nn.Module):
"""Removes model bias from audio produced with waveglow""" """Removes model bias from audio produced with waveglow"""
@@ -25,7 +20,7 @@ class Denoiser(torch.nn.Module):
elif mode == "normal": elif mode == "normal":
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
else: else:
raise ModeException(f"Mode {mode} if not supported") raise Exception(f"Mode {mode} if not supported")
def stft_fn(audio, n_fft, hop_length, win_length, window): def stft_fn(audio, n_fft, hop_length, win_length, window):
spec = torch.stft( spec = torch.stft(


@@ -55,7 +55,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned global mel_basis, hann_window # pylint: disable=global-statement
if fmax not in mel_basis: if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)


@@ -1,7 +1,7 @@
""" from https://github.com/jik876/hifi-gan """ """ from https://github.com/jik876/hifi-gan """
import torch import torch
import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm


@@ -2,7 +2,6 @@
This is a base lightning module that can be used to train a model. This is a base lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
""" """
import inspect import inspect
from abc import ABC from abc import ABC
from typing import Any, Dict from typing import Any, Dict
@@ -66,7 +65,6 @@ class BaseLightningClass(LightningModule, ABC):
y_lengths=y_lengths, y_lengths=y_lengths,
spks=spks, spks=spks,
out_size=self.out_size, out_size=self.out_size,
durations=batch["durations"],
) )
return { return {
"dur_loss": dur_loss, "dur_loss": dur_loss,


@@ -2,7 +2,7 @@ import math
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from conformer import ConformerBlock from conformer import ConformerBlock
from diffusers.models.activations import get_activation from diffusers.models.activations import get_activation


@@ -0,0 +1,448 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack
from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
from matcha.models.components.text_encoder import LayerNorm
# Define available networks
class DurationPredictorNetwork(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class DurationPredictorNetworkWithTimeStep(nn.Module):
"""Similar architecture but with a time embedding support"""
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.time_embeddings = SinusoidalPosEmb(filter_channels)
self.time_mlp = TimestepEmbedding(
in_channels=filter_channels,
time_embed_dim=filter_channels,
act_fn="silu",
)
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask, enc_outputs, t):
t = self.time_embeddings(t)
t = self.time_mlp(t).unsqueeze(-1)
x = pack([x, enc_outputs], "b * t")[0]
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = x + t
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
# Define available methods to compute loss
# Simple MSE deterministic
class DeterministicDurationPredictor(nn.Module):
def __init__(self, params):
super().__init__()
self.estimator = DurationPredictorNetwork(
params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0),
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
@torch.inference_mode()
def forward(self, x, x_mask):
return self.estimator(x, x_mask)
def compute_loss(self, durations, enc_outputs, x_mask):
return F.mse_loss(self.estimator(enc_outputs, x_mask), durations, reduction="sum") / torch.sum(x_mask)
# Flow Matching duration predictor
class FlowMatchingDurationPrediction(nn.Module):
def __init__(self, params) -> None:
super().__init__()
self.estimator = DurationPredictorNetworkWithTimeStep(
1
+ params.n_channels
+ (
params.spk_emb_dim if params.n_spks > 1 else 0
), # 1 for the durations and n_channels for encoder outputs
params.filter_channels,
params.kernel_size,
params.p_dropout,
)
self.sigma_min = params.sigma_min
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
if n_timesteps is None:
n_timesteps = self.n_steps
b, _, t = enc_outputs.shape
z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
return self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
def solve_euler(self, x, t_span, enc_outputs, mask):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, enc_outputs, t)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, enc_outputs, mask):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
enc_outputs = enc_outputs.detach() # don't update encoder from the duration predictor
b, _, t = enc_outputs.shape
# random timestep
t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, enc_outputs, t.squeeze()), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss
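Spelled out, `compute_loss` above is the standard optimal-transport conditional flow matching objective applied to the single-channel log-duration sequence: with noise z ~ N(0, I), MAS log-durations x1, and a uniformly sampled t, it regresses the network output onto the target velocity,
```latex
x_t = \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,z + t\,x_1,
\qquad
u_t = x_1 - (1 - \sigma_{\min})\,z,
\qquad
\mathcal{L}_{\text{dur}} = \mathbb{E}_{t,z}\bigl[\lVert v_\theta(x_t, t \mid \text{enc}) - u_t \rVert^2\bigr],
```
with the squared error summed over masked positions and normalised by the number of valid tokens, exactly as in the code.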
# VITS discrete normalising flow based duration predictor
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels  # this needs to be removed in a future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = Log()
self.flows = nn.ModuleList()
self.flows.append(ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
# Meta class to wrap all duration predictors
class DP(nn.Module):
def __init__(self, params):
super().__init__()
self.name = params.name
if params.name == "deterministic":
self.dp = DeterministicDurationPredictor(
params,
)
elif params.name == "flow_matching":
self.dp = FlowMatchingDurationPrediction(
params,
)
else:
raise ValueError(f"Invalid duration predictor configuration: {params.name}")
@torch.inference_mode()
def forward(self, enc_outputs, mask):
return self.dp(enc_outputs, mask)
def compute_loss(self, durations, enc_outputs, mask):
return self.dp.compute_loss(durations, enc_outputs, mask)
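As an end-to-end sanity check, the hedged sketch below wires up the `DP` wrapper by hand outside the Hydra config system; the shapes and hyperparameter values are illustrative, and the import path matches the one used in `matcha_tts.py` later in this diff.
```python
# Usage sketch (assumption: standalone use outside the Lightning/Hydra setup).
from types import SimpleNamespace

import torch

from matcha.models.components.duration_predictors import DP

params = SimpleNamespace(
    name="flow_matching",   # or "deterministic"
    n_channels=192,         # encoder channel count (single speaker, so no spk_emb_dim is added)
    n_spks=1,
    spk_emb_dim=64,
    filter_channels=256,
    kernel_size=3,
    p_dropout=0.1,
    sigma_min=1e-4,         # only read by the flow-matching predictor
    n_steps=10,
)
dp = DP(params)

enc_out = torch.randn(2, 192, 50)     # (batch, n_channels, n_tokens)
x_mask = torch.ones(2, 1, 50)         # (batch, 1, n_tokens)
logw_mas = torch.randn(2, 1, 50)      # log-durations obtained from MAS

loss = dp.compute_loss(logw_mas, enc_out, x_mask)   # training objective
logw = dp(enc_out, x_mask)                          # inference: sampled log-durations
print(loss.item(), logw.shape)                      # scalar, torch.Size([2, 1, 50])
```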


@@ -3,10 +3,10 @@
import math import math
import torch import torch
import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn as nn
from einops import rearrange from einops import rearrange
import matcha.utils as utils # pylint: disable=consider-using-from-import import matcha.utils as utils
from matcha.utils.model import sequence_mask from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__) log = utils.get_pylogger(__name__)
@@ -67,33 +67,6 @@ class ConvReluNorm(nn.Module):
return x * x_mask return x * x_mask
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class RotaryPositionalEmbeddings(nn.Module): class RotaryPositionalEmbeddings(nn.Module):
""" """
## RoPE module ## RoPE module
@@ -330,7 +303,6 @@ class TextEncoder(nn.Module):
self, self,
encoder_type, encoder_type,
encoder_params, encoder_params,
duration_predictor_params,
n_vocab, n_vocab,
n_spks=1, n_spks=1,
spk_emb_dim=128, spk_emb_dim=128,
@@ -368,12 +340,6 @@ class TextEncoder(nn.Module):
) )
self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
self.proj_w = DurationPredictor(
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
duration_predictor_params.filter_channels_dp,
duration_predictor_params.kernel_size,
duration_predictor_params.p_dropout,
)
def forward(self, x, x_lengths, spks=None): def forward(self, x, x_lengths, spks=None):
"""Run forward pass to the transformer based encoder and duration predictor """Run forward pass to the transformer based encoder and duration predictor
@@ -404,7 +370,7 @@ class TextEncoder(nn.Module):
x = self.encoder(x, x_mask) x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask mu = self.proj_m(x) * x_mask
x_dp = torch.detach(x) # x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask) # logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask return mu, x, x_mask


@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch import torch
import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn as nn
from diffusers.models.attention import ( from diffusers.models.attention import (
GEGLU, GEGLU,
GELU, GELU,


@@ -7,11 +7,11 @@ import torch
import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
from matcha import utils from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.duration_predictors import DP
from matcha.models.components.flow_matching import CFM from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import ( from matcha.utils.model import (
denormalize, denormalize,
duration_loss,
fix_len_compatibility, fix_len_compatibility,
generate_path, generate_path,
sequence_mask, sequence_mask,
@@ -28,6 +28,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
spk_emb_dim, spk_emb_dim,
n_feats, n_feats,
encoder, encoder,
duration_predictor,
decoder, decoder,
cfm, cfm,
data_statistics, data_statistics,
@@ -35,7 +36,6 @@ class MatchaTTS(BaseLightningClass): # 🍵
optimizer=None, optimizer=None,
scheduler=None, scheduler=None,
prior_loss=True, prior_loss=True,
use_precomputed_durations=False,
): ):
super().__init__() super().__init__()
@@ -47,7 +47,6 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.n_feats = n_feats self.n_feats = n_feats
self.out_size = out_size self.out_size = out_size
self.prior_loss = prior_loss self.prior_loss = prior_loss
self.use_precomputed_durations = use_precomputed_durations
if n_spks > 1: if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
@@ -55,12 +54,13 @@ class MatchaTTS(BaseLightningClass): # 🍵
self.encoder = TextEncoder( self.encoder = TextEncoder(
encoder.encoder_type, encoder.encoder_type,
encoder.encoder_params, encoder.encoder_params,
encoder.duration_predictor_params,
n_vocab, n_vocab,
n_spks, n_spks,
spk_emb_dim, spk_emb_dim,
) )
self.dp = DP(duration_predictor)
self.decoder = CFM( self.decoder = CFM(
in_channels=2 * encoder.encoder_params.n_feats, in_channels=2 * encoder.encoder_params.n_feats,
out_channel=encoder.encoder_params.n_feats, out_channel=encoder.encoder_params.n_feats,
@@ -106,7 +106,6 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Lengths of mel spectrograms # Lengths of mel spectrograms
"rtf": float, "rtf": float,
# Real-time factor # Real-time factor
}
""" """
# For RTF computation # For RTF computation
t = dt.datetime.now() t = dt.datetime.now()
@@ -115,11 +114,15 @@ class MatchaTTS(BaseLightningClass): # 🍵
# Get speaker embedding # Get speaker embedding
spks = self.spk_emb(spks.long()) spks = self.spk_emb(spks.long())
# Get encoder_outputs `mu_x` and log-scaled token durations `logw` # Get encoder_outputs `mu_x` and encoded text `enc_output`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
# Get log-scaled token durations `logw`
logw = self.dp(enc_output, x_mask)
w = torch.exp(logw) * x_mask w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale w_ceil = torch.ceil(w) * length_scale
# print(w_ceil)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max() y_max_length = y_lengths.max()
y_max_length_ = fix_len_compatibility(y_max_length) y_max_length_ = fix_len_compatibility(y_max_length)
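To make the duration bookkeeping in this hunk concrete, here is a tiny self-contained illustration (made-up numbers, not part of the model) of how the ceiled durations fix the number of mel frames and which token each frame is expanded from; inside the model this expansion is what `generate_path` and the attention map encode.
```python
import torch

logw = torch.tensor([[[0.0, 1.0, 0.5]]])   # predicted log-durations, (batch=1, 1, n_tokens=3)
x_mask = torch.ones_like(logw)
length_scale = 1.0

w_ceil = torch.ceil(torch.exp(logw) * x_mask) * length_scale   # frames per token: [1., 3., 2.]
y_length = torch.clamp_min(w_ceil.sum(), 1).long()             # total mel frames: 6

# token index that each output frame is expanded from
frame_to_token = torch.repeat_interleave(torch.arange(w_ceil.shape[-1]), w_ceil.long().flatten())
print(y_length.item(), frame_to_token.tolist())                # 6 [0, 1, 1, 1, 2, 2]
```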
@@ -150,10 +153,10 @@ class MatchaTTS(BaseLightningClass): # 🍵
"rtf": rtf, "rtf": rtf,
} }
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None): def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
""" """
Computes 3 losses: Computes 3 losses:
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS). 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs. 2. prior loss: loss between mel-spectrogram and encoder outputs.
3. flow matching loss: loss between mel-spectrogram and decoder outputs. 3. flow matching loss: loss between mel-spectrogram and decoder outputs.
@@ -176,15 +179,12 @@ class MatchaTTS(BaseLightningClass): # 🍵
spks = self.spk_emb(spks) spks = self.spk_emb(spks)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw` # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) mu_x, enc_output, x_mask = self.encoder(x, x_lengths, spks)
y_max_length = y.shape[-1] y_max_length = y.shape[-1]
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
if self.use_precomputed_durations:
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
else:
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad(): with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * self.n_feats const = -0.5 * math.log(2 * math.pi) * self.n_feats
@@ -195,12 +195,11 @@ class MatchaTTS(BaseLightningClass): # 🍵
log_prior = y_square - y_mu_double + mu_square + const log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach() # b, t_text, T_mel attn = attn.detach()
# Compute loss between predicted log-scaled durations and those obtained from MAS # Compute loss between predicted log-scaled durations and those obtained from MAS
# referred to as prior loss in the paper
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths) dur_loss = self.dp.compute_loss(logw_, enc_output, x_mask)
# Cut a small segment of mel-spectrogram in order to increase batch size # Cut a small segment of mel-spectrogram in order to increase batch size
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it


@@ -1,5 +1,4 @@
""" from https://github.com/keithito/tacotron """ """ from https://github.com/keithito/tacotron """
from matcha.text import cleaners from matcha.text import cleaners
from matcha.text.symbols import symbols from matcha.text.symbols import symbols
@@ -8,10 +7,6 @@ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
class UnknownCleanerException(Exception):
pass
def text_to_sequence(text, cleaner_names): def text_to_sequence(text, cleaner_names):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text. """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args: Args:
@@ -53,6 +48,6 @@ def _clean_text(text, cleaner_names):
for name in cleaner_names: for name in cleaner_names:
cleaner = getattr(cleaners, name) cleaner = getattr(cleaners, name)
if not cleaner: if not cleaner:
raise UnknownCleanerException(f"Unknown cleaner: {name}") raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text) text = cleaner(text)
return text return text


@@ -15,6 +15,7 @@ import logging
import re import re
import phonemizer import phonemizer
import piper_phonemize
from unidecode import unidecode from unidecode import unidecode
# To avoid excessive logging we set the log level of the phonemizer package to Critical # To avoid excessive logging we set the log level of the phonemizer package to Critical
@@ -36,12 +37,9 @@ global_phonemizer = phonemizer.backend.EspeakBackend(
# Regular expression matching whitespace: # Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+") _whitespace_re = re.compile(r"\s+")
# Remove brackets
_brackets_re = re.compile(r"[\[\]\(\)\{\}]")
# List of (regular expression, replacement) pairs for abbreviations: # List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [ _abbreviations = [
(re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1]) (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [ for x in [
("mrs", "misess"), ("mrs", "misess"),
("mr", "mister"), ("mr", "mister"),
@@ -75,10 +73,6 @@ def lowercase(text):
return text.lower() return text.lower()
def remove_brackets(text):
return re.sub(_brackets_re, "", text)
def collapse_whitespace(text): def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text) return re.sub(_whitespace_re, " ", text)
@@ -108,37 +102,15 @@ def english_cleaners2(text):
text = lowercase(text) text = lowercase(text)
text = expand_abbreviations(text) text = expand_abbreviations(text)
phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
# Added in some cases espeak is not removing brackets
phonemes = remove_brackets(phonemes)
phonemes = collapse_whitespace(phonemes) phonemes = collapse_whitespace(phonemes)
return phonemes return phonemes
def ipa_simplifier(text): def english_cleaners_piper(text):
replacements = [ """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
("ɐ", "ə"), text = convert_to_ascii(text)
("ˈə", "ə"), text = lowercase(text)
("ʤ", ""), text = expand_abbreviations(text)
("ʧ", ""), phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
("", "ɪ"), phonemes = collapse_whitespace(phonemes)
]
for replacement in replacements:
text = text.replace(replacement[0], replacement[1])
phonemes = collapse_whitespace(text)
return phonemes return phonemes
# I am removing this due to incompatibility with several versions of Python
# However, if you want to use it, you can uncomment it
# and install piper-phonemize with the following command:
# pip install piper-phonemize
# import piper_phonemize
# def english_cleaners_piper(text):
# """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
# text = convert_to_ascii(text)
# text = lowercase(text)
# text = expand_abbreviations(text)
# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
# phonemes = collapse_whitespace(phonemes)
# return phonemes


@@ -2,7 +2,6 @@
Defines the set of symbols used in text input to the model. Defines the set of symbols used in text input to the model.
""" """
_pad = "_" _pad = "_"
_punctuation = ';:,.!?¡¿—…"«»“” ' _punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


@@ -48,7 +48,7 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned global mel_basis, hann_window # pylint: disable=global-statement
if f"{str(fmax)}_{str(y.device)}" not in mel_basis: if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)


@@ -1,148 +0,0 @@
#!/usr/bin/env python
import argparse
import os
import sys
import tempfile
from pathlib import Path
import torchaudio
from torch.hub import download_url_to_file
from tqdm import tqdm
from matcha.utils.data.utils import _extract_zip
URLS = {
"en-US": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip",
},
"ja-JP": {
"female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip",
"male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip",
},
}
INFO_PAGE = "https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/"
# On their website they say "We NICT open-sourced Hi-Fi-CAPTAIN",
# but they use this very-much-not-open-source licence.
# Dunno if this is open washing or stupidity.
LICENCE = "CC BY-NC-SA 4.0"
# I'd normally put the citation here. It's on their website.
# Boo to non-open-source stuff.
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"-r",
"--skip-resampling",
action="store_true",
default=False,
help="Skip resampling the data (from 48 to 22.05)",
)
parser.add_argument(
"-l", "--language", type=str, choices=["en-US", "ja-JP"], default="en-US", help="The language to download"
)
parser.add_argument(
"-g",
"--gender",
type=str,
choices=["male", "female"],
default="female",
help="The gender of the speaker to download",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
default="data",
help="Place to store the converted data. Top-level only, the subdirectory will be created",
)
return parser.parse_args()
def process_text(infile, outpath: Path):
outmode = "w"
if infile.endswith("dev.txt"):
outfile = outpath / "valid.txt"
elif infile.endswith("eval.txt"):
outfile = outpath / "test.txt"
else:
outfile = outpath / "train.txt"
if outfile.exists():
outmode = "a"
with (
open(infile, encoding="utf-8") as inf,
open(outfile, outmode, encoding="utf-8") as of,
):
for line in inf.readlines():
line = line.strip()
fileid, rest = line.split(" ", maxsplit=1)
outfile = str(outpath / f"{fileid}.wav")
of.write(f"{outfile}|{rest}\n")
def process_files(zipfile, outpath, resample=True):
with tempfile.TemporaryDirectory() as tmpdirname:
for filename in tqdm(_extract_zip(zipfile, tmpdirname)):
if not filename.startswith(tmpdirname):
filename = os.path.join(tmpdirname, filename)
if filename.endswith(".txt"):
process_text(filename, outpath)
elif filename.endswith(".wav"):
filepart = filename.rsplit("/", maxsplit=1)[-1]
outfile = str(outpath / filepart)
arr, sr = torchaudio.load(filename)
if resample:
arr = torchaudio.functional.resample(arr, orig_freq=sr, new_freq=22050)
torchaudio.save(outfile, arr, 22050)
else:
continue
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
if not args.output_dir:
print("output directory not specified, exiting")
sys.exit(1)
URL = URLS[args.language][args.gender]
dirname = f"hi-fi_{args.language}_{args.gender}"
outbasepath = Path(args.output_dir)
if not outbasepath.is_dir():
outbasepath.mkdir()
outpath = outbasepath / dirname
if not outpath.is_dir():
outpath.mkdir()
resample = True
if args.skip_resampling:
resample = False
if save_dir:
zipname = URL.rsplit("/", maxsplit=1)[-1]
zipfile = save_dir / zipname
if not zipfile.exists():
download_url_to_file(URL, zipfile, progress=True)
process_files(zipfile, outpath, resample)
else:
with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
process_files(zf.name, outpath, resample)
if __name__ == "__main__":
main()


@@ -1,97 +0,0 @@
#!/usr/bin/env python
import argparse
import random
import tempfile
from pathlib import Path
from torch.hub import download_url_to_file
from matcha.utils.data.utils import _extract_tar
URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
INFO_PAGE = "https://keithito.com/LJ-Speech-Dataset/"
LICENCE = "Public domain (LibriVox copyright disclaimer)"
CITATION = """
@misc{ljspeech17,
author = {Keith Ito and Linda Johnson},
title = {The LJ Speech Dataset},
howpublished = {\\url{https://keithito.com/LJ-Speech-Dataset/}},
year = 2017
}
"""
def decision():
return random.random() < 0.98
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
parser.add_argument(
"output_dir",
type=str,
nargs="?",
default="data",
help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)",
)
return parser.parse_args()
def process_csv(ljpath: Path):
    if (ljpath / "metadata.csv").exists():
        basepath = ljpath
    elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists():
        basepath = ljpath / "LJSpeech-1.1"
    else:
        raise FileNotFoundError(f"metadata.csv not found under {ljpath}")
csvpath = basepath / "metadata.csv"
wavpath = basepath / "wavs"
with (
open(csvpath, encoding="utf-8") as csvf,
open(basepath / "train.txt", "w", encoding="utf-8") as tf,
open(basepath / "val.txt", "w", encoding="utf-8") as vf,
):
for line in csvf.readlines():
line = line.strip()
parts = line.split("|")
wavfile = str(wavpath / f"{parts[0]}.wav")
if decision():
tf.write(f"{wavfile}|{parts[1]}\n")
else:
vf.write(f"{wavfile}|{parts[1]}\n")
def main():
args = get_args()
save_dir = None
if args.save_dir:
save_dir = Path(args.save_dir)
if not save_dir.is_dir():
save_dir.mkdir()
outpath = Path(args.output_dir)
if not outpath.is_dir():
outpath.mkdir()
if save_dir:
tarname = URL.rsplit("/", maxsplit=1)[-1]
tarfile = save_dir / tarname
if not tarfile.exists():
download_url_to_file(URL, str(tarfile), progress=True)
_extract_tar(tarfile, outpath)
process_csv(outpath)
else:
with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf:
download_url_to_file(URL, zf.name, progress=True)
_extract_tar(zf.name, outpath)
process_csv(outpath)
if __name__ == "__main__":
main()
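For orientation: decision() makes the train/validation split random rather than fixed, routing roughly 98% of metadata.csv lines to train.txt and the rest to val.txt. A small sanity check of the resulting split, assuming the default output under data/LJSpeech-1.1 (an illustrative path; it depends on the output_dir you passed):

from pathlib import Path

base = Path("data") / "LJSpeech-1.1"  # hypothetical default location of the processed filelists
n_train = sum(1 for _ in open(base / "train.txt", encoding="utf-8"))
n_val = sum(1 for _ in open(base / "val.txt", encoding="utf-8"))
# expect the validation fraction to be around 0.02, but it varies run to run because the split is random
print(f"train: {n_train}  val: {n_val}  val fraction: {n_val / (n_train + n_val):.3f}")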

View File

@@ -1,53 +0,0 @@
# taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala)
# Licence: BSD 2-Clause
# pylint: disable=C0123
import logging
import os
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Union
_LG = logging.getLogger(__name__)
def _extract_tar(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
        from_path = str(from_path)
if to_path is None:
to_path = os.path.dirname(from_path)
with tarfile.open(from_path, "r") as tar:
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
return files
def _extract_zip(from_path: Union[str, Path], to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
if type(from_path) is Path:
        from_path = str(from_path)
if to_path is None:
to_path = os.path.dirname(from_path)
with zipfile.ZipFile(from_path, "r") as zfile:
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
_LG.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
return files
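Both helpers return the list of extracted members, which is what the download scripts above iterate over. A minimal usage sketch, assuming _extract_zip is importable from matcha.utils.data.utils (mirroring the _extract_tar import above) and that downloads/dataset.zip is only a placeholder archive path:

import tempfile

from matcha.utils.data.utils import _extract_zip  # assumed module location

with tempfile.TemporaryDirectory() as tmpdir:
    extracted = _extract_zip("downloads/dataset.zip", tmpdir)  # placeholder archive path
    wavs = [f for f in extracted if f.endswith(".wav")]
    print(f"extracted {len(extracted)} files, {len(wavs)} wavs")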

View File

@@ -4,7 +4,6 @@ when needed.
 Parameters from hparam.py will be used
 """
-
 import argparse
 import json
 import os
@@ -95,7 +94,6 @@ def main():
     cfg["batch_size"] = args.batch_size
     cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
     cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
-    cfg["load_durations"] = False
     text_mel_datamodule = TextMelDataModule(**cfg)
     text_mel_datamodule.setup()
@@ -103,8 +101,10 @@ def main():
     log.info("Dataloader loaded! Now computing stats...")
     params = compute_data_statistics(data_loader, cfg["n_feats"])
     print(params)
-    with open(output_file, "w", encoding="utf-8") as dumpfile:
-        json.dump(params, dumpfile)
+    json.dump(
+        params,
+        open(output_file, "w"),
+    )
 if __name__ == "__main__":

View File

@@ -4,7 +4,6 @@ when needed.
 Parameters from hparam.py will be used
 """
-
 import argparse
 import json
 import os
@@ -87,7 +86,7 @@ def main():
         "-i",
         "--input-config",
         type=str,
-        default="ljspeech.yaml",
+        default="vctk.yaml",
         help="The name of the yaml config file under configs/data",
     )
@@ -141,14 +140,11 @@ def main():
     cfg["batch_size"] = args.batch_size
     cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
     cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
-    cfg["load_durations"] = False
     if args.output_folder is not None:
         output_folder = Path(args.output_folder)
     else:
-        output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
-
-    print(f"Output folder set to: {output_folder}")
+        output_folder = Path("data") / "processed_data" / cfg["name"] / "durations"
     if os.path.exists(output_folder) and not args.force:
         print("Folder already exists. Use -f to force overwrite")

View File

@@ -72,7 +72,7 @@ def print_config_tree(
     # save config tree to file
     if save_to_file:
-        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
             rich.print(tree, file=file)
@@ -97,5 +97,5 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
     log.info(f"Tags: {cfg.tags}")
     if save_to_file:
-        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
             rich.print(cfg.tags, file=file)

View File

@@ -35,10 +35,11 @@ torchaudio
 matplotlib
 pandas
 conformer==0.3.2
-diffusers # developed using version ==0.25.0
+diffusers==0.25.0
 notebook
 ipywidgets
-gradio==3.43.2
+gradio
 gdown
 wget
 seaborn
+piper_phonemize

15
scripts/get_durations.sh Normal file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
echo "Starting script"
echo "Getting LJ Speech durations"
python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c logs/train/lj_det/runs/2024-01-12_12-05-00/checkpoints/last.ckpt -f
echo "Getting TSG2 durations"
python matcha/utils/get_durations_from_trained_model.py -i tsg2.yaml -c logs/train/tsg2_det_dur/runs/2024-01-05_12-33-35/checkpoints/last.ckpt -f
echo "Getting Joe Spont durations"
python matcha/utils/get_durations_from_trained_model.py -i joe_spont_only.yaml -c logs/train/joe_det_dur/runs/2024-02-20_14-01-01/checkpoints/last.ckpt -f
echo "Getting Ryan durations"
python matcha/utils/get_durations_from_trained_model.py -i ryan.yaml -c logs/train/matcha_ryan_det/runs/2024-02-26_09-28-09/checkpoints/last.ckpt -f

7
scripts/transcribe.sh Normal file
View File

@@ -0,0 +1,7 @@
echo "Transcribing"
whispertranscriber -i lj_det_output -o lj_det_output_transcriptions -f
whispertranscriber -i lj_fm_output -o lj_fm_output_transcriptions -f
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_det_output_transcriptions
wercompute -r dur_wer_computation/reference_transcripts/ -i lj_fm_output_transcriptions

30
scripts/wer_computer.sh Normal file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
# Run from root folder with: bash scripts/wer_computer.sh
root_folder=${1:-"dur_wer_computation"}
echo "Running WER computation for Duration predictors"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_fm_output_transcriptions/"
# echo $cmd
echo "LJ"
echo "==================================="
echo "Flow Matching"
$cmd
echo "-----------------------------------"
echo "LJ Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/lj_det_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac"
echo "==================================="
echo "Cormac Flow Matching"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/fm_output_transcriptions/"
$cmd
echo "-----------------------------------"
echo "Cormac Determinstic"
cmd="wercompute -r ${root_folder}/reference_transcripts/ -i ${root_folder}/det_output_transcriptions/"
$cmd
echo "-----------------------------------"

View File

@@ -16,16 +16,9 @@ with open("README.md", encoding="utf-8") as readme_file:
     README = readme_file.read()
 cwd = os.path.dirname(os.path.abspath(__file__))
-with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
+with open(os.path.join(cwd, "matcha", "VERSION")) as fin:
     version = fin.read().strip()
-
-def get_requires():
-    requirements = os.path.join(os.path.dirname(__file__), "requirements.txt")
-    with open(requirements, encoding="utf-8") as reqfile:
-        return [str(r).strip() for r in reqfile]
-
 setup(
     name="matcha-tts",
     version=version,
@@ -35,7 +28,7 @@ setup(
     author="Shivam Mehta",
     author_email="shivam.mehta25@gmail.com",
     url="https://shivammehta25.github.io/Matcha-TTS",
-    install_requires=get_requires(),
+    install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))],
     include_dirs=[numpy.get_include()],
     include_package_data=True,
     packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
@@ -45,7 +38,6 @@ setup(
             "matcha-data-stats=matcha.utils.generate_data_statistics:main",
             "matcha-tts=matcha.cli:cli",
             "matcha-tts-app=matcha.app:main",
-            "matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main",
         ]
     },
     ext_modules=cythonize(exts, language_level=3),

File diff suppressed because one or more lines are too long