mirror of https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
Integration with Silero VAD added
@@ -1,13 +1,14 @@
import time, logging
from datetime import datetime
import threading, collections, queue, os, os.path
import deepspeech
import numpy as np
import pyaudio
import wave
import webrtcvad
from halo import Halo
from scipy import signal
import torch
import torchaudio

logging.basicConfig(level=20)
@@ -152,18 +153,9 @@ class VADAudio(Audio):
                    ring_buffer.clear()


def main(ARGS):
    # Load DeepSpeech model
    if os.path.isdir(ARGS.model):
        model_dir = ARGS.model
        ARGS.model = os.path.join(model_dir, 'output_graph.pb')
        ARGS.scorer = os.path.join(model_dir, ARGS.scorer)

    print('Initializing model...')
    logging.info("ARGS.model: %s", ARGS.model)
    model = deepspeech.Model(ARGS.model)
    if ARGS.scorer:
        logging.info("ARGS.scorer: %s", ARGS.scorer)
        model.enableExternalScorer(ARGS.scorer)

    # Start audio with VAD
    vad_audio = VADAudio(aggressiveness=ARGS.vad_aggressiveness,
@@ -173,36 +165,56 @@ def main(ARGS):
    print("Listening (ctrl-C to exit)...")
    frames = vad_audio.vad_collector()

    # Load Silero VAD from torch.hub; bind it to its own name so it does not
    # shadow the DeepSpeech model created above.
    torchaudio.set_audio_backend("soundfile")
    silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                         model='silero_vad',
                                         force_reload=True)
    (get_speech_ts, get_speech_ts_adaptive, _, read_audio, _, _, _) = utils
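    # Note (an assumption based on the hub API of that era, not shown in this
    # diff): get_speech_ts takes a float32 waveform tensor plus the model and
    # returns a list of {'start': ..., 'end': ...} dicts holding sample
    # offsets, e.g. [{'start': 1536, 'end': 8704}]; an empty list means the
    # buffer was classified as non-speech.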

    # Stream from microphone to DeepSpeech using VAD
    spinner = None
    if not ARGS.nospinner:
        spinner = Halo(spinner='line')
    stream_context = model.createStream()
    wav_data = bytearray()
    for frame in frames:
        if frame is not None:
            if spinner: spinner.start()

            logging.debug("streaming frame")
            stream_context.feedAudioContent(np.frombuffer(frame, np.int16))
            wav_data.extend(frame)  # always buffer audio for the Silero pass below
        else:
            if spinner: spinner.stop()
            logging.debug("end utterance")
            text = stream_context.finishStream()
            print("Recognized: %s" % text)
            if ARGS.keyboard:
                from pyautogui import typewrite
                typewrite(text)
            stream_context = model.createStream()
            print("webRTC has detected possible speech")
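
            # Second-stage check: webrtcvad has already segmented the
            # utterance frame by frame; Silero VAD now re-scores the whole
            # buffered utterance at once, so clips that webrtcvad let through
            # can still be rejected as noise before being saved.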
            newsound = np.frombuffer(wav_data, np.int16)
            audio_float32 = Int2Float(newsound)
            time_stamps = get_speech_ts(audio_float32, silero_model, num_steps=4)
            if len(time_stamps) > 0:
                print("silero VAD has detected possible speech")
                if ARGS.savewav:
                    vad_audio.write_wav(os.path.join(ARGS.savewav, datetime.now().strftime("savewav_%Y-%m-%d_%H-%M-%S_%f.wav")), wav_data)
            else:
                print("silero VAD has detected noise")
            print()
            wav_data = bytearray()


def Int2Float(sound):
    _sound = np.copy(sound)  # copy so the caller's int16 buffer is left untouched
    abs_max = np.abs(_sound).max()
    _sound = _sound.astype('float32')
    if abs_max > 0:
        _sound *= 1 / abs_max  # peak-normalize into [-1.0, 1.0]
    audio_float32 = torch.from_numpy(_sound.squeeze())
    return audio_float32

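# Quick illustrative check of Int2Float (hypothetical values, not part of the
# pipeline): a full-scale int16 buffer maps onto [-1.0, 1.0].
#
#   >>> Int2Float(np.array([0, 16384, -32768], dtype=np.int16))
#   tensor([ 0.0000,  0.5000, -1.0000])
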
if __name__ == '__main__':
    DEFAULT_SAMPLE_RATE = 16000

    import argparse
    parser = argparse.ArgumentParser(description="Stream from microphone to webRTC and silero VAD")

    parser.add_argument('-v', '--vad_aggressiveness', type=int, default=3,
                        help="Set aggressiveness of VAD: an integer between 0 and 3, 0 being the least aggressive about filtering out non-speech, 3 the most aggressive. Default: 3")
@@ -212,17 +224,10 @@ if __name__ == '__main__':
                        help="Save .wav files of utterances to given directory")
    parser.add_argument('-f', '--file',
                        help="Read from .wav file instead of microphone")

    parser.add_argument('-m', '--model', required=True,
                        help="Path to the model (protocol buffer binary file, or entire directory containing all standard-named files for model)")
    parser.add_argument('-s', '--scorer',
                        help="Path to the external scorer file.")
    parser.add_argument('-d', '--device', type=int, default=None,
                        help="Device input index (Int) as listed by pyaudio.PyAudio.get_device_info_by_index(). If not provided, falls back to PyAudio.get_default_device().")
    parser.add_argument('-r', '--rate', type=int, default=DEFAULT_SAMPLE_RATE,
                        help=f"Input device sample rate. Default: {DEFAULT_SAMPLE_RATE}. Your device may require 44100.")
    parser.add_argument('-k', '--keyboard', action='store_true',
                        help="Type output through system keyboard")
    ARGS = parser.parse_args()
    if ARGS.savewav: os.makedirs(ARGS.savewav, exist_ok=True)
    main(ARGS)
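
For reference, a typical invocation of this example might look like the line
below, using the -m and -v flags defined above (the script name is an
assumption, since this view does not show the diff's file header):

    python mic_vad_streaming.py -m deepspeech-0.9.3-models.pbmm -v 3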