This commit is contained in:
lyuxiang.lx
2025-01-13 10:30:13 +08:00
parent 59fa786769
commit 43f9e9ab20

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import json import json
import tensorrt as trt
import torchaudio import torchaudio
import logging import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -49,6 +48,7 @@ def load_wav(wav, target_sr):
def convert_onnx_to_trt(trt_model, onnx_model, fp16): def convert_onnx_to_trt(trt_model, onnx_model, fp16):
import tensorrt as trt
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)] _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)] _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)] _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]