Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
add vllm inference
@@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
 
 
 class TrtContextWrapper:
-    def __init__(self, trt_engine, trt_concurrent=1):
-        self.trt_context_pool = queue.Queue()
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
         self.trt_engine = trt_engine
         for _ in range(trt_concurrent):
             trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
             assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
-            self.trt_context_pool.put(trt_context)
+            self.trt_context_pool.put([trt_context, trt_stream])
         assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
 
     def acquire_estimator(self):
         return self.trt_context_pool.get(), self.trt_engine
 
-    def release_estimator(self, context):
-        self.trt_context_pool.put(context)
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
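Reviewer note (not part of the diff): acquire_estimator() now hands out a [context, stream] pair together with the engine, and release_estimator() expects both back so the pair stays intact in the queue. A minimal caller sketch, assuming `wrapper` is a TrtContextWrapper built from an already-deserialized engine (the caller-side names here are hypothetical):

    # hypothetical caller; `wrapper` is assumed to be TrtContextWrapper(trt_engine, trt_concurrent=2)
    [context, stream], engine = wrapper.acquire_estimator()
    try:
        with stream:  # the pooled object is the torch.cuda.stream(...) context manager created in __init__
            ...       # bind input/output tensors on `context` and enqueue inference here
    finally:
        wrapper.release_estimator(context, stream)  # return both so the pair can be reused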
@@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31)  # 1GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()
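Side note on the workspace numbers (not part of the diff): 1 << 31 is 2 GiB, so the removed line's "# 1GB" comment undercounted; the new "# 4GB" comment matches 1 << 32. Quick arithmetic check:

    for shift in (30, 31, 32):
        print(f'1 << {shift} = {1 << shift} bytes = {(1 << shift) // 2**30} GiB')
    # 1 << 30 = 1073741824 bytes = 1 GiB
    # 1 << 31 = 2147483648 bytes = 2 GiB
    # 1 << 32 = 4294967296 bytes = 4 GiB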
@@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device):
     model.llm.model.config.tie_word_embeddings = False
+    model.llm.model.config.use_bias = True
     model.llm.model.save_pretrained(model_path)
     os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
     model.llm.model.config.vocab_size = tmp_vocab_size
     model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
     model.llm.model.set_input_embeddings(embed_tokens)
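Reviewer note (not part of the diff): the surrounding export function appears to patch config fields (tie_word_embeddings, and now use_bias) only for the on-disk checkpoint, then restores the in-memory values after save_pretrained. A standalone sketch of that save-with-temporary-overrides pattern, with hypothetical names:

    # hypothetical helper illustrating the pattern above; not the actual export_cosyvoice2_vllm code
    def save_with_overrides(model, out_dir, **overrides):
        originals = {k: getattr(model.config, k, None) for k in overrides}
        for k, v in overrides.items():
            setattr(model.config, k, v)        # e.g. tie_word_embeddings=False, use_bias=True
        try:
            model.save_pretrained(out_dir)     # checkpoint on disk carries the overridden config
        finally:
            for k, v in originals.items():
                setattr(model.config, k, v)    # in-memory model keeps its original config

    # usage (hypothetical): save_with_overrides(model.llm.model, model_path, tie_word_embeddings=False, use_bias=True)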