update dpo

This commit is contained in:
lyuxiang.lx
2025-06-13 16:14:05 +08:00
parent cc234bd322
commit 63856565f3
23 changed files with 345 additions and 2024 deletions

View File

@@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
from torch.distributed.elastic.multiprocessing.errors import record
from cosyvoice.utils.losses import DPOLoss
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
init_distributed,
@@ -43,6 +44,7 @@ def get_args():
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
action='store_true',
default=False,
help='Use automatic mixed precision training')
parser.add_argument('--dpo',
action='store_true',
default=False,
help='Use Direct Preference Optimization')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
@@ -113,7 +119,7 @@ def main():
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs, gan)
init_dataset_and_dataloader(args, configs, gan, args.dpo)
# Do some sanity checks and save config to arsg.model_dir
configs = check_modify_and_save_config(args, configs)
@@ -122,6 +128,8 @@ def main():
writer = init_summarywriter(args)
# load checkpoint
if args.dpo is True:
configs[args.model].forward = configs[args.model].forward_dpo
model = configs[args.model]
start_step, start_epoch = 0, -1
if args.checkpoint is not None:
@@ -150,13 +158,25 @@ def main():
info_dict['epoch'] = start_epoch
save_model(model, 'init', info_dict)
# DPO related
if args.dpo is True:
ref_model = deepcopy(configs[args.model])
state_dict = torch.load(args.ref_model, map_location='cpu')
ref_model.load_state_dict(state_dict, strict=False)
dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
# NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
ref_model = wrap_cuda_model(args, ref_model)
else:
ref_model, dpo_loss = None, None
# Get executor
executor = Executor(gan=gan)
executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
executor.step = start_step
# Init scaler, used for pytorch amp mixed precision training
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
print('start step {} start epoch {}'.format(start_step, start_epoch))
# Start training loop
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
executor.epoch = epoch
@@ -167,7 +187,7 @@ def main():
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
writer, info_dict, scaler, group_join)
else:
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
dist.destroy_process_group(group_join)