mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 18:09:22 +08:00
feat: Initial commit
This commit is contained in:
144
utils/scheduler.py
Normal file
144
utils/scheduler.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from .registry import Registry
|
||||
|
||||
SCHEDULERS = Registry("schedulers")
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class MultiStepLR(lr_scheduler.MultiStepLR):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
milestones,
|
||||
total_steps,
|
||||
gamma=0.1,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
milestones=[rate * total_steps for rate in milestones],
|
||||
gamma=gamma,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
milestones,
|
||||
total_steps,
|
||||
gamma=0.1,
|
||||
warmup_rate=0.05,
|
||||
warmup_scale=1e-6,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
milestones = [rate * total_steps for rate in milestones]
|
||||
|
||||
def multi_step_with_warmup(s):
|
||||
factor = 1.0
|
||||
for i in range(len(milestones)):
|
||||
if s < milestones[i]:
|
||||
break
|
||||
factor *= gamma
|
||||
|
||||
if s <= warmup_rate * total_steps:
|
||||
warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
|
||||
1 - warmup_scale
|
||||
)
|
||||
else:
|
||||
warmup_coefficient = 1.0
|
||||
return warmup_coefficient * factor
|
||||
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
lr_lambda=multi_step_with_warmup,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class PolyLR(lr_scheduler.LambdaLR):
|
||||
def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class ExpLR(lr_scheduler.LambdaLR):
|
||||
def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
lr_lambda=lambda s: gamma ** (s / total_steps),
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
|
||||
def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
T_max=total_steps,
|
||||
eta_min=eta_min,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class OneCycleLR(lr_scheduler.OneCycleLR):
|
||||
r"""
|
||||
torch.optim.lr_scheduler.OneCycleLR, Block total_steps
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
max_lr,
|
||||
total_steps=None,
|
||||
pct_start=0.3,
|
||||
anneal_strategy="cos",
|
||||
cycle_momentum=True,
|
||||
base_momentum=0.85,
|
||||
max_momentum=0.95,
|
||||
div_factor=25.0,
|
||||
final_div_factor=1e4,
|
||||
three_phase=False,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
max_lr=max_lr,
|
||||
total_steps=total_steps,
|
||||
pct_start=pct_start,
|
||||
anneal_strategy=anneal_strategy,
|
||||
cycle_momentum=cycle_momentum,
|
||||
base_momentum=base_momentum,
|
||||
max_momentum=max_momentum,
|
||||
div_factor=div_factor,
|
||||
final_div_factor=final_div_factor,
|
||||
three_phase=three_phase,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
def build_scheduler(cfg, optimizer):
|
||||
cfg.optimizer = optimizer
|
||||
return SCHEDULERS.build(cfg=cfg)
|
||||
Reference in New Issue
Block a user