From f6a18ee07ae2cb09d19f239151543329c8b95326 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?=
Date: Tue, 25 Feb 2025 19:40:30 +0800
Subject: [PATCH] update vllm_codec_engine

---
 cosyvoice/cli/cosyvoice.py |  2 +-
 cosyvoice/cli/model.py     |  5 +++-
 cosyvoice/llm/llm.py       | 60 +++++++++++++++++++++++++++-----------
 3 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index f67c6d7..00a56a1 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -158,7 +158,7 @@ class CosyVoice2(CosyVoice):
                                      skip_tokenizer_init=True,
                                      gpu_memory_utilization=0.1)
             self.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
-            self.model.llm.vllm_codec_engine = self.vllm_codec_engine
+            self.model.vllm_codec_engine = self.vllm_codec_engine
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 115d7e1..e6ecd19 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -66,6 +66,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.vllm_codec_engine = None

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -117,7 +118,8 @@ class CosyVoiceModel:
                                             prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
+                                            embedding=llm_embedding.to(self.device),
+                                            vllm_codec_engine=self.vllm_codec_engine):
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

@@ -314,6 +316,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.vllm_codec_engine = None

     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index a7f12a5..ac746f9 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -282,7 +282,6 @@ class Qwen2LM(TransformerLM):
         # 4. sampling method
         self.sampling = sampling
         self.mix_ratio = mix_ratio
-        self.vllm_codec_engine = None

     @torch.inference_mode()
     def inference(
@@ -297,6 +296,7 @@ class Qwen2LM(TransformerLM):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            vllm_codec_engine=None,
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -317,22 +317,48 @@ class Qwen2LM(TransformerLM):
         max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

         # 5. step by step decode
-        out_tokens = []
-        cache = None
-        for i in range(max_len):
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            if top_ids > self.speech_token_size:
-                continue
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        if vllm_codec_engine is None:
+            out_tokens = []
+            cache = None
+            for i in range(max_len):
+                y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                          cache=cache)
+                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+                if top_ids == self.speech_token_size:
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                # in stream mode, yield token one by one
+                yield top_ids
+                out_tokens.append(top_ids)
+                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        else:
+            from vllm import SamplingParams, RequestOutput
+            import uuid
+            sampling_params = SamplingParams(top_k=sampling,
+                                             stop_token_ids=[6561, 6563],
+                                             min_tokens=min_len,
+                                             max_tokens=max_len)
+            request_id = str(uuid.uuid4())
+            vllm_codec_engine.add_request(request_id,
+                                          {"prompt_embeds": lm_input.to(torch.bfloat16).to(device)},
+                                          sampling_params)
+            # stream tokens from the engine step by step
+            out_token_ids = []
+            while vllm_codec_engine.has_unfinished_requests():
+                request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+                for request_output in request_outputs:
+                    if request_output.request_id != request_id:
+                        continue
+                    if not request_output.finished:
+                        print(f"Partial request output: {request_output}")
+                        out_token = list(request_output.outputs[0].token_ids)[-1]
+                        yield out_token
+                        out_token_ids.append(out_token)
+                    else:
+                        break

     @torch.inference_mode()
     def inference_bistream(
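
Note on the new vllm branch in Qwen2LM.inference(): it drives a synchronous
vllm.LLMEngine by hand, submitting a single request with add_request() and then
polling engine.step(), yielding the newest token of each partial RequestOutput.
The sketch below shows that add_request()/step() streaming pattern in isolation
against stock vLLM with a plain text prompt; the model name, prompt, and
sampling values are placeholders, and the patch's {"prompt_embeds": ...} input
additionally requires a vLLM build that accepts prompt-embedding inputs.

# Standalone sketch, not part of the patch: stream token ids from a
# synchronous vllm.LLMEngine via add_request()/step().
import uuid

from vllm import EngineArgs, LLMEngine, SamplingParams

# model name and gpu_memory_utilization are placeholder values
engine = LLMEngine.from_engine_args(EngineArgs(model="Qwen/Qwen2-0.5B-Instruct",
                                               gpu_memory_utilization=0.1))


def stream_tokens(prompt: str, max_tokens: int = 64):
    # submit one request and yield its token ids as the engine produces them
    request_id = str(uuid.uuid4())
    engine.add_request(request_id,
                       prompt,
                       SamplingParams(top_k=25, max_tokens=max_tokens))
    n_yielded = 0
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.request_id != request_id:
                continue
            token_ids = list(request_output.outputs[0].token_ids)
            # yield only the tokens added since the previous step
            for token_id in token_ids[n_yielded:]:
                yield token_id
            n_yielded = len(token_ids)


for token_id in stream_tokens("Hello"):
    print(token_id)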