use bf16 for amp

Author: lyuxiang.lx
Date: 2025-07-28 11:55:38 +08:00
parent b048a2d6db
commit 11515d0d5a
2 changed files with 3 additions and 3 deletions


@@ -166,7 +166,7 @@ class Executor:
             for k, v in info_dict['loss_dict'].items():
                 if k not in total_loss_dict:
                     total_loss_dict[k] = []
-                total_loss_dict[k].append(v.item() * num_utts)
+                total_loss_dict[k].append(v.mean().item() * num_utts)
             log_per_step(None, info_dict)
         for k, v in total_loss_dict.items():
             total_loss_dict[k] = sum(v) / total_num_utts
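
A note on the executor hunk above: .item() only works on single-element tensors, so a loss entry that arrives as a non-scalar tensor (for example kept per utterance or per rank) has to be reduced before conversion. A minimal sketch with made-up values:

import torch

scalar_loss = torch.tensor(0.7)
per_sample_loss = torch.tensor([0.6, 0.8])   # hypothetical non-scalar loss entry

print(scalar_loss.item())             # fine: a single-element tensor converts directly
print(per_sample_loss.mean().item())  # fine: averaged to a scalar first
# per_sample_loss.item() would raise:
# RuntimeError: a Tensor with 2 elements cannot be converted to Scalar

Since .mean() on a 0-dim tensor returns the same value, the new line still covers the old scalar case.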


@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
 def check_modify_and_save_config(args, configs):
     if args.train_engine == "torch_ddp":
-        configs['train_conf']["dtype"] = 'fp32'
+        configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
     else:
         with open(args.deepspeed_config, 'r') as fin:
             ds_configs = json.load(fin)
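
For context on the 'bf16' string written into train_conf here: batch_forward (next hunk) later turns the configured string into a torch dtype before building the autocast context. That mapping is outside this diff; the sketch below is an assumption about its shape, with an illustrative helper name:

import torch

def resolve_dtype(dtype_str: str) -> torch.dtype:
    # hypothetical helper mirroring the string -> dtype lookup
    return {'fp16': torch.float16, 'bf16': torch.bfloat16}.get(dtype_str, torch.float32)

assert resolve_dtype('bf16') is torch.bfloat16
assert resolve_dtype('fp32') is torch.float32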
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
         dtype = torch.float32
     if info_dict['train_engine'] == 'torch_ddp':
-        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
     else:
         autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
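
Taken together, the two train_utils hunks make the torch_ddp path honor bf16: with use_amp set, 'bf16' is written into train_conf, and the autocast context now receives that dtype instead of torch.cuda.amp.autocast's fp16 default. A minimal, self-contained sketch of the resulting step (not the project's training loop; model, optimizer, and shapes are made up):

import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dtype = torch.bfloat16                 # what 'bf16' in train_conf resolves to
scaler = torch.cuda.amp.GradScaler()   # still passed in, so autocast stays enabled

x = torch.randn(8, 16, device='cuda')
with torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype):
    loss = model(x).pow(2).mean()      # forward runs in bf16 where autocast allows

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

Because bf16 keeps fp32's exponent range, gradient underflow is far less of a concern than with fp16, so the GradScaler is largely a pass-through here; the diff keeps the existing enabled=scaler is not None check, which also serves as the AMP on/off switch.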