add flow trt wrapper

lyuxiang.lx
2025-04-16 17:57:02 +08:00
parent 7f8bea2669
commit a442317d17
8 changed files with 615 additions and 56 deletions


@@ -1,5 +1,6 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Unility functions for Transformer."""
import queue
import random
from typing import List
@@ -164,3 +166,20 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
mask = (1.0 - mask) * -1.0e+10
return mask


class TrtContextWrapper:
    """Thread-safe pool of TensorRT execution contexts sharing a single engine."""

    def __init__(self, trt_engine, trt_concurrent=1):
        self.trt_context_pool = queue.Queue()
        self.trt_engine = trt_engine
        # Pre-create one execution context per concurrent caller.
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            assert trt_context is not None, \
                'failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent (currently {})'.format(trt_concurrent)
            self.trt_context_pool.put(trt_context)
        assert not self.trt_context_pool.empty(), 'no available estimator context'

    def acquire_estimator(self):
        # Blocks until a context is free; the caller must release it when done.
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context):
        # Return the context to the pool for the next caller.
        self.trt_context_pool.put(context)
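
For context, a minimal usage sketch of the pool's acquire/release protocol (hypothetical: `engine` stands in for an already-deserialized TensorRT engine object; the wrapper itself is the class added above):

    # hypothetical usage sketch, not part of this commit
    wrapper = TrtContextWrapper(engine, trt_concurrent=2)
    trt_context, trt_engine = wrapper.acquire_estimator()
    try:
        # run inference with trt_context here (set bindings, execute, etc.)
        pass
    finally:
        # always return the context so other callers are not starved
        wrapper.release_estimator(trt_context)

Because queue.Queue.get() blocks when the pool is empty, at most trt_concurrent inferences run in parallel and additional callers simply wait for a free context instead of failing.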