add constant lr scheduler

lyuxiang.lx
2024-07-10 16:37:25 +08:00
parent 6a3e44242a
commit 793a24862c
4 changed files with 29 additions and 4 deletions

View File

@@ -715,3 +715,25 @@ class NoamHoldAnnealing(WarmupHoldPolicy):
     def set_step(self, step: int):
         self.last_epoch = step
+
+
+class ConstantLR(_LRScheduler):
+    """The ConstantLR scheduler
+
+    This scheduler keeps a constant lr
+
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+    ):
+        # __init__() must be invoked before setting field
+        # because step() is also invoked in __init__()
+        super().__init__(optimizer)
+
+    def get_lr(self):
+        return self.base_lrs
+
+    def set_step(self, step: int):
+        self.last_epoch = step
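For reference, a minimal sketch of how the new scheduler behaves (the throwaway parameter and SGD optimizer below are illustrative, not part of this commit): because get_lr() always returns base_lrs, the learning rate stays at the optimizer's initial value no matter how many steps run.

import torch

from cosyvoice.utils.scheduler import ConstantLR

# Illustrative setup: a single dummy parameter and optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-5)
scheduler = ConstantLR(optimizer)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # prints [1e-05] every iteration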

View File

@@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
 
 
 def init_distributed(args):
@@ -122,6 +122,9 @@ def init_optimizer_and_scheduler(args, configs, model):
     elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
         scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
     else:
         raise ValueError("unknown scheduler: " + configs['train_conf'])
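A rough sketch of how this new branch resolves the scheduler once a config is loaded (the train_conf dict below is a hand-written stand-in for configs['train_conf'] in an sft run, not the real loaded config): unlike WarmupLR and NoamHoldAnnealing, ConstantLR takes no scheduler_conf arguments.

import torch

from cosyvoice.utils.scheduler import ConstantLR

# Hand-written stand-in for configs['train_conf'] during sft.
train_conf = {'scheduler': 'constantlr', 'scheduler_conf': {'warmup_steps': 2500}}

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-5)

if train_conf['scheduler'] == 'constantlr':
    scheduler = ConstantLR(optimizer)  # scheduler_conf is ignored for this scheduler
else:
    raise ValueError('unknown scheduler: ' + train_conf['scheduler'])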

View File

@@ -186,8 +186,8 @@ data_pipeline: [
 train_conf:
     optim: adam
     optim_conf:
-        lr: 0.001
-    scheduler: warmuplr
+        lr: 0.001 # change to 1e-5 during sft
+    scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
     max_epoch: 200
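Putting the two inline comments together, the same block during sft would read roughly as follows (a sketch of the relevant excerpt only, not a full config; warmup_steps can stay in place because the constantlr branch above ignores scheduler_conf):

train_conf:
    optim: adam
    optim_conf:
        lr: 1e-5
    scheduler: constantlr
    scheduler_conf:
        warmup_steps: 2500  # unused by constantlr
    max_epoch: 200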

View File

@@ -54,7 +54,7 @@ def main(args):
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
     for k, v in spk2embedding.items():
-        spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
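For context on the keepdim change, a small shape illustration (the 192-dim embeddings and the three utterances are made-up numbers): mean(dim=0, keepdim=True) left a singleton leading dimension, so each spk2embedding entry was a nested [[...]] list, while mean(dim=0) stores a flat per-speaker vector.

import torch

# Made-up data: three 192-dim per-utterance embeddings for one speaker.
v = [torch.randn(192).tolist() for _ in range(3)]

print(torch.tensor(v).mean(dim=0, keepdim=True).shape)  # torch.Size([1, 192]) -- old behaviour
print(torch.tensor(v).mean(dim=0).shape)                # torch.Size([192])    -- after this change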