mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add llm bistream
This commit is contained in:
@@ -20,6 +20,7 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
||||
from cosyvoice.utils.common import IGNORE_ID
|
||||
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
||||
from cosyvoice.utils.common import th_accuracy
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
|
||||
|
||||
class TransformerLM(torch.nn.Module):
|
||||
@@ -144,10 +145,14 @@ class TransformerLM(torch.nn.Module):
|
||||
sampling: int,
|
||||
ignore_eos: bool = True,
|
||||
):
|
||||
num_trials, max_trials = 0, 100
|
||||
while True:
|
||||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
||||
break
|
||||
num_trials += 1
|
||||
if num_trials > max_trials:
|
||||
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
||||
return top_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -239,7 +244,7 @@ class Qwen2Encoder(torch.nn.Module):
|
||||
return xs, new_cache
|
||||
|
||||
|
||||
class Qwen2LM(torch.nn.Module):
|
||||
class Qwen2LM(TransformerLM):
|
||||
def __init__(
|
||||
self,
|
||||
llm_input_size: int,
|
||||
@@ -249,8 +254,9 @@ class Qwen2LM(torch.nn.Module):
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
super().__init__()
|
||||
torch.nn.Module.__init__(self)
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
@@ -275,23 +281,7 @@ class Qwen2LM(torch.nn.Module):
|
||||
|
||||
# 4. sampling method
|
||||
self.sampling = sampling
|
||||
|
||||
def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        decoded_tokens: List,
        sampling: int,
        ignore_eos: bool = True,
):
    """Draw token ids with ``self.sampling``, optionally rejecting draws that contain eos.

    Args:
        weighted_scores: scores forwarded unchanged to ``self.sampling``.
        decoded_tokens: previously decoded tokens, forwarded unchanged to ``self.sampling``.
        sampling: sampling parameter, forwarded unchanged to ``self.sampling``.
        ignore_eos: when True, a draw containing ``self.speech_token_size`` is
            discarded and sampling is retried; when False the first draw is returned.

    Returns:
        Whatever ``self.sampling`` returned for the accepted draw.

    Raises:
        RuntimeError: if ``ignore_eos`` is True and every one of the
            ``max_trials + 1`` draws contained ``self.speech_token_size``.
    """
    max_trials = 100
    # One initial draw plus up to `max_trials` retries.
    for _ in range(max_trials + 1):
        candidate = self.sampling(weighted_scores, decoded_tokens, sampling)
        if not ignore_eos:
            return candidate
        if self.speech_token_size not in candidate:
            return candidate
    raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
