add vllm inference

lyuxiang.lx
2025-05-30 07:22:35 +00:00
parent 9f55c5af8f
commit 6dd68b9d5e
6 changed files with 105 additions and 64 deletions


@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
@@ -170,6 +173,7 @@ class TransformerLM(torch.nn.Module):
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
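Note on the new uuid parameter: each inference call is now tagged with a caller-supplied identifier. In the Qwen2LM path further down it doubles as the vLLM request_id and as the key into the per-request output queue, so it must be unique per concurrent request. A minimal caller-side sketch (not part of this diff; how the calling code actually mints the id is an assumption):

import uuid

request_id = str(uuid.uuid1())  # any string unique per concurrent request works
# for speech_token in llm.inference(..., uuid=request_id):
#     ...  # consume the streamed speech tokens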
@@ -270,7 +274,6 @@ class Qwen2LM(TransformerLM):
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
@@ -292,6 +295,11 @@ class Qwen2LM(TransformerLM):
        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()

    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
        lm_target, lm_input = [], []
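The three attributes added here set up the concurrency pattern used by inference_wrapper below: stop_token_ids marks the ids that terminate generation, vllm_output_queue keeps one queue.Queue per in-flight request, and lock serializes access to the shared engine and to that queue map. The following standard-library sketch shows the same multiplexing with a stand-in engine (FakeEngine, request names and token counts are made up for illustration): whichever worker finds its own queue empty advances the shared engine one step and routes each output to the queue of the request that produced it, while a short sleep bounds busy-waiting whenever another worker holds the lock.

import queue
import threading
import time

output_queues = {}           # request_id -> queue.Queue, mirrors self.vllm_output_queue
lock = threading.Lock()      # mirrors self.lock

class FakeEngine:
    """Stand-in for the shared engine (illustration only, not the vLLM API)."""
    def __init__(self):
        self.pending = []    # (request_id, remaining_tokens) pairs
    def add_request(self, request_id, n_tokens):
        self.pending.append((request_id, list(range(n_tokens))))
    def step(self):
        # one decode step: emit one token for every active request
        outs = []
        for request_id, tokens in list(self.pending):
            outs.append((request_id, tokens.pop(0)))
            if not tokens:
                self.pending.remove((request_id, tokens))
        return outs

engine = FakeEngine()

def generate(request_id, n_tokens):
    with lock:
        engine.add_request(request_id, n_tokens)
        output_queues[request_id] = queue.Queue()
    received = []
    while len(received) < n_tokens:
        with lock:
            if output_queues[request_id].empty():
                # this worker drives the shared engine and fans results out
                for rid, token in engine.step():
                    output_queues[rid].put(token)
            if not output_queues[request_id].empty():
                received.append(output_queues[request_id].get())
        time.sleep(0.001)    # bound busy-waiting while another worker is stepping
    with lock:
        output_queues.pop(request_id)
    print(request_id, received)

workers = [threading.Thread(target=generate, args=('req-%d' % i, 5)) for i in range(3)]
for w in workers:
    w.start()
for w in workers:
    w.join()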
@@ -382,6 +390,7 @@ class Qwen2LM(TransformerLM):
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
@@ -402,22 +411,55 @@ class Qwen2LM(TransformerLM):
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        cache = None
        for i in range(max_len):
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            if top_ids > self.speech_token_size:
                continue
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty() is True:
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                    if self.vllm_output_queue[uuid].empty() is False:
                        top_ids = self.vllm_output_queue[uuid].get()
                        if top_ids in self.stop_token_ids:
                            break
                        # in stream mode, yield token one by one
                        yield top_ids
                        out_tokens.append(top_ids)
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
                if top_ids == self.speech_token_size:
                    break
                if top_ids > self.speech_token_size:
                    continue
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    @torch.inference_mode()
    def inference_bistream(
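The vLLM branch above only runs when an engine has already been attached as self.vllm; that wiring lives in one of the other changed files of this commit and is not shown here. Below is a hedged sketch of what the attachment could look like using the public EngineArgs / LLMEngine API; the model directory, the engine arguments, and in particular the enable_prompt_embeds flag are assumptions that depend on the installed vLLM version, since the decode path passes {"prompt_embeds": ...} rather than token ids to add_request.

from vllm import EngineArgs, LLMEngine

def attach_vllm(llm_module, model_dir):
    # llm_module is the Qwen2LM instance; model_dir points at the exported
    # Qwen2 backbone (assumed path layout, not specified in this diff)
    engine_args = EngineArgs(
        model=model_dir,
        dtype='bfloat16',               # matches the .to(torch.bfloat16) cast above
        gpu_memory_utilization=0.5,     # illustrative value
        enable_prompt_embeds=True,      # assumption: flag only exists in newer vLLM releases
    )
    llm_module.vllm = LLMEngine.from_engine_args(engine_args)

# illustrative usage: attach_vllm(cosyvoice_model.llm, '/path/to/qwen2_backbone')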