Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-05 18:09:24 +08:00
add func inference_bistream_vllm
@@ -104,6 +104,7 @@ class CosyVoiceModel:
        with self.llm_context:
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                if self.vllm_codec_engine is None:
                    for i in self.llm.inference_bistream(text=text,
                                                         prompt_text=prompt_text.to(self.device),
                                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
@@ -111,6 +112,15 @@ class CosyVoiceModel:
                                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                         embedding=llm_embedding.to(self.device)):
                        self.tts_speech_token_dict[uuid].append(i)
                else:
                    for i in self.llm.inference_bistream_vllm(text=text,
                                                              prompt_text=prompt_text.to(self.device),
                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                              embedding=llm_embedding.to(self.device),
                                                              vllm_codec_engine=self.vllm_codec_engine):
                        self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
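For context: the two hunks above route the streaming-text ("bistream") path in CosyVoiceModel to the new vLLM backend whenever self.vllm_codec_engine is set, and the hunk below adds the inference_bistream_vllm generator to Qwen2LM. How the engine is created and attached is outside this commit; the snippet below is only a minimal sketch assuming vLLM's standard EngineArgs/LLMEngine API, an exported backbone checkpoint path, and an engine that accepts prompt embeddings. All of these are illustrative assumptions, not part of this change.

# Hypothetical wiring sketch (not part of this commit): build a vLLM engine that
# accepts prompt embeddings and attach it to the CosyVoice model as vllm_codec_engine.
from vllm import EngineArgs, LLMEngine

from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')

engine_args = EngineArgs(
    model='pretrained_models/CosyVoice2-0.5B/vllm_llm',  # assumed export path of the Qwen2 backbone
    dtype='bfloat16',
    enable_prompt_embeds=True,  # assumed flag; prompt-embeds input support depends on the vLLM version
)
cosyvoice.model.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)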
@@ -461,3 +461,130 @@ class Qwen2LM(TransformerLM):
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
    @torch.inference_mode()
    def inference_bistream_vllm(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            vllm_codec_engine=None,
    ) -> Generator[torch.Tensor, None, None]:

        device = prompt_text.device
        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb], dim=1)

        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        from vllm import SamplingParams, RequestOutput
        import uuid
        sampling_params = SamplingParams(top_k=sampling,
                                         stop_token_ids=[6561, 6563],
                                         max_tokens=10000)
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try append to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remain, can decode some speech token
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if vllm_codec_engine is None and len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                request_id = uuid.uuid4()
                vllm_codec_engine.add_request(request_id,
                                              {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
                                              sampling_params)
                ## generator
                while True:
                    speech_token_break = False
                    request_outputs: List[RequestOutput] = vllm_codec_engine.step()
                    for request_output in request_outputs:
                        if str(request_output.request_id) != str(request_id):
                            continue

                        print(f"request output: {request_output}")
                        out_token = list(request_output.outputs[0].token_ids)[-1]
                        if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                            top_ids = self.speech_token_size + 2
                            next_fill_index += (self.mix_ratio[1] + 1)
                        else:
                            top_ids = out_token
                        if top_ids == self.speech_token_size + 2:
                            next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                            logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                        out_tokens.append(top_ids)
                        if top_ids >= self.speech_token_size:
                            if top_ids == self.speech_token_size + 2:
                                speech_token_break = True
                                break
                            else:
                                raise ValueError('should not get token {}'.format(top_ids))
                        yield top_ids
                        token_embedding = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
                        lm_input = torch.concat([lm_input, token_embedding], dim=1)

                    if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
                        break

        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        request_id = uuid.uuid4()
        vllm_codec_engine.add_request(request_id,
                                      {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
                                      sampling_params)
        ## generator
        while True:
            speech_token_break = False
            request_outputs: List[RequestOutput] = vllm_codec_engine.step()
            for request_output in request_outputs:
                if str(request_output.request_id) != str(request_id):
                    continue
                print(f"request output: {request_output}")
                top_ids = list(request_output.outputs[0].token_ids)[-1]
                out_tokens.append(top_ids)
                if top_ids >= self.speech_token_size:
                    if top_ids == self.speech_token_size:
                        speech_token_break = True
                        break
                    else:
                        raise ValueError('should not get token {}'.format(top_ids))
                # in stream mode, yield token one by one
                yield top_ids

            if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
                break
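As a usage note, this commit does not change the public API: streaming text still enters through the normal zero-shot call with a text generator, and the new method is only reached once vllm_codec_engine has been attached as sketched earlier. The example below follows the existing CosyVoice2 streaming-input pattern; the prompt audio path, prompt transcript, and text chunks are placeholders, not values from this commit.

# Minimal driving sketch, assuming the engine wiring shown earlier.
import torchaudio
from cosyvoice.utils.file_utils import load_wav

prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)

def text_generator():
    # Text arrives chunk by chunk, e.g. streamed from an upstream LLM.
    for chunk in ['Receiving a birthday gift from a faraway friend, ',
                  'that unexpected surprise and deep blessing ',
                  'filled my heart with sweet joy.']:
        yield chunk

# Because text is a Generator, CosyVoiceModel takes the bistream path and, with a
# vllm_codec_engine attached, ends up in Qwen2LM.inference_bistream_vllm.
for i, out in enumerate(cosyvoice.inference_zero_shot(text_generator(),
                                                      'This is the transcript of the prompt audio.',
                                                      prompt_speech_16k,
                                                      stream=True)):
    torchaudio.save('bistream_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)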