Initial commit

This commit is contained in:
Shivam Mehta
2023-09-16 17:51:36 +00:00
parent b189c1983a
commit f016784049
100 changed files with 6416 additions and 0 deletions

5
matcha/utils/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
from matcha.utils.logging_utils import log_hyperparameters
from matcha.utils.pylogger import get_pylogger
from matcha.utils.rich_utils import enforce_tags, print_config_tree
from matcha.utils.utils import extras, get_metric_value, task_wrapper

82
matcha/utils/audio.py Normal file
View File

@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec

View File

@@ -0,0 +1,111 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path
import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from tqdm.auto import tqdm
from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.utils.logging_utils import pylogger
log = pylogger.get_pylogger(__name__)
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
"""Generate data mean and standard deviation helpful in data normalisation
Args:
data_loader (torch.utils.data.Dataloader): _description_
out_channels (int): mel spectrogram channels
"""
total_mel_sum = 0
total_mel_sq_sum = 0
total_mel_len = 0
for batch in tqdm(data_loader, leave=False):
mels = batch["y"]
mel_lengths = batch["y_lengths"]
total_mel_len += torch.sum(mel_lengths)
total_mel_sum += torch.sum(mels)
total_mel_sq_sum += torch.sum(torch.pow(mels, 2))
data_mean = total_mel_sum / (total_mel_len * out_channels)
data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-config",
type=str,
default="vctk.yaml",
help="The name of the yaml config file under configs/data",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default="256",
help="Can have increased batch size for faster computation",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
required=False,
help="force overwrite the file",
)
args = parser.parse_args()
output_file = Path(args.input_config).with_suffix(".json")
if os.path.exists(output_file) and not args.force:
print("File already exists. Use -f to force overwrite")
sys.exit(1)
with initialize(version_base="1.3", config_path="../../configs/data"):
cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
with open_dict(cfg):
del cfg["hydra"]
del cfg["_target_"]
cfg["data_statistics"] = None
cfg["seed"] = 1234
cfg["batch_size"] = args.batch_size
cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
text_mel_datamodule = TextMelDataModule(**cfg)
text_mel_datamodule.setup()
data_loader = text_mel_datamodule.train_dataloader()
log.info("Dataloader loaded! Now computing stats...")
params = compute_data_statistics(data_loader, cfg["n_feats"])
print(params)
json.dump(
params,
open(output_file, "w"),
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,56 @@
from typing import List
import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
"""Instantiates callbacks from config.
:param callbacks_cfg: A DictConfig object containing callback configurations.
:return: A list of instantiated callbacks.
"""
callbacks: List[Callback] = []
if not callbacks_cfg:
log.warning("No callback configs found! Skipping..")
return callbacks
if not isinstance(callbacks_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
for _, cb_conf in callbacks_cfg.items():
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access
callbacks.append(hydra.utils.instantiate(cb_conf))
return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
"""Instantiates loggers from config.
:param logger_cfg: A DictConfig object containing logger configurations.
:return: A list of instantiated loggers.
"""
logger: List[Logger] = []
if not logger_cfg:
log.warning("No logger configs found! Skipping...")
return logger
if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
for _, lg_conf in logger_cfg.items():
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access
logger.append(hydra.utils.instantiate(lg_conf))
return logger

View File

@@ -0,0 +1,53 @@
from typing import Any, Dict
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
"""Controls which config parts are saved by Lightning loggers.
Additionally saves:
- Number of model parameters
:param object_dict: A dictionary containing the following objects:
- `"cfg"`: A DictConfig object containing the main config.
- `"model"`: The Lightning model.
- `"trainer"`: The Lightning trainer.
"""
hparams = {}
cfg = OmegaConf.to_container(object_dict["cfg"])
model = object_dict["model"]
trainer = object_dict["trainer"]
if not trainer.logger:
log.warning("Logger not found! Skipping hyperparameter logging...")
return
hparams["model"] = cfg["model"]
# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
hparams["data"] = cfg["data"]
hparams["trainer"] = cfg["trainer"]
hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")
hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")
# send hparams to all loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams)

88
matcha/utils/model.py Normal file
View File

@@ -0,0 +1,88 @@
""" from https://github.com/jaywalnut310/glow-tts """
import numpy as np
import torch
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
while True:
if length % (2**num_downsamplings_in_unet) == 0:
return length
length += 1
def convert_pad_shape(pad_shape):
inverted_shape = pad_shape[::-1]
pad_shape = [item for sublist in inverted_shape for item in sublist]
return pad_shape
def generate_path(duration, mask):
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
def duration_loss(logw, logw_, lengths):
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
return loss
def normalize(data, mu, std):
if not isinstance(mu, (float, int)):
if isinstance(mu, list):
mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
elif isinstance(mu, torch.Tensor):
mu = mu.to(data.device)
elif isinstance(mu, np.ndarray):
mu = torch.from_numpy(mu).to(data.device)
mu = mu.unsqueeze(-1)
if not isinstance(std, (float, int)):
if isinstance(std, list):
std = torch.tensor(std, dtype=data.dtype, device=data.device)
elif isinstance(std, torch.Tensor):
std = std.to(data.device)
elif isinstance(std, np.ndarray):
std = torch.from_numpy(std).to(data.device)
std = std.unsqueeze(-1)
return (data - mu) / std
def denormalize(data, mu, std):
if not isinstance(mu, float):
if isinstance(mu, list):
mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
elif isinstance(mu, torch.Tensor):
mu = mu.to(data.device)
elif isinstance(mu, np.ndarray):
mu = torch.from_numpy(mu).to(data.device)
mu = mu.unsqueeze(-1)
if not isinstance(std, float):
if isinstance(std, list):
std = torch.tensor(std, dtype=data.dtype, device=data.device)
elif isinstance(std, torch.Tensor):
std = std.to(data.device)
elif isinstance(std, np.ndarray):
std = torch.from_numpy(std).to(data.device)
std = std.unsqueeze(-1)
return data * std + mu

View File

@@ -0,0 +1,22 @@
import numpy as np
import torch
from matcha.utils.monotonic_align.core import maximum_path_c
def maximum_path(value, mask):
"""Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32)
path = np.zeros_like(value).astype(np.int32)
mask = mask.data.cpu().numpy()
t_x_max = mask.sum(1)[:, 0].astype(np.int32)
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
maximum_path_c(path, value, t_x_max, t_y_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)

View File

@@ -0,0 +1,47 @@
import numpy as np
cimport cython
cimport numpy as np
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[x, y-1]
if x == 0:
if y == 0:
v_prev = 0.
else:
v_prev = max_neg_val
else:
v_prev = value[x-1, y-1]
value[x, y] = max(v_cur, v_prev) + value[x, y]
for y in range(t_y - 1, -1, -1):
path[index, y] = 1
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
cdef int b = values.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)

View File

@@ -0,0 +1,7 @@
# from distutils.core import setup
# from Cython.Build import cythonize
# import numpy
# setup(name='monotonic_align',
# ext_modules=cythonize("core.pyx"),
# include_dirs=[numpy.get_include()])

21
matcha/utils/pylogger.py Normal file
View File

@@ -0,0 +1,21 @@
import logging
from lightning.pytorch.utilities import rank_zero_only
def get_pylogger(name: str = __name__) -> logging.Logger:
"""Initializes a multi-GPU-friendly python command line logger.
:param name: The name of the logger, defaults to ``__name__``.
:return: A logger object.
"""
logger = logging.getLogger(name)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
for level in logging_levels:
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger

101
matcha/utils/rich_utils.py Normal file
View File

@@ -0,0 +1,101 @@
from pathlib import Path
from typing import Sequence
import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
from matcha.utils import pylogger
log = pylogger.get_pylogger(__name__)
@rank_zero_only
def print_config_tree(
cfg: DictConfig,
print_order: Sequence[str] = (
"data",
"model",
"callbacks",
"logger",
"trainer",
"paths",
"extras",
),
resolve: bool = False,
save_to_file: bool = False,
) -> None:
"""Prints the contents of a DictConfig as a tree structure using the Rich library.
:param cfg: A DictConfig composed by Hydra.
:param print_order: Determines in what order config components are printed. Default is ``("data", "model",
"callbacks", "logger", "trainer", "paths", "extras")``.
:param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
:param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
queue = []
# add fields from `print_order` to queue
for field in print_order:
_ = (
queue.append(field)
if field in cfg
else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
)
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, style=style, guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
else:
branch_content = str(config_group)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
# print config tree
rich.print(tree)
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config.
:param cfg: A DictConfig composed by Hydra.
:param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
"""
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
tags = [t.strip() for t in tags.split(",") if t != ""]
with open_dict(cfg):
cfg.tags = tags
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
rich.print(cfg.tags, file=file)

214
matcha/utils/utils.py Normal file
View File

@@ -0,0 +1,214 @@
import os
import sys
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
import wget
from omegaconf import DictConfig
from matcha.utils import pylogger, rich_utils
log = pylogger.get_pylogger(__name__)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
:param cfg: A DictConfig object containing the config tree.
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
rich_utils.enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
```
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: The name of the metric to retrieve.
:return: The value of the metric.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise Exception(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
def intersperse(lst, item):
# Adds blank symbol
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def save_figure_to_numpy(fig):
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
def plot_tensor(tensor):
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
def save_plot(tensor, savepath):
plt.style.use("default")
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.tight_layout()
fig.canvas.draw()
plt.savefig(savepath)
plt.close()
def to_numpy(tensor):
if isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
elif isinstance(tensor, list):
return np.array(tensor)
else:
raise TypeError("Unsupported type for conversion to numpy array")
def get_user_data_dir(appname="matcha_tts"):
"""
Args:
appname (str): Name of application
Returns:
Path: path to user data directory
"""
MATCHA_HOME = os.environ.get("MATCHA_HOME")
if MATCHA_HOME is not None:
ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
elif sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath(".local/share")
return ans.joinpath(appname)
def assert_model_downloaded(checkpoint_path, url, use_wget=False):
if Path(checkpoint_path).exists():
log.debug(f"[+] Model already present at {checkpoint_path}!")
return
log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
checkpoint_path = str(checkpoint_path)
if not use_wget:
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
else:
wget.download(url=url, out=checkpoint_path)