diff --git a/audio2mouth_cpu.py b/audio2mouth_cpu.py index c5ce12d..bec0e47 100644 --- a/audio2mouth_cpu.py +++ b/audio2mouth_cpu.py @@ -8,12 +8,13 @@ from extract_paraformer_feature import extract_para_feature from scipy import signal class Audio2Mouth(object): - def __init__(self): + def __init__(self. use_gpu): self.p_list = [str(ii) for ii in range(32)] model_path = './weights/model_1.onnx' - self.audio2mouth_model=onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider']) + provider = "CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider" + self.audio2mouth_model=onnxruntime.InferenceSession(model_path, providers=[provider]) self.w = np.array([1.0]).astype(np.float32) self.sp = np.array([2]).astype(np.int64)