mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add qwen lm
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from transformers import Qwen2ForCausalLM
|
||||||
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
||||||
from cosyvoice.utils.common import IGNORE_ID
|
from cosyvoice.utils.common import IGNORE_ID
|
||||||
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
||||||
@@ -213,3 +214,127 @@ class TransformerLM(torch.nn.Module):
|
|||||||
out_tokens.append(top_ids)
|
out_tokens.append(top_ids)
|
||||||
offset += lm_input.size(1)
|
offset += lm_input.size(1)
|
||||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Encoder(torch.nn.Module):
|
||||||
|
def __init__(self, pretrain_path):
|
||||||
|
super().__init__()
|
||||||
|
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
||||||
|
|
||||||
|
def forward_one_step(self, xs, masks, cache=None):
|
||||||
|
input_masks = masks[:, -1, :]
|
||||||
|
outs = self.model(
|
||||||
|
inputs_embeds=xs,
|
||||||
|
attention_mask=input_masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=True,
|
||||||
|
use_cache=True,
|
||||||
|
past_key_values=cache,
|
||||||
|
)
|
||||||
|
xs = outs.hidden_states[-1]
|
||||||
|
new_cache = outs.past_key_values
|
||||||
|
return xs, new_cache
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2LM(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_input_size: int,
|
||||||
|
llm_output_size: int,
|
||||||
|
speech_token_size: int,
|
||||||
|
llm: torch.nn.Module,
|
||||||
|
sampling: Callable,
|
||||||
|
length_normalized_loss: bool = True,
|
||||||
|
lsm_weight: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.llm_input_size = llm_input_size
|
||||||
|
self.llm_output_size = llm_output_size
|
||||||
|
self.speech_token_size = speech_token_size
|
||||||
|
|
||||||
|
# 2. build speech token language model related modules
|
||||||
|
self.sos_eos = 0
|
||||||
|
self.task_id = 1
|
||||||
|
self.fill_token = 2
|
||||||
|
|
||||||
|
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||||
|
self.llm = llm
|
||||||
|
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
||||||
|
self.criterion_ce = LabelSmoothingLoss(
|
||||||
|
size=speech_token_size + 3,
|
||||||
|
padding_idx=IGNORE_ID,
|
||||||
|
smoothing=lsm_weight,
|
||||||
|
normalize_length=length_normalized_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. [Optional] build speech token related modules
|
||||||
|
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
||||||
|
|
||||||
|
# 4. sampling method
|
||||||
|
self.sampling = sampling
|
||||||
|
|
||||||
|
def sampling_ids(
|
||||||
|
self,
|
||||||
|
weighted_scores: torch.Tensor,
|
||||||
|
decoded_tokens: List,
|
||||||
|
sampling: int,
|
||||||
|
ignore_eos: bool = True,
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||||
|
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
||||||
|
break
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
text: torch.Tensor,
|
||||||
|
text_len: torch.Tensor,
|
||||||
|
prompt_text: torch.Tensor,
|
||||||
|
prompt_text_len: torch.Tensor,
|
||||||
|
prompt_speech_token: torch.Tensor,
|
||||||
|
prompt_speech_token_len: torch.Tensor,
|
||||||
|
embedding: torch.Tensor,
|
||||||
|
sampling: int = 25,
|
||||||
|
max_token_text_ratio: float = 20,
|
||||||
|
min_token_text_ratio: float = 2,
|
||||||
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
|
device = text.device
|
||||||
|
text = torch.concat([prompt_text, text], dim=1)
|
||||||
|
text_len += prompt_text_len
|
||||||
|
text = self.llm.model.model.embed_tokens(text)
|
||||||
|
|
||||||
|
# 2. encode embedding
|
||||||
|
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
|
|
||||||
|
# 3. concat llm_input
|
||||||
|
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
if prompt_speech_token_len != 0:
|
||||||
|
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||||
|
else:
|
||||||
|
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
|
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||||
|
|
||||||
|
# 4. cal min/max_length
|
||||||
|
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||||
|
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||||
|
|
||||||
|
# 5. step by step decode
|
||||||
|
out_tokens = []
|
||||||
|
cache = None
|
||||||
|
for i in range(max_len):
|
||||||
|
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||||
|
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||||
|
cache=cache)
|
||||||
|
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||||
|
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||||
|
if top_ids == self.speech_token_size:
|
||||||
|
break
|
||||||
|
if top_ids > self.speech_token_size:
|
||||||
|
continue
|
||||||
|
# in stream mode, yield token one by one
|
||||||
|
yield top_ids
|
||||||
|
out_tokens.append(top_ids)
|
||||||
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
Reference in New Issue
Block a user