use a process specific copy of model

This commit is contained in:
Ziyuan Wang
2023-03-10 16:09:07 +00:00
parent 17903cb41d
commit 55c41abf46

View File

@@ -14,7 +14,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install -q torchaudio\n", "# !pip install -q torchaudio\n",
"SAMPLING_RATE = 16000\n", "SAMPLING_RATE = 16000\n",
"import torch\n", "import torch\n",
"from pprint import pprint\n", "from pprint import pprint\n",
@@ -67,17 +67,27 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import multiprocessing\n",
"\n",
"vad_models = dict()\n",
"\n",
"def init_model(model):\n", "def init_model(model):\n",
" global vad_model\n", " pid = multiprocessing.current_process().pid\n",
" vad_model = model\n", " model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=False,\n",
" onnx=False)\n",
" vad_models[pid] = model\n",
"\n", "\n",
"def vad_process(audio_file: str):\n", "def vad_process(audio_file: str):\n",
" global vad_model\n", " \n",
" pid = multiprocessing.current_process().pid\n",
" \n",
" with torch.no_grad():\n", " with torch.no_grad():\n",
" wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n", " wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
" return get_speech_timestamps(\n", " return get_speech_timestamps(\n",
" wav,\n", " wav,\n",
" vad_model,\n", " vad_models[pid],\n",
" 0.46, # speech prob threshold\n", " 0.46, # speech prob threshold\n",
" 16000, # sample rate\n", " 16000, # sample rate\n",
" 300, # min speech duration in ms\n", " 300, # min speech duration in ms\n",
@@ -132,8 +142,7 @@
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.15" "version": "3.9.15"
}, }
"orig_nbformat": 4
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 2