From 369f3c2c18c68867a06cd4c1dd8d58acd619a1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A6=BE=E6=81=AF?= Date: Wed, 16 Apr 2025 14:39:06 +0800 Subject: [PATCH] Update estimator count retrieval and memory pool limit in CosyVoice - Simplified estimator count retrieval in CosyVoice and CosyVoice2 classes to directly access the configs dictionary. - Adjusted memory pool limit in the ONNX to TensorRT conversion function from 8GB to 1GB for optimized resource management. --- cosyvoice/cli/cosyvoice.py | 4 ++-- cosyvoice/utils/file_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 8606530..7f0211d 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -54,7 +54,7 @@ class CosyVoice: '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: - self.estimator_count = configs['flow']['decoder']['estimator'].get('estimator_count', 1) + self.estimator_count = configs.get('estimator_count', 1) self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), self.fp16, self.estimator_count) @@ -180,7 +180,7 @@ class CosyVoice2(CosyVoice): if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: - self.estimator_count = configs['flow']['decoder']['estimator'].get('estimator_count', 1) + self.estimator_count = configs.get('estimator_count', 1) self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), self.fp16, self.estimator_count) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index ac7fe93..cf8ad03 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -61,7 +61,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile()