mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix dpo bug
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -420,8 +420,8 @@ class Qwen2LM(TransformerLM):
|
||||
rejected_lm_mask = rejected_lm_target == IGNORE_ID
|
||||
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
|
||||
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
|
||||
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
|
||||
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
|
||||
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
Reference in New Issue
Block a user