diff --git a/examples/parallel_example.ipynb b/examples/parallel_example.ipynb new file mode 100644 index 0000000..9704291 --- /dev/null +++ b/examples/parallel_example.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install -q torchaudio\n", + "SAMPLING_RATE = 16000\n", + "import torch\n", + "from pprint import pprint\n", + "\n", + "torch.set_num_threads(1)\n", + "NUM_PROCESS=4 # set to the number of CPU cores in the machine\n", + "NUM_COPIES=8\n", + "# download wav files, make multiple copies\n", + "for idx in range(NUM_COPIES):\n", + " torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example{idx}.wav\")\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load VAD model from torch hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", + " model='silero_vad',\n", + " force_reload=True,\n", + " onnx=False)\n", + "\n", + "(get_speech_timestamps,\n", + "save_audio,\n", + "read_audio,\n", + "VADIterator,\n", + "collect_chunks) = utils" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define a vad process function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import multiprocessing\n", + "\n", + "vad_models = dict()\n", + "\n", + "def init_model(model):\n", + " pid = multiprocessing.current_process().pid\n", + " model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", + " model='silero_vad',\n", + " force_reload=False,\n", + " onnx=False)\n", + " vad_models[pid] = model\n", + "\n", + "def vad_process(audio_file: str):\n", + " \n", + " pid = multiprocessing.current_process().pid\n", + " \n", + " with torch.no_grad():\n", + " wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n", + " return get_speech_timestamps(\n", + " wav,\n", + " vad_models[pid],\n", + " 0.46, # speech prob threshold\n", + " 16000, # sample rate\n", + " 300, # min speech duration in ms\n", + " 20, # max speech duration in seconds\n", + " 600, # min silence duration\n", + " 512, # window size\n", + " 200, # spech pad ms\n", + " )" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parallelization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from concurrent.futures import ProcessPoolExecutor, as_completed\n", + "\n", + "futures = []\n", + "\n", + "with ProcessPoolExecutor(max_workers=NUM_PROCESS, initializer=init_model, initargs=(model,)) as ex:\n", + " for i in range(NUM_COPIES):\n", + " futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n", + "\n", + "for finished in as_completed(futures):\n", + " pprint(finished.result())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "diarization", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}