mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
BIN
src/silero_vad/data/silero_vad_16k.safetensors
Executable file
BIN
src/silero_vad/data/silero_vad_16k.safetensors
Executable file
Binary file not shown.
71
src/silero_vad/tinygrad_model.py
Normal file
71
src/silero_vad/tinygrad_model.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from tinygrad import nn
|
||||
|
||||
|
||||
class TinySileroVAD:
|
||||
def __init__(self):
|
||||
"""
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
|
||||
tiny_model = TinySileroVAD()
|
||||
state_dict = safe_load('data/silero_vad_16k.safetensors')
|
||||
load_state_dict(tiny_model, state_dict)
|
||||
"""
|
||||
self.n_fft = 256
|
||||
self.stride = 128
|
||||
self.pad = 64
|
||||
self.cutoff = int(self.n_fft // 2) + 1
|
||||
|
||||
self.stft_conv = nn.Conv1d(1, 258, kernel_size=256, stride=self.stride, padding=0, bias=False)
|
||||
self.conv1 = nn.Conv1d(129, 128, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1)
|
||||
self.conv3 = nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1)
|
||||
self.conv4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.lstm_cell = nn.LSTMCell(128, 128)
|
||||
self.final_conv = nn.Conv1d(128, 1, 1)
|
||||
|
||||
def __call__(self, x, state=None):
|
||||
"""
|
||||
# full audio example:
|
||||
import torch
|
||||
from tinygrad import Tensor
|
||||
|
||||
wav = read_audio(audio_path, sampling_rate=16000).unsqueeze(0)
|
||||
num_samples = 512
|
||||
context_size = 64
|
||||
context = Tensor(np.zeros((1, context_size))).float()
|
||||
outs = []
|
||||
state = None
|
||||
if wav.shape[1] % num_samples:
|
||||
pad_num = num_samples - (wav.shape[1] % num_samples)
|
||||
wav = torch.nn.functional.pad(wav, (0, pad_num), 'constant', value=0.0)
|
||||
|
||||
wav = torch.nn.functional.pad(wav, (context_size, 0))
|
||||
|
||||
wav = Tensor(wav.numpy()).float()
|
||||
|
||||
for i in tqdm(range(context_size, wav.shape[1], num_samples)):
|
||||
wavs_batch = wav[:, i-context_size:i+num_samples]
|
||||
out_chunk, state = tiny_model(wavs_batch, state)
|
||||
#outs.append(out_chunk.numpy())
|
||||
outs.append(out_chunk)
|
||||
|
||||
predict = outs[0].cat(*outs[1:], dim=1).numpy()
|
||||
|
||||
"""
|
||||
if state is not None:
|
||||
state = (state[0], state[1])
|
||||
x = x.pad((0, self.pad), "reflect").unsqueeze(1)
|
||||
x = self.stft_conv(x)
|
||||
x = (x[:, :self.cutoff, :]**2 + x[:, self.cutoff:, :]**2).sqrt()
|
||||
x = self.conv1(x).relu()
|
||||
x = self.conv2(x).relu()
|
||||
x = self.conv3(x).relu()
|
||||
x = self.conv4(x).relu().squeeze(-1)
|
||||
h, c = self.lstm_cell(x, state)
|
||||
x = h.unsqueeze(-1)
|
||||
state = h.stack(c, dim=0)
|
||||
x = x.relu()
|
||||
x = self.final_conv(x).sigmoid()
|
||||
x = x.squeeze(1).mean(axis=1).unsqueeze(1)
|
||||
return x, state
|
||||
Reference in New Issue
Block a user