Merge pull request #736 from snakers4/adamnsandle

add tinygrad model
This commit is contained in:
Dimitrii Voronin
2025-12-10 16:35:36 +03:00
committed by GitHub
2 changed files with 71 additions and 0 deletions

Binary file not shown.

View File

@@ -0,0 +1,71 @@
from tinygrad import nn
class TinySileroVAD:
def __init__(self):
"""
from tinygrad.nn.state import safe_load, load_state_dict
tiny_model = TinySileroVAD()
state_dict = safe_load('data/silero_vad_16k.safetensors')
load_state_dict(tiny_model, state_dict)
"""
self.n_fft = 256
self.stride = 128
self.pad = 64
self.cutoff = int(self.n_fft // 2) + 1
self.stft_conv = nn.Conv1d(1, 258, kernel_size=256, stride=self.stride, padding=0, bias=False)
self.conv1 = nn.Conv1d(129, 128, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv1d(128, 64, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1)
self.conv4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
self.lstm_cell = nn.LSTMCell(128, 128)
self.final_conv = nn.Conv1d(128, 1, 1)
def __call__(self, x, state=None):
"""
# full audio example:
import torch
from tinygrad import Tensor
wav = read_audio(audio_path, sampling_rate=16000).unsqueeze(0)
num_samples = 512
context_size = 64
context = Tensor(np.zeros((1, context_size))).float()
outs = []
state = None
if wav.shape[1] % num_samples:
pad_num = num_samples - (wav.shape[1] % num_samples)
wav = torch.nn.functional.pad(wav, (0, pad_num), 'constant', value=0.0)
wav = torch.nn.functional.pad(wav, (context_size, 0))
wav = Tensor(wav.numpy()).float()
for i in tqdm(range(context_size, wav.shape[1], num_samples)):
wavs_batch = wav[:, i-context_size:i+num_samples]
out_chunk, state = tiny_model(wavs_batch, state)
#outs.append(out_chunk.numpy())
outs.append(out_chunk)
predict = outs[0].cat(*outs[1:], dim=1).numpy()
"""
if state is not None:
state = (state[0], state[1])
x = x.pad((0, self.pad), "reflect").unsqueeze(1)
x = self.stft_conv(x)
x = (x[:, :self.cutoff, :]**2 + x[:, self.cutoff:, :]**2).sqrt()
x = self.conv1(x).relu()
x = self.conv2(x).relu()
x = self.conv3(x).relu()
x = self.conv4(x).relu().squeeze(-1)
h, c = self.lstm_cell(x, state)
x = h.unsqueeze(-1)
state = h.stack(c, dim=0)
x = x.relu()
x = self.final_conv(x).sigmoid()
x = x.squeeze(1).mean(axis=1).unsqueeze(1)
return x, state