mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 01:49:23 +08:00
feat: Initial commit
This commit is contained in:
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
53
utils/cache.py
Normal file
53
utils/cache.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import os
|
||||
import SharedArray
|
||||
|
||||
try:
|
||||
from multiprocessing.shared_memory import ShareableList
|
||||
except ImportError:
|
||||
import warnings
|
||||
|
||||
warnings.warn("Please update python version >= 3.8 to enable shared_memory")
|
||||
import numpy as np
|
||||
|
||||
|
||||
def shared_array(name, var=None):
|
||||
if var is not None:
|
||||
# check exist
|
||||
if os.path.exists(f"/dev/shm/{name}"):
|
||||
return SharedArray.attach(f"shm://{name}")
|
||||
# create shared_array
|
||||
data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
|
||||
data[...] = var[...]
|
||||
data.flags.writeable = False
|
||||
else:
|
||||
data = SharedArray.attach(f"shm://{name}").copy()
|
||||
return data
|
||||
|
||||
|
||||
def shared_dict(name, var=None):
|
||||
name = str(name)
|
||||
assert "." not in name # '.' is used as sep flag
|
||||
data = {}
|
||||
if var is not None:
|
||||
assert isinstance(var, dict)
|
||||
keys = var.keys()
|
||||
# current version only cache np.array
|
||||
keys_valid = []
|
||||
for key in keys:
|
||||
if isinstance(var[key], np.ndarray):
|
||||
keys_valid.append(key)
|
||||
keys = keys_valid
|
||||
|
||||
ShareableList(sequence=keys, name=name + ".keys")
|
||||
for key in keys:
|
||||
if isinstance(var[key], np.ndarray):
|
||||
data[key] = shared_array(name=f"{name}.{key}", var=var[key])
|
||||
else:
|
||||
keys = list(ShareableList(name=name + ".keys"))
|
||||
for key in keys:
|
||||
data[key] = shared_array(name=f"{name}.{key}")
|
||||
return data
|
||||
192
utils/comm.py
Normal file
192
utils/comm.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
_LOCAL_PROCESS_GROUP = None
|
||||
"""
|
||||
A torch process group which only includes processes that on the same machine as the current process.
|
||||
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
||||
"""
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def get_local_rank() -> int:
|
||||
"""
|
||||
Returns:
|
||||
The rank of the current process within the local (per-machine) process group.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
assert (
|
||||
_LOCAL_PROCESS_GROUP is not None
|
||||
), "Local process group is not created! Please use launch() to spawn processes!"
|
||||
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def get_local_size() -> int:
|
||||
"""
|
||||
Returns:
|
||||
The size of the per-machine process group,
|
||||
i.e. the number of processes per machine.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def is_main_process() -> bool:
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def synchronize():
|
||||
"""
|
||||
Helper function to synchronize (barrier) among all processes when
|
||||
using distributed training
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
world_size = dist.get_world_size()
|
||||
if world_size == 1:
|
||||
return
|
||||
if dist.get_backend() == dist.Backend.NCCL:
|
||||
# This argument is needed to avoid warnings.
|
||||
# It's valid only for NCCL backend.
|
||||
dist.barrier(device_ids=[torch.cuda.current_device()])
|
||||
else:
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_global_gloo_group():
|
||||
"""
|
||||
Return a process group based on gloo backend, containing all the ranks
|
||||
The result is cached.
|
||||
"""
|
||||
if dist.get_backend() == "nccl":
|
||||
return dist.new_group(backend="gloo")
|
||||
else:
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def all_gather(data, group=None):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
||||
Args:
|
||||
data: any picklable object
|
||||
group: a torch process group. By default, will use a group which
|
||||
contains all ranks on gloo backend.
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
if get_world_size() == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = (
|
||||
_get_global_gloo_group()
|
||||
) # use CPU group by default, to reduce GPU RAM usage.
|
||||
world_size = dist.get_world_size(group)
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
output = [None for _ in range(world_size)]
|
||||
dist.all_gather_object(output, data, group=group)
|
||||
return output
|
||||
|
||||
|
||||
def gather(data, dst=0, group=None):
|
||||
"""
|
||||
Run gather on arbitrary picklable data (not necessarily tensors).
|
||||
Args:
|
||||
data: any picklable object
|
||||
dst (int): destination rank
|
||||
group: a torch process group. By default, will use a group which
|
||||
contains all ranks on gloo backend.
|
||||
Returns:
|
||||
list[data]: on dst, a list of data gathered from each rank. Otherwise,
|
||||
an empty list.
|
||||
"""
|
||||
if get_world_size() == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = _get_global_gloo_group()
|
||||
world_size = dist.get_world_size(group=group)
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
rank = dist.get_rank(group=group)
|
||||
|
||||
if rank == dst:
|
||||
output = [None for _ in range(world_size)]
|
||||
dist.gather_object(data, output, dst=dst, group=group)
|
||||
return output
|
||||
else:
|
||||
dist.gather_object(data, None, dst=dst, group=group)
|
||||
return []
|
||||
|
||||
|
||||
def shared_random_seed():
|
||||
"""
|
||||
Returns:
|
||||
int: a random number that is the same across all workers.
|
||||
If workers need a shared RNG, they can use this shared seed to
|
||||
create one.
|
||||
All workers must call this function, otherwise it will deadlock.
|
||||
"""
|
||||
ints = np.random.randint(2**31)
|
||||
all_ints = all_gather(ints)
|
||||
return all_ints[0]
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Reduce the values in the dictionary from all processes so that process with rank
|
||||
0 has the reduced results.
|
||||
Args:
|
||||
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
|
||||
average (bool): whether to do average or sum
|
||||
Returns:
|
||||
a dict with the same keys as input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.reduce(values, dst=0)
|
||||
if dist.get_rank() == 0 and average:
|
||||
# only main process gets accumulated, so only divide by
|
||||
# world_size in this case
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
696
utils/config.py
Normal file
696
utils/config.py
Normal file
@@ -0,0 +1,696 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
import ast
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
from argparse import Action, ArgumentParser
|
||||
from collections import abc
|
||||
from importlib import import_module
|
||||
|
||||
from addict import Dict
|
||||
from yapf.yapflib.yapf_api import FormatCode
|
||||
|
||||
from .misc import import_modules_from_strings
|
||||
from .path import check_file_exist
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import regex as re
|
||||
else:
|
||||
import re
|
||||
|
||||
BASE_KEY = "_base_"
|
||||
DELETE_KEY = "_delete_"
|
||||
DEPRECATION_KEY = "_deprecation_"
|
||||
RESERVED_KEYS = ["filename", "text", "pretty_text"]
|
||||
|
||||
|
||||
class ConfigDict(Dict):
|
||||
def __missing__(self, name):
|
||||
raise KeyError(name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
value = super(ConfigDict, self).__getattr__(name)
|
||||
except KeyError:
|
||||
ex = AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
ex = e
|
||||
else:
|
||||
return value
|
||||
raise ex
|
||||
|
||||
|
||||
def add_args(parser, cfg, prefix=""):
|
||||
for k, v in cfg.items():
|
||||
if isinstance(v, str):
|
||||
parser.add_argument("--" + prefix + k)
|
||||
elif isinstance(v, int):
|
||||
parser.add_argument("--" + prefix + k, type=int)
|
||||
elif isinstance(v, float):
|
||||
parser.add_argument("--" + prefix + k, type=float)
|
||||
elif isinstance(v, bool):
|
||||
parser.add_argument("--" + prefix + k, action="store_true")
|
||||
elif isinstance(v, dict):
|
||||
add_args(parser, v, prefix + k + ".")
|
||||
elif isinstance(v, abc.Iterable):
|
||||
parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
|
||||
else:
|
||||
print(f"cannot parse key {prefix + k} of type {type(v)}")
|
||||
return parser
|
||||
|
||||
|
||||
class Config:
|
||||
"""A facility for config and config files.
|
||||
|
||||
It supports common file formats as configs: python/json/yaml. The interface
|
||||
is the same as a dict object and also allows access config values as
|
||||
attributes.
|
||||
|
||||
Example:
|
||||
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
>>> cfg.a
|
||||
1
|
||||
>>> cfg.b
|
||||
{'b1': [0, 1]}
|
||||
>>> cfg.b.b1
|
||||
[0, 1]
|
||||
>>> cfg = Config.fromfile('tests/data/config/a.py')
|
||||
>>> cfg.filename
|
||||
"/home/kchen/projects/mmcv/tests/data/config/a.py"
|
||||
>>> cfg.item4
|
||||
'test'
|
||||
>>> cfg
|
||||
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
|
||||
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _validate_py_syntax(filename):
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
content = f.read()
|
||||
try:
|
||||
ast.parse(content)
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(
|
||||
"There are syntax errors in config " f"file {filename}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _substitute_predefined_vars(filename, temp_config_name):
|
||||
file_dirname = osp.dirname(filename)
|
||||
file_basename = osp.basename(filename)
|
||||
file_basename_no_extension = osp.splitext(file_basename)[0]
|
||||
file_extname = osp.splitext(filename)[1]
|
||||
support_templates = dict(
|
||||
fileDirname=file_dirname,
|
||||
fileBasename=file_basename,
|
||||
fileBasenameNoExtension=file_basename_no_extension,
|
||||
fileExtname=file_extname,
|
||||
)
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
for key, value in support_templates.items():
|
||||
regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
|
||||
value = value.replace("\\", "/")
|
||||
config_file = re.sub(regexp, value, config_file)
|
||||
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
|
||||
tmp_config_file.write(config_file)
|
||||
|
||||
@staticmethod
|
||||
def _pre_substitute_base_vars(filename, temp_config_name):
|
||||
"""Substitute base variable placehoders to string, so that parsing
|
||||
would work."""
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
base_var_dict = {}
|
||||
regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
|
||||
base_vars = set(re.findall(regexp, config_file))
|
||||
for base_var in base_vars:
|
||||
randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
|
||||
base_var_dict[randstr] = base_var
|
||||
regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
|
||||
config_file = re.sub(regexp, f'"{randstr}"', config_file)
|
||||
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
|
||||
tmp_config_file.write(config_file)
|
||||
return base_var_dict
|
||||
|
||||
@staticmethod
|
||||
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
|
||||
"""Substitute variable strings to their actual values."""
|
||||
cfg = copy.deepcopy(cfg)
|
||||
|
||||
if isinstance(cfg, dict):
|
||||
for k, v in cfg.items():
|
||||
if isinstance(v, str) and v in base_var_dict:
|
||||
new_v = base_cfg
|
||||
for new_k in base_var_dict[v].split("."):
|
||||
new_v = new_v[new_k]
|
||||
cfg[k] = new_v
|
||||
elif isinstance(v, (list, tuple, dict)):
|
||||
cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
|
||||
elif isinstance(cfg, tuple):
|
||||
cfg = tuple(
|
||||
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
|
||||
)
|
||||
elif isinstance(cfg, list):
|
||||
cfg = [
|
||||
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
|
||||
]
|
||||
elif isinstance(cfg, str) and cfg in base_var_dict:
|
||||
new_v = base_cfg
|
||||
for new_k in base_var_dict[cfg].split("."):
|
||||
new_v = new_v[new_k]
|
||||
cfg = new_v
|
||||
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def _file2dict(filename, use_predefined_variables=True):
|
||||
filename = osp.abspath(osp.expanduser(filename))
|
||||
check_file_exist(filename)
|
||||
fileExtname = osp.splitext(filename)[1]
|
||||
if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
|
||||
raise IOError("Only py/yml/yaml/json type are supported now!")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||
temp_config_file = tempfile.NamedTemporaryFile(
|
||||
dir=temp_config_dir, suffix=fileExtname
|
||||
)
|
||||
if platform.system() == "Windows":
|
||||
temp_config_file.close()
|
||||
temp_config_name = osp.basename(temp_config_file.name)
|
||||
# Substitute predefined variables
|
||||
if use_predefined_variables:
|
||||
Config._substitute_predefined_vars(filename, temp_config_file.name)
|
||||
else:
|
||||
shutil.copyfile(filename, temp_config_file.name)
|
||||
# Substitute base variables from placeholders to strings
|
||||
base_var_dict = Config._pre_substitute_base_vars(
|
||||
temp_config_file.name, temp_config_file.name
|
||||
)
|
||||
|
||||
if filename.endswith(".py"):
|
||||
temp_module_name = osp.splitext(temp_config_name)[0]
|
||||
sys.path.insert(0, temp_config_dir)
|
||||
Config._validate_py_syntax(filename)
|
||||
mod = import_module(temp_module_name)
|
||||
sys.path.pop(0)
|
||||
cfg_dict = {
|
||||
name: value
|
||||
for name, value in mod.__dict__.items()
|
||||
if not name.startswith("__")
|
||||
}
|
||||
# delete imported module
|
||||
del sys.modules[temp_module_name]
|
||||
elif filename.endswith((".yml", ".yaml", ".json")):
|
||||
raise NotImplementedError
|
||||
# close temp file
|
||||
temp_config_file.close()
|
||||
|
||||
# check deprecation information
|
||||
if DEPRECATION_KEY in cfg_dict:
|
||||
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
|
||||
warning_msg = (
|
||||
f"The config file {filename} will be deprecated " "in the future."
|
||||
)
|
||||
if "expected" in deprecation_info:
|
||||
warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
|
||||
if "reference" in deprecation_info:
|
||||
warning_msg += (
|
||||
" More information can be found at "
|
||||
f'{deprecation_info["reference"]}'
|
||||
)
|
||||
warnings.warn(warning_msg)
|
||||
|
||||
cfg_text = filename + "\n"
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
cfg_text += f.read()
|
||||
|
||||
if BASE_KEY in cfg_dict:
|
||||
cfg_dir = osp.dirname(filename)
|
||||
base_filename = cfg_dict.pop(BASE_KEY)
|
||||
base_filename = (
|
||||
base_filename if isinstance(base_filename, list) else [base_filename]
|
||||
)
|
||||
|
||||
cfg_dict_list = list()
|
||||
cfg_text_list = list()
|
||||
for f in base_filename:
|
||||
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
|
||||
cfg_dict_list.append(_cfg_dict)
|
||||
cfg_text_list.append(_cfg_text)
|
||||
|
||||
base_cfg_dict = dict()
|
||||
for c in cfg_dict_list:
|
||||
duplicate_keys = base_cfg_dict.keys() & c.keys()
|
||||
if len(duplicate_keys) > 0:
|
||||
raise KeyError(
|
||||
"Duplicate key is not allowed among bases. "
|
||||
f"Duplicate keys: {duplicate_keys}"
|
||||
)
|
||||
base_cfg_dict.update(c)
|
||||
|
||||
# Substitute base variables from strings to their actual values
|
||||
cfg_dict = Config._substitute_base_vars(
|
||||
cfg_dict, base_var_dict, base_cfg_dict
|
||||
)
|
||||
|
||||
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
|
||||
cfg_dict = base_cfg_dict
|
||||
|
||||
# merge cfg_text
|
||||
cfg_text_list.append(cfg_text)
|
||||
cfg_text = "\n".join(cfg_text_list)
|
||||
|
||||
return cfg_dict, cfg_text
|
||||
|
||||
@staticmethod
|
||||
def _merge_a_into_b(a, b, allow_list_keys=False):
|
||||
"""merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
|
||||
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
|
||||
in-place modifications.
|
||||
|
||||
Args:
|
||||
a (dict): The source dict to be merged into ``b``.
|
||||
b (dict): The origin dict to be fetch keys from ``a``.
|
||||
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
|
||||
are allowed in source ``a`` and will replace the element of the
|
||||
corresponding index in b if b is a list. Default: False.
|
||||
|
||||
Returns:
|
||||
dict: The modified dict of ``b`` using ``a``.
|
||||
|
||||
Examples:
|
||||
# Normally merge a into b.
|
||||
>>> Config._merge_a_into_b(
|
||||
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
|
||||
{'obj': {'a': 2}}
|
||||
|
||||
# Delete b first and merge a into b.
|
||||
>>> Config._merge_a_into_b(
|
||||
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
|
||||
{'obj': {'a': 2}}
|
||||
|
||||
# b is a list
|
||||
>>> Config._merge_a_into_b(
|
||||
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
|
||||
[{'a': 2}, {'b': 2}]
|
||||
"""
|
||||
b = b.copy()
|
||||
for k, v in a.items():
|
||||
if allow_list_keys and k.isdigit() and isinstance(b, list):
|
||||
k = int(k)
|
||||
if len(b) <= k:
|
||||
raise KeyError(f"Index {k} exceeds the length of list {b}")
|
||||
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
||||
elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
|
||||
allowed_types = (dict, list) if allow_list_keys else dict
|
||||
if not isinstance(b[k], allowed_types):
|
||||
raise TypeError(
|
||||
f"{k}={v} in child config cannot inherit from base "
|
||||
f"because {k} is a dict in the child config but is of "
|
||||
f"type {type(b[k])} in base config. You may set "
|
||||
f"`{DELETE_KEY}=True` to ignore the base config"
|
||||
)
|
||||
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
||||
else:
|
||||
b[k] = v
|
||||
return b
|
||||
|
||||
@staticmethod
|
||||
def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
|
||||
cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
|
||||
if import_custom_modules and cfg_dict.get("custom_imports", None):
|
||||
import_modules_from_strings(**cfg_dict["custom_imports"])
|
||||
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def fromstring(cfg_str, file_format):
|
||||
"""Generate config from config str.
|
||||
|
||||
Args:
|
||||
cfg_str (str): Config str.
|
||||
file_format (str): Config file format corresponding to the
|
||||
config str. Only py/yml/yaml/json type are supported now!
|
||||
|
||||
Returns:
|
||||
obj:`Config`: Config obj.
|
||||
"""
|
||||
if file_format not in [".py", ".json", ".yaml", ".yml"]:
|
||||
raise IOError("Only py/yml/yaml/json type are supported now!")
|
||||
if file_format != ".py" and "dict(" in cfg_str:
|
||||
# check if users specify a wrong suffix for python
|
||||
warnings.warn('Please check "file_format", the file format may be .py')
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w", encoding="utf-8", suffix=file_format, delete=False
|
||||
) as temp_file:
|
||||
temp_file.write(cfg_str)
|
||||
# on windows, previous implementation cause error
|
||||
# see PR 1077 for details
|
||||
cfg = Config.fromfile(temp_file.name)
|
||||
os.remove(temp_file.name)
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def auto_argparser(description=None):
|
||||
"""Generate argparser from config file automatically (experimental)"""
|
||||
partial_parser = ArgumentParser(description=description)
|
||||
partial_parser.add_argument("config", help="config file path")
|
||||
cfg_file = partial_parser.parse_known_args()[0].config
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
parser = ArgumentParser(description=description)
|
||||
parser.add_argument("config", help="config file path")
|
||||
add_args(parser, cfg)
|
||||
return parser, cfg
|
||||
|
||||
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
|
||||
if cfg_dict is None:
|
||||
cfg_dict = dict()
|
||||
elif not isinstance(cfg_dict, dict):
|
||||
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
|
||||
for key in cfg_dict:
|
||||
if key in RESERVED_KEYS:
|
||||
raise KeyError(f"{key} is reserved for config file")
|
||||
|
||||
super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
|
||||
super(Config, self).__setattr__("_filename", filename)
|
||||
if cfg_text:
|
||||
text = cfg_text
|
||||
elif filename:
|
||||
with open(filename, "r") as f:
|
||||
text = f.read()
|
||||
else:
|
||||
text = ""
|
||||
super(Config, self).__setattr__("_text", text)
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._filename
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return self._text
|
||||
|
||||
@property
|
||||
def pretty_text(self):
|
||||
indent = 4
|
||||
|
||||
def _indent(s_, num_spaces):
|
||||
s = s_.split("\n")
|
||||
if len(s) == 1:
|
||||
return s_
|
||||
first = s.pop(0)
|
||||
s = [(num_spaces * " ") + line for line in s]
|
||||
s = "\n".join(s)
|
||||
s = first + "\n" + s
|
||||
return s
|
||||
|
||||
def _format_basic_types(k, v, use_mapping=False):
|
||||
if isinstance(v, str):
|
||||
v_str = f"'{v}'"
|
||||
else:
|
||||
v_str = str(v)
|
||||
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f"{k_str}: {v_str}"
|
||||
else:
|
||||
attr_str = f"{str(k)}={v_str}"
|
||||
attr_str = _indent(attr_str, indent)
|
||||
|
||||
return attr_str
|
||||
|
||||
def _format_list(k, v, use_mapping=False):
|
||||
# check if all items in the list are dict
|
||||
if all(isinstance(_, dict) for _ in v):
|
||||
v_str = "[\n"
|
||||
v_str += "\n".join(
|
||||
f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
|
||||
).rstrip(",")
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f"{k_str}: {v_str}"
|
||||
else:
|
||||
attr_str = f"{str(k)}={v_str}"
|
||||
attr_str = _indent(attr_str, indent) + "]"
|
||||
else:
|
||||
attr_str = _format_basic_types(k, v, use_mapping)
|
||||
return attr_str
|
||||
|
||||
def _contain_invalid_identifier(dict_str):
|
||||
contain_invalid_identifier = False
|
||||
for key_name in dict_str:
|
||||
contain_invalid_identifier |= not str(key_name).isidentifier()
|
||||
return contain_invalid_identifier
|
||||
|
||||
def _format_dict(input_dict, outest_level=False):
|
||||
r = ""
|
||||
s = []
|
||||
|
||||
use_mapping = _contain_invalid_identifier(input_dict)
|
||||
if use_mapping:
|
||||
r += "{"
|
||||
for idx, (k, v) in enumerate(input_dict.items()):
|
||||
is_last = idx >= len(input_dict) - 1
|
||||
end = "" if outest_level or is_last else ","
|
||||
if isinstance(v, dict):
|
||||
v_str = "\n" + _format_dict(v)
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f"{k_str}: dict({v_str}"
|
||||
else:
|
||||
attr_str = f"{str(k)}=dict({v_str}"
|
||||
attr_str = _indent(attr_str, indent) + ")" + end
|
||||
elif isinstance(v, list):
|
||||
attr_str = _format_list(k, v, use_mapping) + end
|
||||
else:
|
||||
attr_str = _format_basic_types(k, v, use_mapping) + end
|
||||
|
||||
s.append(attr_str)
|
||||
r += "\n".join(s)
|
||||
if use_mapping:
|
||||
r += "}"
|
||||
return r
|
||||
|
||||
cfg_dict = self._cfg_dict.to_dict()
|
||||
text = _format_dict(cfg_dict, outest_level=True)
|
||||
# copied from setup.cfg
|
||||
yapf_style = dict(
|
||||
based_on_style="pep8",
|
||||
blank_line_before_nested_class_or_def=True,
|
||||
split_before_expression_after_opening_paren=True,
|
||||
)
|
||||
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
|
||||
|
||||
return text
|
||||
|
||||
def __repr__(self):
|
||||
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cfg_dict)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._cfg_dict, name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self._cfg_dict.__getitem__(name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if isinstance(value, dict):
|
||||
value = ConfigDict(value)
|
||||
self._cfg_dict.__setattr__(name, value)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
if isinstance(value, dict):
|
||||
value = ConfigDict(value)
|
||||
self._cfg_dict.__setitem__(name, value)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._cfg_dict)
|
||||
|
||||
def __getstate__(self):
|
||||
return (self._cfg_dict, self._filename, self._text)
|
||||
|
||||
def __setstate__(self, state):
|
||||
_cfg_dict, _filename, _text = state
|
||||
super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
|
||||
super(Config, self).__setattr__("_filename", _filename)
|
||||
super(Config, self).__setattr__("_text", _text)
|
||||
|
||||
def dump(self, file=None):
|
||||
cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
|
||||
if self.filename.endswith(".py"):
|
||||
if file is None:
|
||||
return self.pretty_text
|
||||
else:
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
f.write(self.pretty_text)
|
||||
else:
|
||||
import mmcv
|
||||
|
||||
if file is None:
|
||||
file_format = self.filename.split(".")[-1]
|
||||
return mmcv.dump(cfg_dict, file_format=file_format)
|
||||
else:
|
||||
mmcv.dump(cfg_dict, file)
|
||||
|
||||
def merge_from_dict(self, options, allow_list_keys=True):
|
||||
"""Merge list into cfg_dict.
|
||||
|
||||
Merge the dict parsed by MultipleKVAction into this cfg.
|
||||
|
||||
Examples:
|
||||
>>> options = {'models.backbone.depth': 50,
|
||||
... 'models.backbone.with_cp':True}
|
||||
>>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
|
||||
>>> cfg.merge_from_dict(options)
|
||||
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
>>> assert cfg_dict == dict(
|
||||
... models=dict(backbone=dict(depth=50, with_cp=True)))
|
||||
|
||||
# Merge list element
|
||||
>>> cfg = Config(dict(pipeline=[
|
||||
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
|
||||
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
|
||||
>>> cfg.merge_from_dict(options, allow_list_keys=True)
|
||||
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
>>> assert cfg_dict == dict(pipeline=[
|
||||
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
|
||||
|
||||
Args:
|
||||
options (dict): dict of configs to merge from.
|
||||
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
|
||||
are allowed in ``options`` and will replace the element of the
|
||||
corresponding index in the config if the config is a list.
|
||||
Default: True.
|
||||
"""
|
||||
option_cfg_dict = {}
|
||||
for full_key, v in options.items():
|
||||
d = option_cfg_dict
|
||||
key_list = full_key.split(".")
|
||||
for subkey in key_list[:-1]:
|
||||
d.setdefault(subkey, ConfigDict())
|
||||
d = d[subkey]
|
||||
subkey = key_list[-1]
|
||||
d[subkey] = v
|
||||
|
||||
cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
|
||||
super(Config, self).__setattr__(
|
||||
"_cfg_dict",
|
||||
Config._merge_a_into_b(
|
||||
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DictAction(Action):
|
||||
"""
|
||||
argparse action to split an argument into KEY=VALUE form
|
||||
on the first = and append to a dictionary. List options can
|
||||
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
|
||||
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
|
||||
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_int_float_bool(val):
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
pass
|
||||
if val.lower() in ["true", "false"]:
|
||||
return True if val.lower() == "true" else False
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def _parse_iterable(val):
|
||||
"""Parse iterable values in the string.
|
||||
|
||||
All elements inside '()' or '[]' are treated as iterable values.
|
||||
|
||||
Args:
|
||||
val (str): Value string.
|
||||
|
||||
Returns:
|
||||
list | tuple: The expanded list or tuple from the string.
|
||||
|
||||
Examples:
|
||||
>>> DictAction._parse_iterable('1,2,3')
|
||||
[1, 2, 3]
|
||||
>>> DictAction._parse_iterable('[a, b, c]')
|
||||
['a', 'b', 'c']
|
||||
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
|
||||
[(1, 2, 3), ['a', 'b'], 'c']
|
||||
"""
|
||||
|
||||
def find_next_comma(string):
|
||||
"""Find the position of next comma in the string.
|
||||
|
||||
If no ',' is found in the string, return the string length. All
|
||||
chars inside '()' and '[]' are treated as one element and thus ','
|
||||
inside these brackets are ignored.
|
||||
"""
|
||||
assert (string.count("(") == string.count(")")) and (
|
||||
string.count("[") == string.count("]")
|
||||
), f"Imbalanced brackets exist in {string}"
|
||||
end = len(string)
|
||||
for idx, char in enumerate(string):
|
||||
pre = string[:idx]
|
||||
# The string before this ',' is balanced
|
||||
if (
|
||||
(char == ",")
|
||||
and (pre.count("(") == pre.count(")"))
|
||||
and (pre.count("[") == pre.count("]"))
|
||||
):
|
||||
end = idx
|
||||
break
|
||||
return end
|
||||
|
||||
# Strip ' and " characters and replace whitespace.
|
||||
val = val.strip("'\"").replace(" ", "")
|
||||
is_tuple = False
|
||||
if val.startswith("(") and val.endswith(")"):
|
||||
is_tuple = True
|
||||
val = val[1:-1]
|
||||
elif val.startswith("[") and val.endswith("]"):
|
||||
val = val[1:-1]
|
||||
elif "," not in val:
|
||||
# val is a single value
|
||||
return DictAction._parse_int_float_bool(val)
|
||||
|
||||
values = []
|
||||
while len(val) > 0:
|
||||
comma_idx = find_next_comma(val)
|
||||
element = DictAction._parse_iterable(val[:comma_idx])
|
||||
values.append(element)
|
||||
val = val[comma_idx + 1 :]
|
||||
if is_tuple:
|
||||
values = tuple(values)
|
||||
return values
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
options = {}
|
||||
for kv in values:
|
||||
key, val = kv.split("=", maxsplit=1)
|
||||
options[key] = self._parse_iterable(val)
|
||||
setattr(namespace, self.dest, options)
|
||||
33
utils/env.py
Normal file
33
utils/env.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
seed = (
|
||||
os.getpid()
|
||||
+ int(datetime.now().strftime("%S%f"))
|
||||
+ int.from_bytes(os.urandom(2), "big")
|
||||
)
|
||||
return seed
|
||||
|
||||
|
||||
def set_seed(seed=None):
|
||||
if seed is None:
|
||||
seed = get_random_seed()
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
585
utils/events.py
Normal file
585
utils/events.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
|
||||
__all__ = [
|
||||
"get_event_storage",
|
||||
"JSONWriter",
|
||||
"TensorboardXWriter",
|
||||
"CommonMetricPrinter",
|
||||
"EventStorage",
|
||||
]
|
||||
|
||||
_CURRENT_STORAGE_STACK = []
|
||||
|
||||
|
||||
def get_event_storage():
|
||||
"""
|
||||
Returns:
|
||||
The :class:`EventStorage` object that's currently being used.
|
||||
Throws an error if no :class:`EventStorage` is currently enabled.
|
||||
"""
|
||||
assert len(
|
||||
_CURRENT_STORAGE_STACK
|
||||
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
|
||||
return _CURRENT_STORAGE_STACK[-1]
|
||||
|
||||
|
||||
class EventWriter:
|
||||
"""
|
||||
Base class for writers that obtain events from :class:`EventStorage` and process them.
|
||||
"""
|
||||
|
||||
def write(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class JSONWriter(EventWriter):
|
||||
"""
|
||||
Write scalars to a json file.
|
||||
It saves scalars as one json per line (instead of a big json) for easy parsing.
|
||||
Examples parsing such a json file:
|
||||
::
|
||||
$ cat metrics.json | jq -s '.[0:2]'
|
||||
[
|
||||
{
|
||||
"data_time": 0.008433341979980469,
|
||||
"iteration": 19,
|
||||
"loss": 1.9228371381759644,
|
||||
"loss_box_reg": 0.050025828182697296,
|
||||
"loss_classifier": 0.5316952466964722,
|
||||
"loss_mask": 0.7236229181289673,
|
||||
"loss_rpn_box": 0.0856662318110466,
|
||||
"loss_rpn_cls": 0.48198649287223816,
|
||||
"lr": 0.007173333333333333,
|
||||
"time": 0.25401854515075684
|
||||
},
|
||||
{
|
||||
"data_time": 0.007216215133666992,
|
||||
"iteration": 39,
|
||||
"loss": 1.282649278640747,
|
||||
"loss_box_reg": 0.06222952902317047,
|
||||
"loss_classifier": 0.30682939291000366,
|
||||
"loss_mask": 0.6970193982124329,
|
||||
"loss_rpn_box": 0.038663312792778015,
|
||||
"loss_rpn_cls": 0.1471673548221588,
|
||||
"lr": 0.007706666666666667,
|
||||
"time": 0.2490077018737793
|
||||
}
|
||||
]
|
||||
$ cat metrics.json | jq '.loss_mask'
|
||||
0.7126231789588928
|
||||
0.689423680305481
|
||||
0.6776131987571716
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, json_file, window_size=20):
|
||||
"""
|
||||
Args:
|
||||
json_file (str): path to the json file. New data will be appended if the file exists.
|
||||
window_size (int): the window size of median smoothing for the scalars whose
|
||||
`smoothing_hint` are True.
|
||||
"""
|
||||
self._file_handle = open(json_file, "a")
|
||||
self._window_size = window_size
|
||||
self._last_write = -1
|
||||
|
||||
def write(self):
|
||||
storage = get_event_storage()
|
||||
to_save = defaultdict(dict)
|
||||
|
||||
for k, (v, iter) in storage.latest_with_smoothing_hint(
|
||||
self._window_size
|
||||
).items():
|
||||
# keep scalars that have not been written
|
||||
if iter <= self._last_write:
|
||||
continue
|
||||
to_save[iter][k] = v
|
||||
if len(to_save):
|
||||
all_iters = sorted(to_save.keys())
|
||||
self._last_write = max(all_iters)
|
||||
|
||||
for itr, scalars_per_iter in to_save.items():
|
||||
scalars_per_iter["iteration"] = itr
|
||||
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
|
||||
self._file_handle.flush()
|
||||
try:
|
||||
os.fsync(self._file_handle.fileno())
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self._file_handle.close()
|
||||
|
||||
|
||||
class TensorboardXWriter(EventWriter):
|
||||
"""
|
||||
Write all scalars to a tensorboard file.
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
log_dir (str): the directory to save the output events
|
||||
window_size (int): the scalars will be median-smoothed by this window size
|
||||
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
|
||||
"""
|
||||
self._window_size = window_size
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
self._writer = SummaryWriter(log_dir, **kwargs)
|
||||
self._last_write = -1
|
||||
|
||||
def write(self):
|
||||
storage = get_event_storage()
|
||||
new_last_write = self._last_write
|
||||
for k, (v, iter) in storage.latest_with_smoothing_hint(
|
||||
self._window_size
|
||||
).items():
|
||||
if iter > self._last_write:
|
||||
self._writer.add_scalar(k, v, iter)
|
||||
new_last_write = max(new_last_write, iter)
|
||||
self._last_write = new_last_write
|
||||
|
||||
# storage.put_{image,histogram} is only meant to be used by
|
||||
# tensorboard writer. So we access its internal fields directly from here.
|
||||
if len(storage._vis_data) >= 1:
|
||||
for img_name, img, step_num in storage._vis_data:
|
||||
self._writer.add_image(img_name, img, step_num)
|
||||
# Storage stores all image data and rely on this writer to clear them.
|
||||
# As a result it assumes only one writer will use its image data.
|
||||
# An alternative design is to let storage store limited recent
|
||||
# data (e.g. only the most recent image) that all writers can access.
|
||||
# In that case a writer may not see all image data if its period is long.
|
||||
storage.clear_images()
|
||||
|
||||
if len(storage._histograms) >= 1:
|
||||
for params in storage._histograms:
|
||||
self._writer.add_histogram_raw(**params)
|
||||
storage.clear_histograms()
|
||||
|
||||
def close(self):
|
||||
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
|
||||
self._writer.close()
|
||||
|
||||
|
||||
class CommonMetricPrinter(EventWriter):
|
||||
"""
|
||||
Print **common** metrics to the terminal, including
|
||||
iteration time, ETA, memory, all losses, and the learning rate.
|
||||
It also applies smoothing using a window of 20 elements.
|
||||
It's meant to print common metrics in common ways.
|
||||
To print something in more customized ways, please implement a similar printer by yourself.
|
||||
"""
|
||||
|
||||
def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
|
||||
"""
|
||||
Args:
|
||||
max_iter: the maximum number of iterations to train.
|
||||
Used to compute ETA. If not given, ETA will not be printed.
|
||||
window_size (int): the losses will be median-smoothed by this window size
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._max_iter = max_iter
|
||||
self._window_size = window_size
|
||||
self._last_write = (
|
||||
None # (step, time) of last call to write(). Used to compute ETA
|
||||
)
|
||||
|
||||
def _get_eta(self, storage) -> Optional[str]:
|
||||
if self._max_iter is None:
|
||||
return ""
|
||||
iteration = storage.iter
|
||||
try:
|
||||
eta_seconds = storage.history("time").median(1000) * (
|
||||
self._max_iter - iteration - 1
|
||||
)
|
||||
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
|
||||
return str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
except KeyError:
|
||||
# estimate eta on our own - more noisy
|
||||
eta_string = None
|
||||
if self._last_write is not None:
|
||||
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
|
||||
iteration - self._last_write[0]
|
||||
)
|
||||
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
self._last_write = (iteration, time.perf_counter())
|
||||
return eta_string
|
||||
|
||||
def write(self):
|
||||
storage = get_event_storage()
|
||||
iteration = storage.iter
|
||||
if iteration == self._max_iter:
|
||||
# This hook only reports training progress (loss, ETA, etc) but not other data,
|
||||
# therefore do not write anything after training succeeds, even if this method
|
||||
# is called.
|
||||
return
|
||||
|
||||
try:
|
||||
data_time = storage.history("data_time").avg(20)
|
||||
except KeyError:
|
||||
# they may not exist in the first few iterations (due to warmup)
|
||||
# or when SimpleTrainer is not used
|
||||
data_time = None
|
||||
try:
|
||||
iter_time = storage.history("time").global_avg()
|
||||
except KeyError:
|
||||
iter_time = None
|
||||
try:
|
||||
lr = "{:.5g}".format(storage.history("lr").latest())
|
||||
except KeyError:
|
||||
lr = "N/A"
|
||||
|
||||
eta_string = self._get_eta(storage)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
||||
else:
|
||||
max_mem_mb = None
|
||||
|
||||
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
|
||||
self.logger.info(
|
||||
" {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
|
||||
eta=f"eta: {eta_string} " if eta_string else "",
|
||||
iter=iteration,
|
||||
losses=" ".join(
|
||||
[
|
||||
"{}: {:.4g}".format(k, v.median(self._window_size))
|
||||
for k, v in storage.histories().items()
|
||||
if "loss" in k
|
||||
]
|
||||
),
|
||||
time="time: {:.4f} ".format(iter_time)
|
||||
if iter_time is not None
|
||||
else "",
|
||||
data_time="data_time: {:.4f} ".format(data_time)
|
||||
if data_time is not None
|
||||
else "",
|
||||
lr=lr,
|
||||
memory="max_mem: {:.0f}M".format(max_mem_mb)
|
||||
if max_mem_mb is not None
|
||||
else "",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class EventStorage:
|
||||
"""
|
||||
The user-facing class that provides metric storage functionalities.
|
||||
In the future we may add support for storing / logging other types of data if needed.
|
||||
"""
|
||||
|
||||
def __init__(self, start_iter=0):
|
||||
"""
|
||||
Args:
|
||||
start_iter (int): the iteration number to start with
|
||||
"""
|
||||
self._history = defaultdict(AverageMeter)
|
||||
self._smoothing_hints = {}
|
||||
self._latest_scalars = {}
|
||||
self._iter = start_iter
|
||||
self._current_prefix = ""
|
||||
self._vis_data = []
|
||||
self._histograms = []
|
||||
|
||||
# def put_image(self, img_name, img_tensor):
|
||||
# """
|
||||
# Add an `img_tensor` associated with `img_name`, to be shown on
|
||||
# tensorboard.
|
||||
# Args:
|
||||
# img_name (str): The name of the image to put into tensorboard.
|
||||
# img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
|
||||
# Tensor of shape `[channel, height, width]` where `channel` is
|
||||
# 3. The image format should be RGB. The elements in img_tensor
|
||||
# can either have values in [0, 1] (float32) or [0, 255] (uint8).
|
||||
# The `img_tensor` will be visualized in tensorboard.
|
||||
# """
|
||||
# self._vis_data.append((img_name, img_tensor, self._iter))
|
||||
|
||||
def put_scalar(self, name, value, n=1, smoothing_hint=False):
|
||||
"""
|
||||
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
|
||||
Args:
|
||||
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
|
||||
smoothed when logged. The hint will be accessible through
|
||||
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
|
||||
and apply custom smoothing rule.
|
||||
It defaults to True because most scalars we save need to be smoothed to
|
||||
provide any useful signal.
|
||||
"""
|
||||
name = self._current_prefix + name
|
||||
history = self._history[name]
|
||||
history.update(value, n)
|
||||
self._latest_scalars[name] = (value, self._iter)
|
||||
|
||||
existing_hint = self._smoothing_hints.get(name)
|
||||
if existing_hint is not None:
|
||||
assert (
|
||||
existing_hint == smoothing_hint
|
||||
), "Scalar {} was put with a different smoothing_hint!".format(name)
|
||||
else:
|
||||
self._smoothing_hints[name] = smoothing_hint
|
||||
|
||||
# def put_scalars(self, *, smoothing_hint=True, **kwargs):
|
||||
# """
|
||||
# Put multiple scalars from keyword arguments.
|
||||
# Examples:
|
||||
# storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
|
||||
# """
|
||||
# for k, v in kwargs.items():
|
||||
# self.put_scalar(k, v, smoothing_hint=smoothing_hint)
|
||||
#
|
||||
# def put_histogram(self, hist_name, hist_tensor, bins=1000):
|
||||
# """
|
||||
# Create a histogram from a tensor.
|
||||
# Args:
|
||||
# hist_name (str): The name of the histogram to put into tensorboard.
|
||||
# hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
|
||||
# into a histogram.
|
||||
# bins (int): Number of histogram bins.
|
||||
# """
|
||||
# ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
|
||||
#
|
||||
# # Create a histogram with PyTorch
|
||||
# hist_counts = torch.histc(hist_tensor, bins=bins)
|
||||
# hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
|
||||
#
|
||||
# # Parameter for the add_histogram_raw function of SummaryWriter
|
||||
# hist_params = dict(
|
||||
# tag=hist_name,
|
||||
# min=ht_min,
|
||||
# max=ht_max,
|
||||
# num=len(hist_tensor),
|
||||
# sum=float(hist_tensor.sum()),
|
||||
# sum_squares=float(torch.sum(hist_tensor**2)),
|
||||
# bucket_limits=hist_edges[1:].tolist(),
|
||||
# bucket_counts=hist_counts.tolist(),
|
||||
# global_step=self._iter,
|
||||
# )
|
||||
# self._histograms.append(hist_params)
|
||||
|
||||
def history(self, name):
|
||||
"""
|
||||
Returns:
|
||||
AverageMeter: the history for name
|
||||
"""
|
||||
ret = self._history.get(name, None)
|
||||
if ret is None:
|
||||
raise KeyError("No history metric available for {}!".format(name))
|
||||
return ret
|
||||
|
||||
def histories(self):
|
||||
"""
|
||||
Returns:
|
||||
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
|
||||
"""
|
||||
return self._history
|
||||
|
||||
def latest(self):
|
||||
"""
|
||||
Returns:
|
||||
dict[str -> (float, int)]: mapping from the name of each scalar to the most
|
||||
recent value and the iteration number its added.
|
||||
"""
|
||||
return self._latest_scalars
|
||||
|
||||
def latest_with_smoothing_hint(self, window_size=20):
|
||||
"""
|
||||
Similar to :meth:`latest`, but the returned values
|
||||
are either the un-smoothed original latest value,
|
||||
or a median of the given window_size,
|
||||
depend on whether the smoothing_hint is True.
|
||||
This provides a default behavior that other writers can use.
|
||||
"""
|
||||
result = {}
|
||||
for k, (v, itr) in self._latest_scalars.items():
|
||||
result[k] = (
|
||||
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
|
||||
itr,
|
||||
)
|
||||
return result
|
||||
|
||||
def smoothing_hints(self):
|
||||
"""
|
||||
Returns:
|
||||
dict[name -> bool]: the user-provided hint on whether the scalar
|
||||
is noisy and needs smoothing.
|
||||
"""
|
||||
return self._smoothing_hints
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
User should either: (1) Call this function to increment storage.iter when needed. Or
|
||||
(2) Set `storage.iter` to the correct iteration number before each iteration.
|
||||
The storage will then be able to associate the new data with an iteration number.
|
||||
"""
|
||||
self._iter += 1
|
||||
|
||||
@property
|
||||
def iter(self):
|
||||
"""
|
||||
Returns:
|
||||
int: The current iteration number. When used together with a trainer,
|
||||
this is ensured to be the same as trainer.iter.
|
||||
"""
|
||||
return self._iter
|
||||
|
||||
@iter.setter
|
||||
def iter(self, val):
|
||||
self._iter = int(val)
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
# for backward compatibility
|
||||
return self._iter
|
||||
|
||||
def __enter__(self):
|
||||
_CURRENT_STORAGE_STACK.append(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
assert _CURRENT_STORAGE_STACK[-1] == self
|
||||
_CURRENT_STORAGE_STACK.pop()
|
||||
|
||||
@contextmanager
|
||||
def name_scope(self, name):
|
||||
"""
|
||||
Yields:
|
||||
A context within which all the events added to this storage
|
||||
will be prefixed by the name scope.
|
||||
"""
|
||||
old_prefix = self._current_prefix
|
||||
self._current_prefix = name.rstrip("/") + "/"
|
||||
yield
|
||||
self._current_prefix = old_prefix
|
||||
|
||||
def clear_images(self):
|
||||
"""
|
||||
Delete all the stored images for visualization. This should be called
|
||||
after images are written to tensorboard.
|
||||
"""
|
||||
self._vis_data = []
|
||||
|
||||
def clear_histograms(self):
|
||||
"""
|
||||
Delete all the stored histograms for visualization.
|
||||
This should be called after histograms are written to tensorboard.
|
||||
"""
|
||||
self._histograms = []
|
||||
|
||||
def reset_history(self, name):
|
||||
ret = self._history.get(name, None)
|
||||
if ret is None:
|
||||
raise KeyError("No history metric available for {}!".format(name))
|
||||
ret.reset()
|
||||
|
||||
def reset_histories(self):
|
||||
for name in self._history.keys():
|
||||
self._history[name].reset()
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.total = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.total = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.total += val * n
|
||||
self.count += n
|
||||
self.avg = self.total / self.count
|
||||
|
||||
|
||||
class HistoryBuffer:
|
||||
"""
|
||||
Track a series of scalar values and provide access to smoothed values over a
|
||||
window or the global average of the series.
|
||||
"""
|
||||
|
||||
def __init__(self, max_length: int = 1000000) -> None:
|
||||
"""
|
||||
Args:
|
||||
max_length: maximal number of values that can be stored in the
|
||||
buffer. When the capacity of the buffer is exhausted, old
|
||||
values will be removed.
|
||||
"""
|
||||
self._max_length: int = max_length
|
||||
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
|
||||
self._count: int = 0
|
||||
self._global_avg: float = 0
|
||||
|
||||
def update(self, value: float, iteration: Optional[float] = None) -> None:
|
||||
"""
|
||||
Add a new scalar value produced at certain iteration. If the length
|
||||
of the buffer exceeds self._max_length, the oldest element will be
|
||||
removed from the buffer.
|
||||
"""
|
||||
if iteration is None:
|
||||
iteration = self._count
|
||||
if len(self._data) == self._max_length:
|
||||
self._data.pop(0)
|
||||
self._data.append((value, iteration))
|
||||
|
||||
self._count += 1
|
||||
self._global_avg += (value - self._global_avg) / self._count
|
||||
|
||||
def latest(self) -> float:
|
||||
"""
|
||||
Return the latest scalar value added to the buffer.
|
||||
"""
|
||||
return self._data[-1][0]
|
||||
|
||||
def median(self, window_size: int) -> float:
|
||||
"""
|
||||
Return the median of the latest `window_size` values in the buffer.
|
||||
"""
|
||||
return np.median([x[0] for x in self._data[-window_size:]])
|
||||
|
||||
def avg(self, window_size: int) -> float:
|
||||
"""
|
||||
Return the mean of the latest `window_size` values in the buffer.
|
||||
"""
|
||||
return np.mean([x[0] for x in self._data[-window_size:]])
|
||||
|
||||
def global_avg(self) -> float:
|
||||
"""
|
||||
Return the mean of all the elements in the buffer. Note that this
|
||||
includes those getting removed due to limited buffer storage.
|
||||
"""
|
||||
return self._global_avg
|
||||
|
||||
def values(self) -> List[Tuple[float, float]]:
|
||||
"""
|
||||
Returns:
|
||||
list[(number, iteration)]: content of the current buffer.
|
||||
"""
|
||||
return self._data
|
||||
167
utils/logger.py
Normal file
167
utils/logger.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
logger_initialized = {}
|
||||
root_status = 0
|
||||
|
||||
|
||||
class _ColorfulFormatter(logging.Formatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._root_name = kwargs.pop("root_name") + "."
|
||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatMessage(self, record):
|
||||
log = super(_ColorfulFormatter, self).formatMessage(record)
|
||||
if record.levelno == logging.WARNING:
|
||||
prefix = colored("WARNING", "red", attrs=["blink"])
|
||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
||||
else:
|
||||
return log
|
||||
return prefix + " " + log
|
||||
|
||||
|
||||
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False):
|
||||
"""Initialize and get a logger by name.
|
||||
|
||||
If the logger has not been initialized, this method will initialize the
|
||||
logger by adding one or two handlers, otherwise the initialized logger will
|
||||
be directly returned. During initialization, a StreamHandler will always be
|
||||
added. If `log_file` is specified and the process rank is 0, a FileHandler
|
||||
will also be added.
|
||||
|
||||
Args:
|
||||
name (str): Logger name.
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the logger.
|
||||
log_level (int): The logger level. Note that only the process of
|
||||
rank 0 is affected, and other processes will set the level to
|
||||
"Error" thus be silent most of the time.
|
||||
file_mode (str): The file mode used in opening log file.
|
||||
Defaults to 'a'.
|
||||
color (bool): Colorful log output. Defaults to True
|
||||
|
||||
Returns:
|
||||
logging.Logger: The expected logger.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
if name in logger_initialized:
|
||||
return logger
|
||||
# handle hierarchical names
|
||||
# e.g., logger "a" is initialized, then logger "a.b" will skip the
|
||||
# initialization since it is a child of "a".
|
||||
for logger_name in logger_initialized:
|
||||
if name.startswith(logger_name):
|
||||
return logger
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
handlers = [stream_handler]
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
# only rank 0 will add a FileHandler
|
||||
if rank == 0 and log_file is not None:
|
||||
# Here, the default behaviour of the official logger is 'a'. Thus, we
|
||||
# provide an interface to change the file mode to the default
|
||||
# behaviour.
|
||||
file_handler = logging.FileHandler(log_file, file_mode)
|
||||
handlers.append(file_handler)
|
||||
|
||||
plain_formatter = logging.Formatter(
|
||||
"[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
|
||||
)
|
||||
if color:
|
||||
formatter = _ColorfulFormatter(
|
||||
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
|
||||
datefmt="%m/%d %H:%M:%S",
|
||||
root_name=name,
|
||||
)
|
||||
else:
|
||||
formatter = plain_formatter
|
||||
for handler in handlers:
|
||||
handler.setFormatter(formatter)
|
||||
handler.setLevel(log_level)
|
||||
logger.addHandler(handler)
|
||||
|
||||
if rank == 0:
|
||||
logger.setLevel(log_level)
|
||||
else:
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
logger_initialized[name] = True
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def print_log(msg, logger=None, level=logging.INFO):
|
||||
"""Print a log message.
|
||||
|
||||
Args:
|
||||
msg (str): The message to be logged.
|
||||
logger (logging.Logger | str | None): The logger to be used.
|
||||
Some special loggers are:
|
||||
- "silent": no message will be printed.
|
||||
- other str: the logger obtained with `get_root_logger(logger)`.
|
||||
- None: The `print()` method will be used to print log messages.
|
||||
level (int): Logging level. Only available when `logger` is a Logger
|
||||
object or "root".
|
||||
"""
|
||||
if logger is None:
|
||||
print(msg)
|
||||
elif isinstance(logger, logging.Logger):
|
||||
logger.log(level, msg)
|
||||
elif logger == "silent":
|
||||
pass
|
||||
elif isinstance(logger, str):
|
||||
_logger = get_logger(logger)
|
||||
_logger.log(level, msg)
|
||||
else:
|
||||
raise TypeError(
|
||||
"logger should be either a logging.Logger object, str, "
|
||||
f'"silent" or None, but got {type(logger)}'
|
||||
)
|
||||
|
||||
|
||||
def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"):
|
||||
"""Get the root logger.
|
||||
|
||||
The logger will be initialized if it has not been initialized. By default a
|
||||
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
||||
also be added. The name of the root logger is the top-level package name.
|
||||
|
||||
Args:
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the root logger.
|
||||
log_level (int): The root logger level. Note that only the process of
|
||||
rank 0 is affected, while other processes will set the level to
|
||||
"Error" and be silent most of the time.
|
||||
file_mode (str): File Mode of logger. (w or a)
|
||||
|
||||
Returns:
|
||||
logging.Logger: The root logger.
|
||||
"""
|
||||
logger = get_logger(
|
||||
name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode
|
||||
)
|
||||
return logger
|
||||
|
||||
|
||||
def _log_api_usage(identifier: str):
|
||||
"""
|
||||
Internal function used to log the usage of different detectron2 components
|
||||
inside facebook's infra.
|
||||
"""
|
||||
torch._C._log_api_usage_once("pointcept." + identifier)
|
||||
156
utils/misc.py
Normal file
156
utils/misc.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections import abc
|
||||
import numpy as np
|
||||
import torch
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def intersection_and_union(output, target, K, ignore_index=-1):
|
||||
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
|
||||
assert output.ndim in [1, 2, 3]
|
||||
assert output.shape == target.shape
|
||||
output = output.reshape(output.size).copy()
|
||||
target = target.reshape(target.size)
|
||||
output[np.where(target == ignore_index)[0]] = ignore_index
|
||||
intersection = output[np.where(output == target)[0]]
|
||||
area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
|
||||
area_output, _ = np.histogram(output, bins=np.arange(K + 1))
|
||||
area_target, _ = np.histogram(target, bins=np.arange(K + 1))
|
||||
area_union = area_output + area_target - area_intersection
|
||||
return area_intersection, area_union, area_target
|
||||
|
||||
|
||||
def intersection_and_union_gpu(output, target, k, ignore_index=-1):
|
||||
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
|
||||
assert output.dim() in [1, 2, 3]
|
||||
assert output.shape == target.shape
|
||||
output = output.view(-1)
|
||||
target = target.view(-1)
|
||||
output[target == ignore_index] = ignore_index
|
||||
intersection = output[output == target]
|
||||
area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
|
||||
area_output = torch.histc(output, bins=k, min=0, max=k - 1)
|
||||
area_target = torch.histc(target, bins=k, min=0, max=k - 1)
|
||||
area_union = area_output + area_target - area_intersection
|
||||
return area_intersection, area_union, area_target
|
||||
|
||||
|
||||
def make_dirs(dir_name):
|
||||
if not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
|
||||
def find_free_port():
|
||||
import socket
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# Binding to port 0 will cause the OS to find an available port for us
|
||||
sock.bind(("", 0))
|
||||
port = sock.getsockname()[1]
|
||||
sock.close()
|
||||
# NOTE: there is still a chance the port could be taken by other processes.
|
||||
return port
|
||||
|
||||
|
||||
def is_seq_of(seq, expected_type, seq_type=None):
|
||||
"""Check whether it is a sequence of some type.
|
||||
|
||||
Args:
|
||||
seq (Sequence): The sequence to be checked.
|
||||
expected_type (type): Expected type of sequence items.
|
||||
seq_type (type, optional): Expected sequence type.
|
||||
|
||||
Returns:
|
||||
bool: Whether the sequence is valid.
|
||||
"""
|
||||
if seq_type is None:
|
||||
exp_seq_type = abc.Sequence
|
||||
else:
|
||||
assert isinstance(seq_type, type)
|
||||
exp_seq_type = seq_type
|
||||
if not isinstance(seq, exp_seq_type):
|
||||
return False
|
||||
for item in seq:
|
||||
if not isinstance(item, expected_type):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_str(x):
|
||||
"""Whether the input is an string instance.
|
||||
|
||||
Note: This method is deprecated since python 2 is no longer supported.
|
||||
"""
|
||||
return isinstance(x, str)
|
||||
|
||||
|
||||
def import_modules_from_strings(imports, allow_failed_imports=False):
|
||||
"""Import modules from the given list of strings.
|
||||
|
||||
Args:
|
||||
imports (list | str | None): The given module names to be imported.
|
||||
allow_failed_imports (bool): If True, the failed imports will return
|
||||
None. Otherwise, an ImportError is raise. Default: False.
|
||||
|
||||
Returns:
|
||||
list[module] | module | None: The imported modules.
|
||||
|
||||
Examples:
|
||||
>>> osp, sys = import_modules_from_strings(
|
||||
... ['os.path', 'sys'])
|
||||
>>> import os.path as osp_
|
||||
>>> import sys as sys_
|
||||
>>> assert osp == osp_
|
||||
>>> assert sys == sys_
|
||||
"""
|
||||
if not imports:
|
||||
return
|
||||
single_import = False
|
||||
if isinstance(imports, str):
|
||||
single_import = True
|
||||
imports = [imports]
|
||||
if not isinstance(imports, list):
|
||||
raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
|
||||
imported = []
|
||||
for imp in imports:
|
||||
if not isinstance(imp, str):
|
||||
raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
|
||||
try:
|
||||
imported_tmp = import_module(imp)
|
||||
except ImportError:
|
||||
if allow_failed_imports:
|
||||
warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
|
||||
imported_tmp = None
|
||||
else:
|
||||
raise ImportError
|
||||
imported.append(imported_tmp)
|
||||
if single_import:
|
||||
imported = imported[0]
|
||||
return imported
|
||||
52
utils/optimizer.py
Normal file
52
utils/optimizer.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import torch
|
||||
from utils.logger import get_root_logger
|
||||
from utils.registry import Registry
|
||||
|
||||
OPTIMIZERS = Registry("optimizers")
|
||||
|
||||
|
||||
OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
|
||||
OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
|
||||
OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")
|
||||
|
||||
|
||||
def build_optimizer(cfg, model, param_dicts=None):
|
||||
if param_dicts is None:
|
||||
cfg.params = model.parameters()
|
||||
else:
|
||||
cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
|
||||
for i in range(len(param_dicts)):
|
||||
param_group = dict(names=[], params=[])
|
||||
if "lr" in param_dicts[i].keys():
|
||||
param_group["lr"] = param_dicts[i].lr
|
||||
if "momentum" in param_dicts[i].keys():
|
||||
param_group["momentum"] = param_dicts[i].momentum
|
||||
if "weight_decay" in param_dicts[i].keys():
|
||||
param_group["weight_decay"] = param_dicts[i].weight_decay
|
||||
cfg.params.append(param_group)
|
||||
|
||||
for n, p in model.named_parameters():
|
||||
flag = False
|
||||
for i in range(len(param_dicts)):
|
||||
if param_dicts[i].keyword in n:
|
||||
cfg.params[i + 1]["names"].append(n)
|
||||
cfg.params[i + 1]["params"].append(p)
|
||||
flag = True
|
||||
break
|
||||
if not flag:
|
||||
cfg.params[0]["names"].append(n)
|
||||
cfg.params[0]["params"].append(p)
|
||||
|
||||
logger = get_root_logger()
|
||||
for i in range(len(cfg.params)):
|
||||
param_names = cfg.params[i].pop("names")
|
||||
message = ""
|
||||
for key in cfg.params[i].keys():
|
||||
if key != "params":
|
||||
message += f" {key}: {cfg.params[i][key]};"
|
||||
logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
|
||||
return OPTIMIZERS.build(cfg=cfg)
|
||||
105
utils/path.py
Normal file
105
utils/path.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
import os
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
|
||||
from .misc import is_str
|
||||
|
||||
|
||||
def is_filepath(x):
|
||||
return is_str(x) or isinstance(x, Path)
|
||||
|
||||
|
||||
def fopen(filepath, *args, **kwargs):
|
||||
if is_str(filepath):
|
||||
return open(filepath, *args, **kwargs)
|
||||
elif isinstance(filepath, Path):
|
||||
return filepath.open(*args, **kwargs)
|
||||
raise ValueError("`filepath` should be a string or a Path")
|
||||
|
||||
|
||||
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
|
||||
if not osp.isfile(filename):
|
||||
raise FileNotFoundError(msg_tmpl.format(filename))
|
||||
|
||||
|
||||
def mkdir_or_exist(dir_name, mode=0o777):
|
||||
if dir_name == "":
|
||||
return
|
||||
dir_name = osp.expanduser(dir_name)
|
||||
os.makedirs(dir_name, mode=mode, exist_ok=True)
|
||||
|
||||
|
||||
def symlink(src, dst, overwrite=True, **kwargs):
|
||||
if os.path.lexists(dst) and overwrite:
|
||||
os.remove(dst)
|
||||
os.symlink(src, dst, **kwargs)
|
||||
|
||||
|
||||
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
|
||||
"""Scan a directory to find the interested files.
|
||||
|
||||
Args:
|
||||
dir_path (str | obj:`Path`): Path of the directory.
|
||||
suffix (str | tuple(str), optional): File suffix that we are
|
||||
interested in. Default: None.
|
||||
recursive (bool, optional): If set to True, recursively scan the
|
||||
directory. Default: False.
|
||||
case_sensitive (bool, optional) : If set to False, ignore the case of
|
||||
suffix. Default: True.
|
||||
|
||||
Returns:
|
||||
A generator for all the interested files with relative paths.
|
||||
"""
|
||||
if isinstance(dir_path, (str, Path)):
|
||||
dir_path = str(dir_path)
|
||||
else:
|
||||
raise TypeError('"dir_path" must be a string or Path object')
|
||||
|
||||
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
||||
raise TypeError('"suffix" must be a string or tuple of strings')
|
||||
|
||||
if suffix is not None and not case_sensitive:
|
||||
suffix = (
|
||||
suffix.lower()
|
||||
if isinstance(suffix, str)
|
||||
else tuple(item.lower() for item in suffix)
|
||||
)
|
||||
|
||||
root = dir_path
|
||||
|
||||
def _scandir(dir_path, suffix, recursive, case_sensitive):
|
||||
for entry in os.scandir(dir_path):
|
||||
if not entry.name.startswith(".") and entry.is_file():
|
||||
rel_path = osp.relpath(entry.path, root)
|
||||
_rel_path = rel_path if case_sensitive else rel_path.lower()
|
||||
if suffix is None or _rel_path.endswith(suffix):
|
||||
yield rel_path
|
||||
elif recursive and os.path.isdir(entry.path):
|
||||
# scan recursively if entry.path is a directory
|
||||
yield from _scandir(entry.path, suffix, recursive, case_sensitive)
|
||||
|
||||
return _scandir(dir_path, suffix, recursive, case_sensitive)
|
||||
|
||||
|
||||
def find_vcs_root(path, markers=(".git",)):
|
||||
"""Finds the root directory (including itself) of specified markers.
|
||||
|
||||
Args:
|
||||
path (str): Path of directory or file.
|
||||
markers (list[str], optional): List of file or directory names.
|
||||
|
||||
Returns:
|
||||
The directory contained one of the markers or None if not found.
|
||||
"""
|
||||
if osp.isfile(path):
|
||||
path = osp.dirname(path)
|
||||
|
||||
prev, cur = None, osp.abspath(osp.expanduser(path))
|
||||
while cur != prev:
|
||||
if any(osp.exists(osp.join(cur, marker)) for marker in markers):
|
||||
return cur
|
||||
prev, cur = cur, osp.split(cur)[0]
|
||||
return None
|
||||
318
utils/registry.py
Normal file
318
utils/registry.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
import inspect
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
from .misc import is_seq_of
|
||||
|
||||
|
||||
def build_from_cfg(cfg, registry, default_args=None):
|
||||
"""Build a module from configs dict.
|
||||
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
registry (:obj:`Registry`): The registry to search the type from.
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
|
||||
if "type" not in cfg:
|
||||
if default_args is None or "type" not in default_args:
|
||||
raise KeyError(
|
||||
'`cfg` or `default_args` must contain the key "type", '
|
||||
f"but got {cfg}\n{default_args}"
|
||||
)
|
||||
if not isinstance(registry, Registry):
|
||||
raise TypeError(
|
||||
"registry must be an mmcv.Registry object, " f"but got {type(registry)}"
|
||||
)
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError(
|
||||
"default_args must be a dict or None, " f"but got {type(default_args)}"
|
||||
)
|
||||
|
||||
args = cfg.copy()
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
|
||||
obj_type = args.pop("type")
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = registry.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(f"{obj_type} is not in the {registry.name} registry")
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
|
||||
try:
|
||||
return obj_cls(**args)
|
||||
except Exception as e:
|
||||
# Normal TypeError does not print class name.
|
||||
raise type(e)(f"{obj_cls.__name__}: {e}")
|
||||
|
||||
|
||||
class Registry:
|
||||
"""A registry to map strings to classes.
|
||||
|
||||
Registered object could be built from registry.
|
||||
Example:
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||
|
||||
Please refer to
|
||||
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
|
||||
advanced usage.
|
||||
|
||||
Args:
|
||||
name (str): Registry name.
|
||||
build_func(func, optional): Build function to construct instance from
|
||||
Registry, func:`build_from_cfg` is used if neither ``parent`` or
|
||||
``build_func`` is specified. If ``parent`` is specified and
|
||||
``build_func`` is not given, ``build_func`` will be inherited
|
||||
from ``parent``. Default: None.
|
||||
parent (Registry, optional): Parent registry. The class registered in
|
||||
children registry could be built from parent. Default: None.
|
||||
scope (str, optional): The scope of registry. It is the key to search
|
||||
for children registry. If not specified, scope will be the name of
|
||||
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, name, build_func=None, parent=None, scope=None):
|
||||
self._name = name
|
||||
self._module_dict = dict()
|
||||
self._children = dict()
|
||||
self._scope = self.infer_scope() if scope is None else scope
|
||||
|
||||
# self.build_func will be set with the following priority:
|
||||
# 1. build_func
|
||||
# 2. parent.build_func
|
||||
# 3. build_from_cfg
|
||||
if build_func is None:
|
||||
if parent is not None:
|
||||
self.build_func = parent.build_func
|
||||
else:
|
||||
self.build_func = build_from_cfg
|
||||
else:
|
||||
self.build_func = build_func
|
||||
if parent is not None:
|
||||
assert isinstance(parent, Registry)
|
||||
parent._add_children(self)
|
||||
self.parent = parent
|
||||
else:
|
||||
self.parent = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self._module_dict)
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.get(key) is not None
|
||||
|
||||
def __repr__(self):
|
||||
format_str = (
|
||||
self.__class__.__name__ + f"(name={self._name}, "
|
||||
f"items={self._module_dict})"
|
||||
)
|
||||
return format_str
|
||||
|
||||
@staticmethod
|
||||
def infer_scope():
|
||||
"""Infer the scope of registry.
|
||||
|
||||
The name of the package where registry is defined will be returned.
|
||||
|
||||
Example:
|
||||
# in mmdet/models/backbone/resnet.py
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
The scope of ``ResNet`` will be ``mmdet``.
|
||||
|
||||
|
||||
Returns:
|
||||
scope (str): The inferred scope name.
|
||||
"""
|
||||
# inspect.stack() trace where this function is called, the index-2
|
||||
# indicates the frame where `infer_scope()` is called
|
||||
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
|
||||
split_filename = filename.split(".")
|
||||
return split_filename[0]
|
||||
|
||||
@staticmethod
|
||||
def split_scope_key(key):
|
||||
"""Split scope and key.
|
||||
|
||||
The first scope will be split from key.
|
||||
|
||||
Examples:
|
||||
>>> Registry.split_scope_key('mmdet.ResNet')
|
||||
'mmdet', 'ResNet'
|
||||
>>> Registry.split_scope_key('ResNet')
|
||||
None, 'ResNet'
|
||||
|
||||
Return:
|
||||
scope (str, None): The first scope.
|
||||
key (str): The remaining key.
|
||||
"""
|
||||
split_index = key.find(".")
|
||||
if split_index != -1:
|
||||
return key[:split_index], key[split_index + 1 :]
|
||||
else:
|
||||
return None, key
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def scope(self):
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def module_dict(self):
|
||||
return self._module_dict
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
return self._children
|
||||
|
||||
def get(self, key):
|
||||
"""Get the registry record.
|
||||
|
||||
Args:
|
||||
key (str): The class name in string format.
|
||||
|
||||
Returns:
|
||||
class: The corresponding class.
|
||||
"""
|
||||
scope, real_key = self.split_scope_key(key)
|
||||
if scope is None or scope == self._scope:
|
||||
# get from self
|
||||
if real_key in self._module_dict:
|
||||
return self._module_dict[real_key]
|
||||
else:
|
||||
# get from self._children
|
||||
if scope in self._children:
|
||||
return self._children[scope].get(real_key)
|
||||
else:
|
||||
# goto root
|
||||
parent = self.parent
|
||||
while parent.parent is not None:
|
||||
parent = parent.parent
|
||||
return parent.get(key)
|
||||
|
||||
def build(self, *args, **kwargs):
|
||||
return self.build_func(*args, **kwargs, registry=self)
|
||||
|
||||
def _add_children(self, registry):
|
||||
"""Add children for a registry.
|
||||
|
||||
The ``registry`` will be added as children based on its scope.
|
||||
The parent registry could build objects from children registry.
|
||||
|
||||
Example:
|
||||
>>> models = Registry('models')
|
||||
>>> mmdet_models = Registry('models', parent=models)
|
||||
>>> @mmdet_models.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = models.build(dict(type='mmdet.ResNet'))
|
||||
"""
|
||||
|
||||
assert isinstance(registry, Registry)
|
||||
assert registry.scope is not None
|
||||
assert (
|
||||
registry.scope not in self.children
|
||||
), f"scope {registry.scope} exists in {self.name} registry"
|
||||
self.children[registry.scope] = registry
|
||||
|
||||
def _register_module(self, module_class, module_name=None, force=False):
|
||||
if not inspect.isclass(module_class):
|
||||
raise TypeError("module must be a class, " f"but got {type(module_class)}")
|
||||
|
||||
if module_name is None:
|
||||
module_name = module_class.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in self._module_dict:
|
||||
raise KeyError(f"{name} is already registered " f"in {self.name}")
|
||||
self._module_dict[name] = module_class
|
||||
|
||||
def deprecated_register_module(self, cls=None, force=False):
|
||||
warnings.warn(
|
||||
"The old API of register_module(module, force=False) "
|
||||
"is deprecated and will be removed, please use the new API "
|
||||
"register_module(name=None, force=False, module=None) instead."
|
||||
)
|
||||
if cls is None:
|
||||
return partial(self.deprecated_register_module, force=force)
|
||||
self._register_module(cls, force=force)
|
||||
return cls
|
||||
|
||||
def register_module(self, name=None, force=False, module=None):
|
||||
"""Register a module.
|
||||
|
||||
A record will be added to `self._module_dict`, whose key is the class
|
||||
name or the specified name, and value is the class itself.
|
||||
It can be used as a decorator or a normal function.
|
||||
|
||||
Example:
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> @backbones.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> @backbones.register_module(name='mnet')
|
||||
>>> class MobileNet:
|
||||
>>> pass
|
||||
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> backbones.register_module(ResNet)
|
||||
|
||||
Args:
|
||||
name (str | None): The module name to be registered. If not
|
||||
specified, the class name will be used.
|
||||
force (bool, optional): Whether to override an existing class with
|
||||
the same name. Default: False.
|
||||
module (type): Module class to be registered.
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f"force must be a boolean, but got {type(force)}")
|
||||
# NOTE: This is a walkaround to be compatible with the old api,
|
||||
# while it may introduce unexpected bugs.
|
||||
if isinstance(name, type):
|
||||
return self.deprecated_register_module(name, force=force)
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
|
||||
raise TypeError(
|
||||
"name must be either of None, an instance of str or a sequence"
|
||||
f" of str, but got {type(name)}"
|
||||
)
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
self._register_module(module_class=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(cls):
|
||||
self._register_module(module_class=cls, module_name=name, force=force)
|
||||
return cls
|
||||
|
||||
return _register
|
||||
144
utils/scheduler.py
Normal file
144
utils/scheduler.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from .registry import Registry
|
||||
|
||||
SCHEDULERS = Registry("schedulers")
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class MultiStepLR(lr_scheduler.MultiStepLR):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
milestones,
|
||||
total_steps,
|
||||
gamma=0.1,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
milestones=[rate * total_steps for rate in milestones],
|
||||
gamma=gamma,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
milestones,
|
||||
total_steps,
|
||||
gamma=0.1,
|
||||
warmup_rate=0.05,
|
||||
warmup_scale=1e-6,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
milestones = [rate * total_steps for rate in milestones]
|
||||
|
||||
def multi_step_with_warmup(s):
|
||||
factor = 1.0
|
||||
for i in range(len(milestones)):
|
||||
if s < milestones[i]:
|
||||
break
|
||||
factor *= gamma
|
||||
|
||||
if s <= warmup_rate * total_steps:
|
||||
warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
|
||||
1 - warmup_scale
|
||||
)
|
||||
else:
|
||||
warmup_coefficient = 1.0
|
||||
return warmup_coefficient * factor
|
||||
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
lr_lambda=multi_step_with_warmup,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class PolyLR(lr_scheduler.LambdaLR):
|
||||
def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class ExpLR(lr_scheduler.LambdaLR):
|
||||
def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
lr_lambda=lambda s: gamma ** (s / total_steps),
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
|
||||
def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
T_max=total_steps,
|
||||
eta_min=eta_min,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
@SCHEDULERS.register_module()
|
||||
class OneCycleLR(lr_scheduler.OneCycleLR):
|
||||
r"""
|
||||
torch.optim.lr_scheduler.OneCycleLR, Block total_steps
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
max_lr,
|
||||
total_steps=None,
|
||||
pct_start=0.3,
|
||||
anneal_strategy="cos",
|
||||
cycle_momentum=True,
|
||||
base_momentum=0.85,
|
||||
max_momentum=0.95,
|
||||
div_factor=25.0,
|
||||
final_div_factor=1e4,
|
||||
three_phase=False,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__(
|
||||
optimizer=optimizer,
|
||||
max_lr=max_lr,
|
||||
total_steps=total_steps,
|
||||
pct_start=pct_start,
|
||||
anneal_strategy=anneal_strategy,
|
||||
cycle_momentum=cycle_momentum,
|
||||
base_momentum=base_momentum,
|
||||
max_momentum=max_momentum,
|
||||
div_factor=div_factor,
|
||||
final_div_factor=final_div_factor,
|
||||
three_phase=three_phase,
|
||||
last_epoch=last_epoch,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
def build_scheduler(cfg, optimizer):
|
||||
cfg.optimizer = optimizer
|
||||
return SCHEDULERS.build(cfg=cfg)
|
||||
71
utils/timer.py
Normal file
71
utils/timer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
from time import perf_counter
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Timer:
|
||||
"""
|
||||
A timer which computes the time elapsed since the start/reset of the timer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the timer.
|
||||
"""
|
||||
self._start = perf_counter()
|
||||
self._paused: Optional[float] = None
|
||||
self._total_paused = 0
|
||||
self._count_start = 1
|
||||
|
||||
def pause(self) -> None:
|
||||
"""
|
||||
Pause the timer.
|
||||
"""
|
||||
if self._paused is not None:
|
||||
raise ValueError("Trying to pause a Timer that is already paused!")
|
||||
self._paused = perf_counter()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
bool: whether the timer is currently paused
|
||||
"""
|
||||
return self._paused is not None
|
||||
|
||||
def resume(self) -> None:
|
||||
"""
|
||||
Resume the timer.
|
||||
"""
|
||||
if self._paused is None:
|
||||
raise ValueError("Trying to resume a Timer that is not paused!")
|
||||
# pyre-fixme[58]: `-` is not supported for operand types `float` and
|
||||
# `Optional[float]`.
|
||||
self._total_paused += perf_counter() - self._paused
|
||||
self._paused = None
|
||||
self._count_start += 1
|
||||
|
||||
def seconds(self) -> float:
|
||||
"""
|
||||
Returns:
|
||||
(float): the total number of seconds since the start/reset of the
|
||||
timer, excluding the time when the timer is paused.
|
||||
"""
|
||||
if self._paused is not None:
|
||||
end_time: float = self._paused # type: ignore
|
||||
else:
|
||||
end_time = perf_counter()
|
||||
return end_time - self._start - self._total_paused
|
||||
|
||||
def avg_seconds(self) -> float:
|
||||
"""
|
||||
Returns:
|
||||
(float): the average number of seconds between every start/reset and
|
||||
pause.
|
||||
"""
|
||||
return self.seconds() / self._count_start
|
||||
86
utils/visualization.py
Normal file
86
utils/visualization.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import os
|
||||
import open3d as o3d
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_numpy(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.clone().detach().cpu().numpy()
|
||||
assert isinstance(x, np.ndarray)
|
||||
return x
|
||||
|
||||
|
||||
def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
coord = to_numpy(coord)
|
||||
if color is not None:
|
||||
color = to_numpy(color)
|
||||
pcd = o3d.geometry.PointCloud()
|
||||
pcd.points = o3d.utility.Vector3dVector(coord)
|
||||
pcd.colors = o3d.utility.Vector3dVector(
|
||||
np.ones_like(coord) if color is None else color
|
||||
)
|
||||
o3d.io.write_point_cloud(file_path, pcd)
|
||||
if logger is not None:
|
||||
logger.info(f"Save Point Cloud to: {file_path}")
|
||||
|
||||
|
||||
def save_bounding_boxes(
|
||||
bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None
|
||||
):
|
||||
bboxes_corners = to_numpy(bboxes_corners)
|
||||
# point list
|
||||
points = bboxes_corners.reshape(-1, 3)
|
||||
# line list
|
||||
box_lines = np.array(
|
||||
[
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 0],
|
||||
[4, 5],
|
||||
[5, 6],
|
||||
[6, 7],
|
||||
[7, 0],
|
||||
[0, 4],
|
||||
[1, 5],
|
||||
[2, 6],
|
||||
[3, 7],
|
||||
]
|
||||
)
|
||||
lines = []
|
||||
for i, _ in enumerate(bboxes_corners):
|
||||
lines.append(box_lines + i * 8)
|
||||
lines = np.concatenate(lines)
|
||||
# color list
|
||||
color = np.array([color for _ in range(len(lines))])
|
||||
# generate line set
|
||||
line_set = o3d.geometry.LineSet()
|
||||
line_set.points = o3d.utility.Vector3dVector(points)
|
||||
line_set.lines = o3d.utility.Vector2iVector(lines)
|
||||
line_set.colors = o3d.utility.Vector3dVector(color)
|
||||
o3d.io.write_line_set(file_path, line_set)
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"Save Boxes to: {file_path}")
|
||||
|
||||
|
||||
def save_lines(
|
||||
points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None
|
||||
):
|
||||
points = to_numpy(points)
|
||||
lines = to_numpy(lines)
|
||||
colors = np.array([color for _ in range(len(lines))])
|
||||
line_set = o3d.geometry.LineSet()
|
||||
line_set.points = o3d.utility.Vector3dVector(points)
|
||||
line_set.lines = o3d.utility.Vector2iVector(lines)
|
||||
line_set.colors = o3d.utility.Vector3dVector(colors)
|
||||
o3d.io.write_line_set(file_path, line_set)
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"Save Lines to: {file_path}")
|
||||
Reference in New Issue
Block a user