mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update
This commit is contained in:
@@ -38,23 +38,21 @@ def main():
|
||||
args = get_args()
|
||||
|
||||
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
|
||||
|
||||
flow = cosyvoice.model.flow
|
||||
estimator = cosyvoice.model.flow.decoder.estimator
|
||||
|
||||
dtype = torch.float32 if not args.export_half else torch.float16
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seq_len = 1024
|
||||
hidden_size = flow.output_size
|
||||
seq_len = 256
|
||||
hidden_size = cosyvoice.model.flow.output_size
|
||||
x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
|
||||
mask = torch.zeros((batch_size, 1, seq_len), dtype=dtype, device=device)
|
||||
mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device)
|
||||
mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
|
||||
t = torch.tensor([0.], dtype=dtype, device=device)
|
||||
t = torch.rand((batch_size, ), dtype=dtype, device=device)
|
||||
spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device)
|
||||
cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
|
||||
|
||||
onnx_file_name = 'estimator_fp16.onnx' if args.export_half else 'estimator_fp32.onnx'
|
||||
onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx'
|
||||
onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
|
||||
dummy_input = (x, mask, mu, t, spks, cond)
|
||||
|
||||
@@ -90,14 +88,24 @@ def main():
|
||||
print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
|
||||
os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
|
||||
|
||||
trt_file_name = 'estimator_fp16.plan' if args.export_half else 'estimator_fp32.plan'
|
||||
trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan'
|
||||
trt_file_path = os.path.join(args.model_dir, trt_file_name)
|
||||
|
||||
trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
|
||||
"--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
|
||||
"--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose"
|
||||
"--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \
|
||||
("--fp16" if args.export_half else "")
|
||||
# /ossfs/workspace/TensorRT-10.2.0.19/bin/trtexec --onnx=estimator_fp32.onnx --saveEngine=estimator_fp32.plan --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose
|
||||
print("execute ", trtexec_cmd)
|
||||
|
||||
os.system(trtexec_cmd)
|
||||
|
||||
print("x.shape", x.shape)
|
||||
print("mask.shape", mask.shape)
|
||||
print("mu.shape", mu.shape)
|
||||
print("t.shape", t.shape)
|
||||
print("spks.shape", spks.shape)
|
||||
print("cond.shape", cond.shape)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user