mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 01:49:23 +08:00
feat: Initial commit
This commit is contained in:
696
utils/config.py
Normal file
696
utils/config.py
Normal file
@@ -0,0 +1,696 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
import ast
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
from argparse import Action, ArgumentParser
|
||||
from collections import abc
|
||||
from importlib import import_module
|
||||
|
||||
from addict import Dict
|
||||
from yapf.yapflib.yapf_api import FormatCode
|
||||
|
||||
from .misc import import_modules_from_strings
|
||||
from .path import check_file_exist
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import regex as re
|
||||
else:
|
||||
import re
|
||||
|
||||
BASE_KEY = "_base_"
|
||||
DELETE_KEY = "_delete_"
|
||||
DEPRECATION_KEY = "_deprecation_"
|
||||
RESERVED_KEYS = ["filename", "text", "pretty_text"]
|
||||
|
||||
|
||||
class ConfigDict(Dict):
|
||||
def __missing__(self, name):
|
||||
raise KeyError(name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
value = super(ConfigDict, self).__getattr__(name)
|
||||
except KeyError:
|
||||
ex = AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
ex = e
|
||||
else:
|
||||
return value
|
||||
raise ex
|
||||
|
||||
|
||||
def add_args(parser, cfg, prefix=""):
|
||||
for k, v in cfg.items():
|
||||
if isinstance(v, str):
|
||||
parser.add_argument("--" + prefix + k)
|
||||
elif isinstance(v, int):
|
||||
parser.add_argument("--" + prefix + k, type=int)
|
||||
elif isinstance(v, float):
|
||||
parser.add_argument("--" + prefix + k, type=float)
|
||||
elif isinstance(v, bool):
|
||||
parser.add_argument("--" + prefix + k, action="store_true")
|
||||
elif isinstance(v, dict):
|
||||
add_args(parser, v, prefix + k + ".")
|
||||
elif isinstance(v, abc.Iterable):
|
||||
parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
|
||||
else:
|
||||
print(f"cannot parse key {prefix + k} of type {type(v)}")
|
||||
return parser
|
||||
|
||||
|
||||
class Config:
|
||||
"""A facility for config and config files.
|
||||
|
||||
It supports common file formats as configs: python/json/yaml. The interface
|
||||
is the same as a dict object and also allows access config values as
|
||||
attributes.
|
||||
|
||||
Example:
|
||||
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
>>> cfg.a
|
||||
1
|
||||
>>> cfg.b
|
||||
{'b1': [0, 1]}
|
||||
>>> cfg.b.b1
|
||||
[0, 1]
|
||||
>>> cfg = Config.fromfile('tests/data/config/a.py')
|
||||
>>> cfg.filename
|
||||
"/home/kchen/projects/mmcv/tests/data/config/a.py"
|
||||
>>> cfg.item4
|
||||
'test'
|
||||
>>> cfg
|
||||
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
|
||||
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _validate_py_syntax(filename):
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
content = f.read()
|
||||
try:
|
||||
ast.parse(content)
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(
|
||||
"There are syntax errors in config " f"file {filename}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _substitute_predefined_vars(filename, temp_config_name):
|
||||
file_dirname = osp.dirname(filename)
|
||||
file_basename = osp.basename(filename)
|
||||
file_basename_no_extension = osp.splitext(file_basename)[0]
|
||||
file_extname = osp.splitext(filename)[1]
|
||||
support_templates = dict(
|
||||
fileDirname=file_dirname,
|
||||
fileBasename=file_basename,
|
||||
fileBasenameNoExtension=file_basename_no_extension,
|
||||
fileExtname=file_extname,
|
||||
)
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
for key, value in support_templates.items():
|
||||
regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
|
||||
value = value.replace("\\", "/")
|
||||
config_file = re.sub(regexp, value, config_file)
|
||||
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
|
||||
tmp_config_file.write(config_file)
|
||||
|
||||
@staticmethod
|
||||
def _pre_substitute_base_vars(filename, temp_config_name):
|
||||
"""Substitute base variable placehoders to string, so that parsing
|
||||
would work."""
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
base_var_dict = {}
|
||||
regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
|
||||
base_vars = set(re.findall(regexp, config_file))
|
||||
for base_var in base_vars:
|
||||
randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
|
||||
base_var_dict[randstr] = base_var
|
||||
regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
|
||||
config_file = re.sub(regexp, f'"{randstr}"', config_file)
|
||||
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
|
||||
tmp_config_file.write(config_file)
|
||||
return base_var_dict
|
||||
|
||||
@staticmethod
|
||||
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
|
||||
"""Substitute variable strings to their actual values."""
|
||||
cfg = copy.deepcopy(cfg)
|
||||
|
||||
if isinstance(cfg, dict):
|
||||
for k, v in cfg.items():
|
||||
if isinstance(v, str) and v in base_var_dict:
|
||||
new_v = base_cfg
|
||||
for new_k in base_var_dict[v].split("."):
|
||||
new_v = new_v[new_k]
|
||||
cfg[k] = new_v
|
||||
elif isinstance(v, (list, tuple, dict)):
|
||||
cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
|
||||
elif isinstance(cfg, tuple):
|
||||
cfg = tuple(
|
||||
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
|
||||
)
|
||||
elif isinstance(cfg, list):
|
||||
cfg = [
|
||||
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
|
||||
]
|
||||
elif isinstance(cfg, str) and cfg in base_var_dict:
|
||||
new_v = base_cfg
|
||||
for new_k in base_var_dict[cfg].split("."):
|
||||
new_v = new_v[new_k]
|
||||
cfg = new_v
|
||||
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def _file2dict(filename, use_predefined_variables=True):
|
||||
filename = osp.abspath(osp.expanduser(filename))
|
||||
check_file_exist(filename)
|
||||
fileExtname = osp.splitext(filename)[1]
|
||||
if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
|
||||
raise IOError("Only py/yml/yaml/json type are supported now!")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||
temp_config_file = tempfile.NamedTemporaryFile(
|
||||
dir=temp_config_dir, suffix=fileExtname
|
||||
)
|
||||
if platform.system() == "Windows":
|
||||
temp_config_file.close()
|
||||
temp_config_name = osp.basename(temp_config_file.name)
|
||||
# Substitute predefined variables
|
||||
if use_predefined_variables:
|
||||
Config._substitute_predefined_vars(filename, temp_config_file.name)
|
||||
else:
|
||||
shutil.copyfile(filename, temp_config_file.name)
|
||||
# Substitute base variables from placeholders to strings
|
||||
base_var_dict = Config._pre_substitute_base_vars(
|
||||
temp_config_file.name, temp_config_file.name
|
||||
)
|
||||
|
||||
if filename.endswith(".py"):
|
||||
temp_module_name = osp.splitext(temp_config_name)[0]
|
||||
sys.path.insert(0, temp_config_dir)
|
||||
Config._validate_py_syntax(filename)
|
||||
mod = import_module(temp_module_name)
|
||||
sys.path.pop(0)
|
||||
cfg_dict = {
|
||||
name: value
|
||||
for name, value in mod.__dict__.items()
|
||||
if not name.startswith("__")
|
||||
}
|
||||
# delete imported module
|
||||
del sys.modules[temp_module_name]
|
||||
elif filename.endswith((".yml", ".yaml", ".json")):
|
||||
raise NotImplementedError
|
||||
# close temp file
|
||||
temp_config_file.close()
|
||||
|
||||
# check deprecation information
|
||||
if DEPRECATION_KEY in cfg_dict:
|
||||
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
|
||||
warning_msg = (
|
||||
f"The config file {filename} will be deprecated " "in the future."
|
||||
)
|
||||
if "expected" in deprecation_info:
|
||||
warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
|
||||
if "reference" in deprecation_info:
|
||||
warning_msg += (
|
||||
" More information can be found at "
|
||||
f'{deprecation_info["reference"]}'
|
||||
)
|
||||
warnings.warn(warning_msg)
|
||||
|
||||
cfg_text = filename + "\n"
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
cfg_text += f.read()
|
||||
|
||||
if BASE_KEY in cfg_dict:
|
||||
cfg_dir = osp.dirname(filename)
|
||||
base_filename = cfg_dict.pop(BASE_KEY)
|
||||
base_filename = (
|
||||
base_filename if isinstance(base_filename, list) else [base_filename]
|
||||
)
|
||||
|
||||
cfg_dict_list = list()
|
||||
cfg_text_list = list()
|
||||
for f in base_filename:
|
||||
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
|
||||
cfg_dict_list.append(_cfg_dict)
|
||||
cfg_text_list.append(_cfg_text)
|
||||
|
||||
base_cfg_dict = dict()
|
||||
for c in cfg_dict_list:
|
||||
duplicate_keys = base_cfg_dict.keys() & c.keys()
|
||||
if len(duplicate_keys) > 0:
|
||||
raise KeyError(
|
||||
"Duplicate key is not allowed among bases. "
|
||||
f"Duplicate keys: {duplicate_keys}"
|
||||
)
|
||||
base_cfg_dict.update(c)
|
||||
|
||||
# Substitute base variables from strings to their actual values
|
||||
cfg_dict = Config._substitute_base_vars(
|
||||
cfg_dict, base_var_dict, base_cfg_dict
|
||||
)
|
||||
|
||||
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
|
||||
cfg_dict = base_cfg_dict
|
||||
|
||||
# merge cfg_text
|
||||
cfg_text_list.append(cfg_text)
|
||||
cfg_text = "\n".join(cfg_text_list)
|
||||
|
||||
return cfg_dict, cfg_text
|
||||
|
||||
@staticmethod
|
||||
def _merge_a_into_b(a, b, allow_list_keys=False):
|
||||
"""merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
|
||||
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
|
||||
in-place modifications.
|
||||
|
||||
Args:
|
||||
a (dict): The source dict to be merged into ``b``.
|
||||
b (dict): The origin dict to be fetch keys from ``a``.
|
||||
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
|
||||
are allowed in source ``a`` and will replace the element of the
|
||||
corresponding index in b if b is a list. Default: False.
|
||||
|
||||
Returns:
|
||||
dict: The modified dict of ``b`` using ``a``.
|
||||
|
||||
Examples:
|
||||
# Normally merge a into b.
|
||||
>>> Config._merge_a_into_b(
|
||||
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
|
||||
{'obj': {'a': 2}}
|
||||
|
||||
# Delete b first and merge a into b.
|
||||
>>> Config._merge_a_into_b(
|
||||
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
|
||||
{'obj': {'a': 2}}
|
||||
|
||||
# b is a list
|
||||
>>> Config._merge_a_into_b(
|
||||
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
|
||||
[{'a': 2}, {'b': 2}]
|
||||
"""
|
||||
b = b.copy()
|
||||
for k, v in a.items():
|
||||
if allow_list_keys and k.isdigit() and isinstance(b, list):
|
||||
k = int(k)
|
||||
if len(b) <= k:
|
||||
raise KeyError(f"Index {k} exceeds the length of list {b}")
|
||||
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
||||
elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
|
||||
allowed_types = (dict, list) if allow_list_keys else dict
|
||||
if not isinstance(b[k], allowed_types):
|
||||
raise TypeError(
|
||||
f"{k}={v} in child config cannot inherit from base "
|
||||
f"because {k} is a dict in the child config but is of "
|
||||
f"type {type(b[k])} in base config. You may set "
|
||||
f"`{DELETE_KEY}=True` to ignore the base config"
|
||||
)
|
||||
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
||||
else:
|
||||
b[k] = v
|
||||
return b
|
||||
|
||||
@staticmethod
|
||||
def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
|
||||
cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
|
||||
if import_custom_modules and cfg_dict.get("custom_imports", None):
|
||||
import_modules_from_strings(**cfg_dict["custom_imports"])
|
||||
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def fromstring(cfg_str, file_format):
|
||||
"""Generate config from config str.
|
||||
|
||||
Args:
|
||||
cfg_str (str): Config str.
|
||||
file_format (str): Config file format corresponding to the
|
||||
config str. Only py/yml/yaml/json type are supported now!
|
||||
|
||||
Returns:
|
||||
obj:`Config`: Config obj.
|
||||
"""
|
||||
if file_format not in [".py", ".json", ".yaml", ".yml"]:
|
||||
raise IOError("Only py/yml/yaml/json type are supported now!")
|
||||
if file_format != ".py" and "dict(" in cfg_str:
|
||||
# check if users specify a wrong suffix for python
|
||||
warnings.warn('Please check "file_format", the file format may be .py')
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w", encoding="utf-8", suffix=file_format, delete=False
|
||||
) as temp_file:
|
||||
temp_file.write(cfg_str)
|
||||
# on windows, previous implementation cause error
|
||||
# see PR 1077 for details
|
||||
cfg = Config.fromfile(temp_file.name)
|
||||
os.remove(temp_file.name)
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def auto_argparser(description=None):
|
||||
"""Generate argparser from config file automatically (experimental)"""
|
||||
partial_parser = ArgumentParser(description=description)
|
||||
partial_parser.add_argument("config", help="config file path")
|
||||
cfg_file = partial_parser.parse_known_args()[0].config
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
parser = ArgumentParser(description=description)
|
||||
parser.add_argument("config", help="config file path")
|
||||
add_args(parser, cfg)
|
||||
return parser, cfg
|
||||
|
||||
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
|
||||
if cfg_dict is None:
|
||||
cfg_dict = dict()
|
||||
elif not isinstance(cfg_dict, dict):
|
||||
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
|
||||
for key in cfg_dict:
|
||||
if key in RESERVED_KEYS:
|
||||
raise KeyError(f"{key} is reserved for config file")
|
||||
|
||||
super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
|
||||
super(Config, self).__setattr__("_filename", filename)
|
||||
if cfg_text:
|
||||
text = cfg_text
|
||||
elif filename:
|
||||
with open(filename, "r") as f:
|
||||
text = f.read()
|
||||
else:
|
||||
text = ""
|
||||
super(Config, self).__setattr__("_text", text)
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._filename
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return self._text
|
||||
|
||||
@property
|
||||
def pretty_text(self):
|
||||
indent = 4
|
||||
|
||||
def _indent(s_, num_spaces):
|
||||
s = s_.split("\n")
|
||||
if len(s) == 1:
|
||||
return s_
|
||||
first = s.pop(0)
|
||||
s = [(num_spaces * " ") + line for line in s]
|
||||
s = "\n".join(s)
|
||||
s = first + "\n" + s
|
||||
return s
|
||||
|
||||
def _format_basic_types(k, v, use_mapping=False):
|
||||
if isinstance(v, str):
|
||||
v_str = f"'{v}'"
|
||||
else:
|
||||
v_str = str(v)
|
||||
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f"{k_str}: {v_str}"
|
||||
else:
|
||||
attr_str = f"{str(k)}={v_str}"
|
||||
attr_str = _indent(attr_str, indent)
|
||||
|
||||
return attr_str
|
||||
|
||||
def _format_list(k, v, use_mapping=False):
|
||||
# check if all items in the list are dict
|
||||
if all(isinstance(_, dict) for _ in v):
|
||||
v_str = "[\n"
|
||||
v_str += "\n".join(
|
||||
f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
|
||||
).rstrip(",")
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f"{k_str}: {v_str}"
|
||||
else:
|
||||
attr_str = f"{str(k)}={v_str}"
|
||||
attr_str = _indent(attr_str, indent) + "]"
|
||||
else:
|
||||
attr_str = _format_basic_types(k, v, use_mapping)
|
||||
return attr_str
|
||||
|
||||
def _contain_invalid_identifier(dict_str):
|
||||
contain_invalid_identifier = False
|
||||
for key_name in dict_str:
|
||||
contain_invalid_identifier |= not str(key_name).isidentifier()
|
||||
return contain_invalid_identifier
|
||||
|
||||
def _format_dict(input_dict, outest_level=False):
|
||||
r = ""
|
||||
s = []
|
||||
|
||||
use_mapping = _contain_invalid_identifier(input_dict)
|
||||
if use_mapping:
|
||||
r += "{"
|
||||
for idx, (k, v) in enumerate(input_dict.items()):
|
||||
is_last = idx >= len(input_dict) - 1
|
||||
end = "" if outest_level or is_last else ","
|
||||
if isinstance(v, dict):
|
||||
v_str = "\n" + _format_dict(v)
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f"{k_str}: dict({v_str}"
|
||||
else:
|
||||
attr_str = f"{str(k)}=dict({v_str}"
|
||||
attr_str = _indent(attr_str, indent) + ")" + end
|
||||
elif isinstance(v, list):
|
||||
attr_str = _format_list(k, v, use_mapping) + end
|
||||
else:
|
||||
attr_str = _format_basic_types(k, v, use_mapping) + end
|
||||
|
||||
s.append(attr_str)
|
||||
r += "\n".join(s)
|
||||
if use_mapping:
|
||||
r += "}"
|
||||
return r
|
||||
|
||||
cfg_dict = self._cfg_dict.to_dict()
|
||||
text = _format_dict(cfg_dict, outest_level=True)
|
||||
# copied from setup.cfg
|
||||
yapf_style = dict(
|
||||
based_on_style="pep8",
|
||||
blank_line_before_nested_class_or_def=True,
|
||||
split_before_expression_after_opening_paren=True,
|
||||
)
|
||||
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
|
||||
|
||||
return text
|
||||
|
||||
def __repr__(self):
|
||||
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cfg_dict)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._cfg_dict, name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self._cfg_dict.__getitem__(name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if isinstance(value, dict):
|
||||
value = ConfigDict(value)
|
||||
self._cfg_dict.__setattr__(name, value)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
if isinstance(value, dict):
|
||||
value = ConfigDict(value)
|
||||
self._cfg_dict.__setitem__(name, value)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._cfg_dict)
|
||||
|
||||
def __getstate__(self):
|
||||
return (self._cfg_dict, self._filename, self._text)
|
||||
|
||||
def __setstate__(self, state):
|
||||
_cfg_dict, _filename, _text = state
|
||||
super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
|
||||
super(Config, self).__setattr__("_filename", _filename)
|
||||
super(Config, self).__setattr__("_text", _text)
|
||||
|
||||
def dump(self, file=None):
|
||||
cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
|
||||
if self.filename.endswith(".py"):
|
||||
if file is None:
|
||||
return self.pretty_text
|
||||
else:
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
f.write(self.pretty_text)
|
||||
else:
|
||||
import mmcv
|
||||
|
||||
if file is None:
|
||||
file_format = self.filename.split(".")[-1]
|
||||
return mmcv.dump(cfg_dict, file_format=file_format)
|
||||
else:
|
||||
mmcv.dump(cfg_dict, file)
|
||||
|
||||
def merge_from_dict(self, options, allow_list_keys=True):
|
||||
"""Merge list into cfg_dict.
|
||||
|
||||
Merge the dict parsed by MultipleKVAction into this cfg.
|
||||
|
||||
Examples:
|
||||
>>> options = {'models.backbone.depth': 50,
|
||||
... 'models.backbone.with_cp':True}
|
||||
>>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
|
||||
>>> cfg.merge_from_dict(options)
|
||||
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
>>> assert cfg_dict == dict(
|
||||
... models=dict(backbone=dict(depth=50, with_cp=True)))
|
||||
|
||||
# Merge list element
|
||||
>>> cfg = Config(dict(pipeline=[
|
||||
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
|
||||
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
|
||||
>>> cfg.merge_from_dict(options, allow_list_keys=True)
|
||||
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
>>> assert cfg_dict == dict(pipeline=[
|
||||
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
|
||||
|
||||
Args:
|
||||
options (dict): dict of configs to merge from.
|
||||
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
|
||||
are allowed in ``options`` and will replace the element of the
|
||||
corresponding index in the config if the config is a list.
|
||||
Default: True.
|
||||
"""
|
||||
option_cfg_dict = {}
|
||||
for full_key, v in options.items():
|
||||
d = option_cfg_dict
|
||||
key_list = full_key.split(".")
|
||||
for subkey in key_list[:-1]:
|
||||
d.setdefault(subkey, ConfigDict())
|
||||
d = d[subkey]
|
||||
subkey = key_list[-1]
|
||||
d[subkey] = v
|
||||
|
||||
cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
|
||||
super(Config, self).__setattr__(
|
||||
"_cfg_dict",
|
||||
Config._merge_a_into_b(
|
||||
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DictAction(Action):
|
||||
"""
|
||||
argparse action to split an argument into KEY=VALUE form
|
||||
on the first = and append to a dictionary. List options can
|
||||
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
|
||||
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
|
||||
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_int_float_bool(val):
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
pass
|
||||
if val.lower() in ["true", "false"]:
|
||||
return True if val.lower() == "true" else False
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def _parse_iterable(val):
|
||||
"""Parse iterable values in the string.
|
||||
|
||||
All elements inside '()' or '[]' are treated as iterable values.
|
||||
|
||||
Args:
|
||||
val (str): Value string.
|
||||
|
||||
Returns:
|
||||
list | tuple: The expanded list or tuple from the string.
|
||||
|
||||
Examples:
|
||||
>>> DictAction._parse_iterable('1,2,3')
|
||||
[1, 2, 3]
|
||||
>>> DictAction._parse_iterable('[a, b, c]')
|
||||
['a', 'b', 'c']
|
||||
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
|
||||
[(1, 2, 3), ['a', 'b'], 'c']
|
||||
"""
|
||||
|
||||
def find_next_comma(string):
|
||||
"""Find the position of next comma in the string.
|
||||
|
||||
If no ',' is found in the string, return the string length. All
|
||||
chars inside '()' and '[]' are treated as one element and thus ','
|
||||
inside these brackets are ignored.
|
||||
"""
|
||||
assert (string.count("(") == string.count(")")) and (
|
||||
string.count("[") == string.count("]")
|
||||
), f"Imbalanced brackets exist in {string}"
|
||||
end = len(string)
|
||||
for idx, char in enumerate(string):
|
||||
pre = string[:idx]
|
||||
# The string before this ',' is balanced
|
||||
if (
|
||||
(char == ",")
|
||||
and (pre.count("(") == pre.count(")"))
|
||||
and (pre.count("[") == pre.count("]"))
|
||||
):
|
||||
end = idx
|
||||
break
|
||||
return end
|
||||
|
||||
# Strip ' and " characters and replace whitespace.
|
||||
val = val.strip("'\"").replace(" ", "")
|
||||
is_tuple = False
|
||||
if val.startswith("(") and val.endswith(")"):
|
||||
is_tuple = True
|
||||
val = val[1:-1]
|
||||
elif val.startswith("[") and val.endswith("]"):
|
||||
val = val[1:-1]
|
||||
elif "," not in val:
|
||||
# val is a single value
|
||||
return DictAction._parse_int_float_bool(val)
|
||||
|
||||
values = []
|
||||
while len(val) > 0:
|
||||
comma_idx = find_next_comma(val)
|
||||
element = DictAction._parse_iterable(val[:comma_idx])
|
||||
values.append(element)
|
||||
val = val[comma_idx + 1 :]
|
||||
if is_tuple:
|
||||
values = tuple(values)
|
||||
return values
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
options = {}
|
||||
for kv in values:
|
||||
key, val = kv.split("=", maxsplit=1)
|
||||
options[key] = self._parse_iterable(val)
|
||||
setattr(namespace, self.dest, options)
|
||||
Reference in New Issue
Block a user