mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-04 17:39:24 +08:00
53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
"""
|
|
The code is based on https://github.com/Pointcept/Pointcept
|
|
"""
|
|
|
|
import torch
|
|
from utils.logger import get_root_logger
|
|
from utils.registry import Registry
|
|
|
|
# Registry backing ``build_optimizer`` (which hands its config to
# ``OPTIMIZERS.build``). torch's SGD, Adam and AdamW are registered here under
# the names "SGD", "Adam" and "AdamW", in that order.
OPTIMIZERS = Registry("optimizers")

for _optim_name, _optim_cls in (
    ("SGD", torch.optim.SGD),
    ("Adam", torch.optim.Adam),
    ("AdamW", torch.optim.AdamW),
):
    OPTIMIZERS.register_module(module=_optim_cls, name=_optim_name)

# Drop the loop temporaries so they do not leak into the module namespace.
del _optim_name, _optim_cls
|
|
|
|
|
|
def build_optimizer(cfg, model, param_dicts=None):
    """Build an optimizer for ``model`` from the config ``cfg``.

    ``cfg.params`` is overwritten with the model's parameters (or parameter
    groups) and the config is then passed to ``OPTIMIZERS.build(cfg=cfg)``.

    Args:
        cfg: Optimizer config with attribute access. ``cfg.lr`` is read when
            ``param_dicts`` is given; ``cfg.params`` is (re)assigned here.
        model: Object exposing ``parameters()`` and ``named_parameters()``
            (presumably a ``torch.nn.Module``).
        param_dicts: Optional iterable of per-group override entries. Each
            entry must provide ``keyword``: a parameter whose name contains
            that keyword as a substring is placed in the entry's group, and
            the first matching entry wins. If an entry has ``lr``,
            ``momentum`` or ``weight_decay`` keys, those values are set on
            its group. Parameters matching no entry go into a default group
            whose ``lr`` is ``cfg.lr``. Every parameter yielded by
            ``model.named_parameters()`` is assigned to a group; no
            ``requires_grad`` filtering is done here.

    Returns:
        Whatever ``OPTIMIZERS.build(cfg=cfg)`` returns.
    """
    if param_dicts is None:
        cfg.params = model.parameters()
        return OPTIMIZERS.build(cfg=cfg)

    # Materialize once: the entries are iterated again for every parameter,
    # so a one-shot iterator would otherwise be exhausted after the first pass.
    param_dicts = list(param_dicts)

    # Group 0 is the fallback group; group i corresponds to param_dicts[i - 1].
    # "names" is bookkeeping for the log message and is popped before building.
    groups = [dict(names=[], params=[], lr=cfg.lr)]
    for entry in param_dicts:
        group = dict(names=[], params=[])
        for key in ("lr", "momentum", "weight_decay"):
            if key in entry.keys():
                group[key] = getattr(entry, key)
        groups.append(group)

    for name, param in model.named_parameters():
        # Index of the first entry whose keyword occurs in the name, else 0.
        idx = next(
            (i for i, entry in enumerate(param_dicts, start=1) if entry.keyword in name),
            0,
        )
        groups[idx]["names"].append(name)
        groups[idx]["params"].append(param)

    logger = get_root_logger()
    for group_no, group in enumerate(groups, start=1):
        param_names = group.pop("names")
        message = "".join(
            f" {key}: {value};" for key, value in group.items() if key != "params"
        )
        logger.info(f"Params Group {group_no} -{message} Params: {param_names}.")

    cfg.params = groups
    return OPTIMIZERS.build(cfg=cfg)
|