commit 5412e5a769
parent cef1644a5f
Author: snakers41
Date:   2020-12-11 13:57:22 +00:00

2 changed files with 24 additions and 19 deletions

View File

@@ -301,7 +301,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.3" "version": "3.7.7"
}, },
"toc": { "toc": {
"base_numbering": 1, "base_numbering": 1,

View File

@@ -73,11 +73,15 @@ def init_jit_model(model_path,
     return model
 
 
-def get_speech_ts(wav, model, extractor, trig_sum=0.25, neg_trig_sum=0.01, num_steps=8, batch_size=200):
+def get_speech_ts(wav, model, extractor,
+                  trig_sum=0.25, neg_trig_sum=0.01,
+                  num_steps=8, batch_size=200):
+
     assert 4000 % num_steps == 0
-    step = int(4000 / num_steps)
+    step = int(4000 / num_steps)  # stride / hop
     outs = []
     to_concat = []
+
     for i in range(0, len(wav), step):
         chunk = wav[i: i+4000]
         if len(chunk) < 4000:
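The loop this hunk reformats slides a 4000-sample window over the waveform with a hop of 4000 // num_steps, padding the last short chunk so every row has the same length. A minimal runnable sketch of that windowing, with a random tensor standing in for real audio (all names below are illustrative, not from this file):

import torch
import torch.nn.functional as F

num_steps = 8
window = 4000
step = window // num_steps            # hop of 500 samples, matching step above

wav = torch.randn(10_000)             # stand-in for a mono waveform tensor
chunks = []
for i in range(0, len(wav), step):
    chunk = wav[i:i + window]
    if len(chunk) < window:           # pad the tail so every chunk is 4000 long
        chunk = F.pad(chunk, (0, window - len(chunk)))
    chunks.append(chunk)

batch = torch.vstack(chunks)          # (20, 4000): one row per window
print(batch.shape)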
@@ -89,20 +93,21 @@ def get_speech_ts(wav, model, extractor, trig_sum=0.25, neg_trig_sum=0.01, num_s
             out = model(extractor(chunks))[-2]
             outs.append(out)
             to_concat = []
 
     if to_concat:
         chunks = torch.Tensor(torch.vstack(to_concat))
         with torch.no_grad():
             out = model(extractor(chunks))[-2]
         outs.append(out)
 
     outs = torch.cat(outs, dim=0)
 
-    buffer = deque(maxlen=num_steps)
+    buffer = deque(maxlen=num_steps)  # when the max queue length is reached, the first element is dropped
     triggered = False
     speeches = []
     current_speech = {}
-    for i, predict in enumerate(outs[:, 1]):
+
+    for i, predict in enumerate(outs[:, 1]):  # add name
         buffer.append(predict)
         if (np.mean(buffer) >= trig_sum) and not triggered:
             triggered = True
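The deque plus np.mean above implements a smoothed hysteresis trigger: speech starts once the mean of the last num_steps probabilities reaches trig_sum, and later in the same function ends once it falls below neg_trig_sum. A self-contained sketch with toy probabilities (indices here are window indices rather than sample offsets, and the exact end condition in the real function may differ slightly):

from collections import deque
import numpy as np

trig_sum, neg_trig_sum, num_steps = 0.25, 0.01, 8
probs = [0.0, 0.1, 0.6, 0.9, 0.8] + [0.0] * 10   # toy per-window speech probabilities

buffer = deque(maxlen=num_steps)                 # once full, the oldest value is dropped
triggered = False
speeches, current_speech = [], {}

for i, p in enumerate(probs):
    buffer.append(p)
    smoothed = np.mean(buffer)
    if smoothed >= trig_sum and not triggered:   # rising edge: speech starts
        triggered = True
        current_speech['start'] = i
    elif smoothed < neg_trig_sum and triggered:  # falling edge: speech ends
        current_speech['end'] = i
        speeches.append(current_speech)
        current_speech, triggered = {}, False

print(speeches)                                  # [{'start': 3, 'end': 12}]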
@@ -150,7 +155,9 @@ class STFTExtractor(nn.Module):
 
 
 class VADiterator:
-    def __init__(self, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8):
+    def __init__(self,
+                 trig_sum=0.26, neg_trig_sum=0.01,
+                 num_steps=8):
         self.num_steps = num_steps
         assert 4000 % num_steps == 0
         self.step = int(4000 / num_steps)
@@ -162,14 +169,14 @@ class VADiterator:
         self.trig_sum = trig_sum
         self.neg_trig_sum = neg_trig_sum
         self.current_name = ''
 
     def refresh(self):
         self.prev = torch.zeros(4000)
         self.last = False
         self.triggered = False
         self.buffer = deque(maxlen=8)
         self.num_frames = 0
 
     def prepare_batch(self, wav_chunk, name=None):
         if (name is not None) and (name != self.current_name):
             self.refresh()
@@ -177,15 +184,15 @@ class VADiterator:
         assert len(wav_chunk) <= 4000
         self.num_frames += len(wav_chunk)
         if len(wav_chunk) < 4000:
             wav_chunk = F.pad(wav_chunk, (0, 4000 - len(wav_chunk)))  # assume a short chunk means the end of the audio
             self.last = True
 
         stacked = torch.hstack([self.prev, wav_chunk])
         self.prev = wav_chunk
 
         overlap_chunks = [stacked[i:i+4000] for i in range(500, 4001, self.step)]  # a 500-sample step is good enough
         return torch.vstack(overlap_chunks)
 
     def state(self, model_out):
         current_speech = {}
         for i, predict in enumerate(model_out[:, 1]):
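prepare_batch in the hunk above keeps the previous chunk in self.prev and stacks it in front of the incoming one, so each call yields overlapping windows that straddle the chunk boundary. A standalone sketch of that overlap trick under the same constants (step of 500 for num_steps=8, windows taken from offset 500 onward, as the comment above notes):

import torch

step = 500                        # 4000 // num_steps for num_steps=8
prev = torch.zeros(4000)          # state carried across calls, like self.prev
chunk = torch.randn(4000)         # stand-in for the incoming 4000-sample chunk

stacked = torch.hstack([prev, chunk])                            # (8000,)
windows = [stacked[i:i + 4000] for i in range(500, 4001, step)]  # 8 windows
batch = torch.vstack(windows)     # (8, 4000); consecutive rows overlap by 3500 samples
prev = chunk                      # the next call overlaps into this chunk
print(batch.shape)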
@@ -203,7 +210,6 @@ class VADiterator:
         return current_speech, self.current_name
 
-
 def state_generator(model, audios, extractor, onnx=False, trig_sum=0.26, neg_trig_sum=0.01, num_steps=8, audios_in_stream=5):
     VADiters = [VADiterator(trig_sum, neg_trig_sum, num_steps) for i in range(audios_in_stream)]
     for i, current_pieces in enumerate(stream_imitator(audios, audios_in_stream)):
@@ -218,7 +224,7 @@ def state_generator(model, audios, extractor, onnx=False, trig_sum=0.26, neg_tri
         else:
             outs = model(extractor(batch))
             vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
 
         states = []
         for x, y in zip(VADiters, vad_outs):
             cur_st = x.state(y)
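The np.split here fans the batched model output back out to the per-stream iterators: the batch dimension holds audios_in_stream equal blocks of windows, one block per audio, in stream order. A toy version with a fake output array:

import numpy as np

audios_in_stream, windows_per_audio = 5, 8
outs = np.random.rand(audios_in_stream * windows_per_audio, 2)  # fake (batch, 2) VAD output
vad_outs = np.split(outs, audios_in_stream)                     # five (8, 2) blocks
assert vad_outs[0].shape == (windows_per_audio, 2)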
@@ -259,4 +265,3 @@ def stream_imitator(stereo, audios_in_stream):
             values.append((out, wav_name))
         yield values
-
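Taken together, the streaming path in this file would be driven roughly as below. This is a hypothetical sketch, not part of the commit: the model path, wav paths, and the STFTExtractor constructor arguments are placeholders, and the structure yielded by state_generator is inferred from the hunks above:

# hypothetical driver: init_jit_model, STFTExtractor and state_generator
# come from this file; every literal below is a placeholder
model = init_jit_model('files/model.jit')
extractor = STFTExtractor()                # constructor args assumed to have defaults
audios = ['a.wav', 'b.wav', 'c.wav', 'd.wav', 'e.wav']

for states in state_generator(model, audios, extractor, audios_in_stream=5):
    for speech_dict, wav_name in states:   # assumed (current_speech, current_name) pairs
        if speech_dict:
            print(wav_name, speech_dict)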