From cde3cec6fa2eeda06b4ee1132d9342fa19e348e7 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Tue, 10 Dec 2024 14:50:20 +0800 Subject: [PATCH] add qwen lm --- cosyvoice/llm/llm.py | 125 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index cf9c231..f06b23a 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator import torch from torch import nn import torch.nn.functional as F +from transformers import Qwen2ForCausalLM from torch.nn.utils.rnn import pad_sequence, unpad_sequence from cosyvoice.utils.common import IGNORE_ID from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss @@ -213,3 +214,127 @@ class TransformerLM(torch.nn.Module): out_tokens.append(top_ids) offset += lm_input.size(1) lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + +class Qwen2Encoder(torch.nn.Module): + def __init__(self, pretrain_path): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + + def forward_one_step(self, xs, masks, cache=None): + input_masks = masks[:, -1, :] + outs = self.model( + inputs_embeds=xs, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + use_cache=True, + past_key_values=cache, + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache + + +class Qwen2LM(torch.nn.Module): + def __init__( + self, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 3, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size) + + # 4. sampling method + self.sampling = sampling + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + sampling: int, + ignore_eos: bool = True, + ): + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + if (not ignore_eos) or (self.speech_token_size not in top_ids): + break + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.llm.model.model.embed_tokens(text) + + # 2. encode embedding + embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + out_tokens = [] + cache = None + for i in range(max_len): + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + if top_ids > self.speech_token_size: + continue + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) \ No newline at end of file