mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-05 18:09:22 +08:00
Merge pull request #467 from gau-nernst/fix_grad
Replace `torch.set_grad_enabled(False)` with `torch.no_grad()`
This commit is contained in:
@@ -156,7 +156,6 @@ def save_audio(path: str,
|
|||||||
|
|
||||||
def init_jit_model(model_path: str,
|
def init_jit_model(model_path: str,
|
||||||
device=torch.device('cpu')):
|
device=torch.device('cpu')):
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
model = torch.jit.load(model_path, map_location=device)
|
model = torch.jit.load(model_path, map_location=device)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
@@ -172,6 +171,7 @@ def make_visualization(probs, step):
|
|||||||
colormap='tab20')
|
colormap='tab20')
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def get_speech_timestamps(audio: torch.Tensor,
|
def get_speech_timestamps(audio: torch.Tensor,
|
||||||
model,
|
model,
|
||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
@@ -484,6 +484,7 @@ class VADIterator:
|
|||||||
self.temp_end = 0
|
self.temp_end = 0
|
||||||
self.current_sample = 0
|
self.current_sample = 0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def __call__(self, x, return_seconds=False):
|
def __call__(self, x, return_seconds=False):
|
||||||
"""
|
"""
|
||||||
x: torch.Tensor
|
x: torch.Tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user