mirror of https://github.com/FunAudioLLM/CosyVoice.git
add constant lr scheduler

@@ -715,3 +715,25 @@ class NoamHoldAnnealing(WarmupHoldPolicy):
 
     def set_step(self, step: int):
         self.last_epoch = step
+
+
+class ConstantLR(_LRScheduler):
+    """The ConstantLR scheduler
+
+    This scheduler keeps a constant lr
+
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+    ):
+        # __init__() must be invoked before setting field
+        # because step() is also invoked in __init__()
+        super().__init__(optimizer)
+
+    def get_lr(self):
+        return self.base_lrs
+
+    def set_step(self, step: int):
+        self.last_epoch = step
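For reference, a minimal usage sketch of the new scheduler (not part of the commit): it assumes the repo root is on PYTHONPATH so the class above is importable, and the toy Linear model, Adam optimizer, and 1e-5 lr are only illustrative.

import torch
from cosyvoice.utils.scheduler import ConstantLR

model = torch.nn.Linear(4, 4)  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = ConstantLR(optimizer)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    # get_lr() just returns base_lrs, so the lr stays at 1e-5 on every step
    print(scheduler.get_last_lr())  # [1e-05]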

@@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 
 from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
 
 
 def init_distributed(args):

@@ -122,6 +122,9 @@ def init_optimizer_and_scheduler(args, configs, model):
     elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
         scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
     else:
         raise ValueError("unknown scheduler: " + configs['train_conf'])
 
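A small sketch of the parsed train_conf dict that would take the new branch above (key names are taken from this diff; the values are the ones the comments below suggest for sft and are only illustrative):

configs = {
    'train_conf': {
        'optim': 'adam',
        'optim_conf': {'lr': 1e-5},
        'scheduler': 'constantlr',                 # routes to ConstantLR
        'scheduler_conf': {'warmup_steps': 2500},  # ignored by ConstantLR
        'max_epoch': 200,
    }
}
assert configs['train_conf']['scheduler'] == 'constantlr'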

@@ -186,8 +186,8 @@ data_pipeline: [
 train_conf:
     optim: adam
     optim_conf:
-        lr: 0.001
-    scheduler: warmuplr
+        lr: 0.001 # change to 1e-5 during sft
+    scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
     max_epoch: 200

@@ -54,7 +54,7 @@ def main(args):
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
     for k, v in spk2embedding.items():
-        spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
 
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
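A quick illustration of what dropping keepdim=True changes for the stored speaker embedding (the 3-dim vectors are made up; real embeddings are much larger):

import torch

v = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]  # fake per-utterance embeddings for one speaker

old = torch.tensor(v).mean(dim=0, keepdim=True).tolist()  # [[2.0, 3.0, 4.0]] -- extra nesting level
new = torch.tensor(v).mean(dim=0).tolist()                # [2.0, 3.0, 4.0]   -- flat, same layout as a single utterance embedding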