From 793a24862caed695a375661ae57fa99b5fc7793f Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Wed, 10 Jul 2024 16:37:25 +0800
Subject: [PATCH] add constant lr scheduler

Add a ConstantLR scheduler option (scheduler: constantlr) for SFT runs,
note the recommended SFT overrides in the libritts config, and store
per-speaker embeddings as flat vectors in extract_embedding.py.

---
 cosyvoice/utils/scheduler.py                    | 22 +++++++++++++++++++
 cosyvoice/utils/train_utils.py                  |  5 ++++-
 .../libritts/cosyvoice/conf/cosyvoice.yaml      |  4 ++--
 tools/extract_embedding.py                      |  2 +-
 4 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/cosyvoice/utils/scheduler.py b/cosyvoice/utils/scheduler.py
index eed1ea0..fbf4803 100644
--- a/cosyvoice/utils/scheduler.py
+++ b/cosyvoice/utils/scheduler.py
@@ -715,3 +715,25 @@ class NoamHoldAnnealing(WarmupHoldPolicy):
 
     def set_step(self, step: int):
         self.last_epoch = step
+
+
+class ConstantLR(_LRScheduler):
+    """The ConstantLR scheduler.
+
+    This scheduler keeps the learning rate constant at its initial value.
+
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+    ):
+        # super().__init__() must run before any fields are set,
+        # because it already invokes step() (and hence get_lr())
+        super().__init__(optimizer)
+
+    def get_lr(self):
+        return self.base_lrs
+
+    def set_step(self, step: int):
+        self.last_epoch = step
diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py
index df3a321..f8d7b45 100644
--- a/cosyvoice/utils/train_utils.py
+++ b/cosyvoice/utils/train_utils.py
@@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 
 from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
 
 
 def init_distributed(args):
@@ -122,6 +122,9 @@ def init_optimizer_and_scheduler(args, configs, model):
     elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
         scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
     else:
         raise ValueError("unknown scheduler: " + configs['train_conf'])
 
diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.yaml
index cc5eee0..c791c76 100644
--- a/examples/libritts/cosyvoice/conf/cosyvoice.yaml
+++ b/examples/libritts/cosyvoice/conf/cosyvoice.yaml
@@ -186,8 +186,8 @@ data_pipeline: [
 train_conf:
   optim: adam
   optim_conf:
-    lr: 0.001
-  scheduler: warmuplr
+    lr: 0.001 # change to 1e-5 during SFT
+  scheduler: warmuplr # change to constantlr during SFT
   scheduler_conf:
     warmup_steps: 2500
   max_epoch: 200
diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py
index 9c6f568..96a043c 100755
--- a/tools/extract_embedding.py
+++ b/tools/extract_embedding.py
@@ -54,7 +54,7 @@ def main(args):
         spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
     for k, v in spk2embedding.items():
-        spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
 
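
A minimal usage sketch of the new scheduler, assuming the patched
cosyvoice.utils.scheduler is importable; the toy Linear model and the
1e-5 learning rate (the SFT value suggested in the yaml comment) are
illustrative choices, not part of the patch:

    import torch
    from cosyvoice.utils.scheduler import ConstantLR

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scheduler = ConstantLR(optimizer)

    for _ in range(3):
        optimizer.step()
        scheduler.step()
        # get_lr() returns base_lrs, so the lr never moves
        print(scheduler.get_last_lr())  # [1e-05] on every iteration

Because get_lr() returns base_lrs unconditionally, no scheduler_conf
entry is needed, which matches the init_optimizer_and_scheduler branch
that constructs ConstantLR(optimizer) without extra kwargs.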
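
On the extract_embedding.py change, a sketch of the shape difference from
dropping keepdim=True; the numbers are made up, and the motivation (storing
spk2embedding entries as flat [D] vectors like the utt2embedding entries)
is an inference from the code, not spelled out in the patch:

    import torch

    # Two utterance embeddings for one speaker, embedding dim D=2.
    embeddings = [[0.5, 1.0], [1.5, 2.0]]
    before = torch.tensor(embeddings).mean(dim=0, keepdim=True).tolist()
    after = torch.tensor(embeddings).mean(dim=0).tolist()
    print(before)  # [[1.0, 1.5]] -- keepdim=True leaves a nested [1, D] row
    print(after)   # [1.0, 1.5]  -- the patch stores a flat [D] vector instead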