This commit is contained in:
root
2025-10-08 18:13:09 +08:00
parent 7cbd490253
commit aceede59ba
5 changed files with 20 additions and 29 deletions

View File

@@ -57,10 +57,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if dtype == torch.float16:
config.set_flag(trt.BuilderFlag.FP16)
elif dtype == torch.bfloat16:
config.set_flag(trt.BuilderFlag.BF16)
elif dtype == torch.float32:
config.set_flag(trt.BuilderFlag.FP32)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
@@ -199,7 +196,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())