mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update dpo
This commit is contained in:
@@ -25,14 +25,16 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, gan: bool = False):
|
||||
def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
|
||||
self.gan = gan
|
||||
self.ref_model = ref_model
|
||||
self.dpo_loss = dpo_loss
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.rank = int(os.environ.get('RANK', 0))
|
||||
self.device = torch.device('cuda:{}'.format(self.rank))
|
||||
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
@@ -44,6 +46,8 @@ class Executor:
|
||||
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||
# with uneven inputs across participating processes.
|
||||
model.train()
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||
with model_context():
|
||||
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||
@@ -65,7 +69,7 @@ class Executor:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
|
||||
Reference in New Issue
Block a user