feat: Initial commit

fdyuandong
2025-04-17 23:14:24 +08:00
commit ca93dd0572
51 changed files with 7904 additions and 0 deletions

18
.gitignore vendored Normal file

@@ -0,0 +1,18 @@
image/
__pycache__
**/build/
**/*.egg-info/
**/dist/
*.so
exp
weights
data
log
outputs/
.vscode
.idea
*/.DS_Store
TEMP/
pretrained/
**/*.out
Dockerfile

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

103
README.md Normal file

@@ -0,0 +1,103 @@
# LAM-A2E: Audio to Expression
[![Website](https://raw.githubusercontent.com/prs-eth/Marigold/main/doc/badges/badge-website.svg)](https://aigc3d.github.io/projects/LAM/)
[![Apache License](https://img.shields.io/badge/📃-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0)
#### This project leverages audio input to generate ARKit blendshape-driven facial expressions in ⚡real-time⚡, powering ultra-realistic 3D avatars generated by [LAM](https://github.com/aigc3d/LAM).
## Demo
<div align="center">
<video controls src="https://github.com/user-attachments/assets/30ccbe82-7933-4031-8578-b5248435d317">
</video>
</div>
## 📢 News
### To do list
- [ ] Release Huggingface space.
- [ ] Release Modelscope space.
- [ ] Release the LAM-A2E model based on FLAME expressions.
- [ ] Release Interactive Chatting Avatar SDK with [OpenAvatarChat](https://github.com/HumanAIGC-Engineering/OpenAvatarChat), including LLM, ASR, TTS, LAM-Avatars.
## 🚀 Get Started
### Environment Setup
```bash
git clone git@github.com:aigc3d/LAM_Audio2Expression.git
cd LAM_Audio2Expression
# Install with Cuda 12.1
sh ./scripts/install/install_cu121.sh
# Or Install with Cuda 11.8
sh ./scripts/install/install_cu118.sh
```
### Download
```bash
# HuggingFace download
# Download Assets and Model Weights
huggingface-cli download 3DAIGC/LAM_audio2exp --local-dir ./
tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar
tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar
# Or OSS download (in case the HuggingFace download fails)
# Download Assets
wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_assets.tar
tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar
# Download Model Weights
wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_streaming.tar
tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar
```
### Quick Start Guide
#### Using <a href="https://github.com/gradio-app/gradio">Gradio</a> Interface:
We provide a simple Gradio demo with a **WebGL renderer**; upload an audio clip and you get rendering results in seconds.
<img src="./assets/images/snapshot.png" alt="teaser" width="1000"/>
```bash
python app_lam_audio2exp.py
```
### Inference
```bash
# example: python inference.py --config-file configs/lam_audio2exp_config_streaming.py --options save_path=exp/audio2exp weight=pretrained_models/lam_audio2exp_streaming.tar audio_input=./assets/sample_audio/BarackObama_english.wav
python inference.py --config-file ${CONFIG_PATH} --options save_path=${SAVE_PATH} weight=${CHECKPOINT_PATH} audio_input=${AUDIO_INPUT}
```
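For programmatic use, the same entry points that `app_lam_audio2exp.py` calls can be driven from Python directly. A minimal sketch, assuming the mmcv-style `--options` parsing in `engines/defaults.py` and the paths from the example above:

```python
# Minimal sketch of programmatic inference, mirroring launch_gradio_app()
# in app_lam_audio2exp.py; paths are the ones from the CLI example above.
from engines.defaults import default_argument_parser, default_config_parser, default_setup
from engines.infer import INFER

args = default_argument_parser().parse_args([
    "--options",
    "save_path=exp/audio2exp",
    "weight=pretrained_models/lam_audio2exp_streaming.tar",
    "audio_input=./assets/sample_audio/BarackObama_english.wav",
])
cfg = default_config_parser("configs/lam_audio2exp_config_streaming.py", args.options)
cfg = default_setup(cfg)

infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
infer.infer()  # writes ARKit blendshape coefficients to cfg.save_json_path
```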
### Acknowledgement
This work is built on many amazing research works and open-source projects:
- [FLAME](https://flame.is.tue.mpg.de)
- [FaceFormer](https://github.com/EvelynFan/FaceFormer)
- [Meshtalk](https://github.com/facebookresearch/meshtalk)
- [Unitalker](https://github.com/X-niper/UniTalker)
- [Pointcept](https://github.com/Pointcept/Pointcept)
Thanks for their excellent work and great contributions.
### Related Works
Welcome to follow our other interesting works:
- [LAM](https://github.com/aigc3d/LAM)
- [LHM](https://github.com/aigc3d/LHM)
### Citation
```
@inproceedings{he2025LAM,
title={LAM: Large Avatar Model for One-shot Animatable Gaussian Head},
author={
Yisheng He and Xiaodong Gu and Xiaodan Ye and Chao Xu and Zhengyi Zhao and Yuan Dong and Weihao Yuan and Zilong Dong and Liefeng Bo
},
booktitle={arXiv preprint arXiv:2502.17796},
year={2025}
}
```

271
app_lam_audio2exp.py Normal file

@@ -0,0 +1,271 @@
"""
Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import base64
import gradio as gr
import argparse
from omegaconf import OmegaConf
from gradio_gaussian_render import gaussian_render
from engines.defaults import (
default_argument_parser,
default_config_parser,
default_setup,
)
from engines.infer import INFER
from pathlib import Path
try:
import spaces
except ImportError:
pass
h5_rendering = True
def assert_input_image(input_image):
if input_image is None:
raise gr.Error('No image selected or uploaded!')
def prepare_working_dir():
import tempfile
working_dir = tempfile.TemporaryDirectory()
return working_dir
def get_image_base64(path):
with open(path, 'rb') as image_file:
encoded_string = base64.b64encode(image_file.read()).decode()
return f'data:image/png;base64,{encoded_string}'
def doRender():
print('H5 rendering ....')
def parse_configs():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str)
parser.add_argument("--infer", type=str)
args, unknown = parser.parse_known_args()
cfg = OmegaConf.create()
cli_cfg = OmegaConf.from_cli(unknown)
# parse from ENV
if os.environ.get("APP_INFER") is not None:
args.infer = os.environ.get("APP_INFER")
if os.environ.get("APP_MODEL_NAME") is not None:
cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
args.config = args.infer if args.config is None else args.config
if args.config is not None:
cfg_train = OmegaConf.load(args.config)
cfg.source_size = cfg_train.dataset.source_image_res
try:
cfg.src_head_size = cfg_train.dataset.src_head_size
except Exception:
cfg.src_head_size = 112
cfg.render_size = cfg_train.dataset.render_image.high
_relative_path = os.path.join(
cfg_train.experiment.parent,
cfg_train.experiment.child,
os.path.basename(cli_cfg.model_name).split("_")[-1],
)
cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
cfg.image_dump = os.path.join("exps", "images", _relative_path)
cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
if args.infer is not None:
cfg_infer = OmegaConf.load(args.infer)
cfg.merge_with(cfg_infer)
cfg.setdefault(
"save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
)
cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
cfg.setdefault(
"video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
)
cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
cfg.motion_video_read_fps = 30
cfg.merge_with(cli_cfg)
cfg.setdefault("logger", "INFO")
assert cfg.model_name is not None, "model_name is required"
return cfg, cfg_train
def create_zip_archive(output_zip='assets/arkitWithBSData.zip', base_dir=""):
import os
if os.path.exists(output_zip):
os.remove(output_zip)
print(f"Removed previous file: {output_zip}")
run_command = 'zip -r '+output_zip+' '+base_dir
os.system(run_command)
# check file
if os.path.exists(output_zip):
print(f"Archive created successfully: {output_zip}")
else:
raise ValueError(f"Archive created failed: {output_zip}")
def demo_lam_audio2exp(infer, cfg):
def core_fn(image_path: str, audio_params, working_dir):
base_id = os.path.basename(image_path).split(".")[0]
# set input audio
cfg.audio_input = audio_params
cfg.save_json_path = os.path.join("./assets/sample_lam", base_id, 'arkitWithBSData', 'bsData.json')
infer.infer()
create_zip_archive(output_zip='./assets/arkitWithBSData.zip', base_dir=os.path.join("./assets/sample_lam", base_id))
return
with gr.Blocks(analytics_enabled=False) as demo:
logo_url = './assets/images/logo.jpeg'
logo_base64 = get_image_base64(logo_url)
gr.HTML(f"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1> LAM-A2E: Audio to Expression</h1>
</div>
</div>
""")
gr.HTML(
"""<p><h4 style="color: blue;"> Notes: This project leverages audio input to generate ARKit blendshapes-driven facial expressions in ⚡real-time⚡, powering ultra-realistic 3D avatars generated by LAM.</h4></p>"""
)
# DISPLAY
with gr.Row():
with gr.Column(variant='panel', scale=1):
with gr.Tabs(elem_id='lam_input_image'):
with gr.TabItem('Input Image'):
with gr.Row():
input_image = gr.Image(label='Input Image',
image_mode='RGB',
height=480,
width=270,
sources='upload',
type='filepath', # 'numpy',
elem_id='content_image')
# EXAMPLES
with gr.Row():
examples = [
['assets/sample_input/barbara.jpg'],
['assets/sample_input/status.png'],
['assets/sample_input/james.png'],
['assets/sample_input/vfhq_case1.png'],
]
gr.Examples(
examples=examples,
inputs=[input_image],
examples_per_page=20,
)
with gr.Column():
with gr.Tabs(elem_id='lam_input_audio'):
with gr.TabItem('Input Audio'):
with gr.Row():
audio_input = gr.Audio(label='Input Audio',
type='filepath',
waveform_options={
'sample_rate': 16000,
'waveform_progress_color': '#4682b4'
},
elem_id='content_audio')
examples = [
['assets/sample_audio/Nangyanwen_chinese.wav'],
['assets/sample_audio/LiBai_TTS_chinese.wav'],
['assets/sample_audio/LinJing_TTS_chinese.wav'],
['assets/sample_audio/BarackObama_english.wav'],
['assets/sample_audio/HillaryClinton_english.wav'],
['assets/sample_audio/XitongShi_japanese.wav'],
['assets/sample_audio/FangXiao_japanese.wav'],
]
gr.Examples(
examples=examples,
inputs=[audio_input],
examples_per_page=10,
)
# SETTING
with gr.Row():
with gr.Column(variant='panel', scale=1):
submit = gr.Button('Generate',
elem_id='lam_generate',
variant='primary')
if h5_rendering:
gr.set_static_paths(Path.cwd().absolute() / "assets/")
assetPrefix = 'gradio_api/file=assets/'
with gr.Row():
gs = gaussian_render(width=380, height=680, assets=assetPrefix + 'arkitWithBSData.zip')
working_dir = gr.State()
submit.click(
fn=assert_input_image,
inputs=[input_image],
queue=False,
).success(
fn=prepare_working_dir,
outputs=[working_dir],
queue=False,
).success(
fn=core_fn,
inputs=[input_image, audio_input,
working_dir], # video_params refer to smpl dir
outputs=[],
queue=False,
).success(
doRender, js='''() => window.start()'''
)
demo.queue()
demo.launch()
def launch_gradio_app():
os.environ.update({
'APP_ENABLED': '1',
'APP_MODEL_NAME':'',
'APP_INFER': 'configs/lam_audio2exp_streaming_config.py',
'APP_TYPE': 'infer.audio2exp',
'NUMBA_THREADING_LAYER': 'omp',
})
args = default_argument_parser().parse_args()
args.config_file = 'configs/lam_audio2exp_config_streaming.py'
cfg = default_config_parser(args.config_file, args.options)
cfg = default_setup(cfg)
cfg.ex_vol = True
infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
demo_lam_audio2exp(infer, cfg)
if __name__ == '__main__':
launch_gradio_app()

BIN
assets/images/logo.jpeg Normal file (36 KiB)

BIN
assets/images/snapshot.png Normal file (2.0 MiB)

BIN
assets/images/teaser.jpg Normal file (654 KiB)

92
configs/lam_audio2exp_config.py Normal file

@@ -0,0 +1,92 @@
weight = 'pretrained_models/lam_audio2exp.tar' # path to model weight
ex_vol = True # Isolates vocal track from audio file
audio_input = './assets/sample_audio/BarackObama.wav'
save_json_path = 'bsData.json'
audio_sr = 16000
fps = 30.0
movement_smooth = True
brow_movement = True
id_idx = 153
resume = False # whether to resume the training process
evaluate = True # evaluate after each training epoch
test_only = False # run the test process only
seed = None # if None, training inits a random seed and records it
save_path = "exp/audio2exp"
num_worker = 16 # total workers across all GPUs
batch_size = 16 # total batch size across all GPUs
batch_size_val = None # auto-adapts to batch size 1 per GPU
batch_size_test = None # auto-adapts to batch size 1 per GPU
epoch = 100 # total epochs; data loop = epoch // eval_epoch
eval_epoch = 100 # schedule of eval & checkpoint epochs
sync_bn = False
enable_amp = False
empty_cache = False
find_unused_parameters = False
mix_prob = 0
param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
# model settings
model = dict(
type="DefaultEstimator",
backbone=dict(
type="Audio2Expression",
pretrained_encoder_type='wav2vec',
pretrained_encoder_path='facebook/wav2vec2-base-960h',
wav2vec2_config_path = 'configs/wav2vec2_config.json',
num_identity_classes=5016,
identity_feat_dim=64,
hidden_dim=512,
expression_dim=52,
norm_type='ln',
use_transformer=True,
num_attention_heads=8,
num_transformer_layers=6,
),
criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
)
dataset_type = 'audio2exp'
data_root = './'
data = dict(
train=dict(
type=dataset_type,
split="train",
data_root=data_root,
test_mode=False,
),
val=dict(
type=dataset_type,
split="val",
data_root=data_root,
test_mode=False,
),
test=dict(
type=dataset_type,
split="val",
data_root=data_root,
test_mode=True
),
)
# hook
hooks = [
dict(type="CheckpointLoader"),
dict(type="IterationTimer", warmup_iter=2),
dict(type="InformationWriter"),
dict(type="SemSegEvaluator"),
dict(type="CheckpointSaver", save_freq=None),
dict(type="PreciseEvaluator", test_last=False),
]
# Trainer
train = dict(type="DefaultTrainer")
# Tester
infer = dict(type="Audio2ExpressionInfer",
verbose=True)

92
configs/lam_audio2exp_config_streaming.py Normal file

@@ -0,0 +1,92 @@
weight = 'pretrained_models/lam_audio2exp_streaming.tar' # path to model weight
ex_vol = True # isolates vocal track from audio file
audio_input = './assets/sample_audio/BarackObama.wav'
save_json_path = 'bsData.json'
audio_sr = 16000
fps = 30.0
movement_smooth = False
brow_movement = False
id_idx = 0
resume = False # whether to resume the training process
evaluate = True # evaluate after each training epoch
test_only = False # run the test process only
seed = None # if None, training inits a random seed and records it
save_path = "exp/audio2exp"
num_worker = 16 # total workers across all GPUs
batch_size = 16 # total batch size across all GPUs
batch_size_val = None # auto-adapts to batch size 1 per GPU
batch_size_test = None # auto-adapts to batch size 1 per GPU
epoch = 100 # total epochs; data loop = epoch // eval_epoch
eval_epoch = 100 # schedule of eval & checkpoint epochs
sync_bn = False
enable_amp = False
empty_cache = False
find_unused_parameters = False
mix_prob = 0
param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
# model settings
model = dict(
type="DefaultEstimator",
backbone=dict(
type="Audio2Expression",
pretrained_encoder_type='wav2vec',
pretrained_encoder_path='facebook/wav2vec2-base-960h',
wav2vec2_config_path = 'configs/wav2vec2_config.json',
num_identity_classes=12,
identity_feat_dim=64,
hidden_dim=512,
expression_dim=52,
norm_type='ln',
use_transformer=False,
num_attention_heads=8,
num_transformer_layers=6,
),
criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
)
dataset_type = 'audio2exp'
data_root = './'
data = dict(
train=dict(
type=dataset_type,
split="train",
data_root=data_root,
test_mode=False,
),
val=dict(
type=dataset_type,
split="val",
data_root=data_root,
test_mode=False,
),
test=dict(
type=dataset_type,
split="val",
data_root=data_root,
test_mode=True
),
)
# hook
hooks = [
dict(type="CheckpointLoader"),
dict(type="IterationTimer", warmup_iter=2),
dict(type="InformationWriter"),
dict(type="SemSegEvaluator"),
dict(type="CheckpointSaver", save_freq=None),
dict(type="PreciseEvaluator", test_last=False),
]
# Trainer
train = dict(type="DefaultTrainer")
# Tester
infer = dict(type="Audio2ExpressionInfer",
verbose=True)
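Both config files above follow the Pointcept-style Python-config convention consumed by `engines/defaults.py`. A minimal sketch of how they are loaded and overridden, assuming `utils.config.Config` is mmcv-style (`fromfile`/`merge_from_dict`):

```python
# Hedged sketch: how default_config_parser() in engines/defaults.py
# consumes this file and applies CLI --options overrides.
from utils.config import Config

cfg = Config.fromfile("configs/lam_audio2exp_config_streaming.py")
cfg.merge_from_dict({"audio_input": "./assets/sample_audio/BarackObama_english.wav"})
print(cfg.model.backbone.type)  # "Audio2Expression"
print(cfg.weight)               # "pretrained_models/lam_audio2exp_streaming.tar"
```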

77
configs/wav2vec2_config.json Normal file

@@ -0,0 +1,77 @@
{
"_name_or_path": "facebook/wav2vec2-base-960h",
"activation_dropout": 0.1,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"codevector_dim": 256,
"contrastive_logits_temperature": 0.1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": false,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "group",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_prob": 0.05,
"model_type": "wav2vec2",
"num_attention_heads": 12,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 12,
"num_negatives": 100,
"pad_token_id": 0,
"proj_codevector_dim": 256,
"transformers_version": "4.7.0.dev0",
"vocab_size": 32
}
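A worked consequence of the `conv_stride` schedule above: the wav2vec2 feature extractor downsamples 16 kHz audio by 5*2*2*2*2*2*2 = 320x, so encoder features arrive at 50 per second, presumably resampled downstream to the `fps = 30.0` output configured in the configs above. A quick check:

```python
# Worked check of the feature-frame rate implied by conv_stride above.
from functools import reduce
from operator import mul

conv_stride = [5, 2, 2, 2, 2, 2, 2]
total_stride = reduce(mul, conv_stride)  # 320 audio samples per feature frame
audio_sr = 16000                         # audio_sr from the configs above
print(audio_sr / total_stride)           # 50.0 feature frames per second
```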

0
engines/__init__.py Normal file

147
engines/defaults.py Normal file

@@ -0,0 +1,147 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import utils.comm as comm
from utils.env import get_random_seed, set_seed
from utils.config import Config, DictAction
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
"""
Create a DistributedDataParallel model if there are >1 processes.
Args:
model: a torch.nn.Module
fp16_compression: add fp16 compression hooks to the ddp object.
See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
kwargs: other arguments of :class:`torch.nn.parallel.DistributedDataParallel`.
"""
if comm.get_world_size() == 1:
return model
# kwargs['find_unused_parameters'] = True
if "device_ids" not in kwargs:
kwargs["device_ids"] = [comm.get_local_rank()]
if "output_device" not in kwargs:
kwargs["output_device"] = [comm.get_local_rank()]
ddp = DistributedDataParallel(model, **kwargs)
if fp16_compression:
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
return ddp
def worker_init_fn(worker_id, num_workers, rank, seed):
"""Worker init func for dataloader.
The seed of each worker equals to num_worker * rank + worker_id + user_seed
Args:
worker_id (int): Worker id.
num_workers (int): Number of workers.
rank (int): The rank of current process.
seed (int): The random seed to use.
"""
worker_seed = num_workers * rank + worker_id + seed
set_seed(worker_seed)
def default_argument_parser(epilog=None):
parser = argparse.ArgumentParser(
epilog=epilog
or f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--config-file", default="", metavar="FILE", help="path to config file"
)
parser.add_argument(
"--num-gpus", type=int, default=1, help="number of gpus *per machine*"
)
parser.add_argument(
"--num-machines", type=int, default=1, help="total number of machines"
)
parser.add_argument(
"--machine-rank",
type=int,
default=0,
help="the rank of this machine (unique per machine)",
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
# port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
parser.add_argument(
"--dist-url",
# default="tcp://127.0.0.1:{}".format(port),
default="auto",
help="initialization URL for pytorch distributed backend. See "
"https://pytorch.org/docs/stable/distributed.html for details.",
)
parser.add_argument(
"--options", nargs="+", action=DictAction, help="custom options"
)
return parser
def default_config_parser(file_path, options):
# config name protocol: dataset_name/model_name-exp_name
if os.path.isfile(file_path):
cfg = Config.fromfile(file_path)
else:
sep = file_path.find("-")
cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))
if options is not None:
cfg.merge_from_dict(options)
if cfg.seed is None:
cfg.seed = get_random_seed()
cfg.data.train.loop = cfg.epoch // cfg.eval_epoch
os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
if not cfg.resume:
cfg.dump(os.path.join(cfg.save_path, "config.py"))
return cfg
def default_setup(cfg):
# scale by world size
world_size = comm.get_world_size()
cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
cfg.num_worker_per_gpu = cfg.num_worker // world_size
assert cfg.batch_size % world_size == 0
assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
cfg.batch_size_per_gpu = cfg.batch_size // world_size
cfg.batch_size_val_per_gpu = (
cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
)
cfg.batch_size_test_per_gpu = (
cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
)
# update data loop
assert cfg.epoch % cfg.eval_epoch == 0
# settle random seed
rank = comm.get_rank()
seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
set_seed(seed)
return cfg
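`worker_init_fn` takes more arguments than the single `worker_id` a `DataLoader` passes, so it is presumably bound with `functools.partial` before being handed to the loader (the trainer wiring is not shown in this commit). A runnable sketch with a toy dataset:

```python
# Sketch: binding worker_init_fn so each dataloader worker is seeded with
# num_workers * rank + worker_id + seed, per the docstring above.
from functools import partial
import torch
import torch.utils.data

from engines.defaults import worker_init_fn

dataset = torch.utils.data.TensorDataset(torch.arange(32).float())  # toy data
init_fn = partial(worker_init_fn, num_workers=2, rank=0, seed=123)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, num_workers=2, worker_init_fn=init_fn
)
for (batch,) in loader:
    pass  # worker 0 was seeded with 123, worker 1 with 124
```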

5
engines/hooks/__init__.py Normal file

@@ -0,0 +1,5 @@
from .default import HookBase
from .misc import *
from .evaluator import *
from .builder import build_hooks

15
engines/hooks/builder.py Normal file

@@ -0,0 +1,15 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from utils.registry import Registry
HOOKS = Registry("hooks")
def build_hooks(cfg):
hooks = []
for hook_cfg in cfg:
hooks.append(HOOKS.build(hook_cfg))
return hooks
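`build_hooks` is driven by the `hooks` list in the config files above: each dict's `type` key selects a class registered in the `HOOKS` registry. A minimal sketch, assuming the hook classes from `engines/hooks/misc.py` are importable:

```python
# Sketch: building hooks from config dicts, as the trainer would.
from engines.hooks import build_hooks

hooks = build_hooks([
    dict(type="CheckpointLoader"),
    dict(type="IterationTimer", warmup_iter=2),
])
print([type(h).__name__ for h in hooks])  # ['CheckpointLoader', 'IterationTimer']
```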

29
engines/hooks/default.py Normal file

@@ -0,0 +1,29 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
"""
trainer = None # A weak reference to the trainer object.
def before_train(self):
pass
def before_epoch(self):
pass
def before_step(self):
pass
def after_step(self):
pass
def after_epoch(self):
pass
def after_train(self):
pass
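Any stage of the loop can be extended by subclassing `HookBase`, registering it, and listing it in a config's `hooks`. A hedged sketch in the style of `engines/hooks/misc.py`; `EpochLogger` is hypothetical, not part of this commit:

```python
# Hypothetical custom hook; follows the register-then-configure pattern
# used by the built-in hooks in this commit.
from engines.hooks.builder import HOOKS
from engines.hooks.default import HookBase

@HOOKS.register_module()
class EpochLogger(HookBase):
    def after_epoch(self):
        # self.trainer is attached by the trainer before the loop starts.
        self.trainer.logger.info(f"Finished epoch {self.trainer.epoch + 1}")
```

It would then be enabled with `dict(type="EpochLogger")` in the `hooks` list of a config file.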

577
engines/hooks/evaluator.py Normal file

@@ -0,0 +1,577 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import numpy as np
import torch
import torch.distributed as dist
from uuid import uuid4
try:
import pointops  # Pointcept CUDA ops; required for the "origin_coord" knn_query paths below
except ImportError:
pointops = None
import utils.comm as comm
from utils.misc import intersection_and_union_gpu
from .default import HookBase
from .builder import HOOKS
@HOOKS.register_module()
class ClsEvaluator(HookBase):
def after_epoch(self):
if self.trainer.cfg.evaluate:
self.eval()
def eval(self):
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
self.trainer.model.eval()
for i, input_dict in enumerate(self.trainer.val_loader):
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.no_grad():
output_dict = self.trainer.model(input_dict)
output = output_dict["cls_logits"]
loss = output_dict["loss"]
pred = output.max(1)[1]
label = input_dict["category"]
intersection, union, target = intersection_and_union_gpu(
pred,
label,
self.trainer.cfg.data.num_classes,
self.trainer.cfg.data.ignore_index,
)
if comm.get_world_size() > 1:
dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
target
)
intersection, union, target = (
intersection.cpu().numpy(),
union.cpu().numpy(),
target.cpu().numpy(),
)
# Here there is no need to sync since sync happened in dist.all_reduce
self.trainer.storage.put_scalar("val_intersection", intersection)
self.trainer.storage.put_scalar("val_union", union)
self.trainer.storage.put_scalar("val_target", target)
self.trainer.storage.put_scalar("val_loss", loss.item())
self.trainer.logger.info(
"Test: [{iter}/{max_iter}] "
"Loss {loss:.4f} ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
)
)
loss_avg = self.trainer.storage.history("val_loss").avg
intersection = self.trainer.storage.history("val_intersection").total
union = self.trainer.storage.history("val_union").total
target = self.trainer.storage.history("val_target").total
iou_class = intersection / (union + 1e-10)
acc_class = intersection / (target + 1e-10)
m_iou = np.mean(iou_class)
m_acc = np.mean(acc_class)
all_acc = sum(intersection) / (sum(target) + 1e-10)
self.trainer.logger.info(
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
m_iou, m_acc, all_acc
)
)
for i in range(self.trainer.cfg.data.num_classes):
self.trainer.logger.info(
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
idx=i,
name=self.trainer.cfg.data.names[i],
iou=iou_class[i],
accuracy=acc_class[i],
)
)
current_epoch = self.trainer.epoch + 1
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
self.trainer.comm_info["current_metric_value"] = all_acc # save for saver
self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver
def after_train(self):
self.trainer.logger.info(
"Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
)
@HOOKS.register_module()
class SemSegEvaluator(HookBase):
def after_epoch(self):
if self.trainer.cfg.evaluate:
self.eval()
def eval(self):
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
self.trainer.model.eval()
for i, input_dict in enumerate(self.trainer.val_loader):
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.no_grad():
output_dict = self.trainer.model(input_dict)
output = output_dict["seg_logits"]
loss = output_dict["loss"]
pred = output.max(1)[1]
segment = input_dict["segment"]
if "origin_coord" in input_dict.keys():
idx, _ = pointops.knn_query(
1,
input_dict["coord"].float(),
input_dict["offset"].int(),
input_dict["origin_coord"].float(),
input_dict["origin_offset"].int(),
)
pred = pred[idx.flatten().long()]
segment = input_dict["origin_segment"]
intersection, union, target = intersection_and_union_gpu(
pred,
segment,
self.trainer.cfg.data.num_classes,
self.trainer.cfg.data.ignore_index,
)
if comm.get_world_size() > 1:
dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
target
)
intersection, union, target = (
intersection.cpu().numpy(),
union.cpu().numpy(),
target.cpu().numpy(),
)
# Here there is no need to sync since sync happened in dist.all_reduce
self.trainer.storage.put_scalar("val_intersection", intersection)
self.trainer.storage.put_scalar("val_union", union)
self.trainer.storage.put_scalar("val_target", target)
self.trainer.storage.put_scalar("val_loss", loss.item())
info = "Test: [{iter}/{max_iter}] ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader)
)
if "origin_coord" in input_dict.keys():
info = "Interp. " + info
self.trainer.logger.info(
info
+ "Loss {loss:.4f} ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
)
)
loss_avg = self.trainer.storage.history("val_loss").avg
intersection = self.trainer.storage.history("val_intersection").total
union = self.trainer.storage.history("val_union").total
target = self.trainer.storage.history("val_target").total
iou_class = intersection / (union + 1e-10)
acc_class = intersection / (target + 1e-10)
m_iou = np.mean(iou_class)
m_acc = np.mean(acc_class)
all_acc = sum(intersection) / (sum(target) + 1e-10)
self.trainer.logger.info(
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
m_iou, m_acc, all_acc
)
)
for i in range(self.trainer.cfg.data.num_classes):
self.trainer.logger.info(
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
idx=i,
name=self.trainer.cfg.data.names[i],
iou=iou_class[i],
accuracy=acc_class[i],
)
)
current_epoch = self.trainer.epoch + 1
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
self.trainer.comm_info["current_metric_value"] = m_iou # save for saver
self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver
def after_train(self):
self.trainer.logger.info(
"Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
)
@HOOKS.register_module()
class InsSegEvaluator(HookBase):
def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
self.segment_ignore_index = segment_ignore_index
self.instance_ignore_index = instance_ignore_index
self.valid_class_names = None # update in before train
self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
self.min_region_sizes = 100
self.distance_threshes = float("inf")
self.distance_confs = -float("inf")
def before_train(self):
self.valid_class_names = [
self.trainer.cfg.data.names[i]
for i in range(self.trainer.cfg.data.num_classes)
if i not in self.segment_ignore_index
]
def after_epoch(self):
if self.trainer.cfg.evaluate:
self.eval()
def associate_instances(self, pred, segment, instance):
segment = segment.cpu().numpy()
instance = instance.cpu().numpy()
void_mask = np.in1d(segment, self.segment_ignore_index)
assert (
pred["pred_classes"].shape[0]
== pred["pred_scores"].shape[0]
== pred["pred_masks"].shape[0]
)
assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
# get gt instances
gt_instances = dict()
for i in range(self.trainer.cfg.data.num_classes):
if i not in self.segment_ignore_index:
gt_instances[self.trainer.cfg.data.names[i]] = []
instance_ids, idx, counts = np.unique(
instance, return_index=True, return_counts=True
)
segment_ids = segment[idx]
for i in range(len(instance_ids)):
if instance_ids[i] == self.instance_ignore_index:
continue
if segment_ids[i] in self.segment_ignore_index:
continue
gt_inst = dict()
gt_inst["instance_id"] = instance_ids[i]
gt_inst["segment_id"] = segment_ids[i]
gt_inst["dist_conf"] = 0.0
gt_inst["med_dist"] = -1.0
gt_inst["vert_count"] = counts[i]
gt_inst["matched_pred"] = []
gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)
# get pred instances and associate with gt
pred_instances = dict()
for i in range(self.trainer.cfg.data.num_classes):
if i not in self.segment_ignore_index:
pred_instances[self.trainer.cfg.data.names[i]] = []
instance_id = 0
for i in range(len(pred["pred_classes"])):
if pred["pred_classes"][i] in self.segment_ignore_index:
continue
pred_inst = dict()
pred_inst["uuid"] = uuid4()
pred_inst["instance_id"] = instance_id
pred_inst["segment_id"] = pred["pred_classes"][i]
pred_inst["confidence"] = pred["pred_scores"][i]
pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
pred_inst["void_intersection"] = np.count_nonzero(
np.logical_and(void_mask, pred_inst["mask"])
)
if pred_inst["vert_count"] < self.min_region_sizes:
continue # skip if empty
segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
matched_gt = []
for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
intersection = np.count_nonzero(
np.logical_and(
instance == gt_inst["instance_id"], pred_inst["mask"]
)
)
if intersection > 0:
gt_inst_ = gt_inst.copy()
pred_inst_ = pred_inst.copy()
gt_inst_["intersection"] = intersection
pred_inst_["intersection"] = intersection
matched_gt.append(gt_inst_)
gt_inst["matched_pred"].append(pred_inst_)
pred_inst["matched_gt"] = matched_gt
pred_instances[segment_name].append(pred_inst)
instance_id += 1
return gt_instances, pred_instances
def evaluate_matches(self, scenes):
overlaps = self.overlaps
min_region_sizes = [self.min_region_sizes]
dist_threshes = [self.distance_threshes]
dist_confs = [self.distance_confs]
# results: class x overlap
ap_table = np.zeros(
(len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
)
for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
zip(min_region_sizes, dist_threshes, dist_confs)
):
for oi, overlap_th in enumerate(overlaps):
pred_visited = {}
for scene in scenes:
for _ in scene["pred"]:
for label_name in self.valid_class_names:
for p in scene["pred"][label_name]:
if "uuid" in p:
pred_visited[p["uuid"]] = False
for li, label_name in enumerate(self.valid_class_names):
y_true = np.empty(0)
y_score = np.empty(0)
hard_false_negatives = 0
has_gt = False
has_pred = False
for scene in scenes:
pred_instances = scene["pred"][label_name]
gt_instances = scene["gt"][label_name]
# filter groups in ground truth
gt_instances = [
gt
for gt in gt_instances
if gt["vert_count"] >= min_region_size
and gt["med_dist"] <= distance_thresh
and gt["dist_conf"] >= distance_conf
]
if gt_instances:
has_gt = True
if pred_instances:
has_pred = True
cur_true = np.ones(len(gt_instances))
cur_score = np.ones(len(gt_instances)) * (-float("inf"))
cur_match = np.zeros(len(gt_instances), dtype=bool)
# collect matches
for gti, gt in enumerate(gt_instances):
found_match = False
for pred in gt["matched_pred"]:
# greedy assignments
if pred_visited[pred["uuid"]]:
continue
overlap = float(pred["intersection"]) / (
gt["vert_count"]
+ pred["vert_count"]
- pred["intersection"]
)
if overlap > overlap_th:
confidence = pred["confidence"]
# if already have a prediction for this gt,
# the prediction with the lower score is automatically a false positive
if cur_match[gti]:
max_score = max(cur_score[gti], confidence)
min_score = min(cur_score[gti], confidence)
cur_score[gti] = max_score
# append false positive
cur_true = np.append(cur_true, 0)
cur_score = np.append(cur_score, min_score)
cur_match = np.append(cur_match, True)
# otherwise set score
else:
found_match = True
cur_match[gti] = True
cur_score[gti] = confidence
pred_visited[pred["uuid"]] = True
if not found_match:
hard_false_negatives += 1
# remove non-matched ground truth instances
cur_true = cur_true[cur_match]
cur_score = cur_score[cur_match]
# collect non-matched predictions as false positive
for pred in pred_instances:
found_gt = False
for gt in pred["matched_gt"]:
overlap = float(gt["intersection"]) / (
gt["vert_count"]
+ pred["vert_count"]
- gt["intersection"]
)
if overlap > overlap_th:
found_gt = True
break
if not found_gt:
num_ignore = pred["void_intersection"]
for gt in pred["matched_gt"]:
if gt["segment_id"] in self.segment_ignore_index:
num_ignore += gt["intersection"]
# small ground truth instances
if (
gt["vert_count"] < min_region_size
or gt["med_dist"] > distance_thresh
or gt["dist_conf"] < distance_conf
):
num_ignore += gt["intersection"]
proportion_ignore = (
float(num_ignore) / pred["vert_count"]
)
# if not ignored append false positive
if proportion_ignore <= overlap_th:
cur_true = np.append(cur_true, 0)
confidence = pred["confidence"]
cur_score = np.append(cur_score, confidence)
# append to overall results
y_true = np.append(y_true, cur_true)
y_score = np.append(y_score, cur_score)
# compute average precision
if has_gt and has_pred:
# compute precision recall curve first
# sorting and cumsum
score_arg_sort = np.argsort(y_score)
y_score_sorted = y_score[score_arg_sort]
y_true_sorted = y_true[score_arg_sort]
y_true_sorted_cumsum = np.cumsum(y_true_sorted)
# unique thresholds
(thresholds, unique_indices) = np.unique(
y_score_sorted, return_index=True
)
num_prec_recall = len(unique_indices) + 1
# prepare precision recall
num_examples = len(y_score_sorted)
# https://github.com/ScanNet/ScanNet/pull/26
# all predictions are non-matched but also all of them are ignored and not counted as FP
# y_true_sorted_cumsum is empty
# num_true_examples = y_true_sorted_cumsum[-1]
num_true_examples = (
y_true_sorted_cumsum[-1]
if len(y_true_sorted_cumsum) > 0
else 0
)
precision = np.zeros(num_prec_recall)
recall = np.zeros(num_prec_recall)
# deal with the first point
y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
# deal with remaining
for idx_res, idx_scores in enumerate(unique_indices):
cumsum = y_true_sorted_cumsum[idx_scores - 1]
tp = num_true_examples - cumsum
fp = num_examples - idx_scores - tp
fn = cumsum + hard_false_negatives
p = float(tp) / (tp + fp)
r = float(tp) / (tp + fn)
precision[idx_res] = p
recall[idx_res] = r
# first point in curve is artificial
precision[-1] = 1.0
recall[-1] = 0.0
# compute average of precision-recall curve
recall_for_conv = np.copy(recall)
recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
recall_for_conv = np.append(recall_for_conv, 0.0)
stepWidths = np.convolve(
recall_for_conv, [-0.5, 0, 0.5], "valid"
)
# integrate is now simply a dot product
ap_current = np.dot(precision, stepWidths)
elif has_gt:
ap_current = 0.0
else:
ap_current = float("nan")
ap_table[di, li, oi] = ap_current
d_inf = 0
o50 = np.where(np.isclose(self.overlaps, 0.5))
o25 = np.where(np.isclose(self.overlaps, 0.25))
oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
ap_scores = dict()
ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
ap_scores["classes"] = {}
for li, label_name in enumerate(self.valid_class_names):
ap_scores["classes"][label_name] = {}
ap_scores["classes"][label_name]["ap"] = np.average(
ap_table[d_inf, li, oAllBut25]
)
ap_scores["classes"][label_name]["ap50%"] = np.average(
ap_table[d_inf, li, o50]
)
ap_scores["classes"][label_name]["ap25%"] = np.average(
ap_table[d_inf, li, o25]
)
return ap_scores
def eval(self):
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
self.trainer.model.eval()
scenes = []
for i, input_dict in enumerate(self.trainer.val_loader):
assert (
len(input_dict["offset"]) == 1
) # currently only support bs 1 for each GPU
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.no_grad():
output_dict = self.trainer.model(input_dict)
loss = output_dict["loss"]
segment = input_dict["segment"]
instance = input_dict["instance"]
# map to origin
if "origin_coord" in input_dict.keys():
idx, _ = pointops.knn_query(
1,
input_dict["coord"].float(),
input_dict["offset"].int(),
input_dict["origin_coord"].float(),
input_dict["origin_offset"].int(),
)
idx = idx.cpu().flatten().long()
output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
segment = input_dict["origin_segment"]
instance = input_dict["origin_instance"]
gt_instances, pred_instance = self.associate_instances(
output_dict, segment, instance
)
scenes.append(dict(gt=gt_instances, pred=pred_instance))
self.trainer.storage.put_scalar("val_loss", loss.item())
self.trainer.logger.info(
"Test: [{iter}/{max_iter}] "
"Loss {loss:.4f} ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
)
)
loss_avg = self.trainer.storage.history("val_loss").avg
comm.synchronize()
scenes_sync = comm.gather(scenes, dst=0)
scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
ap_scores = self.evaluate_matches(scenes)
all_ap = ap_scores["all_ap"]
all_ap_50 = ap_scores["all_ap_50%"]
all_ap_25 = ap_scores["all_ap_25%"]
self.trainer.logger.info(
"Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
all_ap, all_ap_50, all_ap_25
)
)
for i, label_name in enumerate(self.valid_class_names):
ap = ap_scores["classes"][label_name]["ap"]
ap_50 = ap_scores["classes"][label_name]["ap50%"]
ap_25 = ap_scores["classes"][label_name]["ap25%"]
self.trainer.logger.info(
"Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
)
)
current_epoch = self.trainer.epoch + 1
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver
self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver

460
engines/hooks/misc.py Normal file

@@ -0,0 +1,460 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict
if sys.version_info >= (3, 10):
from collections.abc import Sequence
else:
from collections import Sequence
from utils.timer import Timer
from utils.comm import is_main_process, synchronize, get_world_size
from utils.cache import shared_dict
import utils.comm as comm
from engines.test import TESTERS
from .default import HookBase
from .builder import HOOKS
@HOOKS.register_module()
class IterationTimer(HookBase):
def __init__(self, warmup_iter=1):
self._warmup_iter = warmup_iter
self._start_time = time.perf_counter()
self._iter_timer = Timer()
self._remain_iter = 0
def before_train(self):
self._start_time = time.perf_counter()
self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)
def before_epoch(self):
self._iter_timer.reset()
def before_step(self):
data_time = self._iter_timer.seconds()
self.trainer.storage.put_scalar("data_time", data_time)
def after_step(self):
batch_time = self._iter_timer.seconds()
self._iter_timer.reset()
self.trainer.storage.put_scalar("batch_time", batch_time)
self._remain_iter -= 1
remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
t_m, t_s = divmod(remain_time, 60)
t_h, t_m = divmod(t_m, 60)
remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
if "iter_info" in self.trainer.comm_info.keys():
info = (
"Data {data_time_val:.3f} ({data_time_avg:.3f}) "
"Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
"Remain {remain_time} ".format(
data_time_val=self.trainer.storage.history("data_time").val,
data_time_avg=self.trainer.storage.history("data_time").avg,
batch_time_val=self.trainer.storage.history("batch_time").val,
batch_time_avg=self.trainer.storage.history("batch_time").avg,
remain_time=remain_time,
)
)
self.trainer.comm_info["iter_info"] += info
if self.trainer.comm_info["iter"] <= self._warmup_iter:
self.trainer.storage.history("data_time").reset()
self.trainer.storage.history("batch_time").reset()
@HOOKS.register_module()
class InformationWriter(HookBase):
def __init__(self):
self.curr_iter = 0
self.model_output_keys = []
def before_train(self):
self.trainer.comm_info["iter_info"] = ""
self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)
def before_step(self):
self.curr_iter += 1
# MSC pretraining has no offset information; the block below is commented out to support MSC.
# info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
# "Scan {batch_size} ({points_num}) ".format(
# epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
# iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
# batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
# points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
# )
info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
epoch=self.trainer.epoch + 1,
max_epoch=self.trainer.max_epoch,
iter=self.trainer.comm_info["iter"] + 1,
max_iter=len(self.trainer.train_loader),
)
self.trainer.comm_info["iter_info"] += info
def after_step(self):
if "model_output_dict" in self.trainer.comm_info.keys():
model_output_dict = self.trainer.comm_info["model_output_dict"]
self.model_output_keys = model_output_dict.keys()
for key in self.model_output_keys:
self.trainer.storage.put_scalar(key, model_output_dict[key].item())
for key in self.model_output_keys:
self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
key=key, value=self.trainer.storage.history(key).val
)
lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
self.trainer.logger.info(self.trainer.comm_info["iter_info"])
self.trainer.comm_info["iter_info"] = "" # reset iter info
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
for key in self.model_output_keys:
self.trainer.writer.add_scalar(
"train_batch/" + key,
self.trainer.storage.history(key).val,
self.curr_iter,
)
def after_epoch(self):
epoch_info = "Train result: "
for key in self.model_output_keys:
epoch_info += "{key}: {value:.4f} ".format(
key=key, value=self.trainer.storage.history(key).avg
)
self.trainer.logger.info(epoch_info)
if self.trainer.writer is not None:
for key in self.model_output_keys:
self.trainer.writer.add_scalar(
"train/" + key,
self.trainer.storage.history(key).avg,
self.trainer.epoch + 1,
)
@HOOKS.register_module()
class CheckpointSaver(HookBase):
def __init__(self, save_freq=None):
self.save_freq = save_freq # None or int; None means only the last model is saved
def after_epoch(self):
if is_main_process():
is_best = False
if self.trainer.cfg.evaluate:
current_metric_value = self.trainer.comm_info["current_metric_value"]
current_metric_name = self.trainer.comm_info["current_metric_name"]
if current_metric_value > self.trainer.best_metric_value:
self.trainer.best_metric_value = current_metric_value
is_best = True
self.trainer.logger.info(
"Best validation {} updated to: {:.4f}".format(
current_metric_name, current_metric_value
)
)
self.trainer.logger.info(
"Currently Best {}: {:.4f}".format(
current_metric_name, self.trainer.best_metric_value
)
)
filename = os.path.join(
self.trainer.cfg.save_path, "model", "model_last.pth"
)
self.trainer.logger.info("Saving checkpoint to: " + filename)
torch.save(
{
"epoch": self.trainer.epoch + 1,
"state_dict": self.trainer.model.state_dict(),
"optimizer": self.trainer.optimizer.state_dict(),
"scheduler": self.trainer.scheduler.state_dict(),
"scaler": self.trainer.scaler.state_dict()
if self.trainer.cfg.enable_amp
else None,
"best_metric_value": self.trainer.best_metric_value,
},
filename + ".tmp",
)
os.replace(filename + ".tmp", filename)
if is_best:
shutil.copyfile(
filename,
os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
)
if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
shutil.copyfile(
filename,
os.path.join(
self.trainer.cfg.save_path,
"model",
f"epoch_{self.trainer.epoch + 1}.pth",
),
)
@HOOKS.register_module()
class CheckpointLoader(HookBase):
def __init__(self, keywords="", replacement=None, strict=False):
self.keywords = keywords
self.replacement = replacement if replacement is not None else keywords
self.strict = strict
def before_train(self):
self.trainer.logger.info("=> Loading checkpoint & weight ...")
if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
checkpoint = torch.load(
self.trainer.cfg.weight,
map_location=lambda storage, loc: storage.cuda(),
)
self.trainer.logger.info(
f"Loading layer weights with keyword: {self.keywords}, "
f"replace keyword with: {self.replacement}"
)
weight = OrderedDict()
for key, value in checkpoint["state_dict"].items():
                if not key.startswith("module."):
                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # Now every key carries the "module." prefix, DDP or not, so the
                # unconditional strip below is safe in single-process runs.
if self.keywords in key:
key = key.replace(self.keywords, self.replacement)
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
weight[key] = value
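            # Example (illustrative): with keywords="backbone." and replacement="encoder.",
            # a checkpoint key "backbone.proj.weight" loads as "module.encoder.proj.weight"
            # under DDP (world_size > 1) and as "encoder.proj.weight" in single-process runs.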
load_state_info = self.trainer.model.load_state_dict(
weight, strict=self.strict
)
self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming training from epoch: {checkpoint['epoch']}"
                )
self.trainer.start_epoch = checkpoint["epoch"]
self.trainer.best_metric_value = checkpoint["best_metric_value"]
self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
if self.trainer.cfg.enable_amp:
self.trainer.scaler.load_state_dict(checkpoint["scaler"])
else:
self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
@HOOKS.register_module()
class PreciseEvaluator(HookBase):
def __init__(self, test_last=False):
self.test_last = test_last
def after_train(self):
self.trainer.logger.info(
">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
)
torch.cuda.empty_cache()
cfg = self.trainer.cfg
tester = TESTERS.build(
dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
)
if self.test_last:
self.trainer.logger.info("=> Testing on model_last ...")
else:
self.trainer.logger.info("=> Testing on model_best ...")
best_path = os.path.join(
self.trainer.cfg.save_path, "model", "model_best.pth"
)
checkpoint = torch.load(best_path)
state_dict = checkpoint["state_dict"]
tester.model.load_state_dict(state_dict, strict=True)
tester.test()
@HOOKS.register_module()
class DataCacheOperator(HookBase):
def __init__(self, data_root, split):
self.data_root = data_root
self.split = split
self.data_list = self.get_data_list()
def get_data_list(self):
if isinstance(self.split, str):
data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
elif isinstance(self.split, Sequence):
data_list = []
for split in self.split:
data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
else:
raise NotImplementedError
return data_list
def get_cache_name(self, data_path):
data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
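        # e.g. (illustrative paths) data_root="data/scannet",
        # data_path="data/scannet/train/scene0000_00.pth"
        #   -> data_name="/scannet/train/scene0000_00"
        #   -> cache name "pointcept-scannet-train-scene0000_00"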
return "pointcept" + data_name.replace(os.path.sep, "-")
def before_train(self):
self.trainer.logger.info(
f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
)
if is_main_process():
for data_path in self.data_list:
cache_name = self.get_cache_name(data_path)
data = torch.load(data_path)
shared_dict(cache_name, data)
synchronize()
@HOOKS.register_module()
class RuntimeProfiler(HookBase):
def __init__(
self,
forward=True,
backward=True,
interrupt=False,
warm_up=2,
sort_by="cuda_time_total",
row_limit=30,
):
self.forward = forward
self.backward = backward
self.interrupt = interrupt
self.warm_up = warm_up
self.sort_by = sort_by
self.row_limit = row_limit
def before_train(self):
self.trainer.logger.info("Profiling runtime ...")
from torch.profiler import profile, record_function, ProfilerActivity
for i, input_dict in enumerate(self.trainer.train_loader):
if i == self.warm_up + 1:
break
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
if self.forward:
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
) as forward_prof:
with record_function("model_inference"):
output_dict = self.trainer.model(input_dict)
else:
output_dict = self.trainer.model(input_dict)
loss = output_dict["loss"]
if self.backward:
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
) as backward_prof:
with record_function("model_inference"):
loss.backward()
self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
if self.forward:
self.trainer.logger.info(
"Forward profile: \n"
+ str(
forward_prof.key_averages().table(
sort_by=self.sort_by, row_limit=self.row_limit
)
)
)
forward_prof.export_chrome_trace(
os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
)
if self.backward:
self.trainer.logger.info(
"Backward profile: \n"
+ str(
backward_prof.key_averages().table(
sort_by=self.sort_by, row_limit=self.row_limit
)
)
)
backward_prof.export_chrome_trace(
os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
)
if self.interrupt:
sys.exit(0)
@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
def __init__(
self,
interrupt=False,
wait=1,
warmup=1,
active=10,
repeat=1,
sort_by="cuda_time_total",
row_limit=30,
):
self.interrupt = interrupt
self.wait = wait
self.warmup = warmup
self.active = active
self.repeat = repeat
self.sort_by = sort_by
self.row_limit = row_limit
def before_train(self):
self.trainer.logger.info("Profiling runtime ...")
from torch.profiler import (
profile,
record_function,
ProfilerActivity,
schedule,
tensorboard_trace_handler,
)
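        # torch.profiler schedule semantics: each cycle skips `wait` steps, warms up
        # for `warmup` steps (profiled but discarded), then records `active` steps;
        # the cycle repeats `repeat` times, so the loop below needs
        # (wait + warmup + active) * repeat iterations.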
prof = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(
wait=self.wait,
warmup=self.warmup,
active=self.active,
repeat=self.repeat,
),
on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
for i, input_dict in enumerate(self.trainer.train_loader):
if i >= (self.wait + self.warmup + self.active) * self.repeat:
break
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with record_function("model_forward"):
output_dict = self.trainer.model(input_dict)
loss = output_dict["loss"]
with record_function("model_backward"):
loss.backward()
prof.step()
self.trainer.logger.info(
f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
)
self.trainer.logger.info(
"Profile: \n"
+ str(
prof.key_averages().table(
sort_by=self.sort_by, row_limit=self.row_limit
)
)
)
prof.stop()
if self.interrupt:
sys.exit(0)

285
engines/infer.py Normal file
View File

@@ -0,0 +1,285 @@
"""
Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import math
import time
import librosa
import numpy as np
from collections import OrderedDict
import torch
import torch.utils.data
import torch.nn.functional as F
from .defaults import create_ddp_model
import utils.comm as comm
from models import build_model
from utils.logger import get_root_logger
from utils.registry import Registry
from utils.misc import (
AverageMeter,
)
from models.utils import smooth_mouth_movements, apply_frame_blending, apply_savitzky_golay_smoothing, apply_random_brow_movement, \
symmetrize_blendshapes, apply_random_eye_blinks, apply_random_eye_blinks_context, export_blendshape_animation, \
RETURN_CODE, DEFAULT_CONTEXT, ARKitBlendShape
INFER = Registry("infer")
class InferBase:
def __init__(self, cfg, model=None, verbose=False) -> None:
torch.multiprocessing.set_sharing_strategy("file_system")
self.logger = get_root_logger(
log_file=os.path.join(cfg.save_path, "infer.log"),
file_mode="a" if cfg.resume else "w",
)
self.logger.info("=> Loading config ...")
self.cfg = cfg
self.verbose = verbose
if self.verbose:
self.logger.info(f"Save path: {cfg.save_path}")
self.logger.info(f"Config:\n{cfg.pretty_text}")
if model is None:
self.logger.info("=> Building model ...")
self.model = self.build_model()
else:
self.model = model
def build_model(self):
model = build_model(self.cfg.model)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.logger.info(f"Num params: {n_parameters}")
model = create_ddp_model(
model.cuda(),
broadcast_buffers=False,
find_unused_parameters=self.cfg.find_unused_parameters,
)
if os.path.isfile(self.cfg.weight):
self.logger.info(f"Loading weight at: {self.cfg.weight}")
checkpoint = torch.load(self.cfg.weight)
weight = OrderedDict()
for key, value in checkpoint["state_dict"].items():
if key.startswith("module."):
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
else:
if comm.get_world_size() > 1:
key = "module." + key # xxx.xxx -> module.xxx.xxx
weight[key] = value
model.load_state_dict(weight, strict=True)
self.logger.info(
"=> Loaded weight '{}'".format(
self.cfg.weight
)
)
else:
raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
return model
def infer(self):
raise NotImplementedError
@INFER.register_module()
class Audio2ExpressionInfer(InferBase):
def infer(self):
logger = get_root_logger()
logger.info(">>>>>>>>>>>>>>>> Start Inference >>>>>>>>>>>>>>>>")
batch_time = AverageMeter()
self.model.eval()
# process audio-input
assert os.path.exists(self.cfg.audio_input)
        if self.cfg.ex_vol:
            logger.info("Extract vocals ...")
            vocal_path = self.extract_vocal_track(self.cfg.audio_input)
            logger.info("=> Extracted vocals at: {}".format(vocal_path if os.path.exists(vocal_path) else '... Failed'))
            if os.path.exists(vocal_path):
                self.cfg.audio_input = vocal_path
with torch.no_grad():
input_dict = {}
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None,...]
speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None,...]
end = time.time()
output_dict = self.model(input_dict)
batch_time.update(time.time() - end)
logger.info(
"Infer: [{}] "
"Running Time: {batch_time.avg:.3f} ".format(
self.cfg.audio_input,
batch_time=batch_time,
)
)
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()
frame_length = math.ceil(speech_array.shape[0] / ssr * 30)
volume = librosa.feature.rms(y=speech_array, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
        if volume.shape[0] > frame_length:
            volume = volume[:frame_length]
        if self.cfg.movement_smooth:
            out_exp = smooth_mouth_movements(out_exp, 0, volume)
        if self.cfg.brow_movement:
            out_exp = apply_random_brow_movement(out_exp, volume)
        pred_exp = self.blendshape_postprocess(out_exp)
        if self.cfg.save_json_path is not None:
export_blendshape_animation(pred_exp,
self.cfg.save_json_path,
ARKitBlendShape,
fps=self.cfg.fps)
logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
def infer_streaming_audio(self,
audio: np.ndarray,
ssr: float,
context: dict):
        if context is None:
            context = DEFAULT_CONTEXT.copy()
max_frame_length = 64
frame_length = math.ceil(audio.shape[0] / ssr * 30)
output_context = DEFAULT_CONTEXT.copy()
volume = librosa.feature.rms(y=audio, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
        if volume.shape[0] > frame_length:
            volume = volume[:frame_length]
        # resample audio to the model's expected sample rate
        if ssr != self.cfg.audio_sr:
            in_audio = librosa.resample(audio.astype(np.float32), orig_sr=ssr, target_sr=self.cfg.audio_sr)
        else:
            in_audio = audio.copy()
        start_frame = int(max_frame_length - in_audio.shape[0] / self.cfg.audio_sr * 30)
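        # The model always sees a fixed 64-frame window (~2.13 s of audio at 30 fps).
        # Shorter chunks are left-padded with silence or with audio carried over from
        # the previous chunk, and start_frame marks where the new chunk's frames
        # begin inside the 64-frame output.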
        if context['is_initial_input'] or (context['previous_audio'] is None):
blank_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
blank_audio = np.zeros(blank_audio_length, dtype=np.float32)
# pre-append
input_audio = np.concatenate([blank_audio, in_audio])
output_context['previous_audio'] = input_audio
else:
clip_pre_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
clip_pre_audio = context['previous_audio'][-clip_pre_audio_length:]
input_audio = np.concatenate([clip_pre_audio, in_audio])
output_context['previous_audio'] = input_audio
with torch.no_grad():
try:
input_dict = {}
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[
None, ...]
input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
output_dict = self.model(input_dict)
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
            except Exception:
                self.logger.error('Error: failed to predict expression.')
                # Fall back to a neutral (all-zero) expression so streaming callers keep running.
                out_exp = np.zeros((max_frame_length - start_frame, 52), dtype=np.float32)
                return {"code": RETURN_CODE['SUCCESS'],
                        "expression": out_exp,
                        "headpose": None}, output_context
# post-process
        if context['previous_expression'] is None:
            out_exp = self.apply_expression_postprocessing(out_exp, audio_volume=volume)
        else:
            previous_length = context['previous_expression'].shape[0]
            out_exp = self.apply_expression_postprocessing(
                expression_params=np.concatenate([context['previous_expression'], out_exp], axis=0),
                audio_volume=np.concatenate([context['previous_volume'], volume], axis=0),
                processed_frames=previous_length,
            )[previous_length:, :]
        if context['previous_expression'] is not None:
            output_context['previous_expression'] = np.concatenate(
                [context['previous_expression'], out_exp], axis=0
            )[-max_frame_length:, :]
            output_context['previous_volume'] = np.concatenate(
                [context['previous_volume'], volume], axis=0
            )[-max_frame_length:]
else:
output_context['previous_expression'] = out_exp.copy()
output_context['previous_volume'] = volume.copy()
        output_context['is_initial_input'] = False  # clear the initial-input flag read above for subsequent chunks
return {"code": RETURN_CODE['SUCCESS'],
"expression": out_exp,
"headpose": None}, output_context
def apply_expression_postprocessing(
self,
expression_params: np.ndarray,
processed_frames: int = 0,
audio_volume: np.ndarray = None
) -> np.ndarray:
"""Applies full post-processing pipeline to facial expression parameters.
Args:
expression_params: Raw output from animation model [num_frames, num_parameters]
processed_frames: Number of frames already processed in previous batches
audio_volume: Optional volume array for audio-visual synchronization
Returns:
Processed expression parameters ready for animation synthesis
"""
# Pipeline execution order matters - maintain sequence
expression_params = smooth_mouth_movements(expression_params, processed_frames, audio_volume)
expression_params = apply_frame_blending(expression_params, processed_frames)
expression_params, _ = apply_savitzky_golay_smoothing(expression_params, window_length=5)
expression_params = symmetrize_blendshapes(expression_params)
expression_params = apply_random_eye_blinks_context(expression_params, processed_frames=processed_frames)
return expression_params
def extract_vocal_track(
self,
input_audio_path: str
) -> str:
"""Isolates vocal track from audio file using source separation.
Args:
input_audio_path: Path to input audio file containing vocals+accompaniment
Returns:
Path to isolated vocal track in WAV format
"""
separation_command = f'spleeter separate -p spleeter:2stems -o {self.cfg.save_path} {input_audio_path}'
os.system(separation_command)
base_name = os.path.splitext(os.path.basename(input_audio_path))[0]
return os.path.join(self.cfg.save_path, base_name, 'vocals.wav')
    def blendshape_postprocess(self,
                               bs_array: np.ndarray
                               ) -> np.ndarray:
bs_array, _ = apply_savitzky_golay_smoothing(bs_array, window_length=5)
bs_array = symmetrize_blendshapes(bs_array)
bs_array = apply_random_eye_blinks(bs_array)
return bs_array

135
engines/launch.py Normal file
View File

@@ -0,0 +1,135 @@
"""
Launcher
modified from detectron2(https://github.com/facebookresearch/detectron2)
"""
import os
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from utils import comm
__all__ = ["DEFAULT_TIMEOUT", "launch"]
DEFAULT_TIMEOUT = timedelta(minutes=30)
def _find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def launch(
main_func,
num_gpus_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
cfg=(),
timeout=DEFAULT_TIMEOUT,
):
"""
Launch multi-gpu or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
Args:
main_func: a function that will be called by `main_func(*args)`
num_gpus_per_machine (int): number of GPUs per machine
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
dist_url (str): url to connect to for distributed jobs, including protocol
e.g. "tcp://127.0.0.1:8686".
Can be set to "auto" to automatically select a free port on localhost
timeout (timedelta): timeout of the distributed workers
        cfg (tuple): arguments passed to main_func
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
if dist_url == "auto":
assert (
num_machines == 1
), "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
mp.spawn(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
cfg,
timeout,
),
daemon=False,
)
else:
main_func(*cfg)
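# Example (illustrative sketch): a single-machine, multi-GPU launch. `main_worker`
# and `cfg` stand in for the caller's entry point and parsed config, mirroring the
# usage in inference.py.
#
#   launch(
#       main_worker,                 # called as main_worker(*cfg) in each process
#       num_gpus_per_machine=4,
#       num_machines=1,
#       machine_rank=0,
#       dist_url="auto",             # resolves to tcp://127.0.0.1:<free port>
#       cfg=(cfg,),
#   )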
def _distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
cfg,
timeout=DEFAULT_TIMEOUT,
):
assert (
torch.cuda.is_available()
), "cuda is not available. Please check your installation."
global_rank = machine_rank * num_gpus_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL",
init_method=dist_url,
world_size=world_size,
rank=global_rank,
timeout=timeout,
)
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Process group URL: {}".format(dist_url))
raise e
# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_gpus_per_machine
for i in range(num_machines):
ranks_on_i = list(
range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
)
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg
assert num_gpus_per_machine <= torch.cuda.device_count()
torch.cuda.set_device(local_rank)
# synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
main_func(*cfg)

299
engines/train.py Normal file
View File

@@ -0,0 +1,299 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import sys
import weakref
import torch
import torch.nn as nn
import torch.utils.data
from functools import partial
from collections.abc import Iterator
from tensorboardX import SummaryWriter
from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import utils.comm as comm
from datasets import build_dataset, point_collate_fn, collate_fn
from models import build_model
from utils.logger import get_root_logger
from utils.optimizer import build_optimizer
from utils.scheduler import build_scheduler
from utils.events import EventStorage
from utils.registry import Registry
TRAINERS = Registry("trainers")
class TrainerBase:
def __init__(self) -> None:
self.hooks = []
self.epoch = 0
self.start_epoch = 0
self.max_epoch = 0
self.max_iter = 0
self.comm_info = dict()
self.data_iterator: Iterator = enumerate([])
self.storage: EventStorage
self.writer: SummaryWriter
def register_hooks(self, hooks) -> None:
hooks = build_hooks(hooks)
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self.hooks.extend(hooks)
def train(self):
with EventStorage() as self.storage:
# => before train
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def before_train(self):
for h in self.hooks:
h.before_train()
def before_epoch(self):
for h in self.hooks:
h.before_epoch()
def before_step(self):
for h in self.hooks:
h.before_step()
def run_step(self):
raise NotImplementedError
def after_step(self):
for h in self.hooks:
h.after_step()
def after_epoch(self):
for h in self.hooks:
h.after_epoch()
self.storage.reset_histories()
def after_train(self):
# Sync GPU before running train hooks
comm.synchronize()
for h in self.hooks:
h.after_train()
if comm.is_main_process():
self.writer.close()
@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
def __init__(self, cfg):
super(Trainer, self).__init__()
self.epoch = 0
self.start_epoch = 0
self.max_epoch = cfg.eval_epoch
self.best_metric_value = -torch.inf
self.logger = get_root_logger(
log_file=os.path.join(cfg.save_path, "train.log"),
file_mode="a" if cfg.resume else "w",
)
self.logger.info("=> Loading config ...")
self.cfg = cfg
self.logger.info(f"Save path: {cfg.save_path}")
self.logger.info(f"Config:\n{cfg.pretty_text}")
self.logger.info("=> Building model ...")
self.model = self.build_model()
self.logger.info("=> Building writer ...")
self.writer = self.build_writer()
self.logger.info("=> Building train dataset & dataloader ...")
self.train_loader = self.build_train_loader()
self.logger.info("=> Building val dataset & dataloader ...")
self.val_loader = self.build_val_loader()
self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
self.optimizer = self.build_optimizer()
self.scheduler = self.build_scheduler()
self.scaler = self.build_scaler()
self.logger.info("=> Building hooks ...")
self.register_hooks(self.cfg.hooks)
def train(self):
with EventStorage() as self.storage:
# => before train
self.before_train()
self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
# TODO: optimize to iteration based
if comm.get_world_size() > 1:
self.train_loader.sampler.set_epoch(self.epoch)
self.model.train()
self.data_iterator = enumerate(self.train_loader)
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def run_step(self):
input_dict = self.comm_info["input_dict"]
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
output_dict = self.model(input_dict)
loss = output_dict["loss"]
self.optimizer.zero_grad()
if self.cfg.enable_amp:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
            # With AMP enabled, optimizer.step() is skipped when the loss scale
            # overflows. Step the scheduler only if the scale did not shrink
            # (i.e. the optimizer actually stepped), which avoids the torch
            # warning about calling scheduler.step() before optimizer.step().
            scale_before = self.scaler.get_scale()
            self.scaler.update()
            if scale_before <= self.scaler.get_scale():
                self.scheduler.step()
else:
loss.backward()
self.optimizer.step()
self.scheduler.step()
if self.cfg.empty_cache:
torch.cuda.empty_cache()
self.comm_info["model_output_dict"] = output_dict
def build_model(self):
model = build_model(self.cfg.model)
if self.cfg.sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logger.info(f"Model: \n{self.model}")
self.logger.info(f"Num params: {n_parameters}")
model = create_ddp_model(
model.cuda(),
broadcast_buffers=False,
find_unused_parameters=self.cfg.find_unused_parameters,
)
return model
def build_writer(self):
writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
return writer
def build_train_loader(self):
train_data = build_dataset(self.cfg.data.train)
if comm.get_world_size() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
else:
train_sampler = None
init_fn = (
partial(
worker_init_fn,
num_workers=self.cfg.num_worker_per_gpu,
rank=comm.get_rank(),
seed=self.cfg.seed,
)
if self.cfg.seed is not None
else None
)
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=self.cfg.batch_size_per_gpu,
shuffle=(train_sampler is None),
num_workers=0,
sampler=train_sampler,
collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
pin_memory=True,
worker_init_fn=init_fn,
drop_last=True,
# persistent_workers=True,
)
return train_loader
def build_val_loader(self):
val_loader = None
if self.cfg.evaluate:
val_data = build_dataset(self.cfg.data.val)
if comm.get_world_size() > 1:
val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
else:
val_sampler = None
val_loader = torch.utils.data.DataLoader(
val_data,
batch_size=self.cfg.batch_size_val_per_gpu,
shuffle=False,
num_workers=self.cfg.num_worker_per_gpu,
pin_memory=True,
sampler=val_sampler,
collate_fn=collate_fn,
)
return val_loader
def build_optimizer(self):
return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
def build_scheduler(self):
assert hasattr(self, "optimizer")
assert hasattr(self, "train_loader")
self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
return build_scheduler(self.cfg.scheduler, self.optimizer)
def build_scaler(self):
scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
return scaler
@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
def build_train_loader(self):
from datasets import MultiDatasetDataloader
train_data = build_dataset(self.cfg.data.train)
train_loader = MultiDatasetDataloader(
train_data,
self.cfg.batch_size_per_gpu,
self.cfg.num_worker_per_gpu,
self.cfg.mix_prob,
self.cfg.seed,
)
self.comm_info["iter_per_epoch"] = len(train_loader)
return train_loader

48
inference.py Normal file
View File

@@ -0,0 +1,48 @@
"""
# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from engines.defaults import (
default_argument_parser,
default_config_parser,
default_setup,
)
from engines.infer import INFER
from engines.launch import launch
def main_worker(cfg):
cfg = default_setup(cfg)
infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
infer.infer()
def main():
args = default_argument_parser().parse_args()
cfg = default_config_parser(args.config_file, args.options)
launch(
main_worker,
num_gpus_per_machine=args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
cfg=(cfg,),
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,60 @@
"""
# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from engines.defaults import (
default_argument_parser,
default_config_parser,
default_setup,
)
from engines.infer import INFER
import librosa
from tqdm import tqdm
import time
def export_json(bs_array, json_path):
from models.utils import export_blendshape_animation, ARKitBlendShape
export_blendshape_animation(bs_array, json_path, ARKitBlendShape, fps=30.0)
if __name__ == '__main__':
args = default_argument_parser().parse_args()
args.config_file = 'configs/lam_audio2exp_config_streaming.py'
cfg = default_config_parser(args.config_file, args.options)
cfg = default_setup(cfg)
infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
infer.model.eval()
audio, sample_rate = librosa.load(cfg.audio_input, sr=16000)
context = None
    input_num = audio.shape[0] // 16000 + 1
gap = 16000
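    # Stream the audio in 1 s chunks (16000 samples at 16 kHz); the final chunk may
    # be shorter, and `context` carries the previous audio/expressions between calls.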
all_exp = []
for i in tqdm(range(input_num)):
start = time.time()
output, context = infer.infer_streaming_audio(audio[i*gap:(i+1)*gap], sample_rate, context)
end = time.time()
        print('Inference time: {:.3f} s'.format(end - start))
all_exp.append(output['expression'])
    all_exp = np.concatenate(all_exp, axis=0)
export_json(all_exp, cfg.save_json_path)

7
models/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from .builder import build_model
from .default import DefaultEstimator
# Backbones
from .network import Audio2Expression

13
models/builder.py Normal file
View File

@@ -0,0 +1,13 @@
"""
Modified from https://github.com/Pointcept/Pointcept
"""
from utils.registry import Registry
MODELS = Registry("models")
MODULES = Registry("modules")
def build_model(cfg):
"""Build models."""
return MODELS.build(cfg)

25
models/default.py Normal file
View File

@@ -0,0 +1,25 @@
import torch.nn as nn
from models.losses import build_criteria
from .builder import MODELS, build_model
@MODELS.register_module()
class DefaultEstimator(nn.Module):
def __init__(self, backbone=None, criteria=None):
super().__init__()
self.backbone = build_model(backbone)
self.criteria = build_criteria(criteria)
def forward(self, input_dict):
pred_exp = self.backbone(input_dict)
# train
if self.training:
loss = self.criteria(pred_exp, input_dict["gt_exp"])
return dict(loss=loss)
# eval
elif "gt_exp" in input_dict.keys():
loss = self.criteria(pred_exp, input_dict["gt_exp"])
return dict(loss=loss, pred_exp=pred_exp)
# infer
else:
return dict(pred_exp=pred_exp)
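# Example (illustrative sketch): building an estimator from a config dict via
# build_model(). The field values are placeholders; the backbone options mirror
# Audio2Expression in models/network.py and the loss type is registered in
# models/losses/misc.py.
#
#   model = build_model(dict(
#       type="DefaultEstimator",
#       backbone=dict(type="Audio2Expression", hidden_dim=512, expression_dim=52),
#       criteria=[dict(type="L1Loss", loss_weight=1.0)],
#   ))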

248
models/encoder/wav2vec.py Normal file
View File

@@ -0,0 +1,248 @@
import numpy as np
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from dataclasses import dataclass
from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.file_utils import ModelOutput
_CONFIG_FOR_DOC = "Wav2Vec2Config"
_HIDDEN_STATES_START_POSITION = 2
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
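# e.g. _compute_mask_indices((2, 100), mask_prob=0.15, mask_length=10) draws one or
# two spans of 10 consecutive masked frames per sample, masking roughly mask_prob
# of the timesteps (SpecAugment-style span masking).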
# linear interpolation layer
def linear_interpolation(features, input_fps, output_fps, output_len=None):
features = features.transpose(1, 2)
seq_len = features.shape[2] / float(input_fps)
if output_len is None:
output_len = int(seq_len * output_fps)
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
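# e.g. 2 s of audio encoded at 50 fps gives features of shape (B, 100, C);
# linear_interpolation(features, 50, 30) resamples them to (B, 60, C), aligning
# the audio features with 30 fps video frames.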
class Wav2Vec2Model(Wav2Vec2Model):
def __init__(self, config):
super().__init__(config)
self.lm_head = nn.Linear(1024, 32)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
frame_num=None
):
self.config.output_attentions = True
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.feature_extractor(input_values)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
)
attention_mask[
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states = self.feature_projection(hidden_states)[0]
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@dataclass
class SpeechClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class Wav2Vec2ClassificationHead(nn.Module):
"""Head for wav2vec classification task."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.final_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pooling_mode = config.pooling_mode
self.config = config
self.wav2vec2 = Wav2Vec2Model(config)
self.classifier = Wav2Vec2ClassificationHead(config)
self.init_weights()
def freeze_feature_extractor(self):
self.wav2vec2.feature_extractor._freeze_parameters()
def merged_strategy(
self,
hidden_states,
mode="mean"
):
if mode == "mean":
outputs = torch.mean(hidden_states, dim=1)
elif mode == "sum":
outputs = torch.sum(hidden_states, dim=1)
elif mode == "max":
outputs = torch.max(hidden_states, dim=1)[0]
else:
raise Exception(
"The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
return outputs
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
frame_num=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states1 = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
hidden_states = self.merged_strategy(hidden_states1, mode=self.pooling_mode)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SpeechClassifierOutput(
loss=loss,
logits=logits,
hidden_states=hidden_states1,
attentions=outputs.attentions,
)

87
models/encoder/wavlm.py Normal file
View File

@@ -0,0 +1,87 @@
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(
features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class WavLMModel(WavLMModel):
def __init__(self, config):
super().__init__(config)
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
frame_num=None,
interpolate_pos: int = 0,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if interpolate_pos == 0:
extract_features = linear_interpolation(
extract_features, output_len=frame_num)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)

View File

@@ -0,0 +1,4 @@
from .builder import build_criteria
from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss
from .lovasz import LovaszLoss

28
models/losses/builder.py Normal file
View File

@@ -0,0 +1,28 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from utils.registry import Registry
LOSSES = Registry("losses")
class Criteria(object):
def __init__(self, cfg=None):
self.cfg = cfg if cfg is not None else []
self.criteria = []
for loss_cfg in self.cfg:
self.criteria.append(LOSSES.build(cfg=loss_cfg))
def __call__(self, pred, target):
if len(self.criteria) == 0:
            # loss computation occurs in the model
return pred
loss = 0
for c in self.criteria:
loss += c(pred, target)
return loss
def build_criteria(cfg):
return Criteria(cfg)
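# Example (illustrative): a criteria config summing two weighted losses; both loss
# types are registered in this package (models/losses).
#
#   criteria = build_criteria([
#       dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1),
#       dict(type="LovaszLoss", mode="multiclass", loss_weight=1.0),
#   ])
#   loss = criteria(pred, target)  # sum of both terms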

253
models/losses/lovasz.py Normal file
View File

@@ -0,0 +1,253 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from typing import Optional
from itertools import filterfalse
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from .builder import LOSSES
BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"
def _lovasz_grad(gt_sorted):
"""Compute gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1.0 - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
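# Worked example: gt_sorted = [1, 1, 0] (ground truth sorted by descending error)
# gives intersection = [1, 0, 0], union = [2, 2, 3], jaccard = [0.5, 1.0, 1.0],
# and after decoupling, the gradient [0.5, 0.5, 0.0].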
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
"""
Binary Lovasz hinge loss
logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
per_image: compute the loss per image instead of per batch
ignore: void class id
"""
if per_image:
loss = mean(
_lovasz_hinge_flat(
*_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
)
for log, lab in zip(logits, labels)
)
else:
loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
return loss
def _lovasz_hinge_flat(logits, labels):
"""Binary Lovasz hinge loss
Args:
logits: [P] Logits at each prediction (between -infinity and +infinity)
labels: [P] Tensor, binary ground truth labels (0 or 1)
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.0
signs = 2.0 * labels.float() - 1.0
errors = 1.0 - logits * signs
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = _lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), grad)
return loss
def _flatten_binary_scores(scores, labels, ignore=None):
"""Flattens predictions in the batch (binary case)
Remove labels equal to 'ignore'
"""
scores = scores.view(-1)
labels = labels.view(-1)
if ignore is None:
return scores, labels
valid = labels != ignore
vscores = scores[valid]
vlabels = labels[valid]
return vscores, vlabels
def _lovasz_softmax(
probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
):
"""Multi-class Lovasz-Softmax loss
Args:
@param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
@param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
@param per_image: compute the loss per image instead of per batch
@param ignore: void class labels
"""
if per_image:
loss = mean(
_lovasz_softmax_flat(
*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
classes=classes
)
for prob, lab in zip(probas, labels)
)
else:
loss = _lovasz_softmax_flat(
*_flatten_probas(probas, labels, ignore),
classes=classes,
class_seen=class_seen
)
return loss
def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
"""Multi-class Lovasz-Softmax loss
Args:
@param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
@param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if probas.numel() == 0:
# only void pixels, the gradients should be 0
return probas * 0.0
C = probas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
# for c in class_to_sum:
for c in labels.unique():
if class_seen is None:
fg = (labels == c).type_as(probas) # foreground for class c
if classes == "present" and fg.sum() == 0:
continue
if C == 1:
if len(classes) > 1:
raise ValueError("Sigmoid output possible only with 1 class")
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (fg - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
else:
if c in class_seen:
fg = (labels == c).type_as(probas) # foreground for class c
if classes == "present" and fg.sum() == 0:
continue
if C == 1:
if len(classes) > 1:
raise ValueError("Sigmoid output possible only with 1 class")
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (fg - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
return mean(losses)
def _flatten_probas(probas, labels, ignore=None):
"""Flattens predictions in the batch"""
if probas.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probas.size()
probas = probas.view(B, 1, H, W)
C = probas.size(1)
probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
probas = probas.contiguous().view(-1, C) # [P, C]
labels = labels.view(-1)
if ignore is None:
return probas, labels
valid = labels != ignore
vprobas = probas[valid]
vlabels = labels[valid]
return vprobas, vlabels
def isnan(x):
return x != x
def mean(values, ignore_nan=False, empty=0):
"""Nan-mean compatible with generators."""
values = iter(values)
if ignore_nan:
values = filterfalse(isnan, values)
try:
n = 1
acc = next(values)
except StopIteration:
if empty == "raise":
raise ValueError("Empty mean")
return empty
for n, v in enumerate(values, 2):
acc += v
if n == 1:
return acc
return acc / n
@LOSSES.register_module()
class LovaszLoss(_Loss):
def __init__(
self,
mode: str,
class_seen: Optional[int] = None,
per_image: bool = False,
ignore_index: Optional[int] = None,
loss_weight: float = 1.0,
):
"""Lovasz loss for segmentation task.
It supports binary, multiclass and multilabel cases
Args:
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
per_image: If True loss computed per each image and then averaged, else computed per whole batch
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
- **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
Reference
https://github.com/BloodAxe/pytorch-toolbelt
"""
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__()
self.mode = mode
self.ignore_index = ignore_index
self.per_image = per_image
self.class_seen = class_seen
self.loss_weight = loss_weight
def forward(self, y_pred, y_true):
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
loss = _lovasz_hinge(
y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
)
elif self.mode == MULTICLASS_MODE:
y_pred = y_pred.softmax(dim=1)
loss = _lovasz_softmax(
y_pred,
y_true,
class_seen=self.class_seen,
per_image=self.per_image,
ignore=self.ignore_index,
)
else:
raise ValueError("Wrong mode {}.".format(self.mode))
return loss * self.loss_weight

241
models/losses/misc.py Normal file
View File

@@ -0,0 +1,241 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
def __init__(
self,
weight=None,
size_average=None,
reduce=None,
reduction="mean",
label_smoothing=0.0,
loss_weight=1.0,
ignore_index=-1,
):
super(CrossEntropyLoss, self).__init__()
weight = torch.tensor(weight).cuda() if weight is not None else None
self.loss_weight = loss_weight
self.loss = nn.CrossEntropyLoss(
weight=weight,
size_average=size_average,
ignore_index=ignore_index,
reduce=reduce,
reduction=reduction,
label_smoothing=label_smoothing,
)
def forward(self, pred, target):
return self.loss(pred, target) * self.loss_weight
@LOSSES.register_module()
class L1Loss(nn.Module):
def __init__(
self,
weight=None,
size_average=None,
reduce=None,
reduction="mean",
label_smoothing=0.0,
loss_weight=1.0,
ignore_index=-1,
):
super(L1Loss, self).__init__()
weight = torch.tensor(weight).cuda() if weight is not None else None
self.loss_weight = loss_weight
        self.loss = nn.L1Loss(reduction=reduction)
    def forward(self, pred, target):
        return self.loss(pred, target[:, None]) * self.loss_weight
@LOSSES.register_module()
class SmoothCELoss(nn.Module):
def __init__(self, smoothing_ratio=0.1):
super(SmoothCELoss, self).__init__()
self.smoothing_ratio = smoothing_ratio
def forward(self, pred, target):
eps = self.smoothing_ratio
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1)
loss = loss[torch.isfinite(loss)].mean()
return loss
@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
"""Binary Focal Loss
<https://arxiv.org/abs/1708.02002>`
"""
super(BinaryFocalLoss, self).__init__()
assert 0 < alpha < 1
self.gamma = gamma
self.alpha = alpha
self.logits = logits
self.reduce = reduce
self.loss_weight = loss_weight
def forward(self, pred, target, **kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction with shape (N)
target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value satisfies 0 <= targets[i] <= 1. If containing class probabilities,
same shape as the input.
Returns:
torch.Tensor: The calculated loss
"""
if self.logits:
bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
else:
bce = F.binary_cross_entropy(pred, target, reduction="none")
pt = torch.exp(-bce)
alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
focal_loss = alpha * (1 - pt) ** self.gamma * bce
if self.reduce:
focal_loss = torch.mean(focal_loss)
return focal_loss * self.loss_weight
@LOSSES.register_module()
class FocalLoss(nn.Module):
def __init__(
self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
):
"""Focal Loss
<https://arxiv.org/abs/1708.02002>`
"""
super(FocalLoss, self).__init__()
assert reduction in (
"mean",
"sum",
), "AssertionError: reduction should be 'mean' or 'sum'"
assert isinstance(
alpha, (float, list)
), "AssertionError: alpha should be of type float"
assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
assert isinstance(
loss_weight, float
), "AssertionError: loss_weight should be of type float"
assert isinstance(ignore_index, int), "ignore_index must be of type int"
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
self.ignore_index = ignore_index
def forward(self, pred, target, **kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes.
target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value satisfies 0 <= targets[i] <= C-1. If containing class probabilities,
same shape as the input.
Returns:
torch.Tensor: The calculated loss
"""
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
pred = pred.transpose(0, 1)
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
pred = pred.reshape(pred.size(0), -1)
# [C, N] -> [N, C]
pred = pred.transpose(0, 1).contiguous()
# (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
target = target.view(-1).contiguous()
assert pred.size(0) == target.size(
0
), "The shape of pred doesn't match the shape of target"
valid_mask = target != self.ignore_index
target = target[valid_mask]
pred = pred[valid_mask]
if len(target) == 0:
return 0.0
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes)
alpha = self.alpha
if isinstance(alpha, list):
alpha = pred.new_tensor(alpha)
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
self.gamma
)
loss = (
F.binary_cross_entropy_with_logits(pred, target, reduction="none")
* focal_weight
)
if self.reduction == "mean":
loss = loss.mean()
elif self.reduction == "sum":
            loss = loss.sum()
return self.loss_weight * loss
@LOSSES.register_module()
class DiceLoss(nn.Module):
def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
"""DiceLoss.
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
"""
super(DiceLoss, self).__init__()
self.smooth = smooth
self.exponent = exponent
self.loss_weight = loss_weight
self.ignore_index = ignore_index
def forward(self, pred, target, **kwargs):
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
pred = pred.transpose(0, 1)
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
pred = pred.reshape(pred.size(0), -1)
# [C, N] -> [N, C]
pred = pred.transpose(0, 1).contiguous()
# (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
target = target.view(-1).contiguous()
assert pred.size(0) == target.size(
0
), "The shape of pred doesn't match the shape of target"
valid_mask = target != self.ignore_index
target = target[valid_mask]
pred = pred[valid_mask]
pred = F.softmax(pred, dim=1)
num_classes = pred.shape[1]
target = F.one_hot(
torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
)
total_loss = 0
for i in range(num_classes):
if i != self.ignore_index:
num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
den = (
torch.sum(
pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
)
+ self.smooth
)
dice_loss = 1 - num / den
total_loss += dice_loss
loss = total_loss / num_classes
return self.loss_weight * loss
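# Usage sketch for DiceLoss (a minimal, hypothetical example; shapes are assumptions):
def _demo_dice_loss():
    criterion = DiceLoss(smooth=1, exponent=2)
    pred = torch.randn(2, 4, 8)            # [B, C, ...] logits, softmaxed internally
    target = torch.randint(0, 4, (2, 8))   # [B, ...] class indices
    return criterion(pred, target)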

646
models/network.py Normal file
View File

@@ -0,0 +1,646 @@
import math
import os.path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
from models.encoder.wav2vec import Wav2Vec2Model
from models.encoder.wavlm import WavLMModel
from models.builder import MODELS
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
@MODELS.register_module("Audio2Expression")
class Audio2Expression(nn.Module):
def __init__(self,
device: torch.device = None,
pretrained_encoder_type: str = 'wav2vec',
pretrained_encoder_path: str = '',
wav2vec2_config_path: str = '',
num_identity_classes: int = 0,
identity_feat_dim: int = 64,
hidden_dim: int = 512,
expression_dim: int = 52,
norm_type: str = 'ln',
decoder_depth: int = 3,
use_transformer: bool = False,
num_attention_heads: int = 8,
num_transformer_layers: int = 6,
):
super().__init__()
self.device = device
# Initialize audio feature encoder
if pretrained_encoder_type == 'wav2vec':
if os.path.exists(pretrained_encoder_path):
self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path)
else:
config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path)
self.audio_encoder = Wav2Vec2Model(config)
encoder_output_dim = 768
elif pretrained_encoder_type == 'wavlm':
self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path)
encoder_output_dim = 768
else:
raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported")
self.audio_encoder.feature_extractor._freeze_parameters()
self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim)
self.identity_encoder = AudioIdentityEncoder(
hidden_dim,
num_identity_classes,
identity_feat_dim,
use_transformer,
num_attention_heads,
num_transformer_layers
)
self.decoder = nn.ModuleList([
nn.Sequential(*[
ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type)
for _ in range(decoder_depth)
])
])
self.output_proj = nn.Linear(hidden_dim, expression_dim)
def freeze_encoder_parameters(self, do_freeze=False):
for name, param in self.audio_encoder.named_parameters():
if('feature_extractor' in name):
param.requires_grad = False
else:
param.requires_grad = (not do_freeze)
def forward(self, input_dict):
if 'time_steps' not in input_dict:
audio_length = input_dict['input_audio_array'].shape[1]
time_steps = math.ceil(audio_length / 16000 * 30)
else:
time_steps = input_dict['time_steps']
# Process audio through encoder
audio_input = input_dict['input_audio_array'].flatten(start_dim=1)
hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state
# Project features to hidden dimension
audio_features = self.feature_projection(hidden_states).transpose(1, 2)
# Process identity-conditioned features
audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx'])
# Refine features through decoder
audio_features = self.decoder[0](audio_features)
# Generate output parameters
audio_features = audio_features.permute(0, 2, 1)
expression_params = self.output_proj(audio_features)
return torch.sigmoid(expression_params)
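# Input-contract sketch for Audio2Expression.forward (hypothetical values; building
# the model itself requires pretrained encoder weights or a wav2vec2 config file):
def _demo_audio2expression_inputs(num_identity_classes: int = 4):
    audio = torch.randn(1, 16000)  # [B, num_samples] of 16 kHz mono audio (1 s here)
    id_idx = F.one_hot(torch.tensor([0]), num_identity_classes)  # [B, num_identity_classes] one-hot
    # 'time_steps' is optional; it defaults to ceil(num_samples / 16000 * 30), i.e. 30 fps.
    return {'input_audio_array': audio, 'id_idx': id_idx}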
class AudioIdentityEncoder(nn.Module):
def __init__(self,
hidden_dim,
num_identity_classes=0,
identity_feat_dim=64,
use_transformer=False,
num_attention_heads = 8,
num_transformer_layers = 6,
dropout_ratio=0.1,
):
super().__init__()
in_dim = hidden_dim + identity_feat_dim
self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1)
self.first_net = SeqTranslator1D(in_dim, hidden_dim,
min_layers_num=3,
residual=True,
norm='ln'
)
self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
self.dropout = nn.Dropout(dropout_ratio)
self.use_transformer = use_transformer
if(self.use_transformer):
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_attention_heads, dim_feedforward= 2 * hidden_dim, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
def forward(self,
audio_features: torch.Tensor,
identity: torch.Tensor = None,
time_steps: int = None) -> tuple:
audio_features = self.dropout(audio_features)
identity = identity.reshape(identity.shape[0], -1, 1).repeat(1, 1, audio_features.shape[2]).to(torch.float32)
identity = self.id_mlp(identity)
audio_features = torch.cat([audio_features, identity], dim=1)
x = self.first_net(audio_features)
if time_steps is not None:
x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear')
if(self.use_transformer):
x = x.permute(0, 2, 1)
x = self.transformer_encoder(x)
x = x.permute(0, 2, 1)
return x
class ConvNormRelu(nn.Module):
    '''
    (B, C_in, H, W) -> (B, C_out, H, W)
    Note: some kernel size / stride combinations yield an output length other than H / s.
    '''
def __init__(self,
in_channels,
out_channels,
type='1d',
leaky=False,
downsample=False,
kernel_size=None,
stride=None,
padding=None,
p=0,
groups=1,
residual=False,
norm='bn'):
'''
conv-bn-relu
'''
super(ConvNormRelu, self).__init__()
self.residual = residual
self.norm_type = norm
# kernel_size = k
# stride = s
if kernel_size is None and stride is None:
if not downsample:
kernel_size = 3
stride = 1
else:
kernel_size = 4
stride = 2
if padding is None:
if isinstance(kernel_size, int) and isinstance(stride, tuple):
padding = tuple(int((kernel_size - st) / 2) for st in stride)
elif isinstance(kernel_size, tuple) and isinstance(stride, int):
padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
else:
padding = int((kernel_size - stride) / 2)
if self.residual:
if downsample:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
else:
if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
in_channels = in_channels * groups
out_channels = out_channels * groups
if type == '1d':
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
elif type == '2d':
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_channels)
self.dropout = nn.Dropout2d(p=p)
if norm == 'gn':
self.norm = nn.GroupNorm(2, out_channels)
elif norm == 'ln':
self.norm = nn.LayerNorm(out_channels)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, **kwargs):
if self.norm_type == 'ln':
out = self.dropout(self.conv(x))
out = self.norm(out.transpose(1,2)).transpose(1,2)
else:
out = self.norm(self.dropout(self.conv(x)))
if self.residual:
residual = self.residual_layer(x)
out += residual
return self.relu(out)
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
class SeqTranslator1D(nn.Module):
'''
(B, C, T)->(B, C_out, T)
'''
def __init__(self,
C_in,
C_out,
kernel_size=None,
stride=None,
min_layers_num=None,
residual=True,
norm='bn'
):
super(SeqTranslator1D, self).__init__()
conv_layers = nn.ModuleList([])
conv_layers.append(ConvNormRelu(
in_channels=C_in,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm
))
self.num_layers = 1
if min_layers_num is not None and self.num_layers < min_layers_num:
while self.num_layers < min_layers_num:
conv_layers.append(ConvNormRelu(
in_channels=C_out,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm
))
self.num_layers += 1
self.conv_layers = nn.Sequential(*conv_layers)
def forward(self, x):
return self.conv_layers(x)
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
"""
:param audio: 1 x T tensor containing a 16kHz audio signal
:param frame_rate: frame rate for video (we need one audio chunk per video frame)
:param chunk_size: number of audio samples per chunk
:return: num_chunks x chunk_size tensor containing sliced audio
"""
samples_per_frame = 16000 // frame_rate
padding = (chunk_size - samples_per_frame) // 2
audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
return audio
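# Shape sketch for audio_chunking (assumed values): at 30 fps and 16 kHz,
# samples_per_frame = 533, so one second of audio yields roughly 30 overlapping
# chunks, each chunk_size samples wide and centered on its video frame.
def _demo_audio_chunking():
    wav = torch.randn(1, 16000)  # 1 x T tensor, one second at 16 kHz
    chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
    return chunks.shape          # approximately (30, 16000)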
""" https://github.com/facebookresearch/meshtalk """
class MeshtalkEncoder(nn.Module):
def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
"""
:param latent_dim: size of the latent audio embedding
:param model_name: name of the model, used to load and save the model
"""
super().__init__()
self.melspec = ta.transforms.MelSpectrogram(
sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
)
conv_len = 5
self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
self.weights_init(self.convert_dimensions)
self.receptive_field = conv_len
convs = []
for i in range(6):
dilation = 2 * (i % 3 + 1)
self.receptive_field += (conv_len - 1) * dilation
convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
self.weights_init(convs[-1])
self.convs = torch.nn.ModuleList(convs)
self.code = torch.nn.Linear(128, latent_dim)
self.apply(lambda x: self.weights_init(x))
def weights_init(self, m):
if isinstance(m, torch.nn.Conv1d):
torch.nn.init.xavier_uniform_(m.weight)
            try:
                torch.nn.init.constant_(m.bias, .01)
            except Exception:  # module may have no bias term
                pass
def forward(self, audio: torch.Tensor):
"""
:param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
:return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
"""
B, T = audio.shape[0], audio.shape[1]
x = self.melspec(audio).squeeze(1)
x = torch.log(x.clamp(min=1e-10, max=None))
if T == 1:
x = x.unsqueeze(1)
# Convert to the right dimensionality
x = x.view(-1, x.shape[2], x.shape[3])
x = F.leaky_relu(self.convert_dimensions(x), .2)
# Process stacks
for conv in self.convs:
x_ = F.leaky_relu(conv(x), .2)
if self.training:
x_ = F.dropout(x_, .2)
l = (x.shape[2] - x_.shape[2]) // 2
x = (x[:, :, l:-l] + x_) / 2
x = torch.mean(x, dim=-1)
x = x.view(B, T, x.shape[-1])
x = self.code(x)
return {"code": x}
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len//period) + 1
pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# print(self.pe.shape, x.shape)
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
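# Sketch: the encoding repeats with the given period, so frames i and i + period
# receive identical positional offsets (before dropout). Values are assumptions.
def _demo_periodic_pe():
    ppe = PeriodicPositionalEncoding(d_model=16, dropout=0.0, period=15, max_seq_len=64)
    x = torch.zeros(1, 64, 16)   # [B, T, d_model]
    return ppe(x)                # positions 0, 15, 30, ... share the same encoding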
class GeneratorTransformer(nn.Module):
def __init__(self,
n_poses,
each_dim: list,
dim_list: list,
training=True,
device=None,
identity=False,
num_classes=0,
):
super().__init__()
self.training = training
self.device = device
self.gen_length = n_poses
norm = 'ln'
in_dim = 256
out_dim = 256
self.encoder_choice = 'faceformer'
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")  # alternative: "vitouphy/wav2vec2-xls-r-300m-phoneme"
self.audio_encoder.feature_extractor._freeze_parameters()
self.audio_feature_map = nn.Linear(768, in_dim)
self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes)
self.dim_list = dim_list
self.decoder = nn.ModuleList()
self.final_out = nn.ModuleList()
self.hidden_size = 768
self.transformer_de_layer = nn.TransformerDecoderLayer(
d_model=self.hidden_size,
nhead=4,
dim_feedforward=self.hidden_size*2,
batch_first=True
)
self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)
self.feature2face = nn.Linear(256, self.hidden_size)
self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64)
        self.id_mapping = nn.Linear(12, self.hidden_size)
self.decoder.append(self.face_decoder)
self.final_out.append(nn.Linear(self.hidden_size, 32))
def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
if gt_poses is None:
time_steps = 64
else:
time_steps = gt_poses.shape[1]
# vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
if self.encoder_choice == 'meshtalk':
in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
elif self.encoder_choice == 'faceformer':
hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
feature = self.audio_feature_map(hidden_states).transpose(1, 2)
else:
feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
feature, _ = self.audio_middle(feature, id=None)
feature = self.feature2face(feature.permute(0,2,1))
id = id.unsqueeze(1).repeat(1,64,1).to(torch.float32)
        id_feature = self.id_mapping(id)
id_feature = self.position_embeddings(id_feature)
        for i in range(len(self.decoder)):
mid = self.decoder[i](tgt=id_feature, memory=feature)
out = self.final_out[i](mid)
return out, None
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(
features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
def init_biased_mask(n_head, max_seq_len, period):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
slopes = torch.Tensor(get_slopes(n_head))
bias = torch.div(
torch.arange(start=0, end=max_seq_len,
step=period).unsqueeze(1).repeat(1, period).view(-1),
period,
rounding_mode='floor')
bias = -torch.flip(bias, dims=[0])
alibi = torch.zeros(max_seq_len, max_seq_len)
for i in range(max_seq_len):
alibi[i, :i + 1] = bias[-(i + 1):]
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
mask = (torch.triu(torch.ones(max_seq_len,
max_seq_len)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
mask == 1, float(0.0))
mask = mask.unsqueeze(0) + alibi
return mask
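# Sketch (assumed sizes): init_biased_mask combines a causal upper-triangular
# -inf mask with ALiBi-style per-head linear biases.
def _demo_biased_mask():
    mask = init_biased_mask(n_head=4, max_seq_len=8, period=2)
    return mask.shape  # torch.Size([4, 8, 8]); row i attends only to columns <= i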
# Alignment Bias
def enc_dec_mask(device, T, S):
mask = torch.ones(T, S)
for i in range(T):
mask[i, i] = 0
return (mask == 1).to(device=device)
# Periodic Positional Encoding
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len // period) + 1
pe = pe.repeat(1, repeat_num, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class BaseModel(nn.Module):
"""Base class for all models."""
def __init__(self):
super(BaseModel, self).__init__()
# self.logger = logging.getLogger(self.__class__.__name__)
def forward(self, *x):
"""Forward pass logic.
:return: Model output
"""
raise NotImplementedError
def freeze_model(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def summary(self, logger, writer=None):
"""Model summary."""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size())
for p in model_parameters]) / 1e6 # Unit is Mega
logger.info('===>Trainable parameters: %.3f M' % params)
if writer is not None:
writer.add_text('Model Summary',
'Trainable parameters: %.3f M' % params)
"""https://github.com/X-niper/UniTalker"""
class UniTalkerDecoderTransformer(BaseModel):
def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None:
super().__init__()
self.learnable_style_emb = nn.Embedding(identity_num, out_dim)
self.PPE = PeriodicPositionalEncoding(
out_dim, period=period, max_seq_len=3000)
self.biased_mask = init_biased_mask(
n_head=4, max_seq_len=3000, period=period)
decoder_layer = nn.TransformerDecoderLayer(
d_model=out_dim,
nhead=4,
dim_feedforward=2 * out_dim,
batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer, num_layers=1)
self.interpolate_pos = interpolate_pos
def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
frame_num: int):
style_idx = torch.argmax(style_idx, dim=1)
obj_embedding = self.learnable_style_emb(style_idx)
obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
style_input = self.PPE(obj_embedding)
        tgt_mask = self.biased_mask.repeat(style_idx.shape[0], 1, 1)[
            :, :style_input.shape[1], :style_input.shape[1]
        ].clone().detach().to(device=style_input.device)
memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
frame_num)
feat_out = self.transformer_decoder(
style_input,
hidden_states,
tgt_mask=tgt_mask,
memory_mask=memory_mask)
if self.interpolate_pos == 2:
feat_out = linear_interpolation(feat_out, output_len=frame_num)
return feat_out

752
models/utils.py Normal file
View File

@@ -0,0 +1,752 @@
import json
import time
import warnings
import numpy as np
from typing import List, Optional, Tuple
from scipy.signal import savgol_filter
ARKitLeftRightPair = [
("jawLeft", "jawRight"),
("mouthLeft", "mouthRight"),
("mouthSmileLeft", "mouthSmileRight"),
("mouthFrownLeft", "mouthFrownRight"),
("mouthDimpleLeft", "mouthDimpleRight"),
("mouthStretchLeft", "mouthStretchRight"),
("mouthPressLeft", "mouthPressRight"),
("mouthLowerDownLeft", "mouthLowerDownRight"),
("mouthUpperUpLeft", "mouthUpperUpRight"),
("cheekSquintLeft", "cheekSquintRight"),
("noseSneerLeft", "noseSneerRight"),
("browDownLeft", "browDownRight"),
("browOuterUpLeft", "browOuterUpRight"),
("eyeBlinkLeft","eyeBlinkRight"),
("eyeLookDownLeft","eyeLookDownRight"),
("eyeLookInLeft", "eyeLookInRight"),
("eyeLookOutLeft","eyeLookOutRight"),
("eyeLookUpLeft","eyeLookUpRight"),
("eyeSquintLeft","eyeSquintRight"),
("eyeWideLeft","eyeWideRight")
]
ARKitBlendShape = [
"browDownLeft",
"browDownRight",
"browInnerUp",
"browOuterUpLeft",
"browOuterUpRight",
"cheekPuff",
"cheekSquintLeft",
"cheekSquintRight",
"eyeBlinkLeft",
"eyeBlinkRight",
"eyeLookDownLeft",
"eyeLookDownRight",
"eyeLookInLeft",
"eyeLookInRight",
"eyeLookOutLeft",
"eyeLookOutRight",
"eyeLookUpLeft",
"eyeLookUpRight",
"eyeSquintLeft",
"eyeSquintRight",
"eyeWideLeft",
"eyeWideRight",
"jawForward",
"jawLeft",
"jawOpen",
"jawRight",
"mouthClose",
"mouthDimpleLeft",
"mouthDimpleRight",
"mouthFrownLeft",
"mouthFrownRight",
"mouthFunnel",
"mouthLeft",
"mouthLowerDownLeft",
"mouthLowerDownRight",
"mouthPressLeft",
"mouthPressRight",
"mouthPucker",
"mouthRight",
"mouthRollLower",
"mouthRollUpper",
"mouthShrugLower",
"mouthShrugUpper",
"mouthSmileLeft",
"mouthSmileRight",
"mouthStretchLeft",
"mouthStretchRight",
"mouthUpperUpLeft",
"mouthUpperUpRight",
"noseSneerLeft",
"noseSneerRight",
"tongueOut"
]
MOUTH_BLENDSHAPES = [ "mouthDimpleLeft",
"mouthDimpleRight",
"mouthFrownLeft",
"mouthFrownRight",
"mouthFunnel",
"mouthLeft",
"mouthLowerDownLeft",
"mouthLowerDownRight",
"mouthPressLeft",
"mouthPressRight",
"mouthPucker",
"mouthRight",
"mouthRollLower",
"mouthRollUpper",
"mouthShrugLower",
"mouthShrugUpper",
"mouthSmileLeft",
"mouthSmileRight",
"mouthStretchLeft",
"mouthStretchRight",
"mouthUpperUpLeft",
"mouthUpperUpRight",
"jawForward",
"jawLeft",
"jawOpen",
"jawRight",
"noseSneerLeft",
"noseSneerRight",
"cheekPuff",
]
DEFAULT_CONTEXT = {
    'is_initial_input': True,
    'previous_audio': None,
    'previous_expression': None,
    'previous_volume': None,
    'previous_headpose': None,
}
RETURN_CODE = {
    "SUCCESS": 0,
    "AUDIO_LENGTH_ERROR": 1,
    "CHECKPOINT_PATH_ERROR": 2,
    "MODEL_INFERENCE_ERROR": 3,
}
DEFAULT_CONTEXTRETURN = {
    "code": RETURN_CODE['SUCCESS'],
    "expression": None,
    "headpose": None,
}
BLINK_PATTERNS = [
np.array([0.365, 0.950, 0.956, 0.917, 0.367, 0.119, 0.025]),
np.array([0.235, 0.910, 0.945, 0.778, 0.191, 0.235, 0.089]),
np.array([0.870, 0.950, 0.949, 0.696, 0.191, 0.073, 0.007]),
np.array([0.000, 0.557, 0.953, 0.942, 0.426, 0.148, 0.018])
]
# Postprocess
def symmetrize_blendshapes(
bs_params: np.ndarray,
mode: str = "average",
symmetric_pairs: list = ARKitLeftRightPair
) -> np.ndarray:
"""
Apply symmetrization to ARKit blendshape parameters (batched version)
Args:
bs_params: numpy array of shape (N, 52), batch of ARKit parameters
mode: symmetrization mode ["average", "max", "min", "left_dominant", "right_dominant"]
symmetric_pairs: list of left-right parameter pairs
Returns:
Symmetrized parameters with same shape (N, 52)
"""
name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)}
# Input validation
if bs_params.ndim != 2 or bs_params.shape[1] != 52:
raise ValueError("Input must be of shape (N, 52)")
symmetric_bs = bs_params.copy() # Shape (N, 52)
# Precompute valid index pairs
valid_pairs = []
for left, right in symmetric_pairs:
left_idx = name_to_idx.get(left)
right_idx = name_to_idx.get(right)
if None not in (left_idx, right_idx):
valid_pairs.append((left_idx, right_idx))
# Vectorized processing
for l_idx, r_idx in valid_pairs:
left_col = symmetric_bs[:, l_idx]
right_col = symmetric_bs[:, r_idx]
if mode == "average":
new_vals = (left_col + right_col) / 2
elif mode == "max":
new_vals = np.maximum(left_col, right_col)
elif mode == "min":
new_vals = np.minimum(left_col, right_col)
elif mode == "left_dominant":
new_vals = left_col
elif mode == "right_dominant":
new_vals = right_col
else:
raise ValueError(f"Invalid mode: {mode}")
# Update both columns simultaneously
symmetric_bs[:, l_idx] = new_vals
symmetric_bs[:, r_idx] = new_vals
return symmetric_bs
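# Usage sketch for symmetrize_blendshapes (random data; purely illustrative):
def _demo_symmetrize():
    params = np.random.rand(10, 52).astype(np.float32)
    symmetric = symmetrize_blendshapes(params, mode="average")
    left = ARKitBlendShape.index("eyeBlinkLeft")
    right = ARKitBlendShape.index("eyeBlinkRight")
    assert np.allclose(symmetric[:, left], symmetric[:, right])
    return symmetric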
def apply_random_eye_blinks(
input: np.ndarray,
blink_scale: tuple = (0.8, 1.0),
blink_interval: tuple = (60, 120),
blink_duration: int = 7
) -> np.ndarray:
"""
Apply randomized eye blinks to blendshape parameters
Args:
output: Input array of shape (N, 52) containing blendshape parameters
blink_scale: Tuple (min, max) for random blink intensity scaling
blink_interval: Tuple (min, max) for random blink spacing in frames
blink_duration: Number of frames for blink animation (fixed)
Returns:
None (modifies output array in-place)
"""
# Define eye blink patterns (normalized 0-1)
# Initialize parameters
n_frames = input.shape[0]
input[:,8:10] = np.zeros((n_frames,2))
current_frame = 0
# Main blink application loop
while current_frame < n_frames - blink_duration:
# Randomize blink parameters
scale = np.random.uniform(*blink_scale)
pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
# Apply blink animation
blink_values = pattern * scale
input[current_frame:current_frame + blink_duration, 8] = blink_values
input[current_frame:current_frame + blink_duration, 9] = blink_values
# Advance to next blink position
current_frame += blink_duration + np.random.randint(*blink_interval)
return input
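# Sketch: columns 8 and 9 (eyeBlinkLeft / eyeBlinkRight in ARKitBlendShape order)
# are zeroed and then overwritten with randomized 7-frame blink patterns.
def _demo_eye_blinks():
    frames = np.zeros((300, 52), dtype=np.float32)  # 10 s of animation at 30 fps
    return apply_random_eye_blinks(frames, blink_scale=(0.8, 1.0), blink_interval=(60, 120))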
def apply_random_eye_blinks_context(
animation_params: np.ndarray,
processed_frames: int = 0,
intensity_range: tuple = (0.8, 1.0)
) -> np.ndarray:
"""Applies random eye blink patterns to facial animation parameters.
Args:
animation_params: Input facial animation parameters array with shape [num_frames, num_features].
Columns 8 and 9 typically represent left/right eye blink parameters.
processed_frames: Number of already processed frames that shouldn't be modified
intensity_range: Tuple defining (min, max) scaling for blink intensity
Returns:
Modified animation parameters array with random eye blinks added to unprocessed frames
"""
remaining_frames = animation_params.shape[0] - processed_frames
# Only apply blinks if there's enough remaining frames (blink pattern requires 7 frames)
if remaining_frames <= 7:
return animation_params
# Configure blink timing parameters
min_blink_interval = 40 # Minimum frames between blinks
max_blink_interval = 100 # Maximum frames between blinks
# Find last blink in previously processed frames (column 8 > 0.5 indicates blink)
previous_blink_indices = np.where(animation_params[:processed_frames, 8] > 0.5)[0]
last_processed_blink = previous_blink_indices[-1] - 7 if previous_blink_indices.size > 0 else processed_frames
# Calculate first new blink position
blink_interval = np.random.randint(min_blink_interval, max_blink_interval)
first_blink_start = max(0, blink_interval - last_processed_blink)
# Apply first blink if there's enough space
if first_blink_start <= (remaining_frames - 7):
# Randomly select blink pattern and intensity
blink_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
intensity = np.random.uniform(*intensity_range)
# Calculate blink frame range
blink_start = processed_frames + first_blink_start
blink_end = blink_start + 7
# Apply pattern to both eyes
animation_params[blink_start:blink_end, 8] = blink_pattern * intensity
animation_params[blink_start:blink_end, 9] = blink_pattern * intensity
# Check space for additional blink
remaining_after_blink = animation_params.shape[0] - blink_end
if remaining_after_blink > min_blink_interval:
# Calculate second blink position
second_intensity = np.random.uniform(*intensity_range)
second_interval = np.random.randint(min_blink_interval, max_blink_interval)
if (remaining_after_blink - 7) > second_interval:
second_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
second_blink_start = blink_end + second_interval
second_blink_end = second_blink_start + 7
# Apply second blink
animation_params[second_blink_start:second_blink_end, 8] = second_pattern * second_intensity
animation_params[second_blink_start:second_blink_end, 9] = second_pattern * second_intensity
return animation_params
def export_blendshape_animation(
blendshape_weights: np.ndarray,
output_path: str,
blendshape_names: List[str],
fps: float,
rotation_data: Optional[np.ndarray] = None
) -> None:
"""
Export blendshape animation data to JSON format compatible with ARKit.
Args:
blendshape_weights: 2D numpy array of shape (N, 52) containing animation frames
output_path: Full path for output JSON file (including .json extension)
blendshape_names: Ordered list of 52 ARKit-standard blendshape names
fps: Frame rate for timing calculations (frames per second)
rotation_data: Optional 3D rotation data array of shape (N, 3)
Raises:
ValueError: If input dimensions are incompatible
IOError: If file writing fails
"""
# Validate input dimensions
if blendshape_weights.shape[1] != 52:
raise ValueError(f"Expected 52 blendshapes, got {blendshape_weights.shape[1]}")
if len(blendshape_names) != 52:
raise ValueError(f"Requires 52 blendshape names, got {len(blendshape_names)}")
if rotation_data is not None and len(rotation_data) != len(blendshape_weights):
raise ValueError("Rotation data length must match animation frames")
# Build animation data structure
animation_data = {
"names":blendshape_names,
"metadata": {
"fps": fps,
"frame_count": len(blendshape_weights),
"blendshape_names": blendshape_names
},
"frames": []
}
# Convert numpy array to serializable format
for frame_idx in range(blendshape_weights.shape[0]):
frame_data = {
"weights": blendshape_weights[frame_idx].tolist(),
"time": frame_idx / fps,
"rotation": rotation_data[frame_idx].tolist() if rotation_data else []
}
animation_data["frames"].append(frame_data)
# Safeguard against data loss
if not output_path.endswith('.json'):
output_path += '.json'
# Write to file with error handling
try:
with open(output_path, 'w', encoding='utf-8') as json_file:
json.dump(animation_data, json_file, indent=2, ensure_ascii=False)
except Exception as e:
raise IOError(f"Failed to write animation data: {str(e)}") from e
def apply_savitzky_golay_smoothing(
input_data: np.ndarray,
window_length: int = 5,
polyorder: int = 2,
axis: int = 0,
validate: bool = True
) -> Tuple[np.ndarray, Optional[float]]:
"""
Apply Savitzky-Golay filter smoothing along specified axis of input data.
Args:
input_data: 2D numpy array of shape (n_samples, n_features)
window_length: Length of the filter window (must be odd and > polyorder)
polyorder: Order of the polynomial fit
axis: Axis along which to filter (0: column-wise, 1: row-wise)
validate: Enable input validation checks when True
    Returns:
        tuple: (smoothed_data, processing_time)
            - smoothed_data: Smoothed output array clipped to [0, 1]
            - processing_time: Execution time in seconds
Raises:
ValueError: For invalid input dimensions or filter parameters
"""
# Validation mode timing bypass
processing_time = None
if validate:
# Input integrity checks
if input_data.ndim != 2:
raise ValueError(f"Expected 2D input, got {input_data.ndim}D array")
if window_length % 2 == 0 or window_length < 3:
raise ValueError("Window length must be odd integer ≥ 3")
if polyorder >= window_length:
raise ValueError("Polynomial order must be < window length")
# Store original dtype and convert to float64 for numerical stability
original_dtype = input_data.dtype
working_data = input_data.astype(np.float64)
# Start performance timer
timer_start = time.perf_counter()
try:
# Vectorized Savitzky-Golay application
smoothed_data = savgol_filter(working_data,
window_length=window_length,
polyorder=polyorder,
axis=axis,
mode='mirror')
except Exception as e:
raise RuntimeError(f"Filtering failed: {str(e)}") from e
# Stop timer and calculate duration
processing_time = time.perf_counter() - timer_start
# Restore original data type with overflow protection
return (
np.clip(smoothed_data,
0.0,
1.0
).astype(original_dtype),
processing_time
)
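# Usage sketch (random data): smooth a noisy (N, 52) parameter track column-wise;
# the output is clipped to [0, 1] to stay in blendshape-weight range.
def _demo_savgol_smoothing():
    noisy = np.random.rand(100, 52)
    smoothed, seconds = apply_savitzky_golay_smoothing(noisy, window_length=5, polyorder=2)
    return smoothed.shape, seconds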
def _blend_region_start(
array: np.ndarray,
region: np.ndarray,
processed_boundary: int,
blend_frames: int
) -> None:
"""Applies linear blend between last active frame and silent region start."""
blend_length = min(blend_frames, region[0] - processed_boundary)
if blend_length <= 0:
return
pre_frame = array[region[0] - 1]
for i in range(blend_length):
weight = (i + 1) / (blend_length + 1)
array[region[0] + i] = pre_frame * (1 - weight) + array[region[0] + i] * weight
def _blend_region_end(
array: np.ndarray,
region: np.ndarray,
blend_frames: int
) -> None:
"""Applies linear blend between silent region end and next active frame."""
blend_length = min(blend_frames, array.shape[0] - region[-1] - 1)
if blend_length <= 0:
return
post_frame = array[region[-1] + 1]
for i in range(blend_length):
weight = (i + 1) / (blend_length + 1)
array[region[-1] - i] = post_frame * (1 - weight) + array[region[-1] - i] * weight
def find_low_value_regions(
signal: np.ndarray,
threshold: float,
min_region_length: int = 5
) -> list:
"""Identifies contiguous regions in a signal where values fall below a threshold.
Args:
signal: Input 1D array of numerical values
threshold: Value threshold for identifying low regions
min_region_length: Minimum consecutive samples required to qualify as a region
Returns:
List of numpy arrays, each containing indices for a qualifying low-value region
"""
low_value_indices = np.where(signal < threshold)[0]
contiguous_regions = []
current_region_length = 0
region_start_idx = 0
for i in range(1, len(low_value_indices)):
# Check if current index continues a consecutive sequence
if low_value_indices[i] != low_value_indices[i - 1] + 1:
# Finalize previous region if it meets length requirement
if current_region_length >= min_region_length:
contiguous_regions.append(low_value_indices[region_start_idx:i])
# Reset tracking for new potential region
region_start_idx = i
current_region_length = 0
current_region_length += 1
# Add the final region if it qualifies
if current_region_length >= min_region_length:
contiguous_regions.append(low_value_indices[region_start_idx:])
return contiguous_regions
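# Sketch: frames whose volume stays below the threshold for at least
# min_region_length consecutive samples come back as one index region.
def _demo_silence_regions():
    volume = np.array([0.5] * 10 + [0.0] * 8 + [0.5] * 10)
    # expected: [array([10, 11, 12, 13, 14, 15, 16, 17])]
    return find_low_value_regions(volume, threshold=0.01, min_region_length=5)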
def smooth_mouth_movements(
blend_shapes: np.ndarray,
processed_frames: int,
volume: np.ndarray = None,
silence_threshold: float = 0.001,
min_silence_duration: int = 7,
blend_window: int = 3
) -> np.ndarray:
"""Reduces jaw movement artifacts during silent periods in audio-driven animation.
Args:
blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
processed_frames: Number of already processed frames that shouldn't be modified
volume: Audio volume array used to detect silent periods
silence_threshold: Volume threshold for considering a frame silent
min_silence_duration: Minimum consecutive silent frames to qualify for processing
blend_window: Number of frames to smooth at region boundaries
Returns:
Modified blend shape array with reduced mouth movements during silence
"""
if volume is None:
return blend_shapes
# Detect silence periods using volume data
silent_regions = find_low_value_regions(
volume,
threshold=silence_threshold,
min_region_length=min_silence_duration
)
    # Precompute mouth blend shape indices once (identical for every region)
    mouth_blend_indices = [ARKitBlendShape.index(name) for name in MOUTH_BLENDSHAPES]
    for region_indices in silent_regions:
        # Reduce mouth blend shapes in silent region
        for region_indice in region_indices.tolist():
            blend_shapes[region_indice, mouth_blend_indices] *= 0.1
try:
# Smooth transition into silent region
_blend_region_start(
blend_shapes,
region_indices,
processed_frames,
blend_window
)
# Smooth transition out of silent region
_blend_region_end(
blend_shapes,
region_indices,
blend_window
)
except IndexError as e:
warnings.warn(f"Edge blending skipped at region {region_indices}: {str(e)}")
return blend_shapes
def apply_frame_blending(
blend_shapes: np.ndarray,
processed_frames: int,
initial_blend_window: int = 3,
subsequent_blend_window: int = 5
) -> np.ndarray:
"""Smooths transitions between processed and unprocessed animation frames using linear blending.
Args:
blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
processed_frames: Number of already processed frames (0 means no previous processing)
initial_blend_window: Max frames to blend at sequence start
subsequent_blend_window: Max frames to blend between processed and new frames
Returns:
Modified blend shape array with smoothed transitions
"""
if processed_frames > 0:
# Blend transition between existing and new animation
_blend_animation_segment(
blend_shapes,
transition_start=processed_frames,
blend_window=subsequent_blend_window,
reference_frame=blend_shapes[processed_frames - 1]
)
else:
# Smooth initial frames from neutral expression (zeros)
_blend_animation_segment(
blend_shapes,
transition_start=0,
blend_window=initial_blend_window,
reference_frame=np.zeros_like(blend_shapes[0])
)
return blend_shapes
def _blend_animation_segment(
array: np.ndarray,
transition_start: int,
blend_window: int,
reference_frame: np.ndarray
) -> None:
"""Applies linear interpolation between reference frame and target frames.
Args:
array: Blend shape array to modify
transition_start: Starting index for blending
blend_window: Maximum number of frames to blend
reference_frame: The reference frame to blend from
"""
actual_blend_length = min(blend_window, array.shape[0] - transition_start)
for frame_offset in range(actual_blend_length):
current_idx = transition_start + frame_offset
blend_weight = (frame_offset + 1) / (actual_blend_length + 1)
# Linear interpolation: ref_frame * (1 - weight) + current_frame * weight
array[current_idx] = (reference_frame * (1 - blend_weight)
+ array[current_idx] * blend_weight)
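# Sketch: a clip that starts abruptly at full weight is eased in from the neutral
# (all-zero) reference over the first few frames. Purely illustrative values.
def _demo_frame_blending():
    weights = np.ones((30, 52), dtype=np.float32)
    return apply_frame_blending(weights, processed_frames=0, initial_blend_window=3)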
BROW1 = np.array([[0.05597309, 0.05727929, 0.07995935, 0. , 0. ],
[0.00757574, 0.00936678, 0.12242376, 0. , 0. ],
[0. , 0. , 0.14943372, 0.04535687, 0.04264118],
[0. , 0. , 0.18015374, 0.09019445, 0.08736137],
[0. , 0. , 0.20549579, 0.12802747, 0.12450772],
[0. , 0. , 0.21098022, 0.1369939 , 0.13343132],
[0. , 0. , 0.20904602, 0.13903855, 0.13562402],
[0. , 0. , 0.20365039, 0.13977394, 0.13653506],
[0. , 0. , 0.19714841, 0.14096624, 0.13805152],
[0. , 0. , 0.20325482, 0.17303431, 0.17028868],
[0. , 0. , 0.21990852, 0.20164253, 0.19818163],
[0. , 0. , 0.23858181, 0.21908803, 0.21540019],
[0. , 0. , 0.2567876 , 0.23762083, 0.23396946],
[0. , 0. , 0.34093422, 0.27898848, 0.27651772],
[0. , 0. , 0.45288125, 0.35008961, 0.34887788],
[0. , 0. , 0.48076251, 0.36878952, 0.36778417],
[0. , 0. , 0.47798249, 0.36362219, 0.36145973],
[0. , 0. , 0.46186113, 0.33865979, 0.33597934],
[0. , 0. , 0.45264384, 0.33152157, 0.32891783],
[0. , 0. , 0.40986338, 0.29646468, 0.2945672 ],
[0. , 0. , 0.35628179, 0.23356403, 0.23155804],
[0. , 0. , 0.30870566, 0.1780673 , 0.17637439],
[0. , 0. , 0.25293985, 0.10710219, 0.10622486],
[0. , 0. , 0.18743332, 0.03252602, 0.03244236],
[0.02340254, 0.02364671, 0.15736724, 0. , 0. ]])
BROW2 = np.array([
[0. , 0. , 0.09799323, 0.05944436, 0.05002545],
[0. , 0. , 0.09780276, 0.07674237, 0.01636653],
[0. , 0. , 0.11136199, 0.1027964 , 0.04249811],
[0. , 0. , 0.26883412, 0.15861984, 0.15832305],
[0. , 0. , 0.42191629, 0.27038204, 0.27007768],
[0. , 0. , 0.3404977 , 0.21633868, 0.21597538],
[0. , 0. , 0.27301185, 0.17176409, 0.17134669],
[0. , 0. , 0.25960442, 0.15670464, 0.15622253],
[0. , 0. , 0.22877269, 0.11805892, 0.11754539],
[0. , 0. , 0.1451605 , 0.06389034, 0.0636282 ]])
BROW3 = np.array([
[0. , 0. , 0.124 , 0.0295, 0.0295],
[0. , 0. , 0.267 , 0.184 , 0.184 ],
[0. , 0. , 0.359 , 0.2765, 0.2765],
[0. , 0. , 0.3945, 0.3125, 0.3125],
[0. , 0. , 0.4125, 0.331 , 0.331 ],
[0. , 0. , 0.4235, 0.3445, 0.3445],
[0. , 0. , 0.4085, 0.3305, 0.3305],
[0. , 0. , 0.3695, 0.294 , 0.294 ],
[0. , 0. , 0.2835, 0.213 , 0.213 ],
[0. , 0. , 0.1795, 0.1005, 0.1005],
[0. , 0. , 0.108 , 0.014 , 0.014 ]])
from scipy.ndimage import label
def apply_random_brow_movement(input_exp, volume):
FRAME_SEGMENT = 150
HOLD_THRESHOLD = 10
VOLUME_THRESHOLD = 0.08
MIN_REGION_LENGTH = 6
STRENGTH_RANGE = (0.7, 1.3)
BROW_PEAKS = {
0: np.argmax(BROW1[:, 2]),
1: np.argmax(BROW2[:, 2])
}
for seg_start in range(0, len(volume), FRAME_SEGMENT):
seg_end = min(seg_start + FRAME_SEGMENT, len(volume))
seg_volume = volume[seg_start:seg_end]
candidate_regions = []
high_vol_mask = seg_volume > VOLUME_THRESHOLD
labeled_array, num_features = label(high_vol_mask)
for i in range(1, num_features + 1):
region = (labeled_array == i)
region_indices = np.where(region)[0]
if len(region_indices) >= MIN_REGION_LENGTH:
candidate_regions.append(region_indices)
if candidate_regions:
selected_region = candidate_regions[np.random.choice(len(candidate_regions))]
region_start = selected_region[0]
region_end = selected_region[-1]
region_length = region_end - region_start + 1
brow_idx = np.random.randint(0, 2)
base_brow = BROW1 if brow_idx == 0 else BROW2
peak_idx = BROW_PEAKS[brow_idx]
if region_length > HOLD_THRESHOLD:
local_max_pos = seg_volume[selected_region].argmax()
global_peak_frame = seg_start + selected_region[local_max_pos]
rise_anim = base_brow[:peak_idx + 1]
hold_frame = base_brow[peak_idx:peak_idx + 1]
insert_start = max(global_peak_frame - peak_idx, seg_start)
insert_end = min(global_peak_frame + (region_length - local_max_pos), seg_end)
strength = np.random.uniform(*STRENGTH_RANGE)
if insert_start + len(rise_anim) <= seg_end:
input_exp[insert_start:insert_start + len(rise_anim), :5] += rise_anim * strength
hold_duration = insert_end - (insert_start + len(rise_anim))
if hold_duration > 0:
input_exp[insert_start + len(rise_anim):insert_end, :5] += np.tile(hold_frame * strength,
(hold_duration, 1))
else:
anim_length = base_brow.shape[0]
insert_pos = seg_start + region_start + (region_length - anim_length) // 2
insert_pos = max(seg_start, min(insert_pos, seg_end - anim_length))
if insert_pos + anim_length <= seg_end:
strength = np.random.uniform(*STRENGTH_RANGE)
input_exp[insert_pos:insert_pos + anim_length, :5] += base_brow * strength
return np.clip(input_exp, 0, 1)

10
requirements.txt Normal file
View File

@@ -0,0 +1,10 @@
spleeter==2.4.2
opencv_python_headless==4.11.0.86
gradio==5.25.2
omegaconf==2.3.0
addict==2.4.0
yapf==0.40.1
librosa==0.11.0
transformers==4.36.2
termcolor==3.0.1
numpy==1.26.3

View File

@@ -0,0 +1,9 @@
# install torch 2.1.2
# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
# install dependencies
pip install -r requirements.txt
# install H5-render
pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl

View File

@@ -0,0 +1,9 @@
# install torch 2.1.2
# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
# install dependencies
pip install -r requirements.txt
# install H5-render
pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl

0
utils/__init__.py Normal file
View File

53
utils/cache.py Normal file
View File

@@ -0,0 +1,53 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import SharedArray
try:
from multiprocessing.shared_memory import ShareableList
except ImportError:
import warnings
warnings.warn("Please update python version >= 3.8 to enable shared_memory")
import numpy as np
def shared_array(name, var=None):
if var is not None:
# check exist
if os.path.exists(f"/dev/shm/{name}"):
return SharedArray.attach(f"shm://{name}")
# create shared_array
data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
data[...] = var[...]
data.flags.writeable = False
else:
data = SharedArray.attach(f"shm://{name}").copy()
return data
def shared_dict(name, var=None):
name = str(name)
assert "." not in name # '.' is used as sep flag
data = {}
if var is not None:
assert isinstance(var, dict)
keys = var.keys()
# current version only cache np.array
keys_valid = []
for key in keys:
if isinstance(var[key], np.ndarray):
keys_valid.append(key)
keys = keys_valid
ShareableList(sequence=keys, name=name + ".keys")
for key in keys:
if isinstance(var[key], np.ndarray):
data[key] = shared_array(name=f"{name}.{key}", var=var[key])
else:
keys = list(ShareableList(name=name + ".keys"))
for key in keys:
data[key] = shared_array(name=f"{name}.{key}")
return data
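# Usage sketch (assumptions: Linux with /dev/shm and the SharedArray package
# installed; the name "demo_cache" is hypothetical). The first call with `var`
# publishes the array; later calls by name attach and return a private copy.
def _demo_shared_array():
    src = np.arange(8, dtype=np.float32)
    shared_array("demo_cache", var=src)  # create, or attach if it already exists
    return shared_array("demo_cache")    # read back a copy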

192
utils/comm.py Normal file
View File

@@ -0,0 +1,192 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import functools
import numpy as np
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
assert (
_LOCAL_PROCESS_GROUP is not None
), "Local process group is not created! Please use launch() to spawn processes!"
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
if dist.get_backend() == dist.Backend.NCCL:
# This argument is needed to avoid warnings.
# It's valid only for NCCL backend.
dist.barrier(device_ids=[torch.cuda.current_device()])
else:
dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def all_gather(data, group=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: list of data gathered from each rank
"""
if get_world_size() == 1:
return [data]
if group is None:
group = (
_get_global_gloo_group()
) # use CPU group by default, to reduce GPU RAM usage.
world_size = dist.get_world_size(group)
if world_size == 1:
return [data]
output = [None for _ in range(world_size)]
dist.all_gather_object(output, data, group=group)
return output
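# Sketch: in a single-process (non-distributed) run, all_gather simply returns
# [data], so the call is safe whether or not torch.distributed is initialized.
def _demo_all_gather():
    return all_gather({"payload": 123})  # -> [{'payload': 123}] when world_size == 1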
def gather(data, dst=0, group=None):
"""
Run gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
dst (int): destination rank
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: on dst, a list of data gathered from each rank. Otherwise,
an empty list.
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
world_size = dist.get_world_size(group=group)
if world_size == 1:
return [data]
rank = dist.get_rank(group=group)
if rank == dst:
output = [None for _ in range(world_size)]
dist.gather_object(data, output, dst=dst, group=group)
return output
else:
dist.gather_object(data, None, dst=dst, group=group)
return []
def shared_random_seed():
"""
Returns:
int: a random number that is the same across all workers.
If workers need a shared RNG, they can use this shared seed to
create one.
All workers must call this function, otherwise it will deadlock.
"""
ints = np.random.randint(2**31)
all_ints = all_gather(ints)
return all_ints[0]
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict

696
utils/config.py Normal file
View File

@@ -0,0 +1,696 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
from .misc import import_modules_from_strings
from .path import check_file_exist
if platform.system() == "Windows":
import regex as re
else:
import re
BASE_KEY = "_base_"
DELETE_KEY = "_delete_"
DEPRECATION_KEY = "_deprecation_"
RESERVED_KEYS = ["filename", "text", "pretty_text"]
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError(
f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
)
except Exception as e:
ex = e
else:
return value
raise ex
def add_args(parser, cfg, prefix=""):
for k, v in cfg.items():
if isinstance(v, str):
parser.add_argument("--" + prefix + k)
elif isinstance(v, int):
parser.add_argument("--" + prefix + k, type=int)
elif isinstance(v, float):
parser.add_argument("--" + prefix + k, type=float)
elif isinstance(v, bool):
parser.add_argument("--" + prefix + k, action="store_true")
elif isinstance(v, dict):
add_args(parser, v, prefix + k + ".")
elif isinstance(v, abc.Iterable):
parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
else:
print(f"cannot parse key {prefix + k} of type {type(v)}")
return parser
class Config:
"""A facility for config and config files.
It supports common file formats as configs: python/json/yaml. The interface
is the same as a dict object and also allows access config values as
attributes.
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError(
"There are syntax errors in config " f"file {filename}: {e}"
)
@staticmethod
def _substitute_predefined_vars(filename, temp_config_name):
file_dirname = osp.dirname(filename)
file_basename = osp.basename(filename)
file_basename_no_extension = osp.splitext(file_basename)[0]
file_extname = osp.splitext(filename)[1]
support_templates = dict(
fileDirname=file_dirname,
fileBasename=file_basename,
fileBasenameNoExtension=file_basename_no_extension,
fileExtname=file_extname,
)
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
for key, value in support_templates.items():
regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
value = value.replace("\\", "/")
config_file = re.sub(regexp, value, config_file)
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
tmp_config_file.write(config_file)
@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
base_var_dict[randstr] = base_var
regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict
@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)
if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split("."):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split("."):
new_v = new_v[new_k]
cfg = new_v
return cfg
@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
fileExtname = osp.splitext(filename)[1]
if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
raise IOError("Only py/yml/yaml/json type are supported now!")
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix=fileExtname
)
if platform.system() == "Windows":
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
# Substitute predefined variables
if use_predefined_variables:
Config._substitute_predefined_vars(filename, temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name
)
if filename.endswith(".py"):
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith("__")
}
# delete imported module
del sys.modules[temp_module_name]
elif filename.endswith((".yml", ".yaml", ".json")):
raise NotImplementedError
# close temp file
temp_config_file.close()
# check deprecation information
if DEPRECATION_KEY in cfg_dict:
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
warning_msg = (
f"The config file {filename} will be deprecated " "in the future."
)
if "expected" in deprecation_info:
warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
if "reference" in deprecation_info:
warning_msg += (
" More information can be found at "
f'{deprecation_info["reference"]}'
)
warnings.warn(warning_msg)
cfg_text = filename + "\n"
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = (
base_filename if isinstance(base_filename, list) else [base_filename]
)
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
duplicate_keys = base_cfg_dict.keys() & c.keys()
if len(duplicate_keys) > 0:
raise KeyError(
"Duplicate key is not allowed among bases. "
f"Duplicate keys: {duplicate_keys}"
)
base_cfg_dict.update(c)
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(
cfg_dict, base_var_dict, base_cfg_dict
)
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = "\n".join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b, allow_list_keys=False):
"""merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Args:
a (dict): The source dict to be merged into ``b``.
b (dict): The origin dict to be fetch keys from ``a``.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in source ``a`` and will replace the element of the
corresponding index in b if b is a list. Default: False.
Returns:
            dict: A copy of ``b`` updated with values from ``a``.
Examples:
# Normally merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# Delete b first and merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# b is a list
>>> Config._merge_a_into_b(
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
[{'a': 2}, {'b': 2}]
"""
b = b.copy()
for k, v in a.items():
if allow_list_keys and k.isdigit() and isinstance(b, list):
k = int(k)
if len(b) <= k:
raise KeyError(f"Index {k} exceeds the length of list {b}")
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
allowed_types = (dict, list) if allow_list_keys else dict
if not isinstance(b[k], allowed_types):
raise TypeError(
f"{k}={v} in child config cannot inherit from base "
f"because {k} is a dict in the child config but is of "
f"type {type(b[k])} in base config. You may set "
f"`{DELETE_KEY}=True` to ignore the base config"
)
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
else:
b[k] = v
return b
@staticmethod
def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
if import_custom_modules and cfg_dict.get("custom_imports", None):
import_modules_from_strings(**cfg_dict["custom_imports"])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
                config str. Only py/yml/yaml/json types are supported now!
Returns:
obj:`Config`: Config obj.
"""
if file_format not in [".py", ".json", ".yaml", ".yml"]:
raise IOError("Only py/yml/yaml/json type are supported now!")
if file_format != ".py" and "dict(" in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn('Please check "file_format", the file format may be .py')
with tempfile.NamedTemporaryFile(
"w", encoding="utf-8", suffix=file_format, delete=False
) as temp_file:
temp_file.write(cfg_str)
            # On Windows, the previous implementation caused an error;
            # see PR 1077 for details.
cfg = Config.fromfile(temp_file.name)
os.remove(temp_file.name)
return cfg
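    # A minimal usage sketch (editor's illustration, not part of the original
    # file): `fromstring` materializes the string as a temp file and defers
    # to `fromfile`.
    #
    # >>> cfg = Config.fromstring("a = 1\nb = dict(c=2)", file_format=".py")
    # >>> cfg.a, cfg.b.c
    # (1, 2)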
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)"""
partial_parser = ArgumentParser(description=description)
partial_parser.add_argument("config", help="config file path")
cfg_file = partial_parser.parse_known_args()[0].config
cfg = Config.fromfile(cfg_file)
parser = ArgumentParser(description=description)
parser.add_argument("config", help="config file path")
add_args(parser, cfg)
return parser, cfg
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f"{key} is reserved for config file")
super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
super(Config, self).__setattr__("_filename", filename)
if cfg_text:
text = cfg_text
elif filename:
            with open(filename, "r", encoding="utf-8") as f:
text = f.read()
else:
text = ""
super(Config, self).__setattr__("_text", text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = "[\n"
v_str += "\n".join(
f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
).rstrip(",")
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent) + "]"
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= not str(key_name).isidentifier()
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ""
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += "{"
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = "" if outest_level or is_last else ","
if isinstance(v, dict):
v_str = "\n" + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: dict({v_str}"
else:
attr_str = f"{str(k)}=dict({v_str}"
attr_str = _indent(attr_str, indent) + ")" + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += "\n".join(s)
if use_mapping:
r += "}"
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style="pep8",
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True,
)
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
return text
def __repr__(self):
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def __getstate__(self):
return (self._cfg_dict, self._filename, self._text)
def __setstate__(self, state):
_cfg_dict, _filename, _text = state
super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
super(Config, self).__setattr__("_filename", _filename)
super(Config, self).__setattr__("_text", _text)
def dump(self, file=None):
cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
if self.filename.endswith(".py"):
if file is None:
return self.pretty_text
else:
with open(file, "w", encoding="utf-8") as f:
f.write(self.pretty_text)
else:
import mmcv
if file is None:
file_format = self.filename.split(".")[-1]
return mmcv.dump(cfg_dict, file_format=file_format)
else:
mmcv.dump(cfg_dict, file)
def merge_from_dict(self, options, allow_list_keys=True):
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'models.backbone.depth': 50,
... 'models.backbone.with_cp':True}
>>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... models=dict(backbone=dict(depth=50, with_cp=True)))
# Merge list element
>>> cfg = Config(dict(pipeline=[
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
>>> cfg.merge_from_dict(options, allow_list_keys=True)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(pipeline=[
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
Args:
options (dict): dict of configs to merge from.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in ``options`` and will replace the element of the
corresponding index in the config if the config is a list.
Default: True.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split(".")
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
super(Config, self).__setattr__(
"_cfg_dict",
Config._merge_a_into_b(
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
),
)
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options can
    be passed as comma-separated values, e.g. 'KEY=V1,V2,V3', or with explicit
    brackets, e.g. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
    list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'.
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ["true", "false"]:
        return val.lower() == "true"
return val
@staticmethod
def _parse_iterable(val):
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
Args:
val (str): Value string.
Returns:
list | tuple: The expanded list or tuple from the string.
Examples:
>>> DictAction._parse_iterable('1,2,3')
[1, 2, 3]
>>> DictAction._parse_iterable('[a, b, c]')
['a', 'b', 'c']
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
[(1, 2, 3), ['a', 'b'], 'c']
"""
def find_next_comma(string):
"""Find the position of next comma in the string.
            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element, so commas
            inside these brackets are ignored.
"""
assert (string.count("(") == string.count(")")) and (
string.count("[") == string.count("]")
), f"Imbalanced brackets exist in {string}"
end = len(string)
for idx, char in enumerate(string):
pre = string[:idx]
# The string before this ',' is balanced
if (
(char == ",")
and (pre.count("(") == pre.count(")"))
and (pre.count("[") == pre.count("]"))
):
end = idx
break
return end
# Strip ' and " characters and replace whitespace.
val = val.strip("'\"").replace(" ", "")
is_tuple = False
if val.startswith("(") and val.endswith(")"):
is_tuple = True
val = val[1:-1]
elif val.startswith("[") and val.endswith("]"):
val = val[1:-1]
elif "," not in val:
# val is a single value
return DictAction._parse_int_float_bool(val)
values = []
while len(val) > 0:
comma_idx = find_next_comma(val)
element = DictAction._parse_iterable(val[:comma_idx])
values.append(element)
val = val[comma_idx + 1 :]
if is_tuple:
values = tuple(values)
return values
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split("=", maxsplit=1)
options[key] = self._parse_iterable(val)
setattr(namespace, self.dest, options)
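# A hedged usage sketch (editor's illustration, not part of the original
# file): DictAction is typically bound to an argparse flag so config values
# can be overridden from the command line.
#
# >>> parser = ArgumentParser()
# >>> _ = parser.add_argument("--options", nargs="+", action=DictAction)
# >>> args = parser.parse_args(["--options", "model.depth=50", "lr=[0.1,0.01]"])
# >>> args.options
# {'model.depth': 50, 'lr': [0.1, 0.01]}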

utils/env.py Normal file

@@ -0,0 +1,33 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from datetime import datetime
def get_random_seed():
seed = (
os.getpid()
+ int(datetime.now().strftime("%S%f"))
+ int.from_bytes(os.urandom(2), "big")
)
return seed
def set_seed(seed=None):
if seed is None:
seed = get_random_seed()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cudnn.benchmark = False
cudnn.deterministic = True
os.environ["PYTHONHASHSEED"] = str(seed)

utils/events.py Normal file

@@ -0,0 +1,585 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import datetime
import json
import logging
import os
import time
import torch
import numpy as np
from typing import List, Optional, Tuple
from collections import defaultdict
from contextlib import contextmanager
__all__ = [
"get_event_storage",
"JSONWriter",
"TensorboardXWriter",
"CommonMetricPrinter",
"EventStorage",
]
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throws an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Examples parsing such a json file:
::
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 19,
"loss": 1.9228371381759644,
"loss_box_reg": 0.050025828182697296,
"loss_classifier": 0.5316952466964722,
"loss_mask": 0.7236229181289673,
"loss_rpn_box": 0.0856662318110466,
"loss_rpn_cls": 0.48198649287223816,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 39,
"loss": 1.282649278640747,
"loss_box_reg": 0.06222952902317047,
"loss_classifier": 0.30682939291000366,
"loss_mask": 0.6970193982124329,
"loss_rpn_box": 0.038663312792778015,
"loss_rpn_cls": 0.1471673548221588,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` are True.
"""
self._file_handle = open(json_file, "a")
self._window_size = window_size
self._last_write = -1
def write(self):
storage = get_event_storage()
to_save = defaultdict(dict)
        for k, (v, itr) in storage.latest_with_smoothing_hint(
            self._window_size
        ).items():
            # keep scalars that have not been written yet
            if itr <= self._last_write:
                continue
            to_save[itr][k] = v
if len(to_save):
all_iters = sorted(to_save.keys())
self._last_write = max(all_iters)
for itr, scalars_per_iter in to_save.items():
scalars_per_iter["iteration"] = itr
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file.
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
"""
self._window_size = window_size
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir, **kwargs)
self._last_write = -1
def write(self):
storage = get_event_storage()
new_last_write = self._last_write
        for k, (v, itr) in storage.latest_with_smoothing_hint(
            self._window_size
        ).items():
            if itr > self._last_write:
                self._writer.add_scalar(k, v, itr)
                new_last_write = max(new_last_write, itr)
self._last_write = new_last_write
# storage.put_{image,histogram} is only meant to be used by
# tensorboard writer. So we access its internal fields directly from here.
if len(storage._vis_data) >= 1:
for img_name, img, step_num in storage._vis_data:
self._writer.add_image(img_name, img, step_num)
# Storage stores all image data and rely on this writer to clear them.
# As a result it assumes only one writer will use its image data.
# An alternative design is to let storage store limited recent
# data (e.g. only the most recent image) that all writers can access.
# In that case a writer may not see all image data if its period is long.
storage.clear_images()
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all losses, and the learning rate.
It also applies smoothing using a window of 20 elements.
It's meant to print common metrics in common ways.
To print something in more customized ways, please implement a similar printer by yourself.
"""
def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
"""
Args:
max_iter: the maximum number of iterations to train.
Used to compute ETA. If not given, ETA will not be printed.
window_size (int): the losses will be median-smoothed by this window size
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
self._window_size = window_size
self._last_write = (
None # (step, time) of last call to write(). Used to compute ETA
)
def _get_eta(self, storage) -> Optional[str]:
if self._max_iter is None:
return ""
iteration = storage.iter
try:
eta_seconds = storage.history("time").median(1000) * (
self._max_iter - iteration - 1
)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
return str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError:
# estimate eta on our own - more noisy
eta_string = None
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())
return eta_string
def write(self):
storage = get_event_storage()
iteration = storage.iter
if iteration == self._max_iter:
# This hook only reports training progress (loss, ETA, etc) but not other data,
# therefore do not write anything after training succeeds, even if this method
# is called.
return
try:
data_time = storage.history("data_time").avg(20)
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None
try:
iter_time = storage.history("time").global_avg()
except KeyError:
iter_time = None
try:
lr = "{:.5g}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
eta_string = self._get_eta(storage)
if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else:
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
iter=iteration,
losses=" ".join(
[
"{}: {:.4g}".format(k, v.median(self._window_size))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f} ".format(iter_time)
if iter_time is not None
else "",
data_time="data_time: {:.4f} ".format(data_time)
if data_time is not None
else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb)
if max_mem_mb is not None
else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(AverageMeter)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
self._current_prefix = ""
self._vis_data = []
self._histograms = []
# def put_image(self, img_name, img_tensor):
# """
# Add an `img_tensor` associated with `img_name`, to be shown on
# tensorboard.
# Args:
# img_name (str): The name of the image to put into tensorboard.
# img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
# Tensor of shape `[channel, height, width]` where `channel` is
# 3. The image format should be RGB. The elements in img_tensor
# can either have values in [0, 1] (float32) or [0, 255] (uint8).
# The `img_tensor` will be visualized in tensorboard.
# """
# self._vis_data.append((img_name, img_tensor, self._iter))
def put_scalar(self, name, value, n=1, smoothing_hint=False):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
                It defaults to False here; pass True for noisy scalars that
                only provide a useful signal after smoothing.
"""
name = self._current_prefix + name
history = self._history[name]
history.update(value, n)
self._latest_scalars[name] = (value, self._iter)
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
# def put_scalars(self, *, smoothing_hint=True, **kwargs):
# """
# Put multiple scalars from keyword arguments.
# Examples:
# storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
# """
# for k, v in kwargs.items():
# self.put_scalar(k, v, smoothing_hint=smoothing_hint)
#
# def put_histogram(self, hist_name, hist_tensor, bins=1000):
# """
# Create a histogram from a tensor.
# Args:
# hist_name (str): The name of the histogram to put into tensorboard.
# hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
# into a histogram.
# bins (int): Number of histogram bins.
# """
# ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
#
# # Create a histogram with PyTorch
# hist_counts = torch.histc(hist_tensor, bins=bins)
# hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
#
# # Parameter for the add_histogram_raw function of SummaryWriter
# hist_params = dict(
# tag=hist_name,
# min=ht_min,
# max=ht_max,
# num=len(hist_tensor),
# sum=float(hist_tensor.sum()),
# sum_squares=float(torch.sum(hist_tensor**2)),
# bucket_limits=hist_edges[1:].tolist(),
# bucket_counts=hist_counts.tolist(),
# global_step=self._iter,
# )
# self._histograms.append(hist_params)
def history(self, name):
"""
Returns:
AverageMeter: the history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
            dict[str -> (float, int)]: mapping from the name of each scalar to
                the most recent value and the iteration at which it was added.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
        Similar to :meth:`latest`, but each returned value is either the
        un-smoothed latest value or a median over the given window_size,
        depending on whether its smoothing_hint is True.
        This provides a default behavior that other writers can use.
"""
result = {}
for k, (v, itr) in self._latest_scalars.items():
result[k] = (
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
itr,
)
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
        Users should either (1) call this function to increment storage.iter
        when needed, or (2) set `storage.iter` to the correct iteration number
        before each iteration. The storage will then associate new data with
        that iteration number.
"""
self._iter += 1
@property
def iter(self):
"""
Returns:
int: The current iteration number. When used together with a trainer,
this is ensured to be the same as trainer.iter.
"""
return self._iter
@iter.setter
def iter(self, val):
self._iter = int(val)
@property
def iteration(self):
# for backward compatibility
return self._iter
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def clear_histograms(self):
"""
Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard.
"""
self._histograms = []
def reset_history(self, name):
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
ret.reset()
def reset_histories(self):
for name in self._history.keys():
self._history[name].reset()
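# A minimal usage sketch (editor's illustration, not part of the original
# file): scalars are recorded inside an EventStorage context, and writers
# such as JSONWriter read them back via get_event_storage().
#
# >>> with EventStorage(start_iter=0) as storage:
# ...     storage.put_scalar("loss", 0.5)
# ...     storage.step()
# ...     writer = JSONWriter("metrics.json")
# ...     writer.write()
# ...     writer.close()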
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.total = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.total = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.total += val * n
self.count += n
self.avg = self.total / self.count
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
"""
def __init__(self, max_length: int = 1000000) -> None:
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
def update(self, value: float, iteration: Optional[float] = None) -> None:
"""
        Add a new scalar value produced at a certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self) -> float:
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int) -> float:
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int) -> float:
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self) -> float:
"""
        Return the mean of all the elements in the buffer. Note that this
        includes values that were evicted when the buffer reached its
        maximum length.
"""
return self._global_avg
def values(self) -> List[Tuple[float, float]]:
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data

utils/logger.py Normal file

@@ -0,0 +1,167 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import logging
import torch
import torch.distributed as dist
from termcolor import colored
logger_initialized = {}
root_status = 0
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified and the process rank is 0, a FileHandler
will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, while other processes will have their level
            set to ERROR and thus be silent most of the time.
file_mode (str): The file mode used in opening log file.
Defaults to 'a'.
        color (bool): Colorful log output. Defaults to False.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
logger.propagate = False
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# only rank 0 will add a FileHandler
if rank == 0 and log_file is not None:
        # The official logging.FileHandler opens files in mode 'a' by
        # default; the file_mode argument exposes this so callers can
        # override it.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
plain_formatter = logging.Formatter(
"[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
)
else:
formatter = plain_formatter
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
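# Usage sketch (editor's note, not part of the original file): on rank 0 this
# logs to stderr and, if given, to `log_file`; other ranks only emit ERRORs.
#
# >>> logger = get_logger("pointcept", log_file="train.log")
# >>> logger.info("training started")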
def print_log(msg, logger=None, level=logging.INFO):
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (logging.Logger | str | None): The logger to be used.
Some special loggers are:
- "silent": no message will be printed.
- other str: the logger obtained with `get_root_logger(logger)`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
"""
if logger is None:
print(msg)
elif isinstance(logger, logging.Logger):
logger.log(level, msg)
elif logger == "silent":
pass
elif isinstance(logger, str):
_logger = get_logger(logger)
_logger.log(level, msg)
else:
raise TypeError(
"logger should be either a logging.Logger object, str, "
f'"silent" or None, but got {type(logger)}'
)
def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added. The name of the root logger is the top-level package name.
Args:
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
        file_mode (str): File mode for the log file ('w' or 'a').
Returns:
logging.Logger: The root logger.
"""
logger = get_logger(
name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode
)
return logger
def _log_api_usage(identifier: str):
"""
Internal function used to log the usage of different detectron2 components
inside facebook's infra.
"""
torch._C._log_api_usage_once("pointcept." + identifier)

utils/misc.py Normal file

@@ -0,0 +1,156 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import warnings
from collections import abc
import numpy as np
import torch
from importlib import import_module
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def intersection_and_union(output, target, K, ignore_index=-1):
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
assert output.ndim in [1, 2, 3]
assert output.shape == target.shape
output = output.reshape(output.size).copy()
target = target.reshape(target.size)
output[np.where(target == ignore_index)[0]] = ignore_index
intersection = output[np.where(output == target)[0]]
area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
area_output, _ = np.histogram(output, bins=np.arange(K + 1))
area_target, _ = np.histogram(target, bins=np.arange(K + 1))
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
def intersection_and_union_gpu(output, target, k, ignore_index=-1):
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
assert output.dim() in [1, 2, 3]
assert output.shape == target.shape
output = output.view(-1)
target = target.view(-1)
output[target == ignore_index] = ignore_index
intersection = output[output == target]
area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
area_output = torch.histc(output, bins=k, min=0, max=k - 1)
area_target = torch.histc(target, bins=k, min=0, max=k - 1)
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
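# Worked example (editor's sketch, not in the original file): per-class IoU
# and mIoU follow directly from the returned histograms; `pred` and `gt`
# stand for hypothetical integer label arrays of identical shape.
#
# >>> inter, union, _ = intersection_and_union(pred, gt, K=20)
# >>> iou = inter / (union + 1e-10)   # per-class IoU, shape (K,)
# >>> miou = iou.mean()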
def make_dirs(dir_name):
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
def find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
def is_str(x):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return isinstance(x, str)
def import_modules_from_strings(imports, allow_failed_imports=False):
"""Import modules from the given list of strings.
Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raised. Default: False.
Returns:
list[module] | module | None: The imported modules.
Examples:
>>> osp, sys = import_modules_from_strings(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if not imports:
return
single_import = False
if isinstance(imports, str):
single_import = True
imports = [imports]
if not isinstance(imports, list):
raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
imported = []
for imp in imports:
if not isinstance(imp, str):
raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
try:
imported_tmp = import_module(imp)
except ImportError:
if allow_failed_imports:
warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
imported_tmp = None
            else:
                raise ImportError(f"Failed to import {imp}")
imported.append(imported_tmp)
if single_import:
imported = imported[0]
return imported

utils/optimizer.py Normal file

@@ -0,0 +1,52 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch
from utils.logger import get_root_logger
from utils.registry import Registry
OPTIMIZERS = Registry("optimizers")
OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")
def build_optimizer(cfg, model, param_dicts=None):
if param_dicts is None:
cfg.params = model.parameters()
else:
cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
for i in range(len(param_dicts)):
param_group = dict(names=[], params=[])
if "lr" in param_dicts[i].keys():
param_group["lr"] = param_dicts[i].lr
if "momentum" in param_dicts[i].keys():
param_group["momentum"] = param_dicts[i].momentum
if "weight_decay" in param_dicts[i].keys():
param_group["weight_decay"] = param_dicts[i].weight_decay
cfg.params.append(param_group)
for n, p in model.named_parameters():
flag = False
for i in range(len(param_dicts)):
if param_dicts[i].keyword in n:
cfg.params[i + 1]["names"].append(n)
cfg.params[i + 1]["params"].append(p)
flag = True
break
if not flag:
cfg.params[0]["names"].append(n)
cfg.params[0]["params"].append(p)
logger = get_root_logger()
for i in range(len(cfg.params)):
param_names = cfg.params[i].pop("names")
message = ""
for key in cfg.params[i].keys():
if key != "params":
message += f" {key}: {cfg.params[i][key]};"
logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
return OPTIMIZERS.build(cfg=cfg)
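# A hedged config sketch (editor's illustration, not part of the original
# file): `param_dicts` entries match parameter names by `keyword`, e.g. to
# give a hypothetical "backbone" a smaller learning rate. ConfigDict and
# `model` are assumed to come from the accompanying config utilities and the
# caller, respectively.
#
# >>> cfg = ConfigDict(dict(type="AdamW", lr=1e-3, weight_decay=0.05))
# >>> optim = build_optimizer(cfg, model, param_dicts=[
# ...     ConfigDict(dict(keyword="backbone", lr=1e-4))])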

utils/path.py Normal file

@@ -0,0 +1,105 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import os.path as osp
from pathlib import Path
from .misc import is_str
def is_filepath(x):
return is_str(x) or isinstance(x, Path)
def fopen(filepath, *args, **kwargs):
if is_str(filepath):
return open(filepath, *args, **kwargs)
elif isinstance(filepath, Path):
return filepath.open(*args, **kwargs)
raise ValueError("`filepath` should be a string or a Path")
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
if not osp.isfile(filename):
raise FileNotFoundError(msg_tmpl.format(filename))
def mkdir_or_exist(dir_name, mode=0o777):
if dir_name == "":
return
dir_name = osp.expanduser(dir_name)
os.makedirs(dir_name, mode=mode, exist_ok=True)
def symlink(src, dst, overwrite=True, **kwargs):
if os.path.lexists(dst) and overwrite:
os.remove(dst)
os.symlink(src, dst, **kwargs)
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
"""Scan a directory to find the interested files.
Args:
dir_path (str | obj:`Path`): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
        case_sensitive (bool, optional): If set to False, ignore the case of
            suffix. Default: True.
Returns:
A generator for all the interested files with relative paths.
"""
if isinstance(dir_path, (str, Path)):
dir_path = str(dir_path)
else:
raise TypeError('"dir_path" must be a string or Path object')
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
if suffix is not None and not case_sensitive:
suffix = (
suffix.lower()
if isinstance(suffix, str)
else tuple(item.lower() for item in suffix)
)
root = dir_path
def _scandir(dir_path, suffix, recursive, case_sensitive):
for entry in os.scandir(dir_path):
if not entry.name.startswith(".") and entry.is_file():
rel_path = osp.relpath(entry.path, root)
_rel_path = rel_path if case_sensitive else rel_path.lower()
if suffix is None or _rel_path.endswith(suffix):
yield rel_path
elif recursive and os.path.isdir(entry.path):
# scan recursively if entry.path is a directory
yield from _scandir(entry.path, suffix, recursive, case_sensitive)
return _scandir(dir_path, suffix, recursive, case_sensitive)
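# Usage sketch (editor's note, not in the original file): collect the
# relative paths of every `.ply` file under a hypothetical `data/scenes` tree.
#
# >>> files = list(scandir("data/scenes", suffix=".ply", recursive=True))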
def find_vcs_root(path, markers=(".git",)):
"""Finds the root directory (including itself) of specified markers.
Args:
path (str): Path of directory or file.
markers (list[str], optional): List of file or directory names.
Returns:
        The directory containing one of the markers, or None if not found.
"""
if osp.isfile(path):
path = osp.dirname(path)
prev, cur = None, osp.abspath(osp.expanduser(path))
while cur != prev:
if any(osp.exists(osp.join(cur, marker)) for marker in markers):
return cur
prev, cur = cur, osp.split(cur)[0]
return None

utils/registry.py Normal file

@@ -0,0 +1,318 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import inspect
import warnings
from functools import partial
from .misc import is_seq_of
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from configs dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
if "type" not in cfg:
if default_args is None or "type" not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f"but got {cfg}\n{default_args}"
)
if not isinstance(registry, Registry):
raise TypeError(
"registry must be an mmcv.Registry object, " f"but got {type(registry)}"
)
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError(
"default_args must be a dict or None, " f"but got {type(default_args)}"
)
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop("type")
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(f"{obj_type} is not in the {registry.name} registry")
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f"{obj_cls.__name__}: {e}")
class Registry:
"""A registry to map strings to classes.
Registered object could be built from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
Please refer to
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
advanced usage.
Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` nor
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = (
self.__class__.__name__ + f"(name={self._name}, "
f"items={self._module_dict})"
)
return format_str
@staticmethod
def infer_scope():
"""Infer the scope of registry.
The name of the package where registry is defined will be returned.
Example:
# in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
scope (str): The inferred scope name.
"""
        # inspect.stack() traces where this function is called; index 2 is
        # the frame in which the Registry is instantiated (two frames above
        # `infer_scope()`)
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
split_filename = filename.split(".")
return split_filename[0]
@staticmethod
def split_scope_key(key):
"""Split scope and key.
The first scope will be split from key.
Examples:
>>> Registry.split_scope_key('mmdet.ResNet')
'mmdet', 'ResNet'
>>> Registry.split_scope_key('ResNet')
None, 'ResNet'
        Returns:
scope (str, None): The first scope.
key (str): The remaining key.
"""
split_index = key.find(".")
if split_index != -1:
return key[:split_index], key[split_index + 1 :]
else:
return None, key
@property
def name(self):
return self._name
@property
def scope(self):
return self._scope
@property
def module_dict(self):
return self._module_dict
@property
def children(self):
return self._children
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope, real_key = self.split_scope_key(key)
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from self._children
if scope in self._children:
return self._children[scope].get(real_key)
else:
# goto root
parent = self.parent
while parent.parent is not None:
parent = parent.parent
return parent.get(key)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def _add_children(self, registry):
"""Add children for a registry.
The ``registry`` will be added as children based on its scope.
The parent registry could build objects from children registry.
Example:
>>> models = Registry('models')
>>> mmdet_models = Registry('models', parent=models)
>>> @mmdet_models.register_module()
>>> class ResNet:
>>> pass
>>> resnet = models.build(dict(type='mmdet.ResNet'))
"""
assert isinstance(registry, Registry)
assert registry.scope is not None
assert (
registry.scope not in self.children
), f"scope {registry.scope} exists in {self.name} registry"
self.children[registry.scope] = registry
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError("module must be a class, " f"but got {type(module_class)}")
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f"{name} is already registered " f"in {self.name}")
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
"The old API of register_module(module, force=False) "
"is deprecated and will be removed, please use the new API "
"register_module(name=None, force=False, module=None) instead."
)
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f"force must be a boolean, but got {type(force)}")
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
"name must be either of None, an instance of str or a sequence"
f" of str, but got {type(name)}"
)
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module_class=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(module_class=cls, module_name=name, force=force)
return cls
return _register

utils/scheduler.py Normal file

@@ -0,0 +1,144 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch.optim.lr_scheduler as lr_scheduler
from .registry import Registry
SCHEDULERS = Registry("schedulers")
@SCHEDULERS.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
def __init__(
self,
optimizer,
milestones,
total_steps,
gamma=0.1,
last_epoch=-1,
verbose=False,
):
super().__init__(
optimizer=optimizer,
milestones=[rate * total_steps for rate in milestones],
gamma=gamma,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
def __init__(
self,
optimizer,
milestones,
total_steps,
gamma=0.1,
warmup_rate=0.05,
warmup_scale=1e-6,
last_epoch=-1,
verbose=False,
):
milestones = [rate * total_steps for rate in milestones]
def multi_step_with_warmup(s):
factor = 1.0
for i in range(len(milestones)):
if s < milestones[i]:
break
factor *= gamma
if s <= warmup_rate * total_steps:
warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
1 - warmup_scale
)
else:
warmup_coefficient = 1.0
return warmup_coefficient * factor
super().__init__(
optimizer=optimizer,
lr_lambda=multi_step_with_warmup,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class PolyLR(lr_scheduler.LambdaLR):
def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
super().__init__(
optimizer=optimizer,
lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class ExpLR(lr_scheduler.LambdaLR):
def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
super().__init__(
optimizer=optimizer,
lr_lambda=lambda s: gamma ** (s / total_steps),
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
super().__init__(
optimizer=optimizer,
T_max=total_steps,
eta_min=eta_min,
last_epoch=last_epoch,
verbose=verbose,
)
@SCHEDULERS.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
r"""
    Wrapper around torch.optim.lr_scheduler.OneCycleLR that keeps
    ``total_steps`` in its signature so it can be built like the other
    registered schedulers.
"""
def __init__(
self,
optimizer,
max_lr,
total_steps=None,
pct_start=0.3,
anneal_strategy="cos",
cycle_momentum=True,
base_momentum=0.85,
max_momentum=0.95,
div_factor=25.0,
final_div_factor=1e4,
three_phase=False,
last_epoch=-1,
verbose=False,
):
super().__init__(
optimizer=optimizer,
max_lr=max_lr,
total_steps=total_steps,
pct_start=pct_start,
anneal_strategy=anneal_strategy,
cycle_momentum=cycle_momentum,
base_momentum=base_momentum,
max_momentum=max_momentum,
div_factor=div_factor,
final_div_factor=final_div_factor,
three_phase=three_phase,
last_epoch=last_epoch,
verbose=verbose,
)
def build_scheduler(cfg, optimizer):
cfg.optimizer = optimizer
return SCHEDULERS.build(cfg=cfg)
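# A hedged usage sketch (editor's illustration, not part of the original
# file): the scheduler cfg carries everything except the optimizer, which
# `build_scheduler` injects before dispatching to the registry. ConfigDict
# and `optimizer` are assumed to come from the accompanying config utilities
# and the caller, respectively.
#
# >>> cfg = ConfigDict(dict(type="OneCycleLR", max_lr=1e-3, total_steps=10000))
# >>> scheduler = build_scheduler(cfg, optimizer)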

utils/timer.py Normal file

@@ -0,0 +1,71 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from time import perf_counter
from typing import Optional
class Timer:
"""
A timer which computes the time elapsed since the start/reset of the timer.
"""
def __init__(self) -> None:
self.reset()
def reset(self) -> None:
"""
Reset the timer.
"""
self._start = perf_counter()
self._paused: Optional[float] = None
self._total_paused = 0
self._count_start = 1
def pause(self) -> None:
"""
Pause the timer.
"""
if self._paused is not None:
raise ValueError("Trying to pause a Timer that is already paused!")
self._paused = perf_counter()
def is_paused(self) -> bool:
"""
Returns:
bool: whether the timer is currently paused
"""
return self._paused is not None
def resume(self) -> None:
"""
Resume the timer.
"""
if self._paused is None:
raise ValueError("Trying to resume a Timer that is not paused!")
# pyre-fixme[58]: `-` is not supported for operand types `float` and
# `Optional[float]`.
self._total_paused += perf_counter() - self._paused
self._paused = None
self._count_start += 1
def seconds(self) -> float:
"""
Returns:
(float): the total number of seconds since the start/reset of the
timer, excluding the time when the timer is paused.
"""
if self._paused is not None:
end_time: float = self._paused # type: ignore
else:
end_time = perf_counter()
return end_time - self._start - self._total_paused
def avg_seconds(self) -> float:
"""
Returns:
(float): the average number of seconds between every start/reset and
pause.
"""
return self.seconds() / self._count_start
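# Usage sketch (editor's note, not part of the original file):
#
# >>> timer = Timer()
# >>> train_one_epoch()      # hypothetical work
# >>> timer.pause()
# >>> print(f"{timer.seconds():.1f}s elapsed")
# >>> timer.resume()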

utils/visualization.py Normal file

@@ -0,0 +1,86 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import open3d as o3d
import numpy as np
import torch
def to_numpy(x):
if isinstance(x, torch.Tensor):
x = x.clone().detach().cpu().numpy()
assert isinstance(x, np.ndarray)
return x
def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
    dirname = os.path.dirname(file_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
coord = to_numpy(coord)
if color is not None:
color = to_numpy(color)
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coord)
pcd.colors = o3d.utility.Vector3dVector(
np.ones_like(coord) if color is None else color
)
o3d.io.write_point_cloud(file_path, pcd)
if logger is not None:
logger.info(f"Save Point Cloud to: {file_path}")
def save_bounding_boxes(
bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None
):
bboxes_corners = to_numpy(bboxes_corners)
# point list
points = bboxes_corners.reshape(-1, 3)
# line list
    box_lines = np.array(
        [
            # bottom face
            [0, 1],
            [1, 2],
            [2, 3],
            [3, 0],
            # top face
            [4, 5],
            [5, 6],
            [6, 7],
            [7, 4],
            # vertical edges
            [0, 4],
            [1, 5],
            [2, 6],
            [3, 7],
        ]
    )
lines = []
for i, _ in enumerate(bboxes_corners):
lines.append(box_lines + i * 8)
lines = np.concatenate(lines)
# color list
color = np.array([color for _ in range(len(lines))])
# generate line set
line_set = o3d.geometry.LineSet()
line_set.points = o3d.utility.Vector3dVector(points)
line_set.lines = o3d.utility.Vector2iVector(lines)
line_set.colors = o3d.utility.Vector3dVector(color)
o3d.io.write_line_set(file_path, line_set)
if logger is not None:
logger.info(f"Save Boxes to: {file_path}")
def save_lines(
points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None
):
points = to_numpy(points)
lines = to_numpy(lines)
colors = np.array([color for _ in range(len(lines))])
line_set = o3d.geometry.LineSet()
line_set.points = o3d.utility.Vector3dVector(points)
line_set.lines = o3d.utility.Vector2iVector(lines)
line_set.colors = o3d.utility.Vector3dVector(colors)
o3d.io.write_line_set(file_path, line_set)
if logger is not None:
logger.info(f"Save Lines to: {file_path}")
