mirror of https://github.com/FunAudioLLM/CosyVoice.git
update dpo
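The diff adds a forward_dpo method to Qwen2LM: the chosen and rejected speech-token sequences are run through the LM as one doubled batch, cross-entropy loss and accuracy are computed on the chosen half, and per-sequence log-probabilities for both halves are returned for a DPO-style preference loss.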
@@ -300,7 +300,6 @@ class Qwen2LM(TransformerLM):
        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()

    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
        lm_target, lm_input = [], []
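Note that the stop_token_ids registered for vllm are the three IDs just past the speech vocabulary; the same speech_token_size + 3 logit width appears in the llm_decoder/th_accuracy calls in the second hunk below.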
@@ -378,6 +377,52 @@ class Qwen2LM(TransformerLM):
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)

        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)

        # 2. encode speech_token
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)

        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)

        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)

        # 5. calculate per-sequence log-probs for the dpo objective
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        # keep only positions with a real target (~mask zeroes the IGNORE_ID slots);
        # note the rejected branch must use rejected_lm_mask, not chosen_lm_mask
        chosen_logps = (chosen_logps * ~chosen_lm_mask).mean(dim=-1)
        rejected_logps = (rejected_logps * ~rejected_lm_mask).mean(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}

    @torch.inference_mode()
    def inference(
            self,
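For orientation, here is a minimal sketch of the batch forward_dpo consumes. The key names match the code above; the batch size, sequence lengths, and vocabulary sizes are illustrative assumptions, not values taken from the CosyVoice pipeline.

import torch

B, T_text, T_speech = 2, 6, 12           # toy sizes, assumptions only
TEXT_VOCAB, SPEECH_VOCAB = 1000, 6561    # placeholder vocab sizes, not CosyVoice constants

batch = {
    'text_token': torch.randint(0, TEXT_VOCAB, (B, T_text)),
    'text_token_len': torch.tensor([6, 4]),
    'speech_token': torch.randint(0, SPEECH_VOCAB, (B, T_speech)),         # preferred continuation
    'speech_token_len': torch.tensor([12, 9]),
    'reject_speech_token': torch.randint(0, SPEECH_VOCAB, (B, T_speech)),  # dispreferred continuation
    'reject_speech_token_len': torch.tensor([10, 12]),
}
# out = qwen2lm.forward_dpo(batch, device=torch.device('cpu'))
# out['chosen_logps'] and out['rejected_logps'] each have shape (B,)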
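Step 2's unpad/re-pad maneuver is the core batching trick: rejected sequences are appended after the chosen ones along the batch dimension, which is exactly the layout the logits[:B] / logits[B:] split in step 4 relies on. The same trick in isolation, with made-up toy tensors:

import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

chosen = torch.tensor([[1, 2, 3], [4, 5, 0]])           # padded, lengths [3, 2]
chosen_len = torch.tensor([3, 2])
rejected = torch.tensor([[7, 8, 0, 0], [9, 1, 2, 3]])   # padded, lengths [2, 4]
rejected_len = torch.tensor([2, 4])

# unpad_sequence strips padding, returning a list of 1-D tensors
chosen_list = unpad_sequence(chosen, chosen_len, batch_first=True)
rejected_list = unpad_sequence(rejected, rejected_len, batch_first=True)

# list concatenation stacks rejected examples after chosen ones, so after
# re-padding, rows [0, B) are chosen and rows [B, 2B) are rejected
combined = pad_sequence(chosen_list + rejected_list, batch_first=True, padding_value=0)
combined_len = torch.concat([chosen_len, rejected_len], dim=0)
print(combined.shape, combined_len)  # torch.Size([4, 4]) tensor([3, 2, 2, 4])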
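forward_dpo stops at returning chosen_logps and rejected_logps; this diff does not show where the preference loss itself is formed. A minimal sketch of how these outputs would typically feed the standard DPO objective (Rafailov et al., 2023), assuming a frozen reference model that produces ref_chosen_logps/ref_rejected_logps the same way and an assumed beta of 0.1, neither of which is confirmed by this commit:

import torch
import torch.nn.functional as F

def dpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor, ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # Standard DPO: push the policy's chosen-vs-rejected log-prob margin
    # above the frozen reference model's margin, scaled by beta.
    pi_logratios = chosen_logps - rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()

A training recipe would presumably combine this term with the cross-entropy 'loss' already in the returned dict; how the two are weighted is not shown here.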