diff --git a/examples/parallel_example.ipynb b/examples/parallel_example.ipynb index 162e4f2..3db2a1c 100644 --- a/examples/parallel_example.ipynb +++ b/examples/parallel_example.ipynb @@ -18,9 +18,12 @@ "SAMPLING_RATE = 16000\n", "import torch\n", "from pprint import pprint\n", + "\n", "torch.set_num_threads(1)\n", + "NUM_PROCESS=4 # set to the number of CPU cores in the machine\n", + "NUM_COPIES=8\n", "# download wav files, make multiple copies\n", - "for idx in range(4):\n", + "for idx in range(NUM_COPIES):\n", " torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n" ] }, @@ -39,15 +42,15 @@ "outputs": [], "source": [ "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", - " model='silero_vad',\n", - " force_reload=True,\n", - " onnx=False)\n", + " model='silero_vad',\n", + " force_reload=True,\n", + " onnx=False)\n", "\n", "(get_speech_timestamps,\n", - " save_audio,\n", - " read_audio,\n", - " VADIterator,\n", - " collect_chunks) = utils" + "save_audio,\n", + "read_audio,\n", + "VADIterator,\n", + "collect_chunks) = utils" ] }, { @@ -64,7 +67,12 @@ "metadata": {}, "outputs": [], "source": [ + "def init_model(model):\n", + " global vad_model\n", + " vad_model = model\n", + "\n", "def vad_process(audio_file: str):\n", + " global vad_model\n", " with torch.no_grad():\n", " wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n", " return get_speech_timestamps(\n", @@ -97,8 +105,9 @@ "from concurrent.futures import ProcessPoolExecutor, as_completed\n", "\n", "futures = []\n", - "with ProcessPoolExecutor(max_workers=4) as ex:\n", - " for i in range(4):\n", - " futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n", + "\n", + "with ProcessPoolExecutor(max_workers=NUM_PROCESS) as ex:\n", + " for i in range(NUM_COPIES): # one task per file downloaded above\n", + " futures.append(ex.submit(vad_process, f\"en_example{i}.wav\"))\n", "\n", "for finished in as_completed(futures):\n",