mirror of https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 02:09:21 +08:00
Initial commit
matcha/utils/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
from matcha.utils.logging_utils import log_hyperparameters
from matcha.utils.pylogger import get_pylogger
from matcha.utils.rich_utils import enforce_tags, print_config_tree
from matcha.utils.utils import extras, get_metric_value, task_wrapper
matcha/utils/audio.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
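A minimal usage sketch for mel_spectrogram (not part of the commit). The waveform is random and the parameter values are illustrative placeholders, not necessarily the values used in this repo's configs:

    import torch
    from matcha.utils.audio import mel_spectrogram

    y = torch.rand(1, 22050) * 2 - 1  # one second of fake audio in [-1, 1], shape (batch, samples)
    mel = mel_spectrogram(
        y, n_fft=1024, num_mels=80, sampling_rate=22050,
        hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
    )
    print(mel.shape)  # (batch, num_mels, frames), log-compressed via spectral_normalize_torch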
matcha/utils/generate_data_statistics.py (new file, 111 lines)
@@ -0,0 +1,111 @@
r"""
This script computes the statistics (mel mean and standard deviation) needed for data normalisation
and stores them in a JSON file that the model can load when needed.

Parameters are read from the yaml config under configs/data.
"""
import argparse
import json
import os
import sys
from pathlib import Path

import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from tqdm.auto import tqdm

from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.utils.logging_utils import pylogger

log = pylogger.get_pylogger(__name__)


def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Generate the data mean and standard deviation used for data normalisation.

    Args:
        data_loader (torch.utils.data.DataLoader): loader yielding batches with "y" (mels) and "y_lengths"
        out_channels (int): number of mel spectrogram channels
    """
    total_mel_sum = 0
    total_mel_sq_sum = 0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"]
        mel_lengths = batch["y_lengths"]

        total_mel_len += torch.sum(mel_lengths)
        total_mel_sum += torch.sum(mels)
        total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

    data_mean = total_mel_sum / (total_mel_len * out_channels)
    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))

    return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)
    with open(output_file, "w") as f:
        json.dump(params, f)


if __name__ == "__main__":
    main()
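The loop accumulates the running sums Σx, Σx², and the total frame count, then computes mean = Σx / N and std = sqrt(Σx² / N − mean²) with N = total frames × mel channels. A usage sketch (the config name here is an illustrative placeholder for any yaml under configs/data):

    python matcha/utils/generate_data_statistics.py -i ljspeech.yaml -b 256
    # writes ljspeech.json containing {"mel_mean": ..., "mel_std": ...}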
matcha/utils/instantiators.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from typing import List

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
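Both helpers iterate over the values of a DictConfig and instantiate anything carrying a `_target_` key. A hypothetical callbacks config illustrating the expected shape (the key and parameter values here are assumptions, not from this commit):

    model_checkpoint:
      _target_: lightning.pytorch.callbacks.ModelCheckpoint
      monitor: val_loss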
matcha/utils/logging_utils.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from typing import Any, Dict

from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
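A sketch of the expected call site, with the three keys taken from the docstring (cfg, model, and trainer are assumed to already exist in the caller):

    log_hyperparameters({"cfg": cfg, "model": model, "trainer": trainer})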
matcha/utils/model.py (new file, 88 lines)
@@ -0,0 +1,88 @@
""" from https://github.com/jaywalnut310/glow-tts """

import numpy as np
import torch


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    while True:
        if length % (2**num_downsamplings_in_unet) == 0:
            return length
        length += 1


def convert_pad_shape(pad_shape):
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


def normalize(data, mu, std):
    if not isinstance(mu, (float, int)):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, (float, int)):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return (data - mu) / std


def denormalize(data, mu, std):
    if not isinstance(mu, float):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, float):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return data * std + mu
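A quick illustration of the masking helpers (these outputs are deterministic, so the comments are exact):

    import torch
    from matcha.utils.model import fix_len_compatibility, sequence_mask

    print(sequence_mask(torch.tensor([2, 4])))
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])

    # rounds a length up to a multiple of 2**num_downsamplings_in_unet (4 by default)
    print(fix_len_compatibility(10))  # 12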
matcha/utils/monotonic_align/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import numpy as np
import torch

from matcha.utils.monotonic_align.core import maximum_path_c


def maximum_path(value, mask):
    """Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
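A usage sketch, assuming the Cython extension from core.pyx below has been compiled; shapes follow the docstring:

    import torch
    from matcha.utils.monotonic_align import maximum_path

    value = torch.randn(1, 3, 5)      # [b, t_x, t_y] alignment log-likelihoods
    mask = torch.ones(1, 3, 5)        # [b, t_x, t_y] valid positions
    path = maximum_path(value, mask)  # monotonic 0/1 alignment of the same shape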
matcha/utils/monotonic_align/core.pyx (new file, 47 lines)
@@ -0,0 +1,47 @@
import numpy as np

cimport cython
cimport numpy as np

from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[x, y-1]
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[x-1, y-1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
            index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
    cdef int b = values.shape[0]

    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
matcha/utils/monotonic_align/setup.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# from distutils.core import setup
# from Cython.Build import cythonize
# import numpy

# setup(name='monotonic_align',
#       ext_modules=cythonize("core.pyx"),
#       include_dirs=[numpy.get_include()])
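If the commented-out build script were restored, the extension would typically be compiled in place with the standard Cython invocation (a sketch; the project may wire this up differently):

    python setup.py build_ext --inplace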
matcha/utils/pylogger.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import logging

from lightning.pytorch.utilities import rank_zero_only


def get_pylogger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger.

    :param name: The name of the logger, defaults to ``__name__``.

    :return: A logger object.
    """
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
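Usage sketch; because every level is wrapped in rank_zero_only, only the rank-zero process emits the message in a multi-GPU run:

    log = get_pylogger(__name__)
    log.info("Logged once, even with multiple GPU processes.")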
matcha/utils/rich_utils.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from pathlib import Path
from typing import Sequence

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        _ = (
            queue.append(field)
            if field in cfg
            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
        )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
            rich.print(tree, file=file)


@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = [t.strip() for t in tags.split(",") if t != ""]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
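Both helpers are decorated with rank_zero_only, so they run on a single process. A minimal calling sketch, assuming a Hydra-composed cfg with cfg.paths.output_dir set:

    from matcha.utils.rich_utils import enforce_tags, print_config_tree

    enforce_tags(cfg, save_to_file=True)   # prompts for tags if cfg.tags is empty
    print_config_tree(cfg, resolve=True)   # pretty-prints the config as a Rich tree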
matcha/utils/utils.py (new file, 214 lines)
@@ -0,0 +1,214 @@
import os
import sys
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
import wget
from omegaconf import DictConfig

from matcha.utils import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise ex

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap


def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: The name of the metric to retrieve.
    :return: The value of the metric.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def save_figure_to_numpy(fig):
    # np.fromstring is deprecated for binary input; np.frombuffer is the drop-in replacement
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_tensor(tensor):
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def save_plot(tensor, savepath):
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    plt.savefig(savepath)
    plt.close()


def to_numpy(tensor):
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, list):
        return np.array(tensor)
    else:
        raise TypeError("Unsupported type for conversion to numpy array")


def get_user_data_dir(appname="matcha_tts"):
    """
    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory
    """

    MATCHA_HOME = os.environ.get("MATCHA_HOME")
    if MATCHA_HOME is not None:
        ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        key = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        )
        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        ans = Path("~/Library/Application Support/").expanduser()
    else:
        ans = Path.home().joinpath(".local/share")
    return ans.joinpath(appname)


def assert_model_downloaded(checkpoint_path, url, use_wget=False):
    if Path(checkpoint_path).exists():
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        return
    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    checkpoint_path = str(checkpoint_path)
    if not use_wget:
        gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
    else:
        wget.download(url=url, out=checkpoint_path)
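intersperse is deterministic, so this worked example is exact; it surrounds every element with the blank symbol:

    from matcha.utils.utils import intersperse

    print(intersperse([1, 2, 3], 0))  # [0, 1, 0, 2, 0, 3, 0]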