From 9a4aebb0ea59b53bfae6ae717dda988d72f7d2b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?=
Date: Sat, 1 Mar 2025 18:50:19 +0800
Subject: [PATCH] add func inference_bistream_vllm

---
 cosyvoice/cli/model.py |  24 +++++---
 cosyvoice/llm/llm.py   | 127 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 144 insertions(+), 7 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index b5ea3af..8c5b6ba 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -104,13 +104,23 @@ class CosyVoiceModel:
         with self.llm_context:
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
-                for i in self.llm.inference_bistream(text=text,
-                                                     prompt_text=prompt_text.to(self.device),
-                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
+                if self.vllm_codec_engine is None:
+                    for i in self.llm.inference_bistream(text=text,
+                                                         prompt_text=prompt_text.to(self.device),
+                                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                         embedding=llm_embedding.to(self.device)):
+                        self.tts_speech_token_dict[uuid].append(i)
+                else:
+                    for i in self.llm.inference_bistream_vllm(text=text,
+                                                              prompt_text=prompt_text.to(self.device),
+                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                              embedding=llm_embedding.to(self.device),
+                                                              vllm_codec_engine=self.vllm_codec_engine):
+                        self.tts_speech_token_dict[uuid].append(i)
             else:
                 for i in self.llm.inference(text=text.to(self.device),
                                             text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 331881f..5d6e4db 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -461,3 +461,130 @@ class Qwen2LM(TransformerLM):
             # in stream mode, yield token one by one
             yield top_ids
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream_vllm(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+            vllm_codec_engine=None,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
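+        # The bistream schedule interleaves text and speech embeddings at mix_ratio
+        # (mix_ratio[0] text tokens per mix_ratio[1] speech tokens). Once the prompt
+        # speech tokens are consumed, each vllm request appends mix_ratio[0] new text
+        # embeddings and decodes speech tokens until the model emits a fill token
+        # (speech_token_size + 2) to request more text.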
+        out_tokens = []
+        # NOTE seed text_cache with prompt_text; in practice prompt_speech_token/prompt_text never falls below the 15/5 mix ratio
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        # lazy imports keep vllm an optional dependency
+        from vllm import SamplingParams, RequestOutput
+        import uuid
+        sampling_params = SamplingParams(top_k=sampling,
+                                         # 6561 = eos (speech_token_size), 6563 = fill token (speech_token_size + 2)
+                                         stop_token_ids=[6561, 6563],
+                                         max_tokens=10000)
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remains, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        # unlike inference_bistream, the full prefix is resubmitted to vllm on every
+                        # request, so the new text embeddings are always appended to lm_input
+                        lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                request_id = str(uuid.uuid4())
+                vllm_codec_engine.add_request(request_id,
+                                              {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
+                                              sampling_params)
+                # step the engine until this request emits a fill token or finishes
+                while True:
+                    speech_token_break = False
+                    request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+                    for request_output in request_outputs:
+                        if request_output.request_id != request_id:
+                            continue
+                        logging.debug('request output: {}'.format(request_output))
+                        out_token = list(request_output.outputs[0].token_ids)[-1]
+                        if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                            top_ids = self.speech_token_size + 2
+                            next_fill_index += (self.mix_ratio[1] + 1)
+                        else:
+                            top_ids = out_token
+                        if top_ids == self.speech_token_size + 2:
+                            next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                            logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                        out_tokens.append(top_ids)
+                        if top_ids >= self.speech_token_size:
+                            if top_ids == self.speech_token_size + 2:
+                                speech_token_break = True
+                                break
+                            else:
+                                raise ValueError('should not get token {}'.format(top_ids))
+                        yield top_ids
+                        token_embedding = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+                        lm_input = torch.concat([lm_input, token_embedding], dim=1)
+
+                    if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
+                        break
+
+        # 3. final decode
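+        # Once the text generator is exhausted, the remaining text embeddings and the
+        # task_id embedding are appended, and one last request decodes speech tokens
+        # until eos (speech_token_size) is emitted.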
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        request_id = str(uuid.uuid4())
+        vllm_codec_engine.add_request(request_id,
+                                      {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
+                                      sampling_params)
+        # step the engine until eos is emitted or the request finishes
+        while True:
+            speech_token_break = False
+            request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+            for request_output in request_outputs:
+                if request_output.request_id != request_id:
+                    continue
+                logging.debug('request output: {}'.format(request_output))
+                top_ids = list(request_output.outputs[0].token_ids)[-1]
+                out_tokens.append(top_ids)
+                if top_ids >= self.speech_token_size:
+                    if top_ids == self.speech_token_size:
+                        speech_token_break = True
+                        break
+                    else:
+                        raise ValueError('should not get token {}'.format(top_ids))
+                # in stream mode, yield token one by one
+                yield top_ids
+
+            if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
+                break
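
Usage note for reviewers: a minimal sketch of exercising the new vLLM bistream path end to end. It assumes a local CosyVoice2 checkpoint and that a vLLM engine accepting prompt_embeds requests has already been attached as model.vllm_codec_engine (engine construction is not part of this patch); the model directory, prompt wav, transcript, and text chunks below are illustrative only.

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav


def text_generator():
    # text arrives in small chunks, as it would from an upstream LLM or a websocket
    for chunk in ['Streaming text-to-speech lets playback start ', 'before the full sentence is available.']:
        yield chunk


cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')  # illustrative local model dir
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)  # illustrative prompt audio

# llm_job dispatches to inference_bistream_vllm only when an engine is attached;
# how vllm_codec_engine is built is outside this patch
assert cosyvoice.model.vllm_codec_engine is not None

for i, out in enumerate(cosyvoice.inference_zero_shot(text_generator(),
                                                      'transcript of the prompt audio.',  # should match the prompt wav
                                                      prompt_speech_16k, stream=True)):
    torchaudio.save('bistream_vllm_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)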