use torch.no_grad()

This commit is contained in:
Thien Tran
2024-06-21 21:10:48 +08:00
parent 82342b8a4c
commit 8b0566682b

View File

@@ -156,7 +156,6 @@ def save_audio(path: str,
def init_jit_model(model_path: str,
device=torch.device('cpu')):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
return model
@@ -172,6 +171,7 @@ def make_visualization(probs, step):
colormap='tab20')
@torch.no_grad()
def get_speech_timestamps(audio: torch.Tensor,
model,
threshold: float = 0.5,
@@ -484,6 +484,7 @@ class VADIterator:
self.temp_end = 0
self.current_sample = 0
@torch.no_grad()
def __call__(self, x, return_seconds=False):
"""
x: torch.Tensor