add flow decoder tensorrt infer

This commit is contained in:
zhoubofan.zbf
2024-08-29 23:35:07 +08:00
parent 1d881df8b2
commit 5f21aef786
5 changed files with 149 additions and 19 deletions

View File

@@ -1,8 +1,103 @@
# TODO 跟export_jit一样的逻辑完成flow部分的estimator的onnx导出。
# tensorrt的安装方式再这里写一下步骤提示如下如果没有安装那么不要执行这个脚本提示用户先安装不给选择
import argparse
import logging
import os
import sys
logging.getLogger('matplotlib').setLevel(logging.WARNING)
try:
import tensorrt
except ImportError:
print('step1, 下载\n step2. 解压安装whl')
# 安装命令里tensosrt的根目录用环境变量导入比如os.environ['tensorrt_root_dir']/bin/exetrace然后python里subprocess里执行导出命令
# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
error_msg_zh = [
"step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x",
"step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl",
"step.3 将 tensorrt 的 lib 路径添加进环境变量中export LD_LIBRARY_PATH=${TensorRT-Path}/lib/"
]
print("\n".join(error_msg_zh))
sys.exit(1)
import torch
from cosyvoice.cli.cosyvoice import CosyVoice
def get_args():
parser = argparse.ArgumentParser(description='Export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/CosyVoice-300M',
help='Local path to the model directory')
parser.add_argument('--export_half',
action='store_true',
help='Export with half precision (FP16)')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
flow = cosyvoice.model.flow
estimator = cosyvoice.model.flow.decoder.estimator
dtype = torch.float32 if not args.export_half else torch.float16
device = torch.device("cuda")
batch_size = 1
seq_len = 1024
hidden_size = flow.output_size
x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
mask = torch.zeros((batch_size, 1, seq_len), dtype=dtype, device=device)
mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
t = torch.tensor([0.], dtype=dtype, device=device)
spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device)
cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
onnx_file_name = 'estimator_fp16.onnx' if args.export_half else 'estimator_fp32.onnx'
onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
dummy_input = (x, mask, mu, t, spks, cond)
estimator = estimator.to(dtype)
torch.onnx.export(
estimator,
dummy_input,
onnx_file_path,
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
output_names=['output'],
dynamic_axes={
'x': {2: 'seq_len'},
'mask': {2: 'seq_len'},
'mu': {2: 'seq_len'},
'cond': {2: 'seq_len'},
'output': {2: 'seq_len'},
}
)
tensorrt_path = os.environ.get('tensorrt_root_dir')
if not tensorrt_path:
raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")
if not os.path.isdir(tensorrt_path):
raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")
trt_lib_path = os.path.join(tensorrt_path, "lib")
if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
trt_file_name = 'estimator_fp16.plan' if args.export_half else 'estimator_fp32.plan'
trt_file_path = os.path.join(args.model_dir, trt_file_name)
trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
"--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
"--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose"
os.system(trtexec_cmd)
if __name__ == "__main__":
main()