Mirror of https://github.com/aigc3d/LAM_Audio2Expression.git (synced 2026-02-05 01:49:23 +08:00)
feat: Initial commit
engines/launch.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""
Launcher

modified from detectron2 (https://github.com/facebookresearch/detectron2)
"""

import os
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    cfg=(),
    timeout=DEFAULT_TIMEOUT,
):
"""
|
||||
Launch multi-gpu or distributed training.
|
||||
This function must be called on all machines involved in the training.
|
||||
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
|
||||
Args:
|
||||
main_func: a function that will be called by `main_func(*args)`
|
||||
num_gpus_per_machine (int): number of GPUs per machine
|
||||
num_machines (int): the total number of machines
|
||||
machine_rank (int): the rank of this machine
|
||||
dist_url (str): url to connect to for distributed jobs, including protocol
|
||||
e.g. "tcp://127.0.0.1:8686".
|
||||
Can be set to "auto" to automatically select a free port on localhost
|
||||
timeout (timedelta): timeout of the distributed workers
|
||||
args (tuple): arguments passed to main_func
|
||||
"""
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert (
                num_machines == 1
            ), "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                cfg,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*cfg)
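# Note on rank layout: mp.spawn() passes a local rank in [0, num_gpus_per_machine) as
# the first argument of _distributed_worker, and the worker derives its global rank as
# machine_rank * num_gpus_per_machine + local_rank; with 2 machines x 4 GPUs, for
# example, the second machine hosts global ranks 4-7.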


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    cfg,
    timeout=DEFAULT_TIMEOUT,
):
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        # NOTE: dist.new_group() is a collective call, so every rank creates the group
        # for every machine and keeps only the one for its own machine.
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*cfg)
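For context, here is a minimal usage sketch of this launcher. The `train_worker` entry point, the config dict, and the `engines.launch` import path are illustrative assumptions, not part of this commit:

import torch

from engines.launch import launch


def train_worker(cfg):
    # Called once per worker process (or directly in-process when world_size == 1).
    ...


if __name__ == "__main__":
    cfg = {"lr": 1e-4}  # hypothetical config
    launch(
        train_worker,
        num_gpus_per_machine=torch.cuda.device_count(),
        num_machines=1,
        machine_rank=0,
        dist_url="auto",  # resolved to a free localhost port via _find_free_port()
        cfg=(cfg,),
    )

With a single GPU, world_size is 1 and launch() calls train_worker(cfg) directly; with several GPUs it spawns one worker per GPU, each with distributed communication already initialized.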