feat: Initial commit

Author: fdyuandong
Date: 2025-04-17 23:14:24 +08:00
Commit: ca93dd0572
51 changed files with 7904 additions and 0 deletions

utils/__init__.py (new file, empty)

utils/cache.py (new file, 53 lines)

@@ -0,0 +1,53 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import SharedArray
try:
from multiprocessing.shared_memory import ShareableList
except ImportError:
import warnings
warnings.warn("Please update python version >= 3.8 to enable shared_memory")
import numpy as np
def shared_array(name, var=None):
if var is not None:
# check exist
if os.path.exists(f"/dev/shm/{name}"):
return SharedArray.attach(f"shm://{name}")
# create shared_array
data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
data[...] = var[...]
data.flags.writeable = False
else:
data = SharedArray.attach(f"shm://{name}").copy()
return data
def shared_dict(name, var=None):
name = str(name)
assert "." not in name # '.' is used as sep flag
data = {}
if var is not None:
assert isinstance(var, dict)
keys = var.keys()
# current version only cache np.array
keys_valid = []
for key in keys:
if isinstance(var[key], np.ndarray):
keys_valid.append(key)
keys = keys_valid
ShareableList(sequence=keys, name=name + ".keys")
for key in keys:
if isinstance(var[key], np.ndarray):
data[key] = shared_array(name=f"{name}.{key}", var=var[key])
else:
keys = list(ShareableList(name=name + ".keys"))
for key in keys:
data[key] = shared_array(name=f"{name}.{key}")
return data
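
Usage sketch (illustrative, not from the commit): one process populates the shared cache and later calls attach to it. It assumes a Linux host with /dev/shm and the third-party SharedArray package installed; the name "scene0000_00" is a placeholder.

import numpy as np
from utils.cache import shared_dict

sample = {
    "coord": np.random.rand(1024, 3).astype(np.float32),
    "color": np.random.randint(0, 255, (1024, 3), dtype=np.uint8),
    "name": "scene0000_00",  # non-array values are skipped by the cache
}

# First call creates "scene0000_00.keys" plus one shared array per ndarray entry.
cached = shared_dict("scene0000_00", sample)

# Later calls (e.g. from other dataloader workers) attach without passing `var`.
reattached = shared_dict("scene0000_00")
assert np.allclose(cached["coord"], reattached["coord"])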

utils/comm.py (new file, 192 lines)

@@ -0,0 +1,192 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import functools
import numpy as np
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
assert (
_LOCAL_PROCESS_GROUP is not None
), "Local process group is not created! Please use launch() to spawn processes!"
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
if dist.get_backend() == dist.Backend.NCCL:
# This argument is needed to avoid warnings.
# It's valid only for NCCL backend.
dist.barrier(device_ids=[torch.cuda.current_device()])
else:
dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on the gloo backend, containing all the ranks.
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def all_gather(data, group=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: list of data gathered from each rank
"""
if get_world_size() == 1:
return [data]
if group is None:
group = (
_get_global_gloo_group()
) # use CPU group by default, to reduce GPU RAM usage.
world_size = dist.get_world_size(group)
if world_size == 1:
return [data]
output = [None for _ in range(world_size)]
dist.all_gather_object(output, data, group=group)
return output
def gather(data, dst=0, group=None):
"""
Run gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
dst (int): destination rank
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: on dst, a list of data gathered from each rank. Otherwise,
an empty list.
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
world_size = dist.get_world_size(group=group)
if world_size == 1:
return [data]
rank = dist.get_rank(group=group)
if rank == dst:
output = [None for _ in range(world_size)]
dist.gather_object(data, output, dst=dst, group=group)
return output
else:
dist.gather_object(data, None, dst=dst, group=group)
return []
def shared_random_seed():
"""
Returns:
int: a random number that is the same across all workers.
If workers need a shared RNG, they can use this shared seed to
create one.
All workers must call this function, otherwise it will deadlock.
"""
ints = np.random.randint(2**31)
all_ints = all_gather(ints)
return all_ints[0]
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
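
Usage sketch (illustrative): with torch.distributed left uninitialized, every helper falls back to single-process behaviour, so the snippet below runs as-is; under DDP the same calls gather and reduce across ranks.

import torch
from utils import comm

print(comm.get_world_size(), comm.get_rank(), comm.is_main_process())  # 1 0 True

# With a single process, all_gather simply wraps the local data in a list.
gathered = comm.all_gather({"loss": 0.5})
assert gathered == [{"loss": 0.5}]

# reduce_dict is a no-op below two processes; under DDP it sums (and optionally
# averages) scalar CUDA tensors onto rank 0.
losses = {"loss_cls": torch.tensor(0.7), "loss_seg": torch.tensor(0.3)}
print(comm.reduce_dict(losses, average=True))

# Every rank receives the same value, handy for building synchronized RNGs.
seed = comm.shared_random_seed()
torch.manual_seed(seed)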

utils/config.py (new file, 696 lines)

@@ -0,0 +1,696 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
from .misc import import_modules_from_strings
from .path import check_file_exist
if platform.system() == "Windows":
import regex as re
else:
import re
BASE_KEY = "_base_"
DELETE_KEY = "_delete_"
DEPRECATION_KEY = "_deprecation_"
RESERVED_KEYS = ["filename", "text", "pretty_text"]
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError(
f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
)
except Exception as e:
ex = e
else:
return value
raise ex
def add_args(parser, cfg, prefix=""):
for k, v in cfg.items():
if isinstance(v, str):
parser.add_argument("--" + prefix + k)
elif isinstance(v, int):
parser.add_argument("--" + prefix + k, type=int)
elif isinstance(v, float):
parser.add_argument("--" + prefix + k, type=float)
elif isinstance(v, bool):
parser.add_argument("--" + prefix + k, action="store_true")
elif isinstance(v, dict):
add_args(parser, v, prefix + k + ".")
elif isinstance(v, abc.Iterable):
parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
else:
print(f"cannot parse key {prefix + k} of type {type(v)}")
return parser
class Config:
"""A facility for config and config files.
It supports common file formats as configs: python/json/yaml. The interface
is the same as a dict object and also allows access config values as
attributes.
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError(
"There are syntax errors in config " f"file {filename}: {e}"
)
@staticmethod
def _substitute_predefined_vars(filename, temp_config_name):
file_dirname = osp.dirname(filename)
file_basename = osp.basename(filename)
file_basename_no_extension = osp.splitext(file_basename)[0]
file_extname = osp.splitext(filename)[1]
support_templates = dict(
fileDirname=file_dirname,
fileBasename=file_basename,
fileBasenameNoExtension=file_basename_no_extension,
fileExtname=file_extname,
)
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
for key, value in support_templates.items():
regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
value = value.replace("\\", "/")
config_file = re.sub(regexp, value, config_file)
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
tmp_config_file.write(config_file)
@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
base_var_dict[randstr] = base_var
regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict
@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)
if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split("."):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split("."):
new_v = new_v[new_k]
cfg = new_v
return cfg
@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
fileExtname = osp.splitext(filename)[1]
if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
raise IOError("Only py/yml/yaml/json type are supported now!")
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix=fileExtname
)
if platform.system() == "Windows":
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
# Substitute predefined variables
if use_predefined_variables:
Config._substitute_predefined_vars(filename, temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name
)
if filename.endswith(".py"):
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith("__")
}
# delete imported module
del sys.modules[temp_module_name]
elif filename.endswith((".yml", ".yaml", ".json")):
raise NotImplementedError
# close temp file
temp_config_file.close()
# check deprecation information
if DEPRECATION_KEY in cfg_dict:
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
warning_msg = (
f"The config file {filename} will be deprecated " "in the future."
)
if "expected" in deprecation_info:
warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
if "reference" in deprecation_info:
warning_msg += (
" More information can be found at "
f'{deprecation_info["reference"]}'
)
warnings.warn(warning_msg)
cfg_text = filename + "\n"
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = (
base_filename if isinstance(base_filename, list) else [base_filename]
)
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
duplicate_keys = base_cfg_dict.keys() & c.keys()
if len(duplicate_keys) > 0:
raise KeyError(
"Duplicate key is not allowed among bases. "
f"Duplicate keys: {duplicate_keys}"
)
base_cfg_dict.update(c)
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(
cfg_dict, base_var_dict, base_cfg_dict
)
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = "\n".join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b, allow_list_keys=False):
"""merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Args:
a (dict): The source dict to be merged into ``b``.
b (dict): The base dict into which ``a`` is merged.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in source ``a`` and will replace the element of the
corresponding index in b if b is a list. Default: False.
Returns:
dict: The modified dict of ``b`` using ``a``.
Examples:
# Normally merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# Delete b first and merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# b is a list
>>> Config._merge_a_into_b(
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
[{'a': 2}, {'b': 2}]
"""
b = b.copy()
for k, v in a.items():
if allow_list_keys and k.isdigit() and isinstance(b, list):
k = int(k)
if len(b) <= k:
raise KeyError(f"Index {k} exceeds the length of list {b}")
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
allowed_types = (dict, list) if allow_list_keys else dict
if not isinstance(b[k], allowed_types):
raise TypeError(
f"{k}={v} in child config cannot inherit from base "
f"because {k} is a dict in the child config but is of "
f"type {type(b[k])} in base config. You may set "
f"`{DELETE_KEY}=True` to ignore the base config"
)
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
else:
b[k] = v
return b
@staticmethod
def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
if import_custom_modules and cfg_dict.get("custom_imports", None):
import_modules_from_strings(**cfg_dict["custom_imports"])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
"""
if file_format not in [".py", ".json", ".yaml", ".yml"]:
raise IOError("Only py/yml/yaml/json type are supported now!")
if file_format != ".py" and "dict(" in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn('Please check "file_format", the file format may be .py')
with tempfile.NamedTemporaryFile(
"w", encoding="utf-8", suffix=file_format, delete=False
) as temp_file:
temp_file.write(cfg_str)
# on windows, previous implementation cause error
# see PR 1077 for details
cfg = Config.fromfile(temp_file.name)
os.remove(temp_file.name)
return cfg
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)"""
partial_parser = ArgumentParser(description=description)
partial_parser.add_argument("config", help="config file path")
cfg_file = partial_parser.parse_known_args()[0].config
cfg = Config.fromfile(cfg_file)
parser = ArgumentParser(description=description)
parser.add_argument("config", help="config file path")
add_args(parser, cfg)
return parser, cfg
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f"{key} is reserved for config file")
super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
super(Config, self).__setattr__("_filename", filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, "r") as f:
text = f.read()
else:
text = ""
super(Config, self).__setattr__("_text", text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = "[\n"
v_str += "\n".join(
f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
).rstrip(",")
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent) + "]"
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= not str(key_name).isidentifier()
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ""
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += "{"
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = "" if outest_level or is_last else ","
if isinstance(v, dict):
v_str = "\n" + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: dict({v_str}"
else:
attr_str = f"{str(k)}=dict({v_str}"
attr_str = _indent(attr_str, indent) + ")" + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += "\n".join(s)
if use_mapping:
r += "}"
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style="pep8",
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True,
)
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
return text
def __repr__(self):
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def __getstate__(self):
return (self._cfg_dict, self._filename, self._text)
def __setstate__(self, state):
_cfg_dict, _filename, _text = state
super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
super(Config, self).__setattr__("_filename", _filename)
super(Config, self).__setattr__("_text", _text)
def dump(self, file=None):
cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
if self.filename.endswith(".py"):
if file is None:
return self.pretty_text
else:
with open(file, "w", encoding="utf-8") as f:
f.write(self.pretty_text)
else:
import mmcv
if file is None:
file_format = self.filename.split(".")[-1]
return mmcv.dump(cfg_dict, file_format=file_format)
else:
mmcv.dump(cfg_dict, file)
def merge_from_dict(self, options, allow_list_keys=True):
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'models.backbone.depth': 50,
... 'models.backbone.with_cp':True}
>>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... models=dict(backbone=dict(depth=50, with_cp=True)))
# Merge list element
>>> cfg = Config(dict(pipeline=[
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
>>> cfg.merge_from_dict(options, allow_list_keys=True)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(pipeline=[
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
Args:
options (dict): dict of configs to merge from.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in ``options`` and will replace the element of the
corresponding index in the config if the config is a list.
Default: True.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split(".")
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
super(Config, self).__setattr__(
"_cfg_dict",
Config._merge_a_into_b(
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
),
)
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma-separated values, e.g. 'KEY=V1,V2,V3', or with explicit
brackets, e.g. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'.
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ["true", "false"]:
return True if val.lower() == "true" else False
return val
@staticmethod
def _parse_iterable(val):
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
Args:
val (str): Value string.
Returns:
list | tuple: The expanded list or tuple from the string.
Examples:
>>> DictAction._parse_iterable('1,2,3')
[1, 2, 3]
>>> DictAction._parse_iterable('[a, b, c]')
['a', 'b', 'c']
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
[(1, 2, 3), ['a', 'b'], 'c']
"""
def find_next_comma(string):
"""Find the position of next comma in the string.
If no ',' is found in the string, return the string length. All
chars inside '()' and '[]' are treated as one element and thus ','
inside these brackets are ignored.
"""
assert (string.count("(") == string.count(")")) and (
string.count("[") == string.count("]")
), f"Imbalanced brackets exist in {string}"
end = len(string)
for idx, char in enumerate(string):
pre = string[:idx]
# The string before this ',' is balanced
if (
(char == ",")
and (pre.count("(") == pre.count(")"))
and (pre.count("[") == pre.count("]"))
):
end = idx
break
return end
# Strip ' and " characters and replace whitespace.
val = val.strip("'\"").replace(" ", "")
is_tuple = False
if val.startswith("(") and val.endswith(")"):
is_tuple = True
val = val[1:-1]
elif val.startswith("[") and val.endswith("]"):
val = val[1:-1]
elif "," not in val:
# val is a single value
return DictAction._parse_int_float_bool(val)
values = []
while len(val) > 0:
comma_idx = find_next_comma(val)
element = DictAction._parse_iterable(val[:comma_idx])
values.append(element)
val = val[comma_idx + 1 :]
if is_tuple:
values = tuple(values)
return values
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split("=", maxsplit=1)
options[key] = self._parse_iterable(val)
setattr(namespace, self.dest, options)
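
A sketch of how Config and DictAction are typically combined on the command line (the config values and override keys below are placeholders, not files from this commit); attribute access works because values are wrapped in ConfigDict.

from argparse import ArgumentParser
from utils.config import Config, DictAction

parser = ArgumentParser()
parser.add_argument("--options", nargs="+", action=DictAction,
                    help="override config entries, e.g. optimizer.lr=0.001")
args = parser.parse_args(["--options", "optimizer.lr=0.001", "data.split=[train,val]"])

cfg = Config(dict(optimizer=dict(type="AdamW", lr=0.01), epochs=50))
assert cfg.optimizer.type == "AdamW"      # attribute access via ConfigDict

cfg.merge_from_dict(args.options)         # dotted keys create nested dicts
print(cfg["optimizer"]["lr"])             # 0.001
print(cfg["data"]["split"])               # ['train', 'val']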

utils/env.py (new file, 33 lines)

@@ -0,0 +1,33 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from datetime import datetime
def get_random_seed():
seed = (
os.getpid()
+ int(datetime.now().strftime("%S%f"))
+ int.from_bytes(os.urandom(2), "big")
)
return seed
def set_seed(seed=None):
if seed is None:
seed = get_random_seed()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cudnn.benchmark = False
cudnn.deterministic = True
os.environ["PYTHONHASHSEED"] = str(seed)

utils/events.py (new file, 585 lines)

@@ -0,0 +1,585 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import datetime
import json
import logging
import os
import time
import torch
import numpy as np
from typing import List, Optional, Tuple
from collections import defaultdict
from contextlib import contextmanager
__all__ = [
"get_event_storage",
"JSONWriter",
"TensorboardXWriter",
"CommonMetricPrinter",
"EventStorage",
]
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throws an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Examples parsing such a json file:
::
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 19,
"loss": 1.9228371381759644,
"loss_box_reg": 0.050025828182697296,
"loss_classifier": 0.5316952466964722,
"loss_mask": 0.7236229181289673,
"loss_rpn_box": 0.0856662318110466,
"loss_rpn_cls": 0.48198649287223816,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 39,
"loss": 1.282649278640747,
"loss_box_reg": 0.06222952902317047,
"loss_classifier": 0.30682939291000366,
"loss_mask": 0.6970193982124329,
"loss_rpn_box": 0.038663312792778015,
"loss_rpn_cls": 0.1471673548221588,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` are True.
"""
self._file_handle = open(json_file, "a")
self._window_size = window_size
self._last_write = -1
def write(self):
storage = get_event_storage()
to_save = defaultdict(dict)
for k, (v, iter) in storage.latest_with_smoothing_hint(
self._window_size
).items():
# keep scalars that have not been written
if iter <= self._last_write:
continue
to_save[iter][k] = v
if len(to_save):
all_iters = sorted(to_save.keys())
self._last_write = max(all_iters)
for itr, scalars_per_iter in to_save.items():
scalars_per_iter["iteration"] = itr
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file.
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
"""
self._window_size = window_size
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir, **kwargs)
self._last_write = -1
def write(self):
storage = get_event_storage()
new_last_write = self._last_write
for k, (v, iter) in storage.latest_with_smoothing_hint(
self._window_size
).items():
if iter > self._last_write:
self._writer.add_scalar(k, v, iter)
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write
# storage.put_{image,histogram} is only meant to be used by
# tensorboard writer. So we access its internal fields directly from here.
if len(storage._vis_data) >= 1:
for img_name, img, step_num in storage._vis_data:
self._writer.add_image(img_name, img, step_num)
# Storage stores all image data and relies on this writer to clear them.
# As a result it assumes only one writer will use its image data.
# An alternative design is to let storage store limited recent
# data (e.g. only the most recent image) that all writers can access.
# In that case a writer may not see all image data if its period is long.
storage.clear_images()
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all losses, and the learning rate.
It also applies smoothing using a window of 20 elements.
It's meant to print common metrics in common ways.
To print something in more customized ways, please implement a similar printer by yourself.
"""
def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
"""
Args:
max_iter: the maximum number of iterations to train.
Used to compute ETA. If not given, ETA will not be printed.
window_size (int): the losses will be median-smoothed by this window size
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
self._window_size = window_size
self._last_write = (
None # (step, time) of last call to write(). Used to compute ETA
)
def _get_eta(self, storage) -> Optional[str]:
if self._max_iter is None:
return ""
iteration = storage.iter
try:
eta_seconds = storage.history("time").median(1000) * (
self._max_iter - iteration - 1
)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
return str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError:
# estimate eta on our own - more noisy
eta_string = None
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())
return eta_string
def write(self):
storage = get_event_storage()
iteration = storage.iter
if iteration == self._max_iter:
# This hook only reports training progress (loss, ETA, etc) but not other data,
# therefore do not write anything after training succeeds, even if this method
# is called.
return
try:
data_time = storage.history("data_time").avg(20)
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None
try:
iter_time = storage.history("time").global_avg()
except KeyError:
iter_time = None
try:
lr = "{:.5g}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
eta_string = self._get_eta(storage)
if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else:
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
iter=iteration,
losses=" ".join(
[
"{}: {:.4g}".format(k, v.median(self._window_size))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f} ".format(iter_time)
if iter_time is not None
else "",
data_time="data_time: {:.4f} ".format(data_time)
if data_time is not None
else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb)
if max_mem_mb is not None
else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(AverageMeter)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
self._current_prefix = ""
self._vis_data = []
self._histograms = []
# def put_image(self, img_name, img_tensor):
# """
# Add an `img_tensor` associated with `img_name`, to be shown on
# tensorboard.
# Args:
# img_name (str): The name of the image to put into tensorboard.
# img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
# Tensor of shape `[channel, height, width]` where `channel` is
# 3. The image format should be RGB. The elements in img_tensor
# can either have values in [0, 1] (float32) or [0, 255] (uint8).
# The `img_tensor` will be visualized in tensorboard.
# """
# self._vis_data.append((img_name, img_tensor, self._iter))
def put_scalar(self, name, value, n=1, smoothing_hint=False):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
It defaults to False here; set it to True for noisy scalars that benefit
from smoothing before logging.
"""
name = self._current_prefix + name
history = self._history[name]
history.update(value, n)
self._latest_scalars[name] = (value, self._iter)
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
# def put_scalars(self, *, smoothing_hint=True, **kwargs):
# """
# Put multiple scalars from keyword arguments.
# Examples:
# storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
# """
# for k, v in kwargs.items():
# self.put_scalar(k, v, smoothing_hint=smoothing_hint)
#
# def put_histogram(self, hist_name, hist_tensor, bins=1000):
# """
# Create a histogram from a tensor.
# Args:
# hist_name (str): The name of the histogram to put into tensorboard.
# hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
# into a histogram.
# bins (int): Number of histogram bins.
# """
# ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
#
# # Create a histogram with PyTorch
# hist_counts = torch.histc(hist_tensor, bins=bins)
# hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
#
# # Parameter for the add_histogram_raw function of SummaryWriter
# hist_params = dict(
# tag=hist_name,
# min=ht_min,
# max=ht_max,
# num=len(hist_tensor),
# sum=float(hist_tensor.sum()),
# sum_squares=float(torch.sum(hist_tensor**2)),
# bucket_limits=hist_edges[1:].tolist(),
# bucket_counts=hist_counts.tolist(),
# global_step=self._iter,
# )
# self._histograms.append(hist_params)
def history(self, name):
"""
Returns:
AverageMeter: the history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
dict[str -> (float, int)]: mapping from the name of each scalar to the most
recent value and the iteration number at which it was added.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
Similar to :meth:`latest`, but the returned values
are either the un-smoothed original latest value,
or a median of the given window_size,
depending on whether the smoothing_hint is True.
This provides a default behavior that other writers can use.
"""
result = {}
for k, (v, itr) in self._latest_scalars.items():
result[k] = (
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
itr,
)
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
User should either: (1) Call this function to increment storage.iter when needed. Or
(2) Set `storage.iter` to the correct iteration number before each iteration.
The storage will then be able to associate the new data with an iteration number.
"""
self._iter += 1
@property
def iter(self):
"""
Returns:
int: The current iteration number. When used together with a trainer,
this is ensured to be the same as trainer.iter.
"""
return self._iter
@iter.setter
def iter(self, val):
self._iter = int(val)
@property
def iteration(self):
# for backward compatibility
return self._iter
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def clear_histograms(self):
"""
Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard.
"""
self._histograms = []
def reset_history(self, name):
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
ret.reset()
def reset_histories(self):
for name in self._history.keys():
self._history[name].reset()
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.total = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.total = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.total += val * n
self.count += n
self.avg = self.total / self.count
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
"""
def __init__(self, max_length: int = 1000000) -> None:
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
def update(self, value: float, iteration: Optional[float] = None) -> None:
"""
Add a new scalar value produced at certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self) -> float:
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int) -> float:
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int) -> float:
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self) -> float:
"""
Return the mean of all the elements in the buffer. Note that this
includes those getting removed due to limited buffer storage.
"""
return self._global_avg
def values(self) -> List[Tuple[float, float]]:
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data
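
Sketch of a minimal training-loop wiring for the event utilities above; "metrics.json" and the loss values are placeholders. smoothing_hint is left False because the storage above tracks scalars with AverageMeter, which does not provide the median() used by the smoothing path.

import math
from utils.events import EventStorage, JSONWriter, get_event_storage

max_iter = 60
writer = JSONWriter("metrics.json")  # appends one JSON object per line

with EventStorage(start_iter=0) as storage:
    for it in range(max_iter):
        # anywhere inside the `with` block, get_event_storage() returns `storage`
        get_event_storage().put_scalar("loss", math.exp(-it / 20.0), smoothing_hint=False)
        storage.put_scalar("lr", 1e-3, smoothing_hint=False)
        if (it + 1) % 20 == 0:
            writer.write()  # only scalars newer than the last write are flushed
        storage.step()
writer.close()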

utils/logger.py (new file, 167 lines)

@@ -0,0 +1,167 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import logging
import torch
import torch.distributed as dist
from termcolor import colored
logger_initialized = {}
root_status = 0
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified and the process rank is 0, a FileHandler
will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
file_mode (str): The file mode used in opening log file.
Defaults to 'a'.
color (bool): Colorful log output. Defaults to False.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
logger.propagate = False
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# only rank 0 will add a FileHandler
if rank == 0 and log_file is not None:
# logging.FileHandler defaults to append mode ('a'); `file_mode` is
# exposed so callers can override it (e.g. with 'w').
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
plain_formatter = logging.Formatter(
"[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
)
else:
formatter = plain_formatter
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
def print_log(msg, logger=None, level=logging.INFO):
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (logging.Logger | str | None): The logger to be used.
Some special loggers are:
- "silent": no message will be printed.
- other str: the logger obtained with `get_root_logger(logger)`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
"""
if logger is None:
print(msg)
elif isinstance(logger, logging.Logger):
logger.log(level, msg)
elif logger == "silent":
pass
elif isinstance(logger, str):
_logger = get_logger(logger)
_logger.log(level, msg)
else:
raise TypeError(
"logger should be either a logging.Logger object, str, "
f'"silent" or None, but got {type(logger)}'
)
def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added. The name of the root logger is the top-level package name.
Args:
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
file_mode (str): File Mode of logger. (w or a)
Returns:
logging.Logger: The root logger.
"""
logger = get_logger(
name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode
)
return logger
def _log_api_usage(identifier: str):
"""
Internal function used to log the usage of different detectron2 components
inside facebook's infra.
"""
torch._C._log_api_usage_once("pointcept." + identifier)
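
Illustrative logging setup: create the shared "pointcept" root logger once, then route messages through it by name from anywhere in the codebase ("train.log" is a placeholder path).

import logging
from utils.logger import get_root_logger, print_log

logger = get_root_logger(log_file="train.log", log_level=logging.INFO)
logger.info("training started")  # rank 0 also writes to train.log

# print_log reuses an already-initialized logger by name, or falls back to print().
print_log("validation mIoU: 0.712", logger="pointcept", level=logging.INFO)
print_log("plain print output", logger=None)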

utils/misc.py (new file, 156 lines)

@@ -0,0 +1,156 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import warnings
from collections import abc
import numpy as np
import torch
from importlib import import_module
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def intersection_and_union(output, target, K, ignore_index=-1):
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
assert output.ndim in [1, 2, 3]
assert output.shape == target.shape
output = output.reshape(output.size).copy()
target = target.reshape(target.size)
output[np.where(target == ignore_index)[0]] = ignore_index
intersection = output[np.where(output == target)[0]]
area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
area_output, _ = np.histogram(output, bins=np.arange(K + 1))
area_target, _ = np.histogram(target, bins=np.arange(K + 1))
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
def intersection_and_union_gpu(output, target, k, ignore_index=-1):
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
assert output.dim() in [1, 2, 3]
assert output.shape == target.shape
output = output.view(-1)
target = target.view(-1)
output[target == ignore_index] = ignore_index
intersection = output[output == target]
area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
area_output = torch.histc(output, bins=k, min=0, max=k - 1)
area_target = torch.histc(target, bins=k, min=0, max=k - 1)
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
def make_dirs(dir_name):
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
def find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
def is_str(x):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return isinstance(x, str)
def import_modules_from_strings(imports, allow_failed_imports=False):
"""Import modules from the given list of strings.
Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
None. Otherwise, an ImportError is raised. Default: False.
Returns:
list[module] | module | None: The imported modules.
Examples:
>>> osp, sys = import_modules_from_strings(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if not imports:
return
single_import = False
if isinstance(imports, str):
single_import = True
imports = [imports]
if not isinstance(imports, list):
raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
imported = []
for imp in imports:
if not isinstance(imp, str):
raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
try:
imported_tmp = import_module(imp)
except ImportError:
if allow_failed_imports:
warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
imported_tmp = None
else:
raise ImportError
imported.append(imported_tmp)
if single_import:
imported = imported[0]
return imported
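
Sketch of the usual semantic-segmentation evaluation pattern built on these helpers, using random dummy predictions (class count and sizes are arbitrary).

import numpy as np
from utils.misc import AverageMeter, intersection_and_union

num_classes = 3
pred = np.random.randint(0, num_classes, size=10000)
gt = np.random.randint(0, num_classes, size=10000)
gt[:100] = -1  # points with ignore_index are excluded from the statistics

inter, union, target = intersection_and_union(pred, gt, num_classes, ignore_index=-1)
iou = inter / (union + 1e-10)
acc = inter / (target + 1e-10)
print("per-class IoU:", iou, "mIoU:", iou.mean(), "mAcc:", acc.mean())

# AverageMeter keeps running statistics, e.g. for batch losses.
meter = AverageMeter()
for loss in (0.9, 0.7, 0.5):
    meter.update(loss, n=1)
print("avg loss:", meter.avg)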

utils/optimizer.py (new file, 52 lines)

@@ -0,0 +1,52 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch
from utils.logger import get_root_logger
from utils.registry import Registry
OPTIMIZERS = Registry("optimizers")
OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")
def build_optimizer(cfg, model, param_dicts=None):
if param_dicts is None:
cfg.params = model.parameters()
else:
cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
for i in range(len(param_dicts)):
param_group = dict(names=[], params=[])
if "lr" in param_dicts[i].keys():
param_group["lr"] = param_dicts[i].lr
if "momentum" in param_dicts[i].keys():
param_group["momentum"] = param_dicts[i].momentum
if "weight_decay" in param_dicts[i].keys():
param_group["weight_decay"] = param_dicts[i].weight_decay
cfg.params.append(param_group)
for n, p in model.named_parameters():
flag = False
for i in range(len(param_dicts)):
if param_dicts[i].keyword in n:
cfg.params[i + 1]["names"].append(n)
cfg.params[i + 1]["params"].append(p)
flag = True
break
if not flag:
cfg.params[0]["names"].append(n)
cfg.params[0]["params"].append(p)
logger = get_root_logger()
for i in range(len(cfg.params)):
param_names = cfg.params[i].pop("names")
message = ""
for key in cfg.params[i].keys():
if key != "params":
message += f" {key}: {cfg.params[i][key]};"
logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
return OPTIMIZERS.build(cfg=cfg)
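
Sketch of building an optimizer from a config dict, with a lower learning rate for parameters whose names contain a keyword (the model, keyword and hyper-parameters are illustrative). ConfigDict is used because build_optimizer reads entries by attribute (cfg.lr, pd.keyword).

import torch.nn as nn
from utils.config import ConfigDict
from utils.optimizer import build_optimizer

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

cfg = ConfigDict(dict(type="AdamW", lr=1e-3, weight_decay=0.05))
# parameters whose name contains "0." (the first Linear layer) get a smaller lr
param_dicts = [ConfigDict(dict(keyword="0.", lr=1e-4))]

optimizer = build_optimizer(cfg, model, param_dicts)
print(optimizer)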

utils/path.py (new file, 105 lines)

@@ -0,0 +1,105 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import os.path as osp
from pathlib import Path
from .misc import is_str
def is_filepath(x):
return is_str(x) or isinstance(x, Path)
def fopen(filepath, *args, **kwargs):
if is_str(filepath):
return open(filepath, *args, **kwargs)
elif isinstance(filepath, Path):
return filepath.open(*args, **kwargs)
raise ValueError("`filepath` should be a string or a Path")
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
if not osp.isfile(filename):
raise FileNotFoundError(msg_tmpl.format(filename))
def mkdir_or_exist(dir_name, mode=0o777):
if dir_name == "":
return
dir_name = osp.expanduser(dir_name)
os.makedirs(dir_name, mode=mode, exist_ok=True)
def symlink(src, dst, overwrite=True, **kwargs):
if os.path.lexists(dst) and overwrite:
os.remove(dst)
os.symlink(src, dst, **kwargs)
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
"""Scan a directory to find the interested files.
Args:
dir_path (str | obj:`Path`): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
case_sensitive (bool, optional) : If set to False, ignore the case of
suffix. Default: True.
Returns:
A generator for all the interested files with relative paths.
"""
if isinstance(dir_path, (str, Path)):
dir_path = str(dir_path)
else:
raise TypeError('"dir_path" must be a string or Path object')
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
if suffix is not None and not case_sensitive:
suffix = (
suffix.lower()
if isinstance(suffix, str)
else tuple(item.lower() for item in suffix)
)
root = dir_path
def _scandir(dir_path, suffix, recursive, case_sensitive):
for entry in os.scandir(dir_path):
if not entry.name.startswith(".") and entry.is_file():
rel_path = osp.relpath(entry.path, root)
_rel_path = rel_path if case_sensitive else rel_path.lower()
if suffix is None or _rel_path.endswith(suffix):
yield rel_path
elif recursive and os.path.isdir(entry.path):
# scan recursively if entry.path is a directory
yield from _scandir(entry.path, suffix, recursive, case_sensitive)
return _scandir(dir_path, suffix, recursive, case_sensitive)
def find_vcs_root(path, markers=(".git",)):
"""Finds the root directory (including itself) of specified markers.
Args:
path (str): Path of directory or file.
markers (list[str], optional): List of file or directory names.
Returns:
The directory containing one of the markers, or None if not found.
"""
if osp.isfile(path):
path = osp.dirname(path)
prev, cur = None, osp.abspath(osp.expanduser(path))
while cur != prev:
if any(osp.exists(osp.join(cur, marker)) for marker in markers):
return cur
prev, cur = cur, osp.split(cur)[0]
return None
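
Illustrative use of the path helpers: create an output directory, enumerate Python files, and locate the repository root (paths are placeholders).

from utils.path import find_vcs_root, mkdir_or_exist, scandir

mkdir_or_exist("exp/logs")  # nested makedirs, no error if it already exists

# relative paths of all .py files under utils/, scanned recursively
for rel_path in scandir("utils", suffix=".py", recursive=True):
    print(rel_path)

# walk upwards from utils/ until a directory containing ".git" is found
print("repo root:", find_vcs_root("utils"))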

utils/registry.py (new file, 318 lines)

@@ -0,0 +1,318 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import inspect
import warnings
from functools import partial
from .misc import is_seq_of
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from configs dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
if "type" not in cfg:
if default_args is None or "type" not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f"but got {cfg}\n{default_args}"
)
if not isinstance(registry, Registry):
raise TypeError(
"registry must be an mmcv.Registry object, " f"but got {type(registry)}"
)
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError(
"default_args must be a dict or None, " f"but got {type(default_args)}"
)
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop("type")
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(f"{obj_type} is not in the {registry.name} registry")
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f"{obj_cls.__name__}: {e}")
class Registry:
"""A registry to map strings to classes.
    Registered objects can be built from the registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
Please refer to
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
advanced usage.
Args:
name (str): Registry name.
        build_func (func, optional): Build function to construct instances
            from the registry. :func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will be
            inherited from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = (
self.__class__.__name__ + f"(name={self._name}, "
f"items={self._module_dict})"
)
return format_str
@staticmethod
def infer_scope():
"""Infer the scope of registry.
The name of the package where registry is defined will be returned.
Example:
# in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
scope (str): The inferred scope name.
"""
        # inspect.stack() traces where this function is called; index 2 is the
        # frame of the module that instantiates the Registry.
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
split_filename = filename.split(".")
return split_filename[0]
@staticmethod
def split_scope_key(key):
"""Split scope and key.
The first scope will be split from key.
Examples:
>>> Registry.split_scope_key('mmdet.ResNet')
'mmdet', 'ResNet'
>>> Registry.split_scope_key('ResNet')
None, 'ResNet'
Return:
scope (str, None): The first scope.
key (str): The remaining key.
"""
split_index = key.find(".")
if split_index != -1:
return key[:split_index], key[split_index + 1 :]
else:
return None, key
@property
def name(self):
return self._name
@property
def scope(self):
return self._scope
@property
def module_dict(self):
return self._module_dict
@property
def children(self):
return self._children
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope, real_key = self.split_scope_key(key)
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from self._children
if scope in self._children:
return self._children[scope].get(real_key)
else:
# goto root
parent = self.parent
while parent.parent is not None:
parent = parent.parent
return parent.get(key)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def _add_children(self, registry):
"""Add children for a registry.
        The ``registry`` will be added as a child based on its scope.
        The parent registry can build objects registered in its children.
Example:
>>> models = Registry('models')
>>> mmdet_models = Registry('models', parent=models)
>>> @mmdet_models.register_module()
>>> class ResNet:
>>> pass
>>> resnet = models.build(dict(type='mmdet.ResNet'))
"""
assert isinstance(registry, Registry)
assert registry.scope is not None
assert (
registry.scope not in self.children
), f"scope {registry.scope} exists in {self.name} registry"
self.children[registry.scope] = registry
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError("module must be a class, " f"but got {type(module_class)}")
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f"{name} is already registered " f"in {self.name}")
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
"The old API of register_module(module, force=False) "
"is deprecated and will be removed, please use the new API "
"register_module(name=None, force=False, module=None) instead."
)
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f"force must be a boolean, but got {type(force)}")
        # NOTE: This is a workaround to stay compatible with the old API,
        # but it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
"name must be either of None, an instance of str or a sequence"
f" of str, but got {type(name)}"
)
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module_class=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(module_class=cls, module_name=name, force=force)
return cls
return _register
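# Minimal usage sketch, mirroring the docstring examples above; the registry
# name and the ToyNet class are illustrative only.
if __name__ == "__main__":
    BACKBONES = Registry("backbone")

    @BACKBONES.register_module()
    class ToyNet:
        def __init__(self, depth=18):
            self.depth = depth

    # Build by type name through build_from_cfg.
    net = BACKBONES.build(dict(type="ToyNet", depth=34))
    print(type(net).__name__, net.depth)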

144
utils/scheduler.py Normal file
View File

@@ -0,0 +1,144 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch.optim.lr_scheduler as lr_scheduler
from .registry import Registry
SCHEDULERS = Registry("schedulers")
@SCHEDULERS.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
def __init__(
self,
optimizer,
milestones,
total_steps,
gamma=0.1,
last_epoch=-1,
verbose=False,
):
super().__init__(
optimizer=optimizer,
milestones=[rate * total_steps for rate in milestones],
gamma=gamma,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
def __init__(
self,
optimizer,
milestones,
total_steps,
gamma=0.1,
warmup_rate=0.05,
warmup_scale=1e-6,
last_epoch=-1,
verbose=False,
):
milestones = [rate * total_steps for rate in milestones]
def multi_step_with_warmup(s):
factor = 1.0
for i in range(len(milestones)):
if s < milestones[i]:
break
factor *= gamma
if s <= warmup_rate * total_steps:
warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
1 - warmup_scale
)
else:
warmup_coefficient = 1.0
return warmup_coefficient * factor
super().__init__(
optimizer=optimizer,
lr_lambda=multi_step_with_warmup,
last_epoch=last_epoch,
verbose=verbose,
)
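# Worked example of the lambda above (illustrative values): with
# total_steps=100, milestones=[0.6, 0.8], gamma=0.1, warmup_rate=0.05 and
# warmup_scale=1e-6, the multiplier ramps linearly from 1e-6 at step 0 to 1.0
# at step 5 (the end of warmup), stays at 1.0 until step 60, then becomes 0.1
# from step 60 and 0.01 from step 80 onwards.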
@SCHEDULERS.register_module()
class PolyLR(lr_scheduler.LambdaLR):
def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
super().__init__(
optimizer=optimizer,
lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class ExpLR(lr_scheduler.LambdaLR):
def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
super().__init__(
optimizer=optimizer,
lr_lambda=lambda s: gamma ** (s / total_steps),
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
super().__init__(
optimizer=optimizer,
T_max=total_steps,
eta_min=eta_min,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
r"""
torch.optim.lr_scheduler.OneCycleLR, Block total_steps
"""
def __init__(
self,
optimizer,
max_lr,
total_steps=None,
pct_start=0.3,
anneal_strategy="cos",
cycle_momentum=True,
base_momentum=0.85,
max_momentum=0.95,
div_factor=25.0,
final_div_factor=1e4,
three_phase=False,
last_epoch=-1,
verbose=False,
):
super().__init__(
optimizer=optimizer,
max_lr=max_lr,
total_steps=total_steps,
pct_start=pct_start,
anneal_strategy=anneal_strategy,
cycle_momentum=cycle_momentum,
base_momentum=base_momentum,
max_momentum=max_momentum,
div_factor=div_factor,
final_div_factor=final_div_factor,
three_phase=three_phase,
last_epoch=last_epoch,
verbose=verbose,
)
def build_scheduler(cfg, optimizer):
cfg.optimizer = optimizer
return SCHEDULERS.build(cfg=cfg)
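# Minimal usage sketch: build a PolyLR from a plain config dict via the
# registry. Note that build_scheduler above assumes an attribute-style config
# (cfg.optimizer = ...), so the registry is used directly here; the SGD
# optimizer and step counts are illustrative.
if __name__ == "__main__":
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    scheduler = SCHEDULERS.build(
        dict(type="PolyLR", optimizer=optimizer, total_steps=100)
    )
    for _ in range(5):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())  # learning rate after 5 poly-decay steps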

71
utils/timer.py Normal file
View File

@@ -0,0 +1,71 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from time import perf_counter
from typing import Optional
class Timer:
"""
A timer which computes the time elapsed since the start/reset of the timer.
"""
def __init__(self) -> None:
self.reset()
def reset(self) -> None:
"""
Reset the timer.
"""
self._start = perf_counter()
self._paused: Optional[float] = None
self._total_paused = 0
self._count_start = 1
def pause(self) -> None:
"""
Pause the timer.
"""
if self._paused is not None:
raise ValueError("Trying to pause a Timer that is already paused!")
self._paused = perf_counter()
def is_paused(self) -> bool:
"""
Returns:
bool: whether the timer is currently paused
"""
return self._paused is not None
def resume(self) -> None:
"""
Resume the timer.
"""
if self._paused is None:
raise ValueError("Trying to resume a Timer that is not paused!")
# pyre-fixme[58]: `-` is not supported for operand types `float` and
# `Optional[float]`.
self._total_paused += perf_counter() - self._paused
self._paused = None
self._count_start += 1
def seconds(self) -> float:
"""
Returns:
(float): the total number of seconds since the start/reset of the
timer, excluding the time when the timer is paused.
"""
if self._paused is not None:
end_time: float = self._paused # type: ignore
else:
end_time = perf_counter()
return end_time - self._start - self._total_paused
def avg_seconds(self) -> float:
"""
Returns:
(float): the average number of seconds between every start/reset and
pause.
"""
return self.seconds() / self._count_start
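# Minimal usage sketch: time a short section of work, pausing the timer so the
# sleep in the middle is excluded (durations are illustrative).
if __name__ == "__main__":
    import time

    timer = Timer()
    time.sleep(0.1)
    timer.pause()
    time.sleep(0.1)  # excluded from the measured time
    timer.resume()
    print(f"elapsed: {timer.seconds():.2f}s, average: {timer.avg_seconds():.2f}s")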

86
utils/visualization.py Normal file
View File

@@ -0,0 +1,86 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import open3d as o3d
import numpy as np
import torch
def to_numpy(x):
if isinstance(x, torch.Tensor):
x = x.clone().detach().cpu().numpy()
assert isinstance(x, np.ndarray)
return x
def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
    dir_name = os.path.dirname(file_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
coord = to_numpy(coord)
if color is not None:
color = to_numpy(color)
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coord)
pcd.colors = o3d.utility.Vector3dVector(
np.ones_like(coord) if color is None else color
)
o3d.io.write_point_cloud(file_path, pcd)
if logger is not None:
logger.info(f"Save Point Cloud to: {file_path}")
def save_bounding_boxes(
bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None
):
bboxes_corners = to_numpy(bboxes_corners)
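    # Expected shape: (N, 8, 3). For each box, corners 0-3 form one face and
    # corners 4-7 the opposite face, with corner i + 4 paired with corner i,
    # so the edge list below draws both faces plus the four connecting edges.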
# point list
points = bboxes_corners.reshape(-1, 3)
# line list
box_lines = np.array(
[
[0, 1],
[1, 2],
[2, 3],
[3, 0],
[4, 5],
[5, 6],
[6, 7],
            [7, 4],
[0, 4],
[1, 5],
[2, 6],
[3, 7],
]
)
lines = []
for i, _ in enumerate(bboxes_corners):
lines.append(box_lines + i * 8)
lines = np.concatenate(lines)
# color list
color = np.array([color for _ in range(len(lines))])
# generate line set
line_set = o3d.geometry.LineSet()
line_set.points = o3d.utility.Vector3dVector(points)
line_set.lines = o3d.utility.Vector2iVector(lines)
line_set.colors = o3d.utility.Vector3dVector(color)
o3d.io.write_line_set(file_path, line_set)
if logger is not None:
logger.info(f"Save Boxes to: {file_path}")
def save_lines(
points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None
):
points = to_numpy(points)
lines = to_numpy(lines)
colors = np.array([color for _ in range(len(lines))])
line_set = o3d.geometry.LineSet()
line_set.points = o3d.utility.Vector3dVector(points)
line_set.lines = o3d.utility.Vector2iVector(lines)
line_set.colors = o3d.utility.Vector3dVector(colors)
o3d.io.write_line_set(file_path, line_set)
if logger is not None:
logger.info(f"Save Lines to: {file_path}")