From 11515d0d5a1348a1b24403b4d1481f005d0d4306 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Mon, 28 Jul 2025 11:55:38 +0800
Subject: [PATCH] use bf16 for amp

---
 cosyvoice/utils/executor.py    | 2 +-
 cosyvoice/utils/train_utils.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/cosyvoice/utils/executor.py b/cosyvoice/utils/executor.py
index f120cb5..f08fa09 100644
--- a/cosyvoice/utils/executor.py
+++ b/cosyvoice/utils/executor.py
@@ -166,7 +166,7 @@ class Executor:
             for k, v in info_dict['loss_dict'].items():
                 if k not in total_loss_dict:
                     total_loss_dict[k] = []
-                total_loss_dict[k].append(v.item() * num_utts)
+                total_loss_dict[k].append(v.mean().item() * num_utts)
             log_per_step(None, info_dict)
         for k, v in total_loss_dict.items():
             total_loss_dict[k] = sum(v) / total_num_utts
diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py
index 783a246..e2d2b09 100644
--- a/cosyvoice/utils/train_utils.py
+++ b/cosyvoice/utils/train_utils.py
@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
 
 def check_modify_and_save_config(args, configs):
     if args.train_engine == "torch_ddp":
-        configs['train_conf']["dtype"] = 'fp32'
+        configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
     else:
         with open(args.deepspeed_config, 'r') as fin:
             ds_configs = json.load(fin)
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
         dtype = torch.float32
 
     if info_dict['train_engine'] == 'torch_ddp':
-        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
     else:
         autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
 
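
Note (not part of the patch): below is a minimal sketch of the torch_ddp
training step this change enables. The names train_step, model, optimizer,
and batch are illustrative, not taken from the repo; only the autocast and
dtype handling mirrors the patched code. bf16 shares fp32's exponent range,
so GradScaler's fp16-oriented loss scaling is usually unnecessary with it,
but the patch keeps the scaler-driven control flow, so the sketch does too.

    import torch

    def train_step(model, optimizer, batch, info_dict, scaler=None):
        # Mirror the patched dtype selection: bf16 when AMP is requested.
        dtype = torch.bfloat16 if info_dict.get('dtype') == 'bf16' else torch.float32

        # enabled=scaler is not None matches the patched torch_ddp branch;
        # dtype is now forwarded so autocast runs in bf16 rather than the
        # fp16 default.
        with torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype):
            loss = model(batch).mean()  # .mean() tolerates non-scalar losses

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        optimizer.zero_grad()
        return loss.detach()

The executor.py change fits the same picture: Tensor.item() raises on
non-scalar tensors, so v.mean().item() presumably guards against loss
entries that arrive unreduced (e.g. per-sample values) under this setup,
while remaining a no-op for 0-dim losses.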