Merge pull request #5 from snakers4/adamnsandle

new model
This commit is contained in:
Alexander Veysov
2020-12-22 15:24:11 +03:00
committed by GitHub
4 changed files with 10 additions and 10 deletions

Binary file not shown.

Binary file not shown.


@@ -233,7 +233,7 @@
     "    ort_inputs = {'input': inputs.cpu().numpy()}\n",
     "    outs = model.run(None, ort_inputs)\n",
     "    outs = [torch.Tensor(x) for x in outs]\n",
-    "    return outs"
+    "    return outs[0]"
    ]
   },
   {
@@ -405,5 +405,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 1
 }
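For context: the notebook helper in the first hunk still asks ONNX Runtime for all outputs, but with the new single-output model only the first tensor is returned, hence return outs[0]. A minimal sketch of the helper as it reads after this hunk (the function name and the InferenceSession setup are assumptions, not part of the diff):

import onnxruntime
import torch

def validate_onnx(model: onnxruntime.InferenceSession, inputs: torch.Tensor):
    # run the chunk batch through ONNX Runtime as a numpy array
    ort_inputs = {'input': inputs.cpu().numpy()}
    outs = model.run(None, ort_inputs)
    outs = [torch.Tensor(x) for x in outs]
    return outs[0]  # new model: a single VAD output tensor, returned directly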


@@ -52,7 +52,7 @@ def init_jit_model(model_path: str,
 def get_speech_ts(wav: torch.Tensor,
                   model,
                   trig_sum: float = 0.25,
-                  neg_trig_sum: float = 0.02,
+                  neg_trig_sum: float = 0.1,
                   num_steps: int = 8,
                   batch_size: int = 200,
                   run_function=validate):
@@ -70,13 +70,13 @@ def get_speech_ts(wav: torch.Tensor,
         to_concat.append(chunk.unsqueeze(0))
         if len(to_concat) >= batch_size:
             chunks = torch.Tensor(torch.cat(to_concat, dim=0))
-            out = run_function(model, chunks)[-2]
+            out = run_function(model, chunks)
             outs.append(out)
             to_concat = []
     if to_concat:
         chunks = torch.Tensor(torch.cat(to_concat, dim=0))
-        out = run_function(model, chunks)[-2]
+        out = run_function(model, chunks)
         outs.append(out)
     outs = torch.cat(outs, dim=0)
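Both branches of this batching loop drop the [-2] indexing: the old model's run_function returned several outputs with the VAD result at index -2, while the new model returns that tensor directly. A condensed sketch of the loop after the change (chunks_list and the helper name are illustrative, not repo code):

import torch

def batched_vad_outs(model, chunks_list, run_function, batch_size: int = 200):
    outs, to_concat = [], []
    for chunk in chunks_list:
        to_concat.append(chunk.unsqueeze(0))
        if len(to_concat) >= batch_size:
            # new model: run_function yields the VAD tensor itself, no [-2]
            outs.append(run_function(model, torch.cat(to_concat, dim=0)))
            to_concat = []
    if to_concat:  # flush the final partial batch
        outs.append(run_function(model, torch.cat(to_concat, dim=0)))
    return torch.cat(outs, dim=0)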
@@ -107,7 +107,7 @@ def get_speech_ts(wav: torch.Tensor,
 class VADiterator:
     def __init__(self,
                  trig_sum: float = 0.26,
-                 neg_trig_sum: float = 0.02,
+                 neg_trig_sum: float = 0.1,
                  num_steps: int = 8):
         self.num_samples = 4000
         self.num_steps = num_steps
@@ -168,7 +168,7 @@ def state_generator(model,
                     audios: List[str],
                     onnx: bool = False,
                     trig_sum: float = 0.26,
-                    neg_trig_sum: float = 0.02,
+                    neg_trig_sum: float = 0.1,
                     num_steps: int = 8,
                     audios_in_stream: int = 2,
                     run_function=validate):
@@ -178,7 +178,7 @@ def state_generator(model,
         batch = torch.cat(for_batch)
         outs = run_function(model, batch)
-        vad_outs = torch.split(outs[-2], num_steps)
+        vad_outs = torch.split(outs, num_steps)
         states = []
         for x, y in zip(VADiters, vad_outs):
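The torch.split call above carves the batched output back into per-stream chunks of num_steps rows each, so every VADiterator receives only its own stream's probabilities. For example, with the defaults num_steps=8 and audios_in_stream=2:

import torch

outs = torch.rand(16)            # batched VAD outputs: 2 streams x 8 steps
vad_outs = torch.split(outs, 8)  # -> two tensors of 8 values, one per stream
print([v.shape for v in vad_outs])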
@@ -227,7 +227,7 @@ def single_audio_stream(model,
                         audio: str,
                         onnx: bool = False,
                         trig_sum: float = 0.26,
-                        neg_trig_sum: float = 0.02,
+                        neg_trig_sum: float = 0.1,
                         num_steps: int = 8,
                         run_function=validate):
     num_samples = 4000
@@ -238,7 +238,7 @@ def single_audio_stream(model,
         batch = VADiter.prepare_batch(chunk)
         outs = run_function(model, batch)
-        vad_outs = outs[-2]  # this is very misleading
+        vad_outs = outs  # this is very misleading
         states = []
         state = VADiter.state(vad_outs)
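Taken together, the utils changes mean every entry point now treats the model output as a single tensor, and the default neg_trig_sum (the negative trigger threshold) is relaxed from 0.02 to 0.1 across the board. A minimal usage sketch under the new defaults (the utils_vad module name, the read_audio helper, and the file paths are assumptions based on the rest of the repo, not part of this diff):

import torch
from utils_vad import init_jit_model, get_speech_ts, read_audio

model = init_jit_model('files/model.jit')  # placeholder path to the new JIT model
wav = read_audio('example.wav')            # assumed 16 kHz mono loader
# neg_trig_sum now defaults to 0.1; passed explicitly to highlight the change
speech_timestamps = get_speech_ts(wav, model, trig_sum=0.25, neg_trig_sum=0.1)
print(speech_timestamps)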