mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-04 17:39:24 +08:00
feat: Initial commit
This commit is contained in:
147
engines/defaults.py
Normal file
147
engines/defaults.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
|
||||
import utils.comm as comm
|
||||
from utils.env import get_random_seed, set_seed
|
||||
from utils.config import Config, DictAction
|
||||
|
||||
|
||||
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
|
||||
"""
|
||||
Create a DistributedDataParallel model if there are >1 processes.
|
||||
Args:
|
||||
model: a torch.nn.Module
|
||||
fp16_compression: add fp16 compression hooks to the ddp object.
|
||||
See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
|
||||
kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
|
||||
"""
|
||||
if comm.get_world_size() == 1:
|
||||
return model
|
||||
# kwargs['find_unused_parameters'] = True
|
||||
if "device_ids" not in kwargs:
|
||||
kwargs["device_ids"] = [comm.get_local_rank()]
|
||||
if "output_device" not in kwargs:
|
||||
kwargs["output_device"] = [comm.get_local_rank()]
|
||||
ddp = DistributedDataParallel(model, **kwargs)
|
||||
if fp16_compression:
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
|
||||
|
||||
ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
|
||||
return ddp
|
||||
|
||||
|
||||
def worker_init_fn(worker_id, num_workers, rank, seed):
|
||||
"""Worker init func for dataloader.
|
||||
|
||||
The seed of each worker equals to num_worker * rank + worker_id + user_seed
|
||||
|
||||
Args:
|
||||
worker_id (int): Worker id.
|
||||
num_workers (int): Number of workers.
|
||||
rank (int): The rank of current process.
|
||||
seed (int): The random seed to use.
|
||||
"""
|
||||
|
||||
worker_seed = num_workers * rank + worker_id + seed
|
||||
set_seed(worker_seed)
|
||||
|
||||
|
||||
def default_argument_parser(epilog=None):
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog=epilog
|
||||
or f"""
|
||||
Examples:
|
||||
Run on single machine:
|
||||
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
|
||||
Change some config options:
|
||||
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
|
||||
Run on multiple machines:
|
||||
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
|
||||
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
|
||||
""",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-file", default="", metavar="FILE", help="path to config file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-gpus", type=int, default=1, help="number of gpus *per machine*"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-machines", type=int, default=1, help="total number of machines"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--machine-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="the rank of this machine (unique per machine)",
|
||||
)
|
||||
# PyTorch still may leave orphan processes in multi-gpu training.
|
||||
# Therefore we use a deterministic way to obtain port,
|
||||
# so that users are aware of orphan processes by seeing the port occupied.
|
||||
# port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
|
||||
parser.add_argument(
|
||||
"--dist-url",
|
||||
# default="tcp://127.0.0.1:{}".format(port),
|
||||
default="auto",
|
||||
help="initialization URL for pytorch distributed backend. See "
|
||||
"https://pytorch.org/docs/stable/distributed.html for details.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--options", nargs="+", action=DictAction, help="custom options"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def default_config_parser(file_path, options):
|
||||
# config name protocol: dataset_name/model_name-exp_name
|
||||
if os.path.isfile(file_path):
|
||||
cfg = Config.fromfile(file_path)
|
||||
else:
|
||||
sep = file_path.find("-")
|
||||
cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))
|
||||
|
||||
if options is not None:
|
||||
cfg.merge_from_dict(options)
|
||||
|
||||
if cfg.seed is None:
|
||||
cfg.seed = get_random_seed()
|
||||
|
||||
cfg.data.train.loop = cfg.epoch // cfg.eval_epoch
|
||||
|
||||
os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
|
||||
if not cfg.resume:
|
||||
cfg.dump(os.path.join(cfg.save_path, "config.py"))
|
||||
return cfg
|
||||
|
||||
|
||||
def default_setup(cfg):
|
||||
# scalar by world size
|
||||
world_size = comm.get_world_size()
|
||||
cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
|
||||
cfg.num_worker_per_gpu = cfg.num_worker // world_size
|
||||
assert cfg.batch_size % world_size == 0
|
||||
assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
|
||||
assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
|
||||
cfg.batch_size_per_gpu = cfg.batch_size // world_size
|
||||
cfg.batch_size_val_per_gpu = (
|
||||
cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
|
||||
)
|
||||
cfg.batch_size_test_per_gpu = (
|
||||
cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
|
||||
)
|
||||
# update data loop
|
||||
assert cfg.epoch % cfg.eval_epoch == 0
|
||||
# settle random seed
|
||||
rank = comm.get_rank()
|
||||
seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
|
||||
set_seed(seed)
|
||||
return cfg
|
||||
Reference in New Issue
Block a user