add llm bistream

This commit is contained in:
lyuxiang.lx
2025-01-23 10:12:06 +08:00
parent 0b75c3a03f
commit 07e477519b
10 changed files with 163 additions and 39 deletions

View File

@@ -20,6 +20,7 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ class TransformerLM(torch.nn.Module):
sampling: int,
ignore_eos: bool = True,
):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
break
num_trials += 1
if num_trials > max_trials:
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
return top_ids
@torch.inference_mode()
@@ -239,7 +244,7 @@ class Qwen2Encoder(torch.nn.Module):
return xs, new_cache
class Qwen2LM(torch.nn.Module):
class Qwen2LM(TransformerLM):
def __init__(
self,
llm_input_size: int,
@@ -249,8 +254,9 @@ class Qwen2LM(torch.nn.Module):
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
mix_ratio: List[int] = [5, 15],
):
super().__init__()
torch.nn.Module.__init__(self)
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ class Qwen2LM(torch.nn.Module):
# 4. sampling method
self.sampling = sampling
def sampling_ids(
self,
weighted_scores: torch.Tensor,
decoded_tokens: List,
sampling: int,
ignore_eos: bool = True,
):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
break
num_trials += 1
if num_trials > max_trials:
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
return top_ids
self.mix_ratio = mix_ratio
@torch.inference_mode()
def inference(
@@ -312,9 +302,6 @@ class Qwen2LM(torch.nn.Module):
text_len += prompt_text_len
text = self.llm.model.model.embed_tokens(text)
# 2. encode embedding
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
@@ -322,7 +309,7 @@ class Qwen2LM(torch.nn.Module):
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,100 @@ class Qwen2LM(torch.nn.Module):
yield top_ids
out_tokens.append(top_ids)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
@torch.inference_mode()
def inference_bistream(
self,
text: Generator,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
device = prompt_text.device
# 1. prepare input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb], dim=1)
# 2. iterate text
out_tokens = []
cache = None
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
text_cache = self.llm.model.model.embed_tokens(prompt_text)
next_fill_index = -1
for this_text in text:
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
# prompt_speech_token_emb not empty, try append to lm_input
while prompt_speech_token_emb.size(1) != 0:
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
else:
logging.info('not enough text token to decode, wait for more')
break
# no prompt_speech_token_emb remain, can decode some speech token
if prompt_speech_token_emb.size(1) == 0:
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
logging.info('get fill token, need to append more text token')
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text = text_cache[:, :self.mix_ratio[0]]
logging.info('append {} text token'.format(lm_input_text.size(1)))
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
text_cache = text_cache[:, self.mix_ratio[0]:]
else:
logging.info('not enough text token to decode, wait for more')
continue
while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2
next_fill_index += (self.mix_ratio[1] + 1)
else:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
if top_ids == self.speech_token_size + 2:
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size + 2:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
# 3. final decode
lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
logging.info('no more text token, decode until met eos')
while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
# in stream mode, yield token one by one
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)