"""
|
|
The code is base on https://github.com/Pointcept/Pointcept
|
|
"""

import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

import utils.comm as comm
from utils.env import get_random_seed, set_seed
from utils.config import Config, DictAction


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """
    if comm.get_world_size() == 1:
        return model
    # kwargs['find_unused_parameters'] = True
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    if "output_device" not in kwargs:
        # DistributedDataParallel expects an int or torch.device here, not a list.
        kwargs["output_device"] = comm.get_local_rank()
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
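
# Usage sketch (not part of the original module): `build_model_from_cfg` and the CUDA
# placement below are hypothetical; the only assumption taken from this file is that
# the process group is already initialized when world_size > 1.
#
#   model = build_model_from_cfg(cfg).cuda(comm.get_local_rank())
#   model = create_ddp_model(model, fp16_compression=True)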


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.

    The seed of each worker equals num_workers * rank + worker_id + user_seed.

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of current process.
        seed (int): The random seed to use.
    """
    worker_seed = num_workers * rank + worker_id + seed
    set_seed(worker_seed)
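
# Usage sketch (not part of the original module): worker_init_fn is normally bound with
# functools.partial and handed to a DataLoader so each worker gets a distinct seed;
# `train_dataset` and the cfg fields referenced below are hypothetical.
#
#   from functools import partial
#   init_fn = partial(
#       worker_init_fn,
#       num_workers=cfg.num_worker_per_gpu,
#       rank=comm.get_rank(),
#       seed=cfg.seed,
#   )
#   train_loader = torch.utils.data.DataLoader(
#       train_dataset,
#       batch_size=cfg.batch_size_per_gpu,
#       num_workers=cfg.num_worker_per_gpu,
#       worker_init_fn=init_fn,
#   )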


def default_argument_parser(epilog=None):
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        # default="tcp://127.0.0.1:{}".format(port),
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser
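
# Usage sketch (not part of the original module): config values can be overridden from
# the command line through --options, assuming MMCV-style key=value semantics in
# utils.config.DictAction; the script name, config path, and keys below are hypothetical.
#
#   $ python train.py --config-file configs/audio2exp/base.py \
#       --options batch_size=8 save_path=exp/demo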


def default_config_parser(file_path, options):
    # config name protocol: dataset_name/model_name-exp_name
    if os.path.isfile(file_path):
        cfg = Config.fromfile(file_path)
    else:
        sep = file_path.find("-")
        cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))

    if options is not None:
        cfg.merge_from_dict(options)

    if cfg.seed is None:
        cfg.seed = get_random_seed()

    cfg.data.train.loop = cfg.epoch // cfg.eval_epoch

    os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
    if not cfg.resume:
        cfg.dump(os.path.join(cfg.save_path, "config.py"))
    return cfg
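
# Usage sketch (not part of the original module): the parser accepts either an explicit
# config path or the short "dataset_name/model_name-exp_name" form, which is resolved by
# splitting on the first "-"; the path below is hypothetical.
#
#   cfg = default_config_parser("configs/audio2exp/base.py", None)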


def default_setup(cfg):
    # scale by world size
    world_size = comm.get_world_size()
    cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
    cfg.num_worker_per_gpu = cfg.num_worker // world_size
    assert cfg.batch_size % world_size == 0
    assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
    assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
    cfg.batch_size_per_gpu = cfg.batch_size // world_size
    cfg.batch_size_val_per_gpu = (
        cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
    )
    cfg.batch_size_test_per_gpu = (
        cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
    )
    # update data loop
    assert cfg.epoch % cfg.eval_epoch == 0
    # settle random seed
    rank = comm.get_rank()
    seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
    set_seed(seed)
    return cfg
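

if __name__ == "__main__":
    # Launch sketch (not part of the original module): wires the helpers above together
    # the way a typical entry point would. It assumes the config passed via --config-file
    # defines the fields accessed above (seed, epoch, eval_epoch, batch_size, num_worker,
    # save_path, resume, ...).
    args = default_argument_parser().parse_args()
    cfg = default_config_parser(args.config_file, args.options)
    cfg = default_setup(cfg)
    print(
        f"save_path={cfg.save_path} seed={cfg.seed} "
        f"batch_size_per_gpu={cfg.batch_size_per_gpu}"
    )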