mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-05 18:09:24 +08:00
update vllm_codec_engine
@@ -158,7 +158,7 @@ class CosyVoice2(CosyVoice):
                                  skip_tokenizer_init=True,
                                  gpu_memory_utilization=0.1)
         self.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
-        self.model.llm.vllm_codec_engine = self.vllm_codec_engine
+        self.model.vllm_codec_engine = self.vllm_codec_engine
 
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
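For orientation, here is a minimal sketch (not part of this commit) of how the vllm_codec_engine referenced in this hunk could be built; only skip_tokenizer_init and gpu_memory_utilization appear in the diff, so the model path and every other detail below are assumptions:

    from vllm import EngineArgs, LLMEngine

    # Assumed export path of the CosyVoice2 LLM converted for vLLM; not taken from the commit.
    engine_args = EngineArgs(model='pretrained_models/CosyVoice2-0.5B/vllm',
                             skip_tokenizer_init=True,      # prompts are fed as embeddings, no tokenizer needed
                             gpu_memory_utilization=0.1)    # leave most VRAM for the flow/hift modules
    vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
    # The commit then attaches the engine to the model wrapper rather than the LLM module:
    # self.model.vllm_codec_engine = self.vllm_codec_engine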
@@ -66,6 +66,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.vllm_codec_engine = None
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -117,7 +118,8 @@ class CosyVoiceModel:
                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
+                                        embedding=llm_embedding.to(self.device),
+                                        vllm_codec_engine=self.vllm_codec_engine):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
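The hunk above only threads the engine handle through the existing self.llm.inference() call; the consuming loop is untouched. A hypothetical illustration of that unchanged contract (collect_speech_tokens, model, and inference_kwargs are placeholder names, not taken from the repository):

    def collect_speech_tokens(model, inference_kwargs, engine=None):
        # engine is None -> original step-by-step PyTorch decode
        # engine given   -> tokens are produced by the vLLM engine (see the Qwen2LM hunks below)
        tokens = []
        for tok in model.llm.inference(**inference_kwargs, vllm_codec_engine=engine):
            tokens.append(tok)
        return tokens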
@@ -314,6 +316,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.vllm_codec_engine = None
 
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -282,7 +282,6 @@ class Qwen2LM(TransformerLM):
         # 4. sampling method
         self.sampling = sampling
         self.mix_ratio = mix_ratio
-        self.vllm_codec_engine = None
 
     @torch.inference_mode()
     def inference(
@@ -297,6 +296,7 @@ class Qwen2LM(TransformerLM):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            vllm_codec_engine=None,
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -317,22 +317,48 @@ class Qwen2LM(TransformerLM):
         max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
 
         # 5. step by step decode
-        out_tokens = []
-        cache = None
-        for i in range(max_len):
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            if top_ids > self.speech_token_size:
-                continue
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        if vllm_codec_engine is None:
+            out_tokens = []
+            cache = None
+            for i in range(max_len):
+                y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                          cache=cache)
+                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+                if top_ids == self.speech_token_size:
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                # in stream mode, yield token one by one
+                yield top_ids
+                out_tokens.append(top_ids)
+                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        else:
+            from vllm import SamplingParams, RequestOutput
+            import uuid
+            sampling_params = SamplingParams(top_k=sampling,
+                                             stop_token_ids=[6561, 6563],
+                                             min_tokens=min_len,
+                                             max_tokens=max_len)
+            request_id = uuid.uuid4()
+            vllm_codec_engine.add_request(request_id,
+                                          {"prompt_embeds": lm_input.to(torch.bfloat16).to(device)},
+                                          sampling_params)
+            ## generator
+            out_token_ids = []
+            while True:
+                request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+                for request_output in request_outputs:
+                    if str(request_output.request_id) != str(request_id):
+                        continue
+                    if not request_output.finished:
+                        print(f"Partial request output: {request_output}")
+                        out_token = list(request_output.outputs[0].token_ids)[-1]
+                        yield out_token
+                        out_token_ids.append(out_token)
+                    else:
+                        break
 
     @torch.inference_mode()
     def inference_bistream(
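The new else branch above follows vLLM's add_request/step polling pattern. Below is a self-contained sketch of that pattern outside CosyVoice; the model name is a placeholder, and passing {"prompt_embeds": ...} as in the hunk requires a vLLM build that accepts embedding prompts, which is assumed rather than shown here:

    import uuid
    from typing import List
    from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput

    engine = LLMEngine.from_engine_args(EngineArgs(model='Qwen/Qwen2-0.5B'))   # placeholder model
    params = SamplingParams(top_k=25, min_tokens=10, max_tokens=100)
    request_id = str(uuid.uuid4())
    engine.add_request(request_id, 'an example prompt', params)                # the hunk passes prompt embeddings instead

    generated: List[int] = []
    while engine.has_unfinished_requests():
        step_outputs: List[RequestOutput] = engine.step()       # schedule and run one decode iteration
        for out in step_outputs:
            if out.request_id != request_id:
                continue
            new_ids = list(out.outputs[0].token_ids)[len(generated):]  # tokens added since the last step
            generated.extend(new_ids)

Note that the hunk streams only token_ids[-1] on each unfinished step, while the sketch tracks the consumed length; both are ways of emitting tokens incrementally as step() advances.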