From 8b0566682b1f4a7c058dace11d801a2151a157b8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 21 Jun 2024 21:10:48 +0800 Subject: [PATCH] use torch.no_grad() --- utils_vad.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils_vad.py b/utils_vad.py index 1532768..1ebf2d7 100644 --- a/utils_vad.py +++ b/utils_vad.py @@ -156,7 +156,6 @@ def save_audio(path: str, def init_jit_model(model_path: str, device=torch.device('cpu')): - torch.set_grad_enabled(False) model = torch.jit.load(model_path, map_location=device) model.eval() return model @@ -172,6 +171,7 @@ def make_visualization(probs, step): colormap='tab20') +@torch.no_grad() def get_speech_timestamps(audio: torch.Tensor, model, threshold: float = 0.5, @@ -484,6 +484,7 @@ class VADIterator: self.temp_end = 0 self.current_sample = 0 + @torch.no_grad() def __call__(self, x, return_seconds=False): """ x: torch.Tensor