Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
update dpo
This commit adds Direct Preference Optimization (DPO) support to the training script: a --dpo flag and --ref_model argument, a DPOLoss, a frozen reference model, and plumbing through the dataset loader and the Executor.
@@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
 
 from torch.distributed.elastic.multiprocessing.errors import record
 
+from cosyvoice.utils.losses import DPOLoss
 from cosyvoice.utils.executor import Executor
 from cosyvoice.utils.train_utils import (
     init_distributed,
@@ -43,6 +44,7 @@ def get_args():
                         choices=['torch_ddp', 'deepspeed'],
                         help='Engine for paralleled training')
     parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
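
Note: taken together, the two new CLI pieces make DPO training opt-in. A minimal, runnable sketch of how the flags parse (the values 'llm' and 'ref_llm.pt' are placeholders, not from this commit):

    # Sketch only: mirrors the three arguments relevant to DPO.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True, help='model which will be trained')
    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
    parser.add_argument('--dpo', action='store_true', default=False,
                        help='Use Direct Preference Optimization')

    args = parser.parse_args(['--model', 'llm', '--ref_model', 'ref_llm.pt', '--dpo'])
    assert args.dpo is True and args.ref_model == 'ref_llm.pt'

One caveat visible later in the diff: --ref_model stays required=False, yet the DPO branch calls torch.load(args.ref_model), so passing --dpo without --ref_model would fail at load time; the pairing is enforced only by convention.
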
@@ -113,7 +119,7 @@ def main():
 
     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)
 
     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
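
Note: threading args.dpo into init_dataset_and_dataloader suggests the data pipeline switches to preference pairs, i.e. each item carries a dispreferred alternative alongside the usual target. A hypothetical sketch of such a switch (PairedDataset and the 'reject_tokens' field are assumed names, not from this diff):

    from torch.utils.data import Dataset

    class PairedDataset(Dataset):
        # Hypothetical: yields a rejected sequence too when dpo=True.
        def __init__(self, samples, dpo=False):
            self.samples = samples  # list of dicts
            self.dpo = dpo

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, i):
            item = {'tokens': self.samples[i]['tokens']}
            if self.dpo:
                # Preference training needs the dispreferred sequence as well.
                item['reject_tokens'] = self.samples[i]['reject_tokens']
            return item
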
@@ -122,6 +128,8 @@ def main():
     writer = init_summarywriter(args)
 
     # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
     model = configs[args.model]
     start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
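
Note: configs[args.model].forward = configs[args.model].forward_dpo rebinds forward on the model instance, so every call routed through nn.Module.__call__ dispatches to the DPO variant without modifying the class. A small self-contained demonstration of the pattern (toy model, not the CosyVoice one):

    import torch

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return {'loss': x.sum()}

        def forward_dpo(self, x):
            # A DPO forward would return per-sequence log-probs instead.
            return {'chosen_logps': x.sum(), 'rejected_logps': x.mean()}

    model = ToyModel()
    model.forward = model.forward_dpo  # instance attribute shadows the class method
    print(model(torch.ones(3)))        # __call__ now dispatches to forward_dpo

Because the rebinding lives on the instance, the deepcopy used to build ref_model further down inherits it, so the reference model runs forward_dpo as well.
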
@@ -150,13 +158,25 @@ def main():
     info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
+        ref_model = wrap_cuda_model(args, ref_model)
+    else:
+        ref_model, dpo_loss = None, None
+
     # Get executor
-    executor = Executor(gan=gan)
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
     executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
     print('start step {} start epoch {}'.format(start_step, start_epoch))
 
     # Start training loop
     for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
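
Note: DPO trains the policy to widen the margin between chosen and rejected sequences relative to a frozen reference model; with beta = 0.01 and no label smoothing, the loss per pair is -log sigmoid(beta * [(log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))]). The internals of cosyvoice.utils.losses.DPOLoss are not shown in this diff; the following is a generic sketch consistent with the constructor call above (beta, label_smoothing, ipo), not the repo's implementation:

    import torch
    import torch.nn.functional as F

    class DPOLossSketch(torch.nn.Module):
        # Generic DPO/IPO loss sketch, not the repo's DPOLoss.
        def __init__(self, beta=0.01, label_smoothing=0.0, ipo=False):
            super().__init__()
            self.beta = beta
            self.label_smoothing = label_smoothing
            self.ipo = ipo

        def forward(self, policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps):
            pi_logratios = policy_chosen_logps - policy_rejected_logps
            ref_logratios = ref_chosen_logps - ref_rejected_logps
            logits = pi_logratios - ref_logratios
            if self.ipo:
                # IPO variant: squared distance to the 1/(2*beta) target.
                losses = (logits - 1 / (2 * self.beta)) ** 2
            else:
                # Standard DPO, with optional conservative label smoothing.
                losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                          - F.logsigmoid(-self.beta * logits) * self.label_smoothing)
            return losses.mean()

The in-line NOTE is reasonable: the reference model receives no gradient updates, so DDP wrapping is not strictly needed, and wrap_cuda_model here mostly serves to place it on the right device. Calling ref_model.eval() and setting requires_grad_(False) on its parameters would make the frozen intent explicit.
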
@@ -167,7 +187,7 @@ def main():
             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                         writer, info_dict, scaler, group_join)
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
         dist.destroy_process_group(group_join)
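
Note: with ref_model passed to both the Executor constructor and train_one_epoc, the per-step DPO flow presumably pairs a no-grad reference forward with the policy forward. The Executor internals are not part of this diff, so the following is a sketch under assumed output names ('chosen_logps', 'rejected_logps'):

    import torch

    def dpo_step(model, ref_model, dpo_loss, batch, optimizer):
        with torch.no_grad():
            ref_out = ref_model(batch)  # frozen reference log-probs
        out = model(batch)              # policy forward (forward_dpo)
        loss = dpo_loss(out['chosen_logps'], out['rejected_logps'],
                        ref_out['chosen_logps'], ref_out['rejected_logps'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()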