diff --git a/examples/parallel_example.ipynb b/examples/parallel_example.ipynb
index 5d1b365..9704291 100644
--- a/examples/parallel_example.ipynb
+++ b/examples/parallel_example.ipynb
@@ -14,7 +14,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -q torchaudio\n",
+    "# !pip install -q torchaudio\n",
     "SAMPLING_RATE = 16000\n",
     "import torch\n",
     "from pprint import pprint\n",
@@ -67,17 +67,27 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import multiprocessing\n",
+    "\n",
+    "vad_models = dict()\n",
+    "\n",
     "def init_model(model):\n",
-    "    global vad_model\n",
-    "    vad_model = model\n",
+    "    pid = multiprocessing.current_process().pid\n",
+    "    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
+    "                              model='silero_vad',\n",
+    "                              force_reload=False,\n",
+    "                              onnx=False)\n",
+    "    vad_models[pid] = model\n",
     "\n",
     "def vad_process(audio_file: str):\n",
-    "    global vad_model\n",
+    "    \n",
+    "    pid = multiprocessing.current_process().pid\n",
+    "    \n",
     "    with torch.no_grad():\n",
     "        wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
     "        return get_speech_timestamps(\n",
     "            wav,\n",
-    "            vad_model,\n",
+    "            vad_models[pid],\n",
     "            0.46, # speech prob threshold\n",
     "            16000, # sample rate\n",
     "            300, # min speech duration in ms\n",
@@ -132,8 +142,7 @@
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.9.15"
-  },
-  "orig_nbformat": 4
+  }
  },
  "nbformat": 4,
  "nbformat_minor": 2
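
For context (not part of the diff): a minimal sketch of how the reworked worker functions might be driven from the notebook, assuming `init_model` and `vad_process` are defined as in the cell above and that a `multiprocessing.Pool` is used, as the example's name suggests. The file list, pool size, and the `None` passed through `initargs` are placeholders, not taken from the source.

```python
from multiprocessing import Pool

# Hypothetical input list; replace with real audio paths.
audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']

if __name__ == '__main__':
    # init_model runs once per worker process; after this change it ignores
    # its argument and loads a fresh Silero VAD model keyed by the worker pid.
    with Pool(processes=2, initializer=init_model, initargs=(None,)) as pool:
        speech_timestamps = pool.map(vad_process, audio_files)

    for path, timestamps in zip(audio_files, speech_timestamps):
        print(path, timestamps)
```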