add vllm inference

This commit is contained in:
lyuxiang.lx
2025-05-30 07:22:35 +00:00
parent 9f55c5af8f
commit 6dd68b9d5e
6 changed files with 105 additions and 64 deletions

View File

@@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1):
self.trt_context_pool = queue.Queue()
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put(trt_context)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context):
self.trt_context_pool.put(context)
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])

View File

@@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
@@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device):
model.llm.model.config.tie_word_embeddings = False
model.llm.model.config.use_bias = True
model.llm.model.save_pretrained(model_path)
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
model.llm.model.config.vocab_size = tmp_vocab_size
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
model.llm.model.set_input_embeddings(embed_tokens)