mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 01:49:23 +08:00
feat: Initial commit
This commit is contained in:
52
utils/optimizer.py
Normal file
52
utils/optimizer.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import torch
|
||||
from utils.logger import get_root_logger
|
||||
from utils.registry import Registry
|
||||
|
||||
OPTIMIZERS = Registry("optimizers")
|
||||
|
||||
|
||||
OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
|
||||
OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
|
||||
OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")
|
||||
|
||||
|
||||
def build_optimizer(cfg, model, param_dicts=None):
|
||||
if param_dicts is None:
|
||||
cfg.params = model.parameters()
|
||||
else:
|
||||
cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
|
||||
for i in range(len(param_dicts)):
|
||||
param_group = dict(names=[], params=[])
|
||||
if "lr" in param_dicts[i].keys():
|
||||
param_group["lr"] = param_dicts[i].lr
|
||||
if "momentum" in param_dicts[i].keys():
|
||||
param_group["momentum"] = param_dicts[i].momentum
|
||||
if "weight_decay" in param_dicts[i].keys():
|
||||
param_group["weight_decay"] = param_dicts[i].weight_decay
|
||||
cfg.params.append(param_group)
|
||||
|
||||
for n, p in model.named_parameters():
|
||||
flag = False
|
||||
for i in range(len(param_dicts)):
|
||||
if param_dicts[i].keyword in n:
|
||||
cfg.params[i + 1]["names"].append(n)
|
||||
cfg.params[i + 1]["params"].append(p)
|
||||
flag = True
|
||||
break
|
||||
if not flag:
|
||||
cfg.params[0]["names"].append(n)
|
||||
cfg.params[0]["params"].append(p)
|
||||
|
||||
logger = get_root_logger()
|
||||
for i in range(len(cfg.params)):
|
||||
param_names = cfg.params[i].pop("names")
|
||||
message = ""
|
||||
for key in cfg.params[i].keys():
|
||||
if key != "params":
|
||||
message += f" {key}: {cfg.params[i][key]};"
|
||||
logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
|
||||
return OPTIMIZERS.build(cfg=cfg)
|
||||
Reference in New Issue
Block a user