diff --git a/src/silero_vad/data/silero_vad_16k.safetensors b/src/silero_vad/data/silero_vad_16k.safetensors new file mode 100755 index 0000000..367975d Binary files /dev/null and b/src/silero_vad/data/silero_vad_16k.safetensors differ diff --git a/src/silero_vad/tinygrad_model.py b/src/silero_vad/tinygrad_model.py new file mode 100644 index 0000000..38108c1 --- /dev/null +++ b/src/silero_vad/tinygrad_model.py @@ -0,0 +1,71 @@ +from tinygrad import nn + + +class TinySileroVAD: + def __init__(self): + """ + from tinygrad.nn.state import safe_load, load_state_dict + + tiny_model = TinySileroVAD() + state_dict = safe_load('data/silero_vad_16k.safetensors') + load_state_dict(tiny_model, state_dict) + """ + self.n_fft = 256 + self.stride = 128 + self.pad = 64 + self.cutoff = int(self.n_fft // 2) + 1 + + self.stft_conv = nn.Conv1d(1, 258, kernel_size=256, stride=self.stride, padding=0, bias=False) + self.conv1 = nn.Conv1d(129, 128, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1) + self.conv3 = nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1) + self.conv4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1) + + self.lstm_cell = nn.LSTMCell(128, 128) + self.final_conv = nn.Conv1d(128, 1, 1) + + def __call__(self, x, state=None): + """ + # full audio example: + import torch + from tinygrad import Tensor + + wav = read_audio(audio_path, sampling_rate=16000).unsqueeze(0) + num_samples = 512 + context_size = 64 + context = Tensor(np.zeros((1, context_size))).float() + outs = [] + state = None + if wav.shape[1] % num_samples: + pad_num = num_samples - (wav.shape[1] % num_samples) + wav = torch.nn.functional.pad(wav, (0, pad_num), 'constant', value=0.0) + + wav = torch.nn.functional.pad(wav, (context_size, 0)) + + wav = Tensor(wav.numpy()).float() + + for i in tqdm(range(context_size, wav.shape[1], num_samples)): + wavs_batch = wav[:, i-context_size:i+num_samples] + out_chunk, state = tiny_model(wavs_batch, state) + #outs.append(out_chunk.numpy()) + outs.append(out_chunk) + + predict = outs[0].cat(*outs[1:], dim=1).numpy() + + """ + if state is not None: + state = (state[0], state[1]) + x = x.pad((0, self.pad), "reflect").unsqueeze(1) + x = self.stft_conv(x) + x = (x[:, :self.cutoff, :]**2 + x[:, self.cutoff:, :]**2).sqrt() + x = self.conv1(x).relu() + x = self.conv2(x).relu() + x = self.conv3(x).relu() + x = self.conv4(x).relu().squeeze(-1) + h, c = self.lstm_cell(x, state) + x = h.unsqueeze(-1) + state = h.stack(c, dim=0) + x = x.relu() + x = self.final_conv(x).sigmoid() + x = x.squeeze(1).mean(axis=1).unsqueeze(1) + return x, state