fx bugs, del extractor

This commit is contained in:
adamnsandle
2020-12-14 13:47:57 +00:00
parent 23bcad96e5
commit bb02e92ff9
4 changed files with 113 additions and 346 deletions

Binary file not shown.

File diff suppressed because one or more lines are too long

109
utils.py
View File

@@ -7,23 +7,10 @@ import torch.nn.functional as F
from collections import deque from collections import deque
import numpy as np import numpy as np
from itertools import repeat from itertools import repeat
import onnxruntime
torchaudio.set_audio_backend("soundfile") # switch backend torchaudio.set_audio_backend("soundfile") # switch backend
def read_batch(audio_paths: List[str]):
return [read_audio(audio_path)
for audio_path
in audio_paths]
def split_into_batches(lst: List[str],
batch_size: int = 10):
return [lst[i:i + batch_size]
for i in
range(0, len(lst), batch_size)]
def read_audio(path: str, def read_audio(path: str,
target_sr: int = 16000): target_sr: int = 16000):
@@ -42,15 +29,10 @@ def read_audio(path: str,
assert sr == target_sr assert sr == target_sr
return wav.squeeze(0) return wav.squeeze(0)
def save_audio(path: str,
def prepare_model_input(batch: List[torch.Tensor], tensor: torch.Tensor,
device=torch.device('cpu')): sr: int):
max_seqlength = max(max([len(_) for _ in batch]), 12800) torchaudio.save(path, tensor, sr)
inputs = torch.zeros(len(batch), max_seqlength)
for i, wav in enumerate(batch):
inputs[i, :len(wav)].copy_(wav)
inputs = inputs.to(device)
return inputs
#def init_jit_model(model_url: str, #def init_jit_model(model_url: str,
@@ -72,8 +54,11 @@ def init_jit_model(model_path,
model.eval() model.eval()
return model return model
def init_onnx_model(model_path):
return onnxruntime.InferenceSession(model_path)
def get_speech_ts(wav, model, extractor,
def get_speech_ts(wav, model,
trig_sum=0.25, neg_trig_sum=0.01, trig_sum=0.25, neg_trig_sum=0.01,
num_steps=8, batch_size=200): num_steps=8, batch_size=200):
@@ -90,15 +75,13 @@ def get_speech_ts(wav, model, extractor,
to_concat.append(chunk) to_concat.append(chunk)
if len(to_concat) >= batch_size: if len(to_concat) >= batch_size:
chunks = torch.Tensor(torch.vstack(to_concat)) chunks = torch.Tensor(torch.vstack(to_concat))
with torch.no_grad(): out = validate(model, chunks)[-2]
out = model(extractor(chunks))[-2]
outs.append(out) outs.append(out)
to_concat = [] to_concat = []
if to_concat: if to_concat:
chunks = torch.Tensor(torch.vstack(to_concat)) chunks = torch.Tensor(torch.vstack(to_concat))
with torch.no_grad(): out = validate(model, chunks)[-2]
out = model(extractor(chunks))[-2]
outs.append(out) outs.append(out)
outs = torch.cat(outs, dim=0) outs = torch.cat(outs, dim=0)
@@ -125,37 +108,6 @@ def get_speech_ts(wav, model, extractor,
speeches.append(current_speech) speeches.append(current_speech)
return speeches return speeches
class STFTExtractor(nn.Module):
def __init__(self, sr=16000, win_size=0.02, mode='mag'):
super(STFTExtractor, self).__init__()
self.sr = sr
self.n_fft = int(sr * (win_size + 1e-8))
self.win_length = self.n_fft
self.hop_length = self.win_length // 2
self.mode = 'mag' if mode == '' else mode
def forward(self, wav):
# center==True because other frame-level features are centered by default in torch/librosa and we can't change this.
stft_sample = torch.stft(wav,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=True)
mag, phase = torchaudio.functional.magphase(stft_sample)
# It seems it is not a "mag", it is "power" (exp == 1).
# Also there is "energy" (exp == 2).
if self.mode == 'mag':
return mag
if self.mode == 'phase':
return phase
elif self.mode == 'magphase':
return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
else:
raise NotImplementedError()
class VADiterator: class VADiterator:
def __init__(self, def __init__(self,
trig_sum=0.26, neg_trig_sum=0.01, trig_sum=0.26, neg_trig_sum=0.01,
@@ -214,7 +166,7 @@ class VADiterator:
return current_speech, self.current_name return current_speech, self.current_name
def state_generator(model, audios, extractor, def state_generator(model, audios,
onnx=False, onnx=False,
trig_sum=0.26, neg_trig_sum=0.01, trig_sum=0.26, neg_trig_sum=0.01,
num_steps=8, audios_in_stream=5): num_steps=8, audios_in_stream=5):
@@ -223,14 +175,8 @@ def state_generator(model, audios, extractor,
for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)] for_batch = [x.prepare_batch(*y) for x, y in zip(VADiters, current_pieces)]
batch = torch.cat(for_batch) batch = torch.cat(for_batch)
with torch.no_grad(): outs = validate(model, batch)
if onnx: vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
ort_inputs = {'input': to_numpy(extractor(batch))}
ort_outs = model.run(None, ort_inputs)
vad_outs = np.split(ort_outs[-2], audios_in_stream)
else:
outs = model(extractor(batch))
vad_outs = np.split(outs[-2].numpy(), audios_in_stream)
states = [] states = []
for x, y in zip(VADiters, vad_outs): for x, y in zip(VADiters, vad_outs):
@@ -273,7 +219,7 @@ def stream_imitator(audios, audios_in_stream):
values.append((out, wav_name)) values.append((out, wav_name))
yield values yield values
def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26, def single_audio_stream(model, audio, onnx=False, trig_sum=0.26,
neg_trig_sum=0.01, num_steps=8): neg_trig_sum=0.01, num_steps=8):
num_samples = 4000 num_samples = 4000
VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps) VADiter = VADiterator(trig_sum, neg_trig_sum, num_steps)
@@ -281,18 +227,25 @@ def single_audio_stream(model, audio, extractor, onnx=False, trig_sum=0.26,
wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)]) wav_chunks = iter([wav[i:i+num_samples] for i in range(0, len(wav), num_samples)])
for chunk in wav_chunks: for chunk in wav_chunks:
batch = VADiter.prepare_batch(chunk) batch = VADiter.prepare_batch(chunk)
with torch.no_grad(): outs = validate(model, batch)
if onnx: vad_outs = outs[-2]
ort_inputs = {'input': to_numpy(extractor(batch))}
ort_outs = model.run(None, ort_inputs)
vad_outs = ort_outs[-2]
else:
outs = model(extractor(batch))
vad_outs = outs[-2]
states = [] states = []
state = VADiter.state(vad_outs) state = VADiter.state(vad_outs)
if state[0]: if state[0]:
states.append(state[0]) states.append(state[0])
yield states yield states
def validate(model, inputs):
onnx = False
if type(model) == onnxruntime.capi.session.InferenceSession:
onnx = True
with torch.no_grad():
if onnx:
ort_inputs = {'input': inputs.cpu().numpy()}
outs = model.run(None, ort_inputs)
outs = [torch.Tensor(x) for x in outs]
else:
outs = model(inputs)
return outs