||||
self.mix_ratio = mix_ratio
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
@@ -312,9 +302,6 @@ class Qwen2LM(torch.nn.Module):
|
||||
text_len += prompt_text_len
|
||||
text = self.llm.model.model.embed_tokens(text)
|
||||
|
||||
# 2. encode embedding
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
@@ -322,7 +309,7 @@ class Qwen2LM(torch.nn.Module):
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
@@ -345,3 +332,100 @@ class Qwen2LM(torch.nn.Module):
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
@torch.inference_mode()
def inference_bistream(
        self,
        text: Generator,
        prompt_text: torch.Tensor,
        prompt_text_len: torch.Tensor,
        prompt_speech_token: torch.Tensor,
        prompt_speech_token_len: torch.Tensor,
        embedding: torch.Tensor,
        sampling: int = 25,
        max_token_text_ratio: float = 20,
        min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
    """Decode speech tokens while the text itself is still arriving chunk by chunk.

    Text embeddings are buffered in ``text_cache`` and handed to the LLM in
    groups of ``self.mix_ratio[0]``. While prompt speech embeddings remain they
    are interleaved with the text in groups of ``self.mix_ratio[1]``. After the
    prompt is used up, speech tokens are decoded until the token id
    ``self.speech_token_size + 2`` (called the fill token in the log messages)
    appears, at which point the next text group is required. Once ``text`` is
    exhausted, the remaining text and the task-id embedding are appended and
    decoding continues until ``self.speech_token_size`` (eos) is sampled.

    Args:
        text: generator of text token id tensors; each chunk is embedded with
            ``self.llm.model.model.embed_tokens`` and concatenated on dim 1.
        prompt_text: prompt text token ids, used to seed ``text_cache``.
        prompt_text_len: not read by this method.
        prompt_speech_token: prompt speech token ids, embedded with
            ``self.speech_embedding``.
        prompt_speech_token_len: when equal to 0 the prompt speech is treated
            as empty.
        embedding: not read by this method.
        sampling: forwarded to ``self.sampling_ids``.
        max_token_text_ratio: not read by this method.
        min_token_text_ratio: not read by this method.

    Yields:
        Decoded speech token ids one at a time (plain Python ints obtained via
        ``.item()``); the fill token and eos are never yielded.

    Raises:
        ValueError: if a token id ``>= self.speech_token_size`` other than the
            one expected in the current phase (fill token while streaming, eos
            in the final phase) is produced.
    """
    device = prompt_text.device
    text_stride, speech_stride = self.mix_ratio[0], self.mix_ratio[1]
    fill_token = self.speech_token_size + 2

    # 1. prepare input
    sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
    task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
    if prompt_speech_token_len != 0:
        prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
    else:
        # Empty placeholder; it is only ever size-checked, never concatenated.
        prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
    lm_input = torch.concat([sos_eos_emb], dim=1)

    # 2. iterate text
    out_tokens = []
    cache = None
    # NOTE: text_cache is seeded with the prompt text, because a prompt whose
    # speech/text token ratio is below 15/5 is considered practically impossible.
    text_cache = self.llm.model.model.embed_tokens(prompt_text)
    next_fill_index = -1
    for this_text in text:
        text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
        # prompt_speech_token_emb not empty, try append to lm_input
        while prompt_speech_token_emb.size(1) != 0:
            if text_cache.size(1) >= text_stride:
                lm_input_text, lm_input_speech = text_cache[:, :text_stride], prompt_speech_token_emb[:, :speech_stride]
                logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                text_cache, prompt_speech_token_emb = text_cache[:, text_stride:], prompt_speech_token_emb[:, speech_stride:]
            else:
                logging.info('not enough text token to decode, wait for more')
                break
        # no prompt_speech_token_emb remain, can decode some speech token
        if prompt_speech_token_emb.size(1) == 0:
            if (len(out_tokens) != 0 and out_tokens[-1] == fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                logging.info('get fill token, need to append more text token')
                if text_cache.size(1) >= text_stride:
                    lm_input_text = text_cache[:, :text_stride]
                    logging.info('append {} text token'.format(lm_input_text.size(1)))
                    # After a fill token lm_input is empty (see the break below),
                    # so only embeddings not yet fed to the LLM are appended here.
                    lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                    text_cache = text_cache[:, text_stride:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    continue
            while True:
                # Mask covers cached positions plus the new ones in lm_input.
                seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                y_pred, cache = self.llm.forward_one_step(
                    lm_input,
                    masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                    cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                    # Force the fill token at the scheduled position.
                    top_ids = fill_token
                    next_fill_index += (speech_stride + 1)
                else:
                    top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                if top_ids == fill_token:
                    next_fill_index = len(out_tokens) + speech_stride + 1
                    logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                out_tokens.append(top_ids)
                if top_ids >= self.speech_token_size:
                    if top_ids == fill_token:
                        # lm_input was already consumed by forward_one_step above
                        # (its positions are now counted in `cache`). Drop it so
                        # the next text append / final decode does not feed the
                        # same embeddings to the LLM a second time.
                        lm_input = lm_input[:, :0]
                        break
                    raise ValueError('should not get token {}'.format(top_ids))
                yield top_ids
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    # 3. final decode
    # lm_input holds only embeddings that have not been fed yet (possibly none).
    lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
    logging.info('no more text token, decode until met eos')
    while True:
        seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
        y_pred, cache = self.llm.forward_one_step(
            lm_input,
            masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
            cache=cache)
        logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
        out_tokens.append(top_ids)
        if top_ids >= self.speech_token_size:
            if top_ids == self.speech_token_size:
                break
            raise ValueError('should not get token {}'.format(top_ids))
        # in stream mode, yield token one by one
        yield top_ids
        lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
Reference in New Issue
Block a user