mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 18:29:19 +08:00
Initial commit
This commit is contained in:
56
matcha/utils/instantiators.py
Normal file
56
matcha/utils/instantiators.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import List
|
||||
|
||||
import hydra
|
||||
from lightning import Callback
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from matcha.utils import pylogger
|
||||
|
||||
log = pylogger.get_pylogger(__name__)
|
||||
|
||||
|
||||
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
|
||||
"""Instantiates callbacks from config.
|
||||
|
||||
:param callbacks_cfg: A DictConfig object containing callback configurations.
|
||||
:return: A list of instantiated callbacks.
|
||||
"""
|
||||
callbacks: List[Callback] = []
|
||||
|
||||
if not callbacks_cfg:
|
||||
log.warning("No callback configs found! Skipping..")
|
||||
return callbacks
|
||||
|
||||
if not isinstance(callbacks_cfg, DictConfig):
|
||||
raise TypeError("Callbacks config must be a DictConfig!")
|
||||
|
||||
for _, cb_conf in callbacks_cfg.items():
|
||||
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
|
||||
log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access
|
||||
callbacks.append(hydra.utils.instantiate(cb_conf))
|
||||
|
||||
return callbacks
|
||||
|
||||
|
||||
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
|
||||
"""Instantiates loggers from config.
|
||||
|
||||
:param logger_cfg: A DictConfig object containing logger configurations.
|
||||
:return: A list of instantiated loggers.
|
||||
"""
|
||||
logger: List[Logger] = []
|
||||
|
||||
if not logger_cfg:
|
||||
log.warning("No logger configs found! Skipping...")
|
||||
return logger
|
||||
|
||||
if not isinstance(logger_cfg, DictConfig):
|
||||
raise TypeError("Logger config must be a DictConfig!")
|
||||
|
||||
for _, lg_conf in logger_cfg.items():
|
||||
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
|
||||
log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access
|
||||
logger.append(hydra.utils.instantiate(lg_conf))
|
||||
|
||||
return logger
|
||||
Reference in New Issue
Block a user