mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
32
funasr_local/optimizers/sgd.py
Normal file
32
funasr_local/optimizers/sgd.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
||||
class SGD(torch.optim.SGD):
|
||||
"""Thin inheritance of torch.optim.SGD to bind the required arguments, 'lr'
|
||||
|
||||
Note that
|
||||
the arguments of the optimizer invoked by AbsTask.main()
|
||||
must have default value except for 'param'.
|
||||
|
||||
I can't understand why only SGD.lr doesn't have the default value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr: float = 0.1,
|
||||
momentum: float = 0.0,
|
||||
dampening: float = 0.0,
|
||||
weight_decay: float = 0.0,
|
||||
nesterov: bool = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
params,
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
dampening=dampening,
|
||||
weight_decay=weight_decay,
|
||||
nesterov=nesterov,
|
||||
)
|
||||
Reference in New Issue
Block a user