Mirror of https://github.com/shivammehta25/Matcha-TTS.git
Initial commit
matcha/utils/logging_utils.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from typing import Any, Dict

from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
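For context, a minimal sketch of how this utility would be wired into a Hydra-driven Lightning training entrypoint. The config layout (configs/train.yaml with a _target_ key per config group) and the entrypoint shape are assumptions for illustration, not this repository's actual training script.

# Hedged usage sketch for log_hyperparameters. The config path, config
# file name, and per-group _target_ keys are assumed, not taken from
# this repository.
import hydra
import lightning as L
from omegaconf import DictConfig

from matcha.utils.logging_utils import log_hyperparameters


@hydra.main(version_base="1.3", config_path="configs", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    # Instantiate model, datamodule, and trainer from the Hydra config;
    # assumes cfg.model, cfg.data, and cfg.trainer each carry a _target_.
    model: L.LightningModule = hydra.utils.instantiate(cfg.model)
    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
    trainer: L.Trainer = hydra.utils.instantiate(cfg.trainer)

    # Send the curated hyperparameter subset to every configured logger
    # before training starts.
    log_hyperparameters({"cfg": cfg, "model": model, "trainer": trainer})

    trainer.fit(model=model, datamodule=datamodule)


if __name__ == "__main__":
    main()

Because log_hyperparameters is wrapped in @rank_zero_only, this call is safe under multi-GPU training: only the rank-zero process writes hyperparameters, and the function is a no-op on all other ranks.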