Add an initializer

This commit is contained in:
Ziyuan Wang
2023-03-09 16:16:51 +00:00
parent 9865b3cb93
commit a6a067de44

View File

@@ -18,9 +18,12 @@
"SAMPLING_RATE = 16000\n", "SAMPLING_RATE = 16000\n",
"import torch\n", "import torch\n",
"from pprint import pprint\n", "from pprint import pprint\n",
"\n",
"torch.set_num_threads(1)\n", "torch.set_num_threads(1)\n",
"NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
"NUM_COPIES=8\n",
"# download wav files, make multiple copies\n", "# download wav files, make multiple copies\n",
"for idx in range(4):\n", "for idx in range(NUM_COPIES):\n",
" torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n" " torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n"
] ]
}, },
@@ -39,15 +42,15 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n", " model='silero_vad',\n",
" force_reload=True,\n", " force_reload=True,\n",
" onnx=False)\n", " onnx=False)\n",
"\n", "\n",
"(get_speech_timestamps,\n", "(get_speech_timestamps,\n",
" save_audio,\n", "save_audio,\n",
" read_audio,\n", "read_audio,\n",
" VADIterator,\n", "VADIterator,\n",
" collect_chunks) = utils" "collect_chunks) = utils"
] ]
}, },
{ {
@@ -64,7 +67,12 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def init_model(model):\n",
" global vad_model\n",
" vad_model = model\n",
"\n",
"def vad_process(audio_file: str):\n", "def vad_process(audio_file: str):\n",
" global vad_model\n",
" with torch.no_grad():\n", " with torch.no_grad():\n",
" wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n", " wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
" return get_speech_timestamps(\n", " return get_speech_timestamps(\n",
@@ -97,8 +105,10 @@
"from concurrent.futures import ProcessPoolExecutor, as_completed\n", "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"\n", "\n",
"futures = []\n", "futures = []\n",
"with ProcessPoolExecutor(max_workers=4) as ex:\n", "NUM_COPIES=20\n",
" for i in range(4):\n", "\n",
"with ProcessPoolExecutor(max_workers=NUM_PROCESS) as ex:\n",
" for i in range(NUM_COPIES):\n",
" futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n", " futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n",
"\n", "\n",
"for finished in as_completed(futures):\n", "for finished in as_completed(futures):\n",