Mirror of https://github.com/aigc3d/LAM_Audio2Expression.git, synced 2026-02-04 09:29:24 +08:00
feat: Initial commit
18 .gitignore vendored Normal file
@@ -0,0 +1,18 @@
image/
__pycache__
**/build/
**/*.egg-info/
**/dist/
*.so
exp
weights
data
log
outputs/
.vscode
.idea
*/.DS_Store
TEMP/
pretrained/
**/*.out
Dockerfile
201 LICENSE Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
103 README.md Normal file
@@ -0,0 +1,103 @@
# LAM-A2E: Audio to Expression

[Project Page](https://aigc3d.github.io/projects/LAM/)
[License: Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0)

#### This project leverages audio input to generate ARKit-blendshape-driven facial expressions in ⚡real time⚡, powering the ultra-realistic 3D avatars generated by [LAM](https://github.com/aigc3d/LAM).

## Demo

<div align="center">
<video controls src="https://github.com/user-attachments/assets/30ccbe82-7933-4031-8578-b5248435d317">
</video>
</div>

## 📢 News

### To-do list
- [ ] Release Hugging Face space.
- [ ] Release ModelScope space.
- [ ] Release the LAM-A2E model based on FLAME expressions.
- [ ] Release the Interactive Chatting Avatar SDK with [OpenAvatarChat](https://github.com/HumanAIGC-Engineering/OpenAvatarChat), including LLM, ASR, TTS, and LAM avatars.

## 🚀 Get Started
### Environment Setup
```bash
git clone git@github.com:aigc3d/LAM_Audio2Expression.git
cd LAM_Audio2Expression
# Install with CUDA 12.1
sh ./scripts/install/install_cu121.sh
# Or install with CUDA 11.8
sh ./scripts/install/install_cu118.sh
```

### Download

```bash
# HuggingFace download
# Download assets and model weights
huggingface-cli download 3DAIGC/LAM_audio2exp --local-dir ./
tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar
tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar


# Or OSS download (in case the HuggingFace download fails)
# Download assets
wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_assets.tar
tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar
# Download model weights
wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_streaming.tar
tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar
```

### Quick Start Guide
#### Using the <a href="https://github.com/gradio-app/gradio">Gradio</a> Interface:
We provide a simple Gradio demo with a **WebGL renderer**; you can get rendered results within seconds of uploading audio.

<img src="./assets/images/snapshot.png" alt="teaser" width="1000"/>

```bash
python app_lam_audio2exp.py
```

### Inference
```bash
# Example: python inference.py --config-file configs/lam_audio2exp_config_streaming.py --options save_path=exp/audio2exp weight=pretrained_models/lam_audio2exp_streaming.tar audio_input=./assets/sample_audio/BarackObama_english.wav
python inference.py --config-file ${CONFIG_PATH} --options save_path=${SAVE_PATH} weight=${CHECKPOINT_PATH} audio_input=${AUDIO_INPUT}
```

### Acknowledgement
This work builds on many amazing research works and open-source projects:
- [FLAME](https://flame.is.tue.mpg.de)
- [FaceFormer](https://github.com/EvelynFan/FaceFormer)
- [Meshtalk](https://github.com/facebookresearch/meshtalk)
- [Unitalker](https://github.com/X-niper/UniTalker)
- [Pointcept](https://github.com/Pointcept/Pointcept)

Thanks for their excellent work and great contributions.

### Related Works
Welcome to follow our other interesting works:
- [LAM](https://github.com/aigc3d/LAM)
- [LHM](https://github.com/aigc3d/LHM)

### Citation
```
@inproceedings{he2025LAM,
  title={LAM: Large Avatar Model for One-shot Animatable Gaussian Head},
  author={Yisheng He and Xiaodong Gu and Xiaodan Ye and Chao Xu and Zhengyi Zhao and Yuan Dong and Weihao Yuan and Zilong Dong and Liefeng Bo},
  booktitle={arXiv preprint arXiv:2502.17796},
  year={2025}
}
```
271 app_lam_audio2exp.py Normal file
@@ -0,0 +1,271 @@
"""
Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import base64

import gradio as gr
import argparse
from omegaconf import OmegaConf
from gradio_gaussian_render import gaussian_render

from engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from engines.infer import INFER
from pathlib import Path

try:
    import spaces
except ImportError:
    pass

h5_rendering = True


def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error('No image selected or uploaded!')


def prepare_working_dir():
    import tempfile
    working_dir = tempfile.TemporaryDirectory()
    return working_dir


def get_image_base64(path):
    with open(path, 'rb') as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f'data:image/png;base64,{encoded_string}'


def doRender():
    print('H5 rendering ....')


def parse_configs():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    parser.add_argument("--infer", type=str)
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.from_cli(unknown)

    # parse from ENV
    if os.environ.get("APP_INFER") is not None:
        args.infer = os.environ.get("APP_INFER")
    if os.environ.get("APP_MODEL_NAME") is not None:
        cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")

    args.config = args.infer if args.config is None else args.config

    if args.config is not None:
        cfg_train = OmegaConf.load(args.config)
        cfg.source_size = cfg_train.dataset.source_image_res
        try:
            cfg.src_head_size = cfg_train.dataset.src_head_size
        except Exception:
            cfg.src_head_size = 112
        cfg.render_size = cfg_train.dataset.render_image.high
        _relative_path = os.path.join(
            cfg_train.experiment.parent,
            cfg_train.experiment.child,
            os.path.basename(cli_cfg.model_name).split("_")[-1],
        )

        cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
        cfg.image_dump = os.path.join("exps", "images", _relative_path)
        cfg.video_dump = os.path.join("exps", "videos", _relative_path)  # output path

    if args.infer is not None:
        cfg_infer = OmegaConf.load(args.infer)
        cfg.merge_with(cfg_infer)
        cfg.setdefault(
            "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
        )
        cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
        cfg.setdefault(
            "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
        )
        cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))

    cfg.motion_video_read_fps = 30
    cfg.merge_with(cli_cfg)

    cfg.setdefault("logger", "INFO")

    assert cfg.model_name is not None, "model_name is required"

    return cfg, cfg_train


def create_zip_archive(output_zip='assets/arkitWithBSData.zip', base_dir=""):
    import os
    if os.path.exists(output_zip):
        os.remove(output_zip)
        print(f"Removed previous file: {output_zip}")
    run_command = 'zip -r ' + output_zip + ' ' + base_dir
    os.system(run_command)
    # check file
    if os.path.exists(output_zip):
        print(f"Archive created successfully: {output_zip}")
    else:
        raise ValueError(f"Archive creation failed: {output_zip}")
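`create_zip_archive` shells out to the `zip` CLI through `os.system`, so it needs `zip` on `PATH` and builds the command by string concatenation. A minimal standard-library sketch of an equivalent helper (a hypothetical alternative, not the repo's implementation; it assumes the archive just needs every file under `base_dir`):

```python
import os
import zipfile


def create_zip_archive_stdlib(output_zip='assets/arkitWithBSData.zip', base_dir=""):
    # Hypothetical stdlib variant: walk base_dir and store each file under a
    # path relative to base_dir's parent, avoiding the external `zip` binary.
    if os.path.exists(output_zip):
        os.remove(output_zip)
    parent = os.path.dirname(os.path.normpath(base_dir))
    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(base_dir):
            for name in files:
                path = os.path.join(root, name)
                zf.write(path, os.path.relpath(path, parent))
    if not os.path.exists(output_zip):
        raise ValueError(f"Archive creation failed: {output_zip}")
```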
def demo_lam_audio2exp(infer, cfg):
    def core_fn(image_path: str, audio_params, working_dir):

        base_id = os.path.basename(image_path).split(".")[0]

        # set input audio
        cfg.audio_input = audio_params
        cfg.save_json_path = os.path.join("./assets/sample_lam", base_id, 'arkitWithBSData', 'bsData.json')
        infer.infer()

        create_zip_archive(output_zip='./assets/arkitWithBSData.zip',
                           base_dir=os.path.join("./assets/sample_lam", base_id))

        return

    with gr.Blocks(analytics_enabled=False) as demo:
        logo_url = './assets/images/logo.jpeg'
        logo_base64 = get_image_base64(logo_url)
        gr.HTML(f"""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
                <h1> LAM-A2E: Audio to Expression</h1>
            </div>
            </div>
            """)

        gr.HTML(
            """<p><h4 style="color: blue;"> Notes: This project leverages audio input to generate ARKit-blendshape-driven facial expressions in ⚡real time⚡, powering ultra-realistic 3D avatars generated by LAM.</h4></p>"""
        )

        # DISPLAY
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_input_image'):
                    with gr.TabItem('Input Image'):
                        with gr.Row():
                            input_image = gr.Image(label='Input Image',
                                                   image_mode='RGB',
                                                   height=480,
                                                   width=270,
                                                   sources='upload',
                                                   type='filepath',  # 'numpy',
                                                   elem_id='content_image')
                # EXAMPLES
                with gr.Row():
                    examples = [
                        ['assets/sample_input/barbara.jpg'],
                        ['assets/sample_input/status.png'],
                        ['assets/sample_input/james.png'],
                        ['assets/sample_input/vfhq_case1.png'],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        examples_per_page=20,
                    )

            with gr.Column():
                with gr.Tabs(elem_id='lam_input_audio'):
                    with gr.TabItem('Input Audio'):
                        with gr.Row():
                            audio_input = gr.Audio(label='Input Audio',
                                                   type='filepath',
                                                   waveform_options={
                                                       'sample_rate': 16000,
                                                       'waveform_progress_color': '#4682b4'
                                                   },
                                                   elem_id='content_audio')

                        examples = [
                            ['assets/sample_audio/Nangyanwen_chinese.wav'],
                            ['assets/sample_audio/LiBai_TTS_chinese.wav'],
                            ['assets/sample_audio/LinJing_TTS_chinese.wav'],
                            ['assets/sample_audio/BarackObama_english.wav'],
                            ['assets/sample_audio/HillaryClinton_english.wav'],
                            ['assets/sample_audio/XitongShi_japanese.wav'],
                            ['assets/sample_audio/FangXiao_japanese.wav'],
                        ]
                        gr.Examples(
                            examples=examples,
                            inputs=[audio_input],
                            examples_per_page=10,
                        )

        # SETTING
        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                submit = gr.Button('Generate',
                                   elem_id='lam_generate',
                                   variant='primary')

        if h5_rendering:
            gr.set_static_paths(Path.cwd().absolute() / "assets/")
            assetPrefix = 'gradio_api/file=assets/'
            with gr.Row():
                gs = gaussian_render(width=380, height=680, assets=assetPrefix + 'arkitWithBSData.zip')

        working_dir = gr.State()
        submit.click(
            fn=assert_input_image,
            inputs=[input_image],
            queue=False,
        ).success(
            fn=prepare_working_dir,
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=core_fn,
            inputs=[input_image, audio_input,
                    working_dir],  # video_params refer to smpl dir
            outputs=[],
            queue=False,
        ).success(
            doRender, js='''() => window.start()'''
        )

    demo.queue()
    demo.launch()


def launch_gradio_app():
    os.environ.update({
        'APP_ENABLED': '1',
        'APP_MODEL_NAME': '',
        'APP_INFER': 'configs/lam_audio2exp_streaming_config.py',
        'APP_TYPE': 'infer.audio2exp',
        'NUMBA_THREADING_LAYER': 'omp',
    })

    args = default_argument_parser().parse_args()
    args.config_file = 'configs/lam_audio2exp_config_streaming.py'
    cfg = default_config_parser(args.config_file, args.options)
    cfg = default_setup(cfg)

    cfg.ex_vol = True
    infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))

    demo_lam_audio2exp(infer, cfg)


if __name__ == '__main__':
    launch_gradio_app()
BIN assets/images/logo.jpeg Normal file (binary file not shown; 36 KiB)
BIN assets/images/snapshot.png Normal file (binary file not shown; 2.0 MiB)
BIN assets/images/teaser.jpg Normal file (binary file not shown; 654 KiB)
92 configs/lam_audio2exp_config.py Normal file
@@ -0,0 +1,92 @@
weight = 'pretrained_models/lam_audio2exp.tar'  # path to model weight
ex_vol = True  # isolate the vocal track from the audio file
audio_input = './assets/sample_audio/BarackObama.wav'
save_json_path = 'bsData.json'

audio_sr = 16000
fps = 30.0

movement_smooth = True
brow_movement = True
id_idx = 153

resume = False  # whether to resume the training process
evaluate = True  # evaluate after each training epoch
test_only = False  # test process

seed = None  # the train process will init a random seed and record it
save_path = "exp/audio2exp"
num_worker = 16  # total workers across all GPUs
batch_size = 16  # total batch size across all GPUs
batch_size_val = None  # auto adapt to bs 1 for each gpu
batch_size_test = None  # auto adapt to bs 1 for each gpu
epoch = 100  # total epochs; data loop = epoch // eval_epoch
eval_epoch = 100  # scheduled total eval & checkpoint epochs

sync_bn = False
enable_amp = False
empty_cache = False
find_unused_parameters = False

mix_prob = 0
param_dicts = None  # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]

# model settings
model = dict(
    type="DefaultEstimator",
    backbone=dict(
        type="Audio2Expression",
        pretrained_encoder_type='wav2vec',
        pretrained_encoder_path='facebook/wav2vec2-base-960h',
        wav2vec2_config_path='configs/wav2vec2_config.json',
        num_identity_classes=5016,
        identity_feat_dim=64,
        hidden_dim=512,
        expression_dim=52,
        norm_type='ln',
        use_transformer=True,
        num_attention_heads=8,
        num_transformer_layers=6,
    ),
    criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
)

dataset_type = 'audio2exp'
data_root = './'
data = dict(
    train=dict(
        type=dataset_type,
        split="train",
        data_root=data_root,
        test_mode=False,
    ),
    val=dict(
        type=dataset_type,
        split="val",
        data_root=data_root,
        test_mode=False,
    ),
    test=dict(
        type=dataset_type,
        split="val",
        data_root=data_root,
        test_mode=True
    ),
)

# hooks
hooks = [
    dict(type="CheckpointLoader"),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
    dict(type="PreciseEvaluator", test_last=False),
]

# Trainer
train = dict(type="DefaultTrainer")

# Tester
infer = dict(type="Audio2ExpressionInfer",
             verbose=True)
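This config is plain Python, loaded with `Config.fromfile` in `engines/defaults.py` below, with `--options key=value` overrides merged on top. A minimal sketch of consuming it the way `launch_gradio_app()` in `app_lam_audio2exp.py` does (the override values are illustrative; the same pattern applies to the streaming config that follows):

```python
# Sketch: build and run inference from this config file.
from engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from engines.infer import INFER

args = default_argument_parser().parse_args([
    '--config-file', 'configs/lam_audio2exp_config.py',
    '--options', 'audio_input=./assets/sample_audio/BarackObama_english.wav',
])
cfg = default_config_parser(args.config_file, args.options)  # file + CLI overrides
cfg = default_setup(cfg)   # derive per-GPU batch sizes, settle the random seed
infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))  # "Audio2ExpressionInfer"
infer.infer()              # writes blendshape coefficients to cfg.save_json_path
```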
92 configs/lam_audio2exp_config_streaming.py Normal file
@@ -0,0 +1,92 @@
weight = 'pretrained_models/lam_audio2exp_streaming.tar'  # path to model weight
ex_vol = True  # isolate the vocal track from the audio file
audio_input = './assets/sample_audio/BarackObama.wav'
save_json_path = 'bsData.json'

audio_sr = 16000
fps = 30.0

movement_smooth = False
brow_movement = False
id_idx = 0

resume = False  # whether to resume the training process
evaluate = True  # evaluate after each training epoch
test_only = False  # test process

seed = None  # the train process will init a random seed and record it
save_path = "exp/audio2exp"
num_worker = 16  # total workers across all GPUs
batch_size = 16  # total batch size across all GPUs
batch_size_val = None  # auto adapt to bs 1 for each gpu
batch_size_test = None  # auto adapt to bs 1 for each gpu
epoch = 100  # total epochs; data loop = epoch // eval_epoch
eval_epoch = 100  # scheduled total eval & checkpoint epochs

sync_bn = False
enable_amp = False
empty_cache = False
find_unused_parameters = False

mix_prob = 0
param_dicts = None  # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]

# model settings
model = dict(
    type="DefaultEstimator",
    backbone=dict(
        type="Audio2Expression",
        pretrained_encoder_type='wav2vec',
        pretrained_encoder_path='facebook/wav2vec2-base-960h',
        wav2vec2_config_path='configs/wav2vec2_config.json',
        num_identity_classes=12,
        identity_feat_dim=64,
        hidden_dim=512,
        expression_dim=52,
        norm_type='ln',
        use_transformer=False,
        num_attention_heads=8,
        num_transformer_layers=6,
    ),
    criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
)

dataset_type = 'audio2exp'
data_root = './'
data = dict(
    train=dict(
        type=dataset_type,
        split="train",
        data_root=data_root,
        test_mode=False,
    ),
    val=dict(
        type=dataset_type,
        split="val",
        data_root=data_root,
        test_mode=False,
    ),
    test=dict(
        type=dataset_type,
        split="val",
        data_root=data_root,
        test_mode=True
    ),
)

# hooks
hooks = [
    dict(type="CheckpointLoader"),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
    dict(type="PreciseEvaluator", test_last=False),
]

# Trainer
train = dict(type="DefaultTrainer")

# Tester
infer = dict(type="Audio2ExpressionInfer",
             verbose=True)
77 configs/wav2vec2_config.json Normal file
@@ -0,0 +1,77 @@
{
  "_name_or_path": "facebook/wav2vec2-base-960h",
  "activation_dropout": 0.1,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_prob": 0.05,
  "model_type": "wav2vec2",
  "num_attention_heads": 12,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "num_negatives": 100,
  "pad_token_id": 0,
  "proj_codevector_dim": 256,
  "transformers_version": "4.7.0.dev0",
  "vocab_size": 32
}
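The configs above point `wav2vec2_config_path` at this file, which mirrors the `facebook/wav2vec2-base-960h` configuration. As a sketch (assuming the `transformers` package; the repo's own encoder-loading code is not part of this commit), the architecture can be instantiated offline from the local JSON:

```python
from transformers import Wav2Vec2Config, Wav2Vec2Model

# Build the encoder architecture from the local config, with no hub download.
config = Wav2Vec2Config.from_json_file('configs/wav2vec2_config.json')
model = Wav2Vec2Model(config)  # randomly initialized; real weights come from the checkpoint
print(config.hidden_size, config.num_hidden_layers)  # 768 12
```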
0 engines/__init__.py Normal file
147 engines/defaults.py Normal file
@@ -0,0 +1,147 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


import utils.comm as comm
from utils.env import get_random_seed, set_seed
from utils.config import Config, DictAction


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """
    if comm.get_world_size() == 1:
        return model
    # kwargs['find_unused_parameters'] = True
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
        if "output_device" not in kwargs:
            kwargs["output_device"] = [comm.get_local_rank()]
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
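A minimal usage sketch for `create_ddp_model` (hypothetical; it assumes `torch.distributed` has already been initialized by a launcher and uses a toy module in place of the real estimator):

```python
import torch.nn as nn

model = nn.Linear(512, 52).cuda()  # toy stand-in for the real estimator
model = create_ddp_model(
    model,
    broadcast_buffers=False,        # forwarded unchanged to DistributedDataParallel
    find_unused_parameters=False,   # matches the configs' default
)
# With world_size == 1 the module is returned as-is, so the same code path
# also works for single-GPU debugging.
```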

def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.

    The seed of each worker equals num_workers * rank + worker_id + user_seed.

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of the current process.
        seed (int): The random seed to use.
    """

    worker_seed = num_workers * rank + worker_id + seed
    set_seed(worker_seed)
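A sketch of wiring `worker_init_fn` into a `DataLoader` (hypothetical: a `dataset` object and a `cfg` already processed by `default_setup` below are assumed):

```python
from functools import partial
from torch.utils.data import DataLoader

init_fn = partial(
    worker_init_fn,
    num_workers=cfg.num_worker_per_gpu,
    rank=comm.get_rank(),
    seed=cfg.seed,
)
train_loader = DataLoader(
    dataset,
    batch_size=cfg.batch_size_per_gpu,
    num_workers=cfg.num_worker_per_gpu,
    worker_init_fn=init_fn,  # called with worker_id; the rest is pre-bound
)
```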

def default_argument_parser(epilog=None):
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        # default="tcp://127.0.0.1:{}".format(port),
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser


def default_config_parser(file_path, options):
    # config name protocol: dataset_name/model_name-exp_name
    if os.path.isfile(file_path):
        cfg = Config.fromfile(file_path)
    else:
        sep = file_path.find("-")
        cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))

    if options is not None:
        cfg.merge_from_dict(options)

    if cfg.seed is None:
        cfg.seed = get_random_seed()

    cfg.data.train.loop = cfg.epoch // cfg.eval_epoch

    os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
    if not cfg.resume:
        cfg.dump(os.path.join(cfg.save_path, "config.py"))
    return cfg


def default_setup(cfg):
    # scale by world size
    world_size = comm.get_world_size()
    cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
    cfg.num_worker_per_gpu = cfg.num_worker // world_size
    assert cfg.batch_size % world_size == 0
    assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
    assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
    cfg.batch_size_per_gpu = cfg.batch_size // world_size
    cfg.batch_size_val_per_gpu = (
        cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
    )
    cfg.batch_size_test_per_gpu = (
        cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
    )
    # update data loop
    assert cfg.epoch % cfg.eval_epoch == 0
    # settle random seed
    rank = comm.get_rank()
    seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
    set_seed(seed)
    return cfg
5 engines/hooks/__init__.py Normal file
@@ -0,0 +1,5 @@
from .default import HookBase
from .misc import *
from .evaluator import *

from .builder import build_hooks
15 engines/hooks/builder.py Normal file
@@ -0,0 +1,15 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from utils.registry import Registry


HOOKS = Registry("hooks")


def build_hooks(cfg):
    hooks = []
    for hook_cfg in cfg:
        hooks.append(HOOKS.build(hook_cfg))
    return hooks
29 engines/hooks/default.py Normal file
@@ -0,0 +1,29 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""


class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.
    """

    trainer = None  # A weak reference to the trainer object.

    def before_train(self):
        pass

    def before_epoch(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        pass

    def after_epoch(self):
        pass

    def after_train(self):
        pass
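A sketch of defining a custom hook against this base class and the `HOOKS` registry from `builder.py` (the `MetricPrinter` class is illustrative, not part of the repo):

```python
from engines.hooks import HookBase, build_hooks
from engines.hooks.builder import HOOKS


@HOOKS.register_module()
class MetricPrinter(HookBase):
    def after_epoch(self):
        # self.trainer is attached by the trainer before hooks run;
        # "current_metric_value" is the key the evaluators below populate.
        print(self.trainer.comm_info.get("current_metric_value"))


# Built from config-style dicts, e.g. the `hooks` list in the configs above:
hooks = build_hooks([
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="MetricPrinter"),
])
```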
577 engines/hooks/evaluator.py Normal file
@@ -0,0 +1,577 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import numpy as np
import torch
import torch.distributed as dist
from uuid import uuid4

import pointops  # needed by SemSegEvaluator's interpolation branch below
import utils.comm as comm
from utils.misc import intersection_and_union_gpu

from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class ClsEvaluator(HookBase):
    def after_epoch(self):
        if self.trainer.cfg.evaluate:
            self.eval()

    def eval(self):
        self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        self.trainer.model.eval()
        for i, input_dict in enumerate(self.trainer.val_loader):
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                output_dict = self.trainer.model(input_dict)
            output = output_dict["cls_logits"]
            loss = output_dict["loss"]
            pred = output.max(1)[1]
            label = input_dict["category"]
            intersection, union, target = intersection_and_union_gpu(
                pred,
                label,
                self.trainer.cfg.data.num_classes,
                self.trainer.cfg.data.ignore_index,
            )
            if comm.get_world_size() > 1:
                dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
                    target
                )
            intersection, union, target = (
                intersection.cpu().numpy(),
                union.cpu().numpy(),
                target.cpu().numpy(),
            )
            # Here there is no need to sync since sync happened in dist.all_reduce
            self.trainer.storage.put_scalar("val_intersection", intersection)
            self.trainer.storage.put_scalar("val_union", union)
            self.trainer.storage.put_scalar("val_target", target)
            self.trainer.storage.put_scalar("val_loss", loss.item())
            self.trainer.logger.info(
                "Test: [{iter}/{max_iter}] "
                "Loss {loss:.4f} ".format(
                    iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
                )
            )
        loss_avg = self.trainer.storage.history("val_loss").avg
        intersection = self.trainer.storage.history("val_intersection").total
        union = self.trainer.storage.history("val_union").total
        target = self.trainer.storage.history("val_target").total
        iou_class = intersection / (union + 1e-10)
        acc_class = intersection / (target + 1e-10)
        m_iou = np.mean(iou_class)
        m_acc = np.mean(acc_class)
        all_acc = sum(intersection) / (sum(target) + 1e-10)
        self.trainer.logger.info(
            "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
                m_iou, m_acc, all_acc
            )
        )
        for i in range(self.trainer.cfg.data.num_classes):
            self.trainer.logger.info(
                "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                    idx=i,
                    name=self.trainer.cfg.data.names[i],
                    iou=iou_class[i],
                    accuracy=acc_class[i],
                )
            )
        current_epoch = self.trainer.epoch + 1
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
            self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
            self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
            self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
        self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
        self.trainer.comm_info["current_metric_value"] = all_acc  # save for saver
        self.trainer.comm_info["current_metric_name"] = "allAcc"  # save for saver

    def after_train(self):
        self.trainer.logger.info(
            "Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
        )
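For reference, a tiny worked example of the histogram-style metrics computed above (assuming, as the usage implies, that `intersection_and_union_gpu` returns per-class intersection, union, and ground-truth counts):

```python
import numpy as np

intersection = np.array([50, 30])   # correctly predicted samples per class
union        = np.array([100, 40])  # |prediction ∪ ground truth| per class
target       = np.array([60, 35])   # ground-truth samples per class

iou_class = intersection / (union + 1e-10)    # [0.5, 0.75]
acc_class = intersection / (target + 1e-10)   # [0.833, 0.857]
m_iou   = iou_class.mean()                             # 0.625
m_acc   = acc_class.mean()                             # ~0.845
all_acc = intersection.sum() / (target.sum() + 1e-10)  # 80 / 95 ~ 0.842
```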
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class SemSegEvaluator(HookBase):
|
||||||
|
def after_epoch(self):
|
||||||
|
if self.trainer.cfg.evaluate:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
|
||||||
|
self.trainer.model.eval()
|
||||||
|
for i, input_dict in enumerate(self.trainer.val_loader):
|
||||||
|
for key in input_dict.keys():
|
||||||
|
if isinstance(input_dict[key], torch.Tensor):
|
||||||
|
input_dict[key] = input_dict[key].cuda(non_blocking=True)
|
||||||
|
with torch.no_grad():
|
||||||
|
output_dict = self.trainer.model(input_dict)
|
||||||
|
output = output_dict["seg_logits"]
|
||||||
|
loss = output_dict["loss"]
|
||||||
|
pred = output.max(1)[1]
|
||||||
|
segment = input_dict["segment"]
|
||||||
|
if "origin_coord" in input_dict.keys():
|
||||||
|
idx, _ = pointops.knn_query(
|
||||||
|
1,
|
||||||
|
input_dict["coord"].float(),
|
||||||
|
input_dict["offset"].int(),
|
||||||
|
input_dict["origin_coord"].float(),
|
||||||
|
input_dict["origin_offset"].int(),
|
||||||
|
)
|
||||||
|
pred = pred[idx.flatten().long()]
|
||||||
|
segment = input_dict["origin_segment"]
|
||||||
|
intersection, union, target = intersection_and_union_gpu(
|
||||||
|
pred,
|
||||||
|
segment,
|
||||||
|
self.trainer.cfg.data.num_classes,
|
||||||
|
self.trainer.cfg.data.ignore_index,
|
||||||
|
)
|
||||||
|
if comm.get_world_size() > 1:
|
||||||
|
dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
|
||||||
|
target
|
||||||
|
)
|
||||||
|
intersection, union, target = (
|
||||||
|
intersection.cpu().numpy(),
|
||||||
|
union.cpu().numpy(),
|
||||||
|
target.cpu().numpy(),
|
||||||
|
)
|
||||||
|
# Here there is no need to sync since sync happened in dist.all_reduce
|
||||||
|
self.trainer.storage.put_scalar("val_intersection", intersection)
|
||||||
|
self.trainer.storage.put_scalar("val_union", union)
|
||||||
|
self.trainer.storage.put_scalar("val_target", target)
|
||||||
|
self.trainer.storage.put_scalar("val_loss", loss.item())
|
||||||
|
info = "Test: [{iter}/{max_iter}] ".format(
|
||||||
|
iter=i + 1, max_iter=len(self.trainer.val_loader)
|
||||||
|
)
|
||||||
|
if "origin_coord" in input_dict.keys():
|
||||||
|
info = "Interp. " + info
|
||||||
|
self.trainer.logger.info(
|
||||||
|
info
|
||||||
|
+ "Loss {loss:.4f} ".format(
|
||||||
|
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
loss_avg = self.trainer.storage.history("val_loss").avg
|
||||||
|
intersection = self.trainer.storage.history("val_intersection").total
|
||||||
|
union = self.trainer.storage.history("val_union").total
|
||||||
|
target = self.trainer.storage.history("val_target").total
|
||||||
|
iou_class = intersection / (union + 1e-10)
|
||||||
|
acc_class = intersection / (target + 1e-10)
|
||||||
|
m_iou = np.mean(iou_class)
|
||||||
|
m_acc = np.mean(acc_class)
|
||||||
|
all_acc = sum(intersection) / (sum(target) + 1e-10)
|
||||||
|
self.trainer.logger.info(
|
||||||
|
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
|
||||||
|
m_iou, m_acc, all_acc
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for i in range(self.trainer.cfg.data.num_classes):
|
||||||
|
self.trainer.logger.info(
|
||||||
|
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
|
||||||
|
idx=i,
|
||||||
|
name=self.trainer.cfg.data.names[i],
|
||||||
|
iou=iou_class[i],
|
||||||
|
accuracy=acc_class[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_epoch = self.trainer.epoch + 1
|
||||||
|
if self.trainer.writer is not None:
|
||||||
|
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
|
||||||
|
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
|
||||||
|
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
|
||||||
|
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
|
||||||
|
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
|
||||||
|
self.trainer.comm_info["current_metric_value"] = m_iou # save for saver
|
||||||
|
self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver
|
||||||
|
|
||||||
|
def after_train(self):
|
||||||
|
self.trainer.logger.info(
|
||||||
|
"Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
class InsSegEvaluator(HookBase):
    def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
        self.segment_ignore_index = segment_ignore_index
        self.instance_ignore_index = instance_ignore_index

        self.valid_class_names = None  # updated in before_train
        self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
        self.min_region_sizes = 100
        self.distance_threshes = float("inf")
        self.distance_confs = -float("inf")

    def before_train(self):
        self.valid_class_names = [
            self.trainer.cfg.data.names[i]
            for i in range(self.trainer.cfg.data.num_classes)
            if i not in self.segment_ignore_index
        ]

    def after_epoch(self):
        if self.trainer.cfg.evaluate:
            self.eval()

    def associate_instances(self, pred, segment, instance):
        segment = segment.cpu().numpy()
        instance = instance.cpu().numpy()
        void_mask = np.in1d(segment, self.segment_ignore_index)

        assert (
            pred["pred_classes"].shape[0]
            == pred["pred_scores"].shape[0]
            == pred["pred_masks"].shape[0]
        )
        assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
        # get gt instances
        gt_instances = dict()
        for i in range(self.trainer.cfg.data.num_classes):
            if i not in self.segment_ignore_index:
                gt_instances[self.trainer.cfg.data.names[i]] = []
        instance_ids, idx, counts = np.unique(
            instance, return_index=True, return_counts=True
        )
        segment_ids = segment[idx]
        for i in range(len(instance_ids)):
            if instance_ids[i] == self.instance_ignore_index:
                continue
            if segment_ids[i] in self.segment_ignore_index:
                continue
            gt_inst = dict()
            gt_inst["instance_id"] = instance_ids[i]
            gt_inst["segment_id"] = segment_ids[i]
            gt_inst["dist_conf"] = 0.0
            gt_inst["med_dist"] = -1.0
            gt_inst["vert_count"] = counts[i]
            gt_inst["matched_pred"] = []
            gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)

        # get pred instances and associate with gt
        pred_instances = dict()
        for i in range(self.trainer.cfg.data.num_classes):
            if i not in self.segment_ignore_index:
                pred_instances[self.trainer.cfg.data.names[i]] = []
        instance_id = 0
        for i in range(len(pred["pred_classes"])):
            if pred["pred_classes"][i] in self.segment_ignore_index:
                continue
            pred_inst = dict()
            pred_inst["uuid"] = uuid4()
            pred_inst["instance_id"] = instance_id
            pred_inst["segment_id"] = pred["pred_classes"][i]
            pred_inst["confidence"] = pred["pred_scores"][i]
            pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
            pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
            pred_inst["void_intersection"] = np.count_nonzero(
                np.logical_and(void_mask, pred_inst["mask"])
            )
            if pred_inst["vert_count"] < self.min_region_sizes:
                continue  # skip predictions smaller than the minimum region size
            segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
            matched_gt = []
            for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
                intersection = np.count_nonzero(
                    np.logical_and(
                        instance == gt_inst["instance_id"], pred_inst["mask"]
                    )
                )
                if intersection > 0:
                    gt_inst_ = gt_inst.copy()
                    pred_inst_ = pred_inst.copy()
                    gt_inst_["intersection"] = intersection
                    pred_inst_["intersection"] = intersection
                    matched_gt.append(gt_inst_)
                    gt_inst["matched_pred"].append(pred_inst_)
            pred_inst["matched_gt"] = matched_gt
            pred_instances[segment_name].append(pred_inst)
            instance_id += 1
        return gt_instances, pred_instances

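    # Note on the structure built above: each predicted instance keeps the list of
    # same-class ground-truth instances it overlaps ("matched_gt"), and each
    # ground-truth instance keeps the mirror list ("matched_pred"), so the greedy
    # matching in evaluate_matches never has to revisit the raw masks.
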
    def evaluate_matches(self, scenes):
        overlaps = self.overlaps
        min_region_sizes = [self.min_region_sizes]
        dist_threshes = [self.distance_threshes]
        dist_confs = [self.distance_confs]

        # results: class x overlap
        ap_table = np.zeros(
            (len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
        )
        for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
            zip(min_region_sizes, dist_threshes, dist_confs)
        ):
            for oi, overlap_th in enumerate(overlaps):
                pred_visited = {}
                for scene in scenes:
                    for _ in scene["pred"]:
                        for label_name in self.valid_class_names:
                            for p in scene["pred"][label_name]:
                                if "uuid" in p:
                                    pred_visited[p["uuid"]] = False
                for li, label_name in enumerate(self.valid_class_names):
                    y_true = np.empty(0)
                    y_score = np.empty(0)
                    hard_false_negatives = 0
                    has_gt = False
                    has_pred = False
                    for scene in scenes:
                        pred_instances = scene["pred"][label_name]
                        gt_instances = scene["gt"][label_name]
                        # filter groups in ground truth
                        gt_instances = [
                            gt
                            for gt in gt_instances
                            if gt["vert_count"] >= min_region_size
                            and gt["med_dist"] <= distance_thresh
                            and gt["dist_conf"] >= distance_conf
                        ]
                        if gt_instances:
                            has_gt = True
                        if pred_instances:
                            has_pred = True

                        cur_true = np.ones(len(gt_instances))
                        cur_score = np.ones(len(gt_instances)) * (-float("inf"))
                        cur_match = np.zeros(len(gt_instances), dtype=bool)
                        # collect matches
                        for gti, gt in enumerate(gt_instances):
                            found_match = False
                            for pred in gt["matched_pred"]:
                                # greedy assignments
                                if pred_visited[pred["uuid"]]:
                                    continue
                                overlap = float(pred["intersection"]) / (
                                    gt["vert_count"]
                                    + pred["vert_count"]
                                    - pred["intersection"]
                                )
                                if overlap > overlap_th:
                                    confidence = pred["confidence"]
                                    # if already have a prediction for this gt,
                                    # the prediction with the lower score is automatically a false positive
                                    if cur_match[gti]:
                                        max_score = max(cur_score[gti], confidence)
                                        min_score = min(cur_score[gti], confidence)
                                        cur_score[gti] = max_score
                                        # append false positive
                                        cur_true = np.append(cur_true, 0)
                                        cur_score = np.append(cur_score, min_score)
                                        cur_match = np.append(cur_match, True)
                                    # otherwise set score
                                    else:
                                        found_match = True
                                        cur_match[gti] = True
                                        cur_score[gti] = confidence
                                    pred_visited[pred["uuid"]] = True
                            if not found_match:
                                hard_false_negatives += 1
                        # remove non-matched ground truth instances
                        cur_true = cur_true[cur_match]
                        cur_score = cur_score[cur_match]

                        # collect non-matched predictions as false positive
                        for pred in pred_instances:
                            found_gt = False
                            for gt in pred["matched_gt"]:
                                overlap = float(gt["intersection"]) / (
                                    gt["vert_count"]
                                    + pred["vert_count"]
                                    - gt["intersection"]
                                )
                                if overlap > overlap_th:
                                    found_gt = True
                                    break
                            if not found_gt:
                                num_ignore = pred["void_intersection"]
                                for gt in pred["matched_gt"]:
                                    if gt["segment_id"] in self.segment_ignore_index:
                                        num_ignore += gt["intersection"]
                                    # small ground truth instances
                                    if (
                                        gt["vert_count"] < min_region_size
                                        or gt["med_dist"] > distance_thresh
                                        or gt["dist_conf"] < distance_conf
                                    ):
                                        num_ignore += gt["intersection"]
                                proportion_ignore = (
                                    float(num_ignore) / pred["vert_count"]
                                )
                                # if not ignored append false positive
                                if proportion_ignore <= overlap_th:
                                    cur_true = np.append(cur_true, 0)
                                    confidence = pred["confidence"]
                                    cur_score = np.append(cur_score, confidence)

                        # append to overall results
                        y_true = np.append(y_true, cur_true)
                        y_score = np.append(y_score, cur_score)

                    # compute average precision
                    if has_gt and has_pred:
                        # compute precision recall curve first

                        # sorting and cumsum
                        score_arg_sort = np.argsort(y_score)
                        y_score_sorted = y_score[score_arg_sort]
                        y_true_sorted = y_true[score_arg_sort]
                        y_true_sorted_cumsum = np.cumsum(y_true_sorted)

                        # unique thresholds
                        (thresholds, unique_indices) = np.unique(
                            y_score_sorted, return_index=True
                        )
                        num_prec_recall = len(unique_indices) + 1

                        # prepare precision recall
                        num_examples = len(y_score_sorted)
                        # https://github.com/ScanNet/ScanNet/pull/26
                        # all predictions are non-matched but also all of them are ignored and not counted as FP
                        # y_true_sorted_cumsum is empty
                        # num_true_examples = y_true_sorted_cumsum[-1]
                        num_true_examples = (
                            y_true_sorted_cumsum[-1]
                            if len(y_true_sorted_cumsum) > 0
                            else 0
                        )
                        precision = np.zeros(num_prec_recall)
                        recall = np.zeros(num_prec_recall)

                        # deal with the first point
                        y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
                        # deal with remaining
                        for idx_res, idx_scores in enumerate(unique_indices):
                            cumsum = y_true_sorted_cumsum[idx_scores - 1]
                            tp = num_true_examples - cumsum
                            fp = num_examples - idx_scores - tp
                            fn = cumsum + hard_false_negatives
                            p = float(tp) / (tp + fp)
                            r = float(tp) / (tp + fn)
                            precision[idx_res] = p
                            recall[idx_res] = r

                        # first point in curve is artificial
                        precision[-1] = 1.0
                        recall[-1] = 0.0

                        # compute average of precision-recall curve
                        recall_for_conv = np.copy(recall)
                        recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
                        recall_for_conv = np.append(recall_for_conv, 0.0)

                        stepWidths = np.convolve(
                            recall_for_conv, [-0.5, 0, 0.5], "valid"
                        )
                        # integrate is now simply a dot product
                        ap_current = np.dot(precision, stepWidths)

                    elif has_gt:
                        ap_current = 0.0
                    else:
                        ap_current = float("nan")
                    ap_table[di, li, oi] = ap_current
        d_inf = 0
        o50 = np.where(np.isclose(self.overlaps, 0.5))
        o25 = np.where(np.isclose(self.overlaps, 0.25))
        oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
        ap_scores = dict()
        ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
        ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
        ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
        ap_scores["classes"] = {}
        for li, label_name in enumerate(self.valid_class_names):
            ap_scores["classes"][label_name] = {}
            ap_scores["classes"][label_name]["ap"] = np.average(
                ap_table[d_inf, li, oAllBut25]
            )
            ap_scores["classes"][label_name]["ap50%"] = np.average(
                ap_table[d_inf, li, o50]
            )
            ap_scores["classes"][label_name]["ap25%"] = np.average(
                ap_table[d_inf, li, o25]
            )
        return ap_scores

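    # A minimal numeric sketch of the step-width integration above (assumed values,
    # not from a real run): with precision = [1.0, 0.5, 1.0] and
    # recall = [0.8, 0.4, 0.0], recall_for_conv becomes [0.8, 0.8, 0.4, 0.0, 0.0];
    # np.convolve(recall_for_conv, [-0.5, 0, 0.5], "valid") yields
    # stepWidths = [0.2, 0.4, 0.2], and ap = 1.0*0.2 + 0.5*0.4 + 1.0*0.2 = 0.6.
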
    def eval(self):
        self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        self.trainer.model.eval()
        scenes = []
        for i, input_dict in enumerate(self.trainer.val_loader):
            assert (
                len(input_dict["offset"]) == 1
            )  # currently only support bs 1 for each GPU
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                output_dict = self.trainer.model(input_dict)

            loss = output_dict["loss"]

            segment = input_dict["segment"]
            instance = input_dict["instance"]
            # map to origin
            if "origin_coord" in input_dict.keys():
                idx, _ = pointops.knn_query(
                    1,
                    input_dict["coord"].float(),
                    input_dict["offset"].int(),
                    input_dict["origin_coord"].float(),
                    input_dict["origin_offset"].int(),
                )
                idx = idx.cpu().flatten().long()
                output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
                segment = input_dict["origin_segment"]
                instance = input_dict["origin_instance"]

            gt_instances, pred_instance = self.associate_instances(
                output_dict, segment, instance
            )
            scenes.append(dict(gt=gt_instances, pred=pred_instance))

            self.trainer.storage.put_scalar("val_loss", loss.item())
            self.trainer.logger.info(
                "Test: [{iter}/{max_iter}] "
                "Loss {loss:.4f} ".format(
                    iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
                )
            )

        loss_avg = self.trainer.storage.history("val_loss").avg
        comm.synchronize()
        scenes_sync = comm.gather(scenes, dst=0)
        scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
        ap_scores = self.evaluate_matches(scenes)
        all_ap = ap_scores["all_ap"]
        all_ap_50 = ap_scores["all_ap_50%"]
        all_ap_25 = ap_scores["all_ap_25%"]
        self.trainer.logger.info(
            "Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
                all_ap, all_ap_50, all_ap_25
            )
        )
        for i, label_name in enumerate(self.valid_class_names):
            ap = ap_scores["classes"][label_name]["ap"]
            ap_50 = ap_scores["classes"][label_name]["ap50%"]
            ap_25 = ap_scores["classes"][label_name]["ap25%"]
            self.trainer.logger.info(
                "Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
                    idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
                )
            )
        current_epoch = self.trainer.epoch + 1
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
            self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
            self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
            self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
        self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
        self.trainer.comm_info["current_metric_value"] = all_ap_50  # save for saver
        self.trainer.comm_info["current_metric_name"] = "AP50"  # save for saver
460
engines/hooks/misc.py
Normal file
@@ -0,0 +1,460 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict

if sys.version_info >= (3, 10):
    from collections.abc import Sequence
else:
    from collections import Sequence
from utils.timer import Timer
from utils.comm import is_main_process, synchronize, get_world_size
from utils.cache import shared_dict

import utils.comm as comm
from engines.test import TESTERS

from .default import HookBase
from .builder import HOOKS

@HOOKS.register_module()
class IterationTimer(HookBase):
    def __init__(self, warmup_iter=1):
        self._warmup_iter = warmup_iter
        self._start_time = time.perf_counter()
        self._iter_timer = Timer()
        self._remain_iter = 0

    def before_train(self):
        self._start_time = time.perf_counter()
        self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)

    def before_epoch(self):
        self._iter_timer.reset()

    def before_step(self):
        data_time = self._iter_timer.seconds()
        self.trainer.storage.put_scalar("data_time", data_time)

    def after_step(self):
        batch_time = self._iter_timer.seconds()
        self._iter_timer.reset()
        self.trainer.storage.put_scalar("batch_time", batch_time)
        self._remain_iter -= 1
        remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
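        # For example (assumed numbers): with 1500 iterations left at an average of
        # 0.42 s per batch, remain_time = 630 s, which the divmod chain above
        # renders as "00:10:30".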
if "iter_info" in self.trainer.comm_info.keys():
|
||||||
|
info = (
|
||||||
|
"Data {data_time_val:.3f} ({data_time_avg:.3f}) "
|
||||||
|
"Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
|
||||||
|
"Remain {remain_time} ".format(
|
||||||
|
data_time_val=self.trainer.storage.history("data_time").val,
|
||||||
|
data_time_avg=self.trainer.storage.history("data_time").avg,
|
||||||
|
batch_time_val=self.trainer.storage.history("batch_time").val,
|
||||||
|
batch_time_avg=self.trainer.storage.history("batch_time").avg,
|
||||||
|
remain_time=remain_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.trainer.comm_info["iter_info"] += info
|
||||||
|
if self.trainer.comm_info["iter"] <= self._warmup_iter:
|
||||||
|
self.trainer.storage.history("data_time").reset()
|
||||||
|
self.trainer.storage.history("batch_time").reset()
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
class InformationWriter(HookBase):
    def __init__(self):
        self.curr_iter = 0
        self.model_output_keys = []

    def before_train(self):
        self.trainer.comm_info["iter_info"] = ""
        self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)

    def before_step(self):
        self.curr_iter += 1
        # MSC pretraining does not provide offset information, so the block below
        # stays commented out to keep MSC supported.
        # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
        #        "Scan {batch_size} ({points_num}) ".format(
        #     epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
        #     iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
        #     batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
        #     points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
        # )
        info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
            epoch=self.trainer.epoch + 1,
            max_epoch=self.trainer.max_epoch,
            iter=self.trainer.comm_info["iter"] + 1,
            max_iter=len(self.trainer.train_loader),
        )
        self.trainer.comm_info["iter_info"] += info

    def after_step(self):
        if "model_output_dict" in self.trainer.comm_info.keys():
            model_output_dict = self.trainer.comm_info["model_output_dict"]
            self.model_output_keys = model_output_dict.keys()
            for key in self.model_output_keys:
                self.trainer.storage.put_scalar(key, model_output_dict[key].item())

        for key in self.model_output_keys:
            self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).val
            )
        lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
        self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
        self.trainer.logger.info(self.trainer.comm_info["iter_info"])
        self.trainer.comm_info["iter_info"] = ""  # reset iter info
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train_batch/" + key,
                    self.trainer.storage.history(key).val,
                    self.curr_iter,
                )

    def after_epoch(self):
        epoch_info = "Train result: "
        for key in self.model_output_keys:
            epoch_info += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).avg
            )
        self.trainer.logger.info(epoch_info)
        if self.trainer.writer is not None:
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train/" + key,
                    self.trainer.storage.history(key).avg,
                    self.trainer.epoch + 1,
                )

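# Example of a line the InformationWriter above emits each step (values assumed;
# hooks such as IterationTimer may contribute additional fields):
#   Train: [3/100][120/500] loss: 0.1234 Lr: 0.00100
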
@HOOKS.register_module()
class CheckpointSaver(HookBase):
    def __init__(self, save_freq=None):
        self.save_freq = save_freq  # None or int; None indicates only saving the last model

    def after_epoch(self):
        if is_main_process():
            is_best = False
            if self.trainer.cfg.evaluate:
                current_metric_value = self.trainer.comm_info["current_metric_value"]
                current_metric_name = self.trainer.comm_info["current_metric_name"]
                if current_metric_value > self.trainer.best_metric_value:
                    self.trainer.best_metric_value = current_metric_value
                    is_best = True
                    self.trainer.logger.info(
                        "Best validation {} updated to: {:.4f}".format(
                            current_metric_name, current_metric_value
                        )
                    )
                self.trainer.logger.info(
                    "Currently Best {}: {:.4f}".format(
                        current_metric_name, self.trainer.best_metric_value
                    )
                )

            filename = os.path.join(
                self.trainer.cfg.save_path, "model", "model_last.pth"
            )
            self.trainer.logger.info("Saving checkpoint to: " + filename)
            torch.save(
                {
                    "epoch": self.trainer.epoch + 1,
                    "state_dict": self.trainer.model.state_dict(),
                    "optimizer": self.trainer.optimizer.state_dict(),
                    "scheduler": self.trainer.scheduler.state_dict(),
                    "scaler": self.trainer.scaler.state_dict()
                    if self.trainer.cfg.enable_amp
                    else None,
                    "best_metric_value": self.trainer.best_metric_value,
                },
                filename + ".tmp",
            )
            os.replace(filename + ".tmp", filename)
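            # Writing to "<filename>.tmp" and then os.replace()-ing it keeps the
            # save atomic: a crash mid-save never leaves a truncated model_last.pth.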
            if is_best:
                shutil.copyfile(
                    filename,
                    os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
                )
            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
                shutil.copyfile(
                    filename,
                    os.path.join(
                        self.trainer.cfg.save_path,
                        "model",
                        f"epoch_{self.trainer.epoch + 1}.pth",
                    ),
                )

@HOOKS.register_module()
class CheckpointLoader(HookBase):
    def __init__(self, keywords="", replacement=None, strict=False):
        self.keywords = keywords
        self.replacement = replacement if replacement is not None else keywords
        self.strict = strict

    def before_train(self):
        self.trainer.logger.info("=> Loading checkpoint & weight ...")
        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
            checkpoint = torch.load(
                self.trainer.cfg.weight,
                map_location=lambda storage, loc: storage.cuda(),
            )
            self.trainer.logger.info(
                f"Loading layer weights with keyword: {self.keywords}, "
                f"replace keyword with: {self.replacement}"
            )
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("module."):
                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # Now all keys contain "module." no matter DDP or not.
                if self.keywords in key:
                    key = key.replace(self.keywords, self.replacement)
                if comm.get_world_size() == 1:
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                weight[key] = value
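            # Worked example (hypothetical layer names): with keywords="backbone."
            # and replacement="encoder.", a stored key "backbone.layer1.weight"
            # becomes "module.encoder.layer1.weight" on multi-GPU runs and
            # "encoder.layer1.weight" on a single GPU.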
            load_state_info = self.trainer.model.load_state_dict(
                weight, strict=self.strict
            )
            self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
            if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming train at eval epoch: {checkpoint['epoch']}"
                )
                self.trainer.start_epoch = checkpoint["epoch"]
                self.trainer.best_metric_value = checkpoint["best_metric_value"]
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
                if self.trainer.cfg.enable_amp:
                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
        else:
            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")

@HOOKS.register_module()
class PreciseEvaluator(HookBase):
    def __init__(self, test_last=False):
        self.test_last = test_last

    def after_train(self):
        self.trainer.logger.info(
            ">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
        )
        torch.cuda.empty_cache()
        cfg = self.trainer.cfg
        tester = TESTERS.build(
            dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
        )
        if self.test_last:
            self.trainer.logger.info("=> Testing on model_last ...")
        else:
            self.trainer.logger.info("=> Testing on model_best ...")
            best_path = os.path.join(
                self.trainer.cfg.save_path, "model", "model_best.pth"
            )
            checkpoint = torch.load(best_path)
            state_dict = checkpoint["state_dict"]
            tester.model.load_state_dict(state_dict, strict=True)
        tester.test()

@HOOKS.register_module()
class DataCacheOperator(HookBase):
    def __init__(self, data_root, split):
        self.data_root = data_root
        self.split = split
        self.data_list = self.get_data_list()

    def get_data_list(self):
        if isinstance(self.split, str):
            data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
        else:
            raise NotImplementedError
        return data_list

    def get_cache_name(self, data_path):
        data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
        return "pointcept" + data_name.replace(os.path.sep, "-")

    def before_train(self):
        self.trainer.logger.info(
            f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
        )
        if is_main_process():
            for data_path in self.data_list:
                cache_name = self.get_cache_name(data_path)
                data = torch.load(data_path)
                shared_dict(cache_name, data)
        synchronize()

@HOOKS.register_module()
class RuntimeProfiler(HookBase):
    def __init__(
        self,
        forward=True,
        backward=True,
        interrupt=False,
        warm_up=2,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.forward = forward
        self.backward = backward
        self.interrupt = interrupt
        self.warm_up = warm_up
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import profile, record_function, ProfilerActivity

        for i, input_dict in enumerate(self.trainer.train_loader):
            if i == self.warm_up + 1:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            if self.forward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as forward_prof:
                    with record_function("model_inference"):
                        output_dict = self.trainer.model(input_dict)
            else:
                output_dict = self.trainer.model(input_dict)
            loss = output_dict["loss"]
            if self.backward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as backward_prof:
                    with record_function("model_inference"):
                        loss.backward()
            self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
        if self.forward:
            self.trainer.logger.info(
                "Forward profile: \n"
                + str(
                    forward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            forward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
            )

        if self.backward:
            self.trainer.logger.info(
                "Backward profile: \n"
                + str(
                    backward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            backward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
            )
        if self.interrupt:
            sys.exit(0)

@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
    def __init__(
        self,
        interrupt=False,
        wait=1,
        warmup=1,
        active=10,
        repeat=1,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.interrupt = interrupt
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import (
            profile,
            record_function,
            ProfilerActivity,
            schedule,
            tensorboard_trace_handler,
        )

        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
        for i, input_dict in enumerate(self.trainer.train_loader):
            if i >= (self.wait + self.warmup + self.active) * self.repeat:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with record_function("model_forward"):
                output_dict = self.trainer.model(input_dict)
                loss = output_dict["loss"]
            with record_function("model_backward"):
                loss.backward()
            prof.step()
            self.trainer.logger.info(
                f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
            )
        self.trainer.logger.info(
            "Profile: \n"
            + str(
                prof.key_averages().table(
                    sort_by=self.sort_by, row_limit=self.row_limit
                )
            )
        )
        prof.stop()

        if self.interrupt:
            sys.exit(0)
285
engines/infer.py
Normal file
@@ -0,0 +1,285 @@
"""
Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import math
import time
import librosa
import numpy as np
from collections import OrderedDict

import torch
import torch.utils.data
import torch.nn.functional as F

from .defaults import create_ddp_model
import utils.comm as comm
from models import build_model
from utils.logger import get_root_logger
from utils.registry import Registry
from utils.misc import (
    AverageMeter,
)

from models.utils import (
    smooth_mouth_movements,
    apply_frame_blending,
    apply_savitzky_golay_smoothing,
    apply_random_brow_movement,
    symmetrize_blendshapes,
    apply_random_eye_blinks,
    apply_random_eye_blinks_context,
    export_blendshape_animation,
    RETURN_CODE,
    DEFAULT_CONTEXT,
    ARKitBlendShape,
)

INFER = Registry("infer")

class InferBase:
    def __init__(self, cfg, model=None, verbose=False) -> None:
        torch.multiprocessing.set_sharing_strategy("file_system")
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "infer.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.verbose = verbose
        if self.verbose:
            self.logger.info(f"Save path: {cfg.save_path}")
            self.logger.info(f"Config:\n{cfg.pretty_text}")
        if model is None:
            self.logger.info("=> Building model ...")
            self.model = self.build_model()
        else:
            self.model = model

    def build_model(self):
        model = build_model(self.cfg.model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        if os.path.isfile(self.cfg.weight):
            self.logger.info(f"Loading weight at: {self.cfg.weight}")
            checkpoint = torch.load(self.cfg.weight)
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if key.startswith("module."):
                    if comm.get_world_size() == 1:
                        key = key[7:]  # module.xxx.xxx -> xxx.xxx
                else:
                    if comm.get_world_size() > 1:
                        key = "module." + key  # xxx.xxx -> module.xxx.xxx
                weight[key] = value
            model.load_state_dict(weight, strict=True)
            self.logger.info(
                "=> Loaded weight '{}'".format(self.cfg.weight)
            )
        else:
            raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
        return model

    def infer(self):
        raise NotImplementedError

@INFER.register_module()
class Audio2ExpressionInfer(InferBase):
    def infer(self):
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Inference >>>>>>>>>>>>>>>>")
        batch_time = AverageMeter()
        self.model.eval()

        # process the audio input
        assert os.path.exists(self.cfg.audio_input)
        if self.cfg.ex_vol:
            logger.info("Extract vocals ...")
            vocal_path = self.extract_vocal_track(self.cfg.audio_input)
            logger.info("=> Extract vocals at: {}".format(vocal_path if os.path.exists(vocal_path) else '... Failed'))
            if os.path.exists(vocal_path):
                self.cfg.audio_input = vocal_path

        with torch.no_grad():
            input_dict = {}
            input_dict['id_idx'] = F.one_hot(
                torch.tensor(self.cfg.id_idx),
                self.cfg.model.backbone.num_identity_classes,
            ).cuda(non_blocking=True)[None, ...]
            speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
            input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None, ...]

            end = time.time()
            output_dict = self.model(input_dict)
            batch_time.update(time.time() - end)

        logger.info(
            "Infer: [{}] "
            "Running Time: {batch_time.avg:.3f} ".format(
                self.cfg.audio_input,
                batch_time=batch_time,
            )
        )

        out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()

        frame_length = math.ceil(speech_array.shape[0] / ssr * 30)
        volume = librosa.feature.rms(y=speech_array, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
        if volume.shape[0] > frame_length:
            volume = volume[:frame_length]

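        # Note (values assumed): with frame_length and hop_length of
        # int(1 / 30 * ssr) samples, librosa.feature.rms yields roughly one RMS
        # value per 30 fps video frame, e.g. 533 samples per frame at ssr = 16000.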
        if self.cfg.movement_smooth:
            out_exp = smooth_mouth_movements(out_exp, 0, volume)

        if self.cfg.brow_movement:
            out_exp = apply_random_brow_movement(out_exp, volume)

        pred_exp = self.blendshape_postprocess(out_exp)

        if self.cfg.save_json_path is not None:
            export_blendshape_animation(
                pred_exp,
                self.cfg.save_json_path,
                ARKitBlendShape,
                fps=self.cfg.fps,
            )

        logger.info("<<<<<<<<<<<<<<<<< End Inference <<<<<<<<<<<<<<<<<")

    def infer_streaming_audio(self,
                              audio: np.ndarray,
                              ssr: float,
                              context: dict):

        if context is None:
            context = DEFAULT_CONTEXT.copy()
        max_frame_length = 64

        frame_length = math.ceil(audio.shape[0] / ssr * 30)
        output_context = DEFAULT_CONTEXT.copy()

        volume = librosa.feature.rms(y=audio, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
        if volume.shape[0] > frame_length:
            volume = volume[:frame_length]

        # resample audio
        if ssr != self.cfg.audio_sr:
            in_audio = librosa.resample(audio.astype(np.float32), orig_sr=ssr, target_sr=self.cfg.audio_sr)
        else:
            in_audio = audio.copy()

        start_frame = int(max_frame_length - in_audio.shape[0] / self.cfg.audio_sr * 30)

        if context['is_initial_input'] or (context['previous_audio'] is None):
            blank_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
            blank_audio = np.zeros(blank_audio_length, dtype=np.float32)

            # prepend silence so the window always spans max_frame_length frames
            input_audio = np.concatenate([blank_audio, in_audio])
            output_context['previous_audio'] = input_audio

        else:
            clip_pre_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
            clip_pre_audio = context['previous_audio'][-clip_pre_audio_length:]
            input_audio = np.concatenate([clip_pre_audio, in_audio])
            output_context['previous_audio'] = input_audio

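        # Sanity check of the window math (values assumed): at audio_sr = 16000 a
        # 0.5 s chunk is 8000 samples; the window holds 16000 * 64 // 30 = 34133
        # samples, so 26133 samples of prior audio (or silence) are prepended and
        # start_frame = int(64 - 15.0) = 49 keeps only the 15 newly arrived frames.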
        with torch.no_grad():
            try:
                input_dict = {}
                input_dict['id_idx'] = F.one_hot(
                    torch.tensor(self.cfg.id_idx),
                    self.cfg.model.backbone.num_identity_classes,
                ).cuda(non_blocking=True)[None, ...]
                input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
                output_dict = self.model(input_dict)
                out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
            except Exception:
                self.logger.error('Error: failed to predict expression.')
                # report the failure to the caller without corrupting the context
                return None, context

        # post-process
        if context['previous_expression'] is None:
            out_exp = self.apply_expression_postprocessing(out_exp, audio_volume=volume)
        else:
            previous_length = context['previous_expression'].shape[0]
            out_exp = self.apply_expression_postprocessing(
                expression_params=np.concatenate([context['previous_expression'], out_exp], axis=0),
                audio_volume=np.concatenate([context['previous_volume'], volume], axis=0),
                processed_frames=previous_length,
            )[previous_length:, :]

        if context['previous_expression'] is not None:
            output_context['previous_expression'] = np.concatenate(
                [context['previous_expression'], out_exp], axis=0
            )[-max_frame_length:, :]
            output_context['previous_volume'] = np.concatenate([context['previous_volume'], volume], axis=0)[-max_frame_length:]
        else:
            output_context['previous_expression'] = out_exp.copy()
            output_context['previous_volume'] = volume.copy()

        # mark that the stream has started (the check above reads 'is_initial_input')
        output_context['is_initial_input'] = False

        return {"code": RETURN_CODE['SUCCESS'],
                "expression": out_exp,
                "headpose": None}, output_context

    def apply_expression_postprocessing(
        self,
        expression_params: np.ndarray,
        processed_frames: int = 0,
        audio_volume: np.ndarray = None
    ) -> np.ndarray:
        """Applies the full post-processing pipeline to facial expression parameters.

        Args:
            expression_params: Raw output from the animation model [num_frames, num_parameters]
            processed_frames: Number of frames already processed in previous batches
            audio_volume: Optional volume array for audio-visual synchronization

        Returns:
            Processed expression parameters ready for animation synthesis
        """
        # Pipeline execution order matters - maintain the sequence below
        expression_params = smooth_mouth_movements(expression_params, processed_frames, audio_volume)
        expression_params = apply_frame_blending(expression_params, processed_frames)
        expression_params, _ = apply_savitzky_golay_smoothing(expression_params, window_length=5)
        expression_params = symmetrize_blendshapes(expression_params)
        expression_params = apply_random_eye_blinks_context(expression_params, processed_frames=processed_frames)

        return expression_params

    def extract_vocal_track(
        self,
        input_audio_path: str
    ) -> str:
        """Isolates the vocal track from an audio file using source separation.

        Args:
            input_audio_path: Path to input audio file containing vocals + accompaniment

        Returns:
            Path to the isolated vocal track in WAV format
        """
        separation_command = f'spleeter separate -p spleeter:2stems -o {self.cfg.save_path} {input_audio_path}'
        os.system(separation_command)

        base_name = os.path.splitext(os.path.basename(input_audio_path))[0]
        return os.path.join(self.cfg.save_path, base_name, 'vocals.wav')

    def blendshape_postprocess(self,
                               bs_array: np.ndarray
                               ) -> np.ndarray:

        bs_array, _ = apply_savitzky_golay_smoothing(bs_array, window_length=5)
        bs_array = symmetrize_blendshapes(bs_array)
        bs_array = apply_random_eye_blinks(bs_array)

        return bs_array
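
# A minimal offline usage sketch (assumed config fields; the registry call follows
# the same convention as TESTERS.build elsewhere in this repo):
#
#   from engines.infer import INFER
#   infer = INFER.build(dict(type="Audio2ExpressionInfer", cfg=cfg))
#   infer.infer()  # reads cfg.audio_input, writes ARKit curves to cfg.save_json_path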
135
engines/launch.py
Normal file
@@ -0,0 +1,135 @@
"""
Launcher

Modified from detectron2 (https://github.com/facebookresearch/detectron2)
"""

import os
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    cfg=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
    Args:
        main_func: a function that will be called by `main_func(*cfg)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
        cfg (tuple): arguments passed to main_func
        timeout (timedelta): timeout of the distributed workers
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert (
                num_machines == 1
            ), "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                cfg,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*cfg)

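# A minimal usage sketch (train_worker is an assumed entry point): run four
# workers on one machine, letting the launcher pick a free localhost port:
#
#   launch(train_worker, num_gpus_per_machine=4, dist_url="auto", cfg=(cfg,))
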
def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    cfg,
    timeout=DEFAULT_TIMEOUT,
):
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*cfg)
299
engines/train.py
Normal file
@@ -0,0 +1,299 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import sys
import weakref
import torch
import torch.nn as nn
import torch.utils.data
from functools import partial

if sys.version_info >= (3, 10):
    from collections.abc import Iterator
else:
    from collections import Iterator
from tensorboardX import SummaryWriter

from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import utils.comm as comm
from datasets import build_dataset, point_collate_fn, collate_fn
from models import build_model
from utils.logger import get_root_logger
from utils.optimizer import build_optimizer
from utils.scheduler import build_scheduler
from utils.events import EventStorage
from utils.registry import Registry


TRAINERS = Registry("trainers")


class TrainerBase:
    def __init__(self) -> None:
        self.hooks = []
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = 0
        self.max_iter = 0
        self.comm_info = dict()
        self.data_iterator: Iterator = enumerate([])
        self.storage: EventStorage
        self.writer: SummaryWriter

    def register_hooks(self, hooks) -> None:
        hooks = build_hooks(hooks)
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self.hooks.extend(hooks)

    def train(self):
        with EventStorage() as self.storage:
            # => before train
            self.before_train()
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train
            self.after_train()

    def before_train(self):
        for h in self.hooks:
            h.before_train()

    def before_epoch(self):
        for h in self.hooks:
            h.before_epoch()

    def before_step(self):
        for h in self.hooks:
            h.before_step()

    def run_step(self):
        raise NotImplementedError

    def after_step(self):
        for h in self.hooks:
            h.after_step()

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()

    def after_train(self):
        # Sync GPU before running train hooks
        comm.synchronize()
        for h in self.hooks:
            h.after_train()
        if comm.is_main_process():
            self.writer.close()


@TRAINERS.register_module("DefaultTrainer")
|
||||||
|
class Trainer(TrainerBase):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super(Trainer, self).__init__()
|
||||||
|
self.epoch = 0
|
||||||
|
self.start_epoch = 0
|
||||||
|
self.max_epoch = cfg.eval_epoch
|
||||||
|
self.best_metric_value = -torch.inf
|
||||||
|
self.logger = get_root_logger(
|
||||||
|
log_file=os.path.join(cfg.save_path, "train.log"),
|
||||||
|
file_mode="a" if cfg.resume else "w",
|
||||||
|
)
|
||||||
|
self.logger.info("=> Loading config ...")
|
||||||
|
self.cfg = cfg
|
||||||
|
self.logger.info(f"Save path: {cfg.save_path}")
|
||||||
|
self.logger.info(f"Config:\n{cfg.pretty_text}")
|
||||||
|
self.logger.info("=> Building model ...")
|
||||||
|
self.model = self.build_model()
|
||||||
|
self.logger.info("=> Building writer ...")
|
||||||
|
self.writer = self.build_writer()
|
||||||
|
self.logger.info("=> Building train dataset & dataloader ...")
|
||||||
|
self.train_loader = self.build_train_loader()
|
||||||
|
self.logger.info("=> Building val dataset & dataloader ...")
|
||||||
|
self.val_loader = self.build_val_loader()
|
||||||
|
self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
|
||||||
|
self.optimizer = self.build_optimizer()
|
||||||
|
self.scheduler = self.build_scheduler()
|
||||||
|
self.scaler = self.build_scaler()
|
||||||
|
self.logger.info("=> Building hooks ...")
|
||||||
|
self.register_hooks(self.cfg.hooks)
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
with EventStorage() as self.storage:
|
||||||
|
# => before train
|
||||||
|
self.before_train()
|
||||||
|
self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
|
||||||
|
for self.epoch in range(self.start_epoch, self.max_epoch):
|
||||||
|
# => before epoch
|
||||||
|
# TODO: optimize to iteration based
|
||||||
|
if comm.get_world_size() > 1:
|
||||||
|
self.train_loader.sampler.set_epoch(self.epoch)
|
||||||
|
self.model.train()
|
||||||
|
self.data_iterator = enumerate(self.train_loader)
|
||||||
|
self.before_epoch()
|
||||||
|
# => run_epoch
|
||||||
|
for (
|
||||||
|
self.comm_info["iter"],
|
||||||
|
self.comm_info["input_dict"],
|
||||||
|
) in self.data_iterator:
|
||||||
|
# => before_step
|
||||||
|
self.before_step()
|
||||||
|
# => run_step
|
||||||
|
self.run_step()
|
||||||
|
# => after_step
|
||||||
|
self.after_step()
|
||||||
|
# => after epoch
|
||||||
|
self.after_epoch()
|
||||||
|
# => after train
|
||||||
|
self.after_train()
|
||||||
|
|
||||||
|
    def run_step(self):
        input_dict = self.comm_info["input_dict"]
        for key in input_dict.keys():
            if isinstance(input_dict[key], torch.Tensor):
                input_dict[key] = input_dict[key].cuda(non_blocking=True)
        with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
            output_dict = self.model(input_dict)
            loss = output_dict["loss"]
        self.optimizer.zero_grad()
        if self.cfg.enable_amp:
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)

            # When amp is enabled, optimizer.step() is skipped if the loss scaling
            # factor is too large; stepping the scheduler only when the scale did
            # not shrink avoids the torch warning about calling scheduler.step()
            # before optimizer.step().
            scaler = self.scaler.get_scale()
            self.scaler.update()
            if scaler <= self.scaler.get_scale():
                self.scheduler.step()
        else:
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
        if self.cfg.empty_cache:
            torch.cuda.empty_cache()
        self.comm_info["model_output_dict"] = output_dict

    def build_model(self):
        model = build_model(self.cfg.model)
        if self.cfg.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # logger.info(f"Model: \n{self.model}")
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        return model
    def build_writer(self):
        writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
        self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
        return writer
    def build_train_loader(self):
        train_data = build_dataset(self.cfg.data.train)

        if comm.get_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
        else:
            train_sampler = None

        init_fn = (
            partial(
                worker_init_fn,
                num_workers=self.cfg.num_worker_per_gpu,
                rank=comm.get_rank(),
                seed=self.cfg.seed,
            )
            if self.cfg.seed is not None
            else None
        )

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.cfg.batch_size_per_gpu,
            shuffle=(train_sampler is None),
            num_workers=0,
            sampler=train_sampler,
            collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
            pin_memory=True,
            worker_init_fn=init_fn,
            drop_last=True,
            # persistent_workers=True,
        )
        return train_loader
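worker_init_fn is imported from elsewhere in the repo and only its call signature is visible here; a typical deterministic per-worker seeding helper with that signature (an assumption, not the repo's actual implementation) looks like:

import random
import numpy as np
import torch

def worker_init_fn(worker_id, num_workers, rank, seed):
    # derive a distinct, reproducible seed for each dataloader worker
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)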
    def build_val_loader(self):
        val_loader = None
        if self.cfg.evaluate:
            val_data = build_dataset(self.cfg.data.val)
            if comm.get_world_size() > 1:
                val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
            else:
                val_sampler = None
            val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.cfg.batch_size_val_per_gpu,
                shuffle=False,
                num_workers=self.cfg.num_worker_per_gpu,
                pin_memory=True,
                sampler=val_sampler,
                collate_fn=collate_fn,
            )
        return val_loader
    def build_optimizer(self):
        return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)

    def build_scheduler(self):
        assert hasattr(self, "optimizer")
        assert hasattr(self, "train_loader")
        self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
        return build_scheduler(self.cfg.scheduler, self.optimizer)

    def build_scaler(self):
        scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
        return scaler

@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
    def build_train_loader(self):
        from datasets import MultiDatasetDataloader

        train_data = build_dataset(self.cfg.data.train)
        train_loader = MultiDatasetDataloader(
            train_data,
            self.cfg.batch_size_per_gpu,
            self.cfg.num_worker_per_gpu,
            self.cfg.mix_prob,
            self.cfg.seed,
        )
        self.comm_info["iter_per_epoch"] = len(train_loader)
        return train_loader
48
inference.py
Normal file
@@ -0,0 +1,48 @@
"""
# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from engines.infer import INFER
from engines.launch import launch


def main_worker(cfg):
    cfg = default_setup(cfg)
    infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
    infer.infer()


def main():
    args = default_argument_parser().parse_args()
    cfg = default_config_parser(args.config_file, args.options)

    launch(
        main_worker,
        num_gpus_per_machine=args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        cfg=(cfg,),
    )


if __name__ == "__main__":
    main()
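For reference, a programmatic equivalent of this entry point without the distributed launcher, mirroring the flow used by inference_streaming_audio.py below (the config path is an assumption, and None stands in for unset command-line options):

from engines.defaults import default_config_parser, default_setup
from engines.infer import INFER

cfg = default_config_parser("configs/lam_audio2exp_config_streaming.py", None)
cfg = default_setup(cfg)
infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
infer.infer()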
60
inference_streaming_audio.py
Normal file
@@ -0,0 +1,60 @@
"""
# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from engines.infer import INFER
import librosa
from tqdm import tqdm
import time


def export_json(bs_array, json_path):
    from models.utils import export_blendshape_animation, ARKitBlendShape
    export_blendshape_animation(bs_array, json_path, ARKitBlendShape, fps=30.0)


if __name__ == '__main__':
    args = default_argument_parser().parse_args()
    args.config_file = 'configs/lam_audio2exp_config_streaming.py'
    cfg = default_config_parser(args.config_file, args.options)

    cfg = default_setup(cfg)
    infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
    infer.model.eval()

    audio, sample_rate = librosa.load(cfg.audio_input, sr=16000)
    context = None
    input_num = audio.shape[0] // 16000 + 1
    gap = 16000
    all_exp = []
    for i in tqdm(range(input_num)):
        start = time.time()
        output, context = infer.infer_streaming_audio(audio[i * gap:(i + 1) * gap], sample_rate, context)
        end = time.time()
        print('Inference time {}'.format(end - start))
        all_exp.append(output['expression'])

    all_exp = np.concatenate(all_exp, axis=0)

    export_json(all_exp, cfg.save_json_path)
7
models/__init__.py
Normal file
@@ -0,0 +1,7 @@
from .builder import build_model

from .default import DefaultEstimator

# Backbones
from .network import Audio2Expression
13
models/builder.py
Normal file
@@ -0,0 +1,13 @@
"""
Modified from https://github.com/Pointcept/Pointcept
"""

from utils.registry import Registry

MODELS = Registry("models")
MODULES = Registry("modules")


def build_model(cfg):
    """Build models."""
    return MODELS.build(cfg)
25
models/default.py
Normal file
@@ -0,0 +1,25 @@
import torch.nn as nn

from models.losses import build_criteria
from .builder import MODELS, build_model


@MODELS.register_module()
class DefaultEstimator(nn.Module):
    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        pred_exp = self.backbone(input_dict)
        # train
        if self.training:
            loss = self.criteria(pred_exp, input_dict["gt_exp"])
            return dict(loss=loss)
        # eval
        elif "gt_exp" in input_dict.keys():
            loss = self.criteria(pred_exp, input_dict["gt_exp"])
            return dict(loss=loss, pred_exp=pred_exp)
        # infer
        else:
            return dict(pred_exp=pred_exp)
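A sketch of how the registry wires these pieces together; the field values below are placeholder assumptions (a runnable config must also point the encoder paths at real weights):

from models import build_model

cfg = dict(
    type="DefaultEstimator",
    backbone=dict(
        type="Audio2Expression",
        pretrained_encoder_path="<path/to/wav2vec2>",  # placeholder
        num_identity_classes=1,                        # placeholder
        hidden_dim=512,
        expression_dim=52,
    ),
    criteria=[dict(type="L1Loss", loss_weight=1.0)],
)
model = build_model(cfg)  # DefaultEstimator wrapping an Audio2Expression backbone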
248
models/encoder/wav2vec.py
Normal file
@@ -0,0 +1,248 @@
import numpy as np
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from dataclasses import dataclass
from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.file_utils import ModelOutput


_CONFIG_FOR_DOC = "Wav2Vec2Config"
_HIDDEN_STATES_START_POSITION = 2


# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.Tensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )
    all_num_mask = max(min_masks, all_num_mask)
    mask_idcs = []
    padding_mask = attention_mask.ne(1) if attention_mask is not None else None
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        lengths = np.full(num_mask, mask_length)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        min_len = min(lengths)
        if sz - min_len <= num_mask:
            min_len = sz - num_mask - 1

        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
        mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True
    return mask
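A quick illustration of the masking helper (shapes are arbitrary):

mask = _compute_mask_indices((2, 100), mask_prob=0.15, mask_length=10)
print(mask.shape, mask.dtype)  # (2, 100) bool; True marks masked time steps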
# linear interpolation layer
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    features = features.transpose(1, 2)
    seq_len = features.shape[2] / float(input_fps)
    if output_len is None:
        output_len = int(seq_len * output_fps)
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
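For example, resampling two seconds of 50 Hz wav2vec features to 30 fps video frames:

feats = torch.randn(2, 100, 768)           # (batch, time @ 50 fps, channels)
out = linear_interpolation(feats, 50, 30)  # -> torch.Size([2, 60, 768])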
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(1024, 32)

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        frame_num=None
    ):
        self.config.output_attentions = True
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.feature_extractor(input_values)
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)

        if attention_mask is not None:
            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
            attention_mask = torch.zeros(
                hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
            )
            attention_mask[
                (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
            ] = 1
            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()

        hidden_states = self.feature_projection(hidden_states)[0]

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]
        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(
        self,
        hidden_states,
        mode="mean"
    ):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        frame_num=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states1 = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
        hidden_states = self.merged_strategy(hidden_states1, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states1,
            attentions=outputs.attentions,
        )
87
models/encoder/wavlm.py
Normal file
@@ -0,0 +1,87 @@
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F

def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
    output_features = F.interpolate(
        features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)

# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model  # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.


class WavLMModel(WavLMModel):
    def __init__(self, config):
        super().__init__(config)

    def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        frame_num=None,
        interpolate_pos: int = 0,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if interpolate_pos == 0:
            extract_features = linear_interpolation(
                extract_features, output_len=frame_num)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if interpolate_pos == 1:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
4
models/losses/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .builder import build_criteria

from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss
from .lovasz import LovaszLoss
28
models/losses/builder.py
Normal file
@@ -0,0 +1,28 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from utils.registry import Registry

LOSSES = Registry("losses")


class Criteria(object):
    def __init__(self, cfg=None):
        self.cfg = cfg if cfg is not None else []
        self.criteria = []
        for loss_cfg in self.cfg:
            self.criteria.append(LOSSES.build(cfg=loss_cfg))

    def __call__(self, pred, target):
        if len(self.criteria) == 0:
            # loss computation occurs in the model
            return pred
        loss = 0
        for c in self.criteria:
            loss += c(pred, target)
        return loss


def build_criteria(cfg):
    return Criteria(cfg)
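A criteria config is just a list of registered loss specs, e.g.:

criteria = build_criteria([dict(type="L1Loss", loss_weight=1.0)])
# loss = criteria(pred, target)  # sums every configured loss over (pred, target)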
253
models/losses/lovasz.py
Normal file
@@ -0,0 +1,253 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from typing import Optional
from itertools import filterfalse
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from .builder import LOSSES

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"


def _lovasz_grad(gt_sorted):
    """Compute gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
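A small worked example of the gradient above (values verified by hand):

gt_sorted = torch.tensor([1, 1, 0, 1])
# gts = 3; intersection = [2, 1, 1, 0]; union = [3, 3, 4, 4]
# jaccard = [1/3, 2/3, 3/4, 1]; after differencing: [1/3, 1/3, 1/12, 1/4]
print(_lovasz_grad(gt_sorted))  # tensor([0.3333, 0.3333, 0.0833, 0.2500])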
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        loss = mean(
            _lovasz_hinge_flat(
                *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
            )
            for log, lab in zip(logits, labels)
        )
    else:
        loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
    return loss


def _lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss
    Args:
        logits: [P] Logits at each prediction (between -infinity and +infinity)
        labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.0
    signs = 2.0 * labels.float() - 1.0
    errors = 1.0 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = _lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), grad)
    return loss


def _flatten_binary_scores(scores, labels, ignore=None):
    """Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = labels != ignore
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


def _lovasz_softmax(
    probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
):
    """Multi-class Lovasz-Softmax loss
    Args:
        @param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
            Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
        @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
        @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
        @param per_image: compute the loss per image instead of per batch
        @param ignore: void class labels
    """
    if per_image:
        loss = mean(
            _lovasz_softmax_flat(
                *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                classes=classes
            )
            for prob, lab in zip(probas, labels)
        )
    else:
        loss = _lovasz_softmax_flat(
            *_flatten_probas(probas, labels, ignore),
            classes=classes,
            class_seen=class_seen
        )
    return loss


def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
    """Multi-class Lovasz-Softmax loss
    Args:
        @param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
        @param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
        @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.0
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
    # for c in class_to_sum:
    for c in labels.unique():
        if class_seen is None:
            fg = (labels == c).type_as(probas)  # foreground for class c
            if classes == "present" and fg.sum() == 0:
                continue
            if C == 1:
                if len(classes) > 1:
                    raise ValueError("Sigmoid output possible only with 1 class")
                class_pred = probas[:, 0]
            else:
                class_pred = probas[:, c]
            errors = (fg - class_pred).abs()
            errors_sorted, perm = torch.sort(errors, 0, descending=True)
            perm = perm.data
            fg_sorted = fg[perm]
            losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
        else:
            if c in class_seen:
                fg = (labels == c).type_as(probas)  # foreground for class c
                if classes == "present" and fg.sum() == 0:
                    continue
                if C == 1:
                    if len(classes) > 1:
                        raise ValueError("Sigmoid output possible only with 1 class")
                    class_pred = probas[:, 0]
                else:
                    class_pred = probas[:, c]
                errors = (fg - class_pred).abs()
                errors_sorted, perm = torch.sort(errors, 0, descending=True)
                perm = perm.data
                fg_sorted = fg[perm]
                losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
    return mean(losses)


def _flatten_probas(probas, labels, ignore=None):
    """Flattens predictions in the batch"""
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)

    C = probas.size(1)
    probas = torch.movedim(probas, 1, -1)  # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
    probas = probas.contiguous().view(-1, C)  # [P, C]

    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels


def isnan(x):
    return x != x


def mean(values, ignore_nan=False, empty=0):
    """Nan-mean compatible with generators."""
    values = iter(values)
    if ignore_nan:
        values = filterfalse(isnan, values)
    try:
        n = 1
        acc = next(values)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean")
        return empty
    for n, v in enumerate(values, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


@LOSSES.register_module()
class LovaszLoss(_Loss):
    def __init__(
        self,
        mode: str,
        class_seen: Optional[int] = None,
        per_image: bool = False,
        ignore_index: Optional[int] = None,
        loss_weight: float = 1.0,
    ):
        """Lovasz loss for segmentation task.
        It supports binary, multiclass and multilabel cases
        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            per_image: If True loss computed per each image and then averaged, else computed per whole batch
        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.per_image = per_image
        self.class_seen = class_seen
        self.loss_weight = loss_weight

    def forward(self, y_pred, y_true):
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            loss = _lovasz_hinge(
                y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
            )
        elif self.mode == MULTICLASS_MODE:
            y_pred = y_pred.softmax(dim=1)
            loss = _lovasz_softmax(
                y_pred,
                y_true,
                class_seen=self.class_seen,
                per_image=self.per_image,
                ignore=self.ignore_index,
            )
        else:
            raise ValueError("Wrong mode {}.".format(self.mode))
        return loss * self.loss_weight
241
models/losses/misc.py
Normal file
@@ -0,0 +1,241 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super(CrossEntropyLoss, self).__init__()
        weight = torch.tensor(weight).cuda() if weight is not None else None
        self.loss_weight = loss_weight
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, target):
        return self.loss(pred, target) * self.loss_weight


@LOSSES.register_module()
class L1Loss(nn.Module):
    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super(L1Loss, self).__init__()
        weight = torch.tensor(weight).cuda() if weight is not None else None
        self.loss_weight = loss_weight
        self.loss = nn.L1Loss(reduction='mean')

    def forward(self, pred, target):
        return self.loss(pred, target[:, None]) * self.loss_weight


@LOSSES.register_module()
class SmoothCELoss(nn.Module):
    def __init__(self, smoothing_ratio=0.1):
        super(SmoothCELoss, self).__init__()
        self.smoothing_ratio = smoothing_ratio

    def forward(self, pred, target):
        eps = self.smoothing_ratio
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss[torch.isfinite(loss)].mean()
        return loss


@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
        """Binary Focal Loss
        <https://arxiv.org/abs/1708.02002>`
        """
        super(BinaryFocalLoss, self).__init__()
        assert 0 < alpha < 1
        self.gamma = gamma
        self.alpha = alpha
        self.logits = logits
        self.reduce = reduce
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction with shape (N)
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤1, If containing class probabilities,
                same shape as the input.
        Returns:
            torch.Tensor: The calculated loss
        """
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
        else:
            bce = F.binary_cross_entropy(pred, target, reduction="none")
        pt = torch.exp(-bce)
        alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
        focal_loss = alpha * (1 - pt) ** self.gamma * bce

        if self.reduce:
            focal_loss = torch.mean(focal_loss)
        return focal_loss * self.loss_weight


@LOSSES.register_module()
class FocalLoss(nn.Module):
    def __init__(
        self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
    ):
        """Focal Loss
        <https://arxiv.org/abs/1708.02002>`
        """
        super(FocalLoss, self).__init__()
        assert reduction in (
            "mean",
            "sum",
        ), "AssertionError: reduction should be 'mean' or 'sum'"
        assert isinstance(
            alpha, (float, list)
        ), "AssertionError: alpha should be of type float or list"
        assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
        assert isinstance(
            loss_weight, float
        ), "AssertionError: loss_weight should be of type float"
        assert isinstance(ignore_index, int), "ignore_index must be of type int"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1, If containing class probabilities,
                same shape as the input.
        Returns:
            torch.Tensor: The calculated loss
        """
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        if len(target) == 0:
            return 0.0

        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes)

        alpha = self.alpha
        if isinstance(alpha, list):
            alpha = pred.new_tensor(alpha)
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
            self.gamma
        )

        loss = (
            F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            * focal_weight
        )
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return self.loss_weight * loss
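For instance, with a handful of class predictions and one ignored target:

pred = torch.randn(4, 3)                 # 4 predictions over 3 classes
target = torch.tensor([0, 2, 1, -1])     # -1 is dropped by ignore_index
loss = FocalLoss(gamma=2.0, alpha=0.5)(pred, target)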
@LOSSES.register_module()
class DiceLoss(nn.Module):
    def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
        """DiceLoss.
        This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
        Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
        )

        total_loss = 0
        for i in range(num_classes):
            if i != self.ignore_index:
                num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
                den = (
                    torch.sum(
                        pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
                    )
                    + self.smooth
                )
                dice_loss = 1 - num / den
                total_loss += dice_loss
        loss = total_loss / num_classes
        return self.loss_weight * loss
646
models/network.py
Normal file
@@ -0,0 +1,646 @@
import math
import os.path

import torch

import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta

from models.encoder.wav2vec import Wav2Vec2Model
from models.encoder.wavlm import WavLMModel

from models.builder import MODELS

from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config

@MODELS.register_module("Audio2Expression")
class Audio2Expression(nn.Module):
    def __init__(self,
                 device: torch.device = None,
                 pretrained_encoder_type: str = 'wav2vec',
                 pretrained_encoder_path: str = '',
                 wav2vec2_config_path: str = '',
                 num_identity_classes: int = 0,
                 identity_feat_dim: int = 64,
                 hidden_dim: int = 512,
                 expression_dim: int = 52,
                 norm_type: str = 'ln',
                 decoder_depth: int = 3,
                 use_transformer: bool = False,
                 num_attention_heads: int = 8,
                 num_transformer_layers: int = 6,
                 ):
        super().__init__()

        self.device = device

        # Initialize audio feature encoder
        if pretrained_encoder_type == 'wav2vec':
            if os.path.exists(pretrained_encoder_path):
                self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path)
            else:
                config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path)
                self.audio_encoder = Wav2Vec2Model(config)
            encoder_output_dim = 768
        elif pretrained_encoder_type == 'wavlm':
            self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path)
            encoder_output_dim = 768
        else:
            raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported")

        self.audio_encoder.feature_extractor._freeze_parameters()
        self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim)

        self.identity_encoder = AudioIdentityEncoder(
            hidden_dim,
            num_identity_classes,
            identity_feat_dim,
            use_transformer,
            num_attention_heads,
            num_transformer_layers
        )

        self.decoder = nn.ModuleList([
            nn.Sequential(*[
                ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type)
                for _ in range(decoder_depth)
            ])
        ])

        self.output_proj = nn.Linear(hidden_dim, expression_dim)

    def freeze_encoder_parameters(self, do_freeze=False):
        for name, param in self.audio_encoder.named_parameters():
            if 'feature_extractor' in name:
                param.requires_grad = False
            else:
                param.requires_grad = (not do_freeze)

    def forward(self, input_dict):
        if 'time_steps' not in input_dict:
            audio_length = input_dict['input_audio_array'].shape[1]
            time_steps = math.ceil(audio_length / 16000 * 30)
        else:
            time_steps = input_dict['time_steps']

        # Process audio through encoder
        audio_input = input_dict['input_audio_array'].flatten(start_dim=1)
        hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state

        # Project features to hidden dimension
        audio_features = self.feature_projection(hidden_states).transpose(1, 2)

        # Process identity-conditioned features
        audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx'])

        # Refine features through decoder
        audio_features = self.decoder[0](audio_features)

        # Generate output parameters
        audio_features = audio_features.permute(0, 2, 1)
        expression_params = self.output_proj(audio_features)

        return torch.sigmoid(expression_params)
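The 16 kHz-to-30 fps frame count above is a plain ceiling; for example:

import math

assert math.ceil(32000 / 16000 * 30) == 60  # 2 s of audio -> 60 frames
assert math.ceil(16100 / 16000 * 30) == 31  # a partial trailing frame rounds up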
class AudioIdentityEncoder(nn.Module):
    def __init__(self,
                 hidden_dim,
                 num_identity_classes=0,
                 identity_feat_dim=64,
                 use_transformer=False,
                 num_attention_heads=8,
                 num_transformer_layers=6,
                 dropout_ratio=0.1,
                 ):
        super().__init__()

        in_dim = hidden_dim + identity_feat_dim
        self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1)
        self.first_net = SeqTranslator1D(in_dim, hidden_dim,
                                         min_layers_num=3,
                                         residual=True,
                                         norm='ln'
                                         )
        self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
        self.dropout = nn.Dropout(dropout_ratio)

        self.use_transformer = use_transformer
        if self.use_transformer:
            encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_attention_heads, dim_feedforward=2 * hidden_dim, batch_first=True)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)

    def forward(self,
                audio_features: torch.Tensor,
                identity: torch.Tensor = None,
                time_steps: int = None) -> tuple:

        audio_features = self.dropout(audio_features)
        identity = identity.reshape(identity.shape[0], -1, 1).repeat(1, 1, audio_features.shape[2]).to(torch.float32)
        identity = self.id_mlp(identity)
        audio_features = torch.cat([audio_features, identity], dim=1)

        x = self.first_net(audio_features)

        if time_steps is not None:
            x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear')

        if self.use_transformer:
            x = x.permute(0, 2, 1)
            x = self.transformer_encoder(x)
            x = x.permute(0, 2, 1)

        return x
class ConvNormRelu(nn.Module):
    '''
    (B, C_in, H, W) -> (B, C_out, H, W)
    note: some kernel size/stride combinations do not produce an output of exactly H/s
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 type='1d',
                 leaky=False,
                 downsample=False,
                 kernel_size=None,
                 stride=None,
                 padding=None,
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        '''
        conv-bn-relu
        '''
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm
        # kernel_size = k
        # stride = s

        if kernel_size is None and stride is None:
            if not downsample:
                kernel_size = 3
                stride = 1
            else:
                kernel_size = 4
                stride = 2

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(stride, tuple):
                padding = tuple(int((kernel_size - st) / 2) for st in stride)
            elif isinstance(kernel_size, tuple) and isinstance(stride, int):
                padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
            elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
                padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
            else:
                padding = int((kernel_size - stride) / 2)

        if self.residual:
            if downsample:
                if type == '1d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding
                        )
                    )
                elif type == '2d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding
                        )
                    )
            else:
                if in_channels == out_channels:
                    self.residual_layer = nn.Identity()
                else:
                    if type == '1d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv1d(
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding
                            )
                        )
                    elif type == '2d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv2d(
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding
                            )
                        )

        in_channels = in_channels * groups
        out_channels = out_channels * groups
        if type == '1d':
            self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm1d(out_channels)
            self.dropout = nn.Dropout(p=p)
        elif type == '2d':
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm2d(out_channels)
            self.dropout = nn.Dropout2d(p=p)
        if norm == 'gn':
            self.norm = nn.GroupNorm(2, out_channels)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        if self.norm_type == 'ln':
            out = self.dropout(self.conv(x))
            out = self.norm(out.transpose(1, 2)).transpose(1, 2)
        else:
            out = self.norm(self.dropout(self.conv(x)))
        if self.residual:
            residual = self.residual_layer(x)
            out += residual
        return self.relu(out)
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
|
||||||
|
class SeqTranslator1D(nn.Module):
|
||||||
|
'''
|
||||||
|
(B, C, T)->(B, C_out, T)
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
C_in,
|
||||||
|
C_out,
|
||||||
|
kernel_size=None,
|
||||||
|
stride=None,
|
||||||
|
min_layers_num=None,
|
||||||
|
residual=True,
|
||||||
|
norm='bn'
|
||||||
|
):
|
||||||
|
super(SeqTranslator1D, self).__init__()
|
||||||
|
|
||||||
|
conv_layers = nn.ModuleList([])
|
||||||
|
conv_layers.append(ConvNormRelu(
|
||||||
|
in_channels=C_in,
|
||||||
|
out_channels=C_out,
|
||||||
|
type='1d',
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
residual=residual,
|
||||||
|
norm=norm
|
||||||
|
))
|
||||||
|
self.num_layers = 1
|
||||||
|
if min_layers_num is not None and self.num_layers < min_layers_num:
|
||||||
|
while self.num_layers < min_layers_num:
|
||||||
|
conv_layers.append(ConvNormRelu(
|
||||||
|
in_channels=C_out,
|
||||||
|
out_channels=C_out,
|
||||||
|
type='1d',
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
residual=residual,
|
||||||
|
norm=norm
|
||||||
|
))
|
||||||
|
self.num_layers += 1
|
||||||
|
self.conv_layers = nn.Sequential(*conv_layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv_layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
    """
    :param audio: 1 x T tensor containing a 16kHz audio signal
    :param frame_rate: frame rate for video (we need one audio chunk per video frame)
    :param chunk_size: number of audio samples per chunk
    :return: num_chunks x chunk_size tensor containing sliced audio
    """
    samples_per_frame = 16000 // frame_rate
    padding = (chunk_size - samples_per_frame) // 2
    audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
    anchor_points = list(range(chunk_size // 2, audio.shape[-1] - chunk_size // 2, samples_per_frame))
    audio = torch.cat([audio[:, i - chunk_size // 2:i + chunk_size // 2] for i in anchor_points], dim=0)
    return audio
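Illustrative shapes: one second of 16 kHz audio at 30 fps yields 30 overlapping one-second chunks, one per video frame:

audio = torch.randn(1, 16000)
chunks = audio_chunking(audio, frame_rate=30, chunk_size=16000)
print(chunks.shape)  # torch.Size([30, 16000])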
""" https://github.com/facebookresearch/meshtalk """
|
||||||
|
class MeshtalkEncoder(nn.Module):
|
||||||
|
def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
|
||||||
|
"""
|
||||||
|
:param latent_dim: size of the latent audio embedding
|
||||||
|
:param model_name: name of the model, used to load and save the model
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.melspec = ta.transforms.MelSpectrogram(
|
||||||
|
sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_len = 5
|
||||||
|
self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
|
||||||
|
self.weights_init(self.convert_dimensions)
|
||||||
|
self.receptive_field = conv_len
|
||||||
|
|
||||||
|
convs = []
|
||||||
|
for i in range(6):
|
||||||
|
dilation = 2 * (i % 3 + 1)
|
||||||
|
self.receptive_field += (conv_len - 1) * dilation
|
||||||
|
convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
|
||||||
|
self.weights_init(convs[-1])
|
||||||
|
self.convs = torch.nn.ModuleList(convs)
|
||||||
|
self.code = torch.nn.Linear(128, latent_dim)
|
||||||
|
|
||||||
|
self.apply(lambda x: self.weights_init(x))
|
||||||
|
|
||||||
|
def weights_init(self, m):
|
||||||
|
if isinstance(m, torch.nn.Conv1d):
|
||||||
|
torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
try:
|
||||||
|
torch.nn.init.constant_(m.bias, .01)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, audio: torch.Tensor):
|
||||||
|
"""
|
||||||
|
:param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
|
||||||
|
:return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
|
||||||
|
"""
|
||||||
|
B, T = audio.shape[0], audio.shape[1]
|
||||||
|
x = self.melspec(audio).squeeze(1)
|
||||||
|
x = torch.log(x.clamp(min=1e-10, max=None))
|
||||||
|
if T == 1:
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
|
||||||
|
# Convert to the right dimensionality
|
||||||
|
x = x.view(-1, x.shape[2], x.shape[3])
|
||||||
|
x = F.leaky_relu(self.convert_dimensions(x), .2)
|
||||||
|
|
||||||
|
# Process stacks
|
||||||
|
for conv in self.convs:
|
||||||
|
x_ = F.leaky_relu(conv(x), .2)
|
||||||
|
if self.training:
|
||||||
|
x_ = F.dropout(x_, .2)
|
||||||
|
l = (x.shape[2] - x_.shape[2]) // 2
|
||||||
|
x = (x[:, :, l:-l] + x_) / 2
|
||||||
|
|
||||||
|
x = torch.mean(x, dim=-1)
|
||||||
|
x = x.view(B, T, x.shape[-1])
|
||||||
|
x = self.code(x)
|
||||||
|
|
||||||
|
return {"code": x}
|
||||||
|
|
||||||
|
class PeriodicPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)  # (1, repeat_num * period, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class GeneratorTransformer(nn.Module):
    def __init__(self,
                 n_poses,
                 each_dim: list,
                 dim_list: list,
                 training=True,
                 device=None,
                 identity=False,
                 num_classes=0,
                 ):
        super().__init__()

        self.training = training
        self.device = device
        self.gen_length = n_poses

        norm = 'ln'
        in_dim = 256
        out_dim = 256

        self.encoder_choice = 'faceformer'

        # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.audio_encoder.feature_extractor._freeze_parameters()
        self.audio_feature_map = nn.Linear(768, in_dim)

        self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes)

        self.dim_list = dim_list

        self.decoder = nn.ModuleList()
        self.final_out = nn.ModuleList()

        self.hidden_size = 768
        self.transformer_de_layer = nn.TransformerDecoderLayer(
            d_model=self.hidden_size,
            nhead=4,
            dim_feedforward=self.hidden_size * 2,
            batch_first=True
        )
        self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)
        self.feature2face = nn.Linear(256, self.hidden_size)

        self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64)
        self.id_maping = nn.Linear(12, self.hidden_size)

        self.decoder.append(self.face_decoder)
        self.final_out.append(nn.Linear(self.hidden_size, 32))

    def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
        if gt_poses is None:
            time_steps = 64
        else:
            time_steps = gt_poses.shape[1]

        # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
        if self.encoder_choice == 'meshtalk':
            in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
            feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
        elif self.encoder_choice == 'faceformer':
            hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
            feature = self.audio_feature_map(hidden_states).transpose(1, 2)
        else:
            feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)

        feature, _ = self.audio_middle(feature, id=None)
        feature = self.feature2face(feature.permute(0, 2, 1))

        id = id.unsqueeze(1).repeat(1, 64, 1).to(torch.float32)
        id_feature = self.id_maping(id)
        id_feature = self.position_embeddings(id_feature)

        for i in range(len(self.decoder)):
            mid = self.decoder[i](tgt=id_feature, memory=feature)
            out = self.final_out[i](mid)

        return out, None

def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
    output_features = F.interpolate(
        features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)

def init_biased_mask(n_head, max_seq_len, period):

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2]

    slopes = torch.Tensor(get_slopes(n_head))
    bias = torch.div(
        torch.arange(start=0, end=max_seq_len,
                     step=period).unsqueeze(1).repeat(1, period).view(-1),
        period,
        rounding_mode='floor')
    bias = -torch.flip(bias, dims=[0])
    alibi = torch.zeros(max_seq_len, max_seq_len)
    for i in range(max_seq_len):
        alibi[i, :i + 1] = bias[-(i + 1):]
    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
    mask = (torch.triu(torch.ones(max_seq_len,
                                  max_seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0))
    mask = mask.unsqueeze(0) + alibi
    return mask

# Alignment Bias
def enc_dec_mask(device, T, S):
    mask = torch.ones(T, S)
    for i in range(T):
        mask[i, i] = 0
    return (mask == 1).to(device=device)

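# A small sketch of how these two masks are typically consumed (values are illustrative):
# tgt_mask = init_biased_mask(n_head=4, max_seq_len=3000, period=30)  # (4, 3000, 3000) ALiBi-style causal bias
# mem_mask = enc_dec_mask(torch.device('cpu'), T=64, S=64)            # bool (T, S); True = blocked position
# Both are sliced to the actual sequence length before being handed to the transformer decoder.
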
# Periodic Positional Encoding
class PeriodicPositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class BaseModel(nn.Module):
    """Base class for all models."""

    def __init__(self):
        super(BaseModel, self).__init__()
        # self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *x):
        """Forward pass logic.

        :return: Model output
        """
        raise NotImplementedError

    def freeze_model(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def summary(self, logger, writer=None):
        """Model summary."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size())
                      for p in model_parameters]) / 1e6  # Unit is Mega
        logger.info('===>Trainable parameters: %.3f M' % params)
        if writer is not None:
            writer.add_text('Model Summary',
                            'Trainable parameters: %.3f M' % params)

"""https://github.com/X-niper/UniTalker"""
|
||||||
|
class UniTalkerDecoderTransformer(BaseModel):
|
||||||
|
|
||||||
|
def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.learnable_style_emb = nn.Embedding(identity_num, out_dim)
|
||||||
|
self.PPE = PeriodicPositionalEncoding(
|
||||||
|
out_dim, period=period, max_seq_len=3000)
|
||||||
|
self.biased_mask = init_biased_mask(
|
||||||
|
n_head=4, max_seq_len=3000, period=period)
|
||||||
|
decoder_layer = nn.TransformerDecoderLayer(
|
||||||
|
d_model=out_dim,
|
||||||
|
nhead=4,
|
||||||
|
dim_feedforward=2 * out_dim,
|
||||||
|
batch_first=True)
|
||||||
|
self.transformer_decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer, num_layers=1)
|
||||||
|
self.interpolate_pos = interpolate_pos
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
|
||||||
|
frame_num: int):
|
||||||
|
style_idx = torch.argmax(style_idx, dim=1)
|
||||||
|
obj_embedding = self.learnable_style_emb(style_idx)
|
||||||
|
obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
|
||||||
|
style_input = self.PPE(obj_embedding)
|
||||||
|
tgt_mask = self.biased_mask.repeat(style_idx.shape[0], 1, 1)[:, :style_input.shape[1], :style_input.
|
||||||
|
shape[1]].clone().detach().to(
|
||||||
|
device=style_input.device)
|
||||||
|
memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
|
||||||
|
frame_num)
|
||||||
|
feat_out = self.transformer_decoder(
|
||||||
|
style_input,
|
||||||
|
hidden_states,
|
||||||
|
tgt_mask=tgt_mask,
|
||||||
|
memory_mask=memory_mask)
|
||||||
|
if self.interpolate_pos == 2:
|
||||||
|
feat_out = linear_interpolation(feat_out, output_len=frame_num)
|
||||||
|
return feat_out
|
||||||
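# A rough usage sketch (shapes and the one-hot style vector are assumptions):
# decoder = UniTalkerDecoderTransformer(out_dim=256, identity_num=8, period=30)
# audio_feats = torch.randn(1, 64, 256)            # (B, frame_num, out_dim) from an audio encoder
# style = torch.nn.functional.one_hot(torch.tensor([0]), num_classes=8)
# out = decoder(audio_feats, style, frame_num=64)  # (1, 64, 256) decoded motion features
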
752
models/utils.py
Normal file
752
models/utils.py
Normal file
@@ -0,0 +1,752 @@
import json
import time
import warnings

import numpy as np
from scipy.ndimage import label
from scipy.signal import savgol_filter
from typing import List, Optional, Tuple


ARKitLeftRightPair = [
    ("jawLeft", "jawRight"),
    ("mouthLeft", "mouthRight"),
    ("mouthSmileLeft", "mouthSmileRight"),
    ("mouthFrownLeft", "mouthFrownRight"),
    ("mouthDimpleLeft", "mouthDimpleRight"),
    ("mouthStretchLeft", "mouthStretchRight"),
    ("mouthPressLeft", "mouthPressRight"),
    ("mouthLowerDownLeft", "mouthLowerDownRight"),
    ("mouthUpperUpLeft", "mouthUpperUpRight"),
    ("cheekSquintLeft", "cheekSquintRight"),
    ("noseSneerLeft", "noseSneerRight"),
    ("browDownLeft", "browDownRight"),
    ("browOuterUpLeft", "browOuterUpRight"),
    ("eyeBlinkLeft", "eyeBlinkRight"),
    ("eyeLookDownLeft", "eyeLookDownRight"),
    ("eyeLookInLeft", "eyeLookInRight"),
    ("eyeLookOutLeft", "eyeLookOutRight"),
    ("eyeLookUpLeft", "eyeLookUpRight"),
    ("eyeSquintLeft", "eyeSquintRight"),
    ("eyeWideLeft", "eyeWideRight")
]

ARKitBlendShape = [
    "browDownLeft",
    "browDownRight",
    "browInnerUp",
    "browOuterUpLeft",
    "browOuterUpRight",
    "cheekPuff",
    "cheekSquintLeft",
    "cheekSquintRight",
    "eyeBlinkLeft",
    "eyeBlinkRight",
    "eyeLookDownLeft",
    "eyeLookDownRight",
    "eyeLookInLeft",
    "eyeLookInRight",
    "eyeLookOutLeft",
    "eyeLookOutRight",
    "eyeLookUpLeft",
    "eyeLookUpRight",
    "eyeSquintLeft",
    "eyeSquintRight",
    "eyeWideLeft",
    "eyeWideRight",
    "jawForward",
    "jawLeft",
    "jawOpen",
    "jawRight",
    "mouthClose",
    "mouthDimpleLeft",
    "mouthDimpleRight",
    "mouthFrownLeft",
    "mouthFrownRight",
    "mouthFunnel",
    "mouthLeft",
    "mouthLowerDownLeft",
    "mouthLowerDownRight",
    "mouthPressLeft",
    "mouthPressRight",
    "mouthPucker",
    "mouthRight",
    "mouthRollLower",
    "mouthRollUpper",
    "mouthShrugLower",
    "mouthShrugUpper",
    "mouthSmileLeft",
    "mouthSmileRight",
    "mouthStretchLeft",
    "mouthStretchRight",
    "mouthUpperUpLeft",
    "mouthUpperUpRight",
    "noseSneerLeft",
    "noseSneerRight",
    "tongueOut"
]

MOUTH_BLENDSHAPES = [
    "mouthDimpleLeft",
    "mouthDimpleRight",
    "mouthFrownLeft",
    "mouthFrownRight",
    "mouthFunnel",
    "mouthLeft",
    "mouthLowerDownLeft",
    "mouthLowerDownRight",
    "mouthPressLeft",
    "mouthPressRight",
    "mouthPucker",
    "mouthRight",
    "mouthRollLower",
    "mouthRollUpper",
    "mouthShrugLower",
    "mouthShrugUpper",
    "mouthSmileLeft",
    "mouthSmileRight",
    "mouthStretchLeft",
    "mouthStretchRight",
    "mouthUpperUpLeft",
    "mouthUpperUpRight",
    "jawForward",
    "jawLeft",
    "jawOpen",
    "jawRight",
    "noseSneerLeft",
    "noseSneerRight",
    "cheekPuff",
]

DEFAULT_CONTEXT = {
    'is_initial_input': True,
    'previous_audio': None,
    'previous_expression': None,
    'previous_volume': None,
    'previous_headpose': None,
}

RETURN_CODE = {
    "SUCCESS": 0,
    "AUDIO_LENGTH_ERROR": 1,
    "CHECKPOINT_PATH_ERROR": 2,
    "MODEL_INFERENCE_ERROR": 3,
}

DEFAULT_CONTEXTRETURN = {
    "code": RETURN_CODE['SUCCESS'],
    "expression": None,
    "headpose": None,
}

BLINK_PATTERNS = [
    np.array([0.365, 0.950, 0.956, 0.917, 0.367, 0.119, 0.025]),
    np.array([0.235, 0.910, 0.945, 0.778, 0.191, 0.235, 0.089]),
    np.array([0.870, 0.950, 0.949, 0.696, 0.191, 0.073, 0.007]),
    np.array([0.000, 0.557, 0.953, 0.942, 0.426, 0.148, 0.018])
]

# Postprocess
def symmetrize_blendshapes(
    bs_params: np.ndarray,
    mode: str = "average",
    symmetric_pairs: list = ARKitLeftRightPair
) -> np.ndarray:
    """
    Apply symmetrization to ARKit blendshape parameters (batched version)

    Args:
        bs_params: numpy array of shape (N, 52), batch of ARKit parameters
        mode: symmetrization mode ["average", "max", "min", "left_dominant", "right_dominant"]
        symmetric_pairs: list of left-right parameter pairs

    Returns:
        Symmetrized parameters with same shape (N, 52)
    """

    name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)}

    # Input validation
    if bs_params.ndim != 2 or bs_params.shape[1] != 52:
        raise ValueError("Input must be of shape (N, 52)")

    symmetric_bs = bs_params.copy()  # Shape (N, 52)

    # Precompute valid index pairs
    valid_pairs = []
    for left, right in symmetric_pairs:
        left_idx = name_to_idx.get(left)
        right_idx = name_to_idx.get(right)
        if None not in (left_idx, right_idx):
            valid_pairs.append((left_idx, right_idx))

    # Vectorized processing
    for l_idx, r_idx in valid_pairs:
        left_col = symmetric_bs[:, l_idx]
        right_col = symmetric_bs[:, r_idx]

        if mode == "average":
            new_vals = (left_col + right_col) / 2
        elif mode == "max":
            new_vals = np.maximum(left_col, right_col)
        elif mode == "min":
            new_vals = np.minimum(left_col, right_col)
        elif mode == "left_dominant":
            new_vals = left_col
        elif mode == "right_dominant":
            new_vals = right_col
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # Update both columns simultaneously
        symmetric_bs[:, l_idx] = new_vals
        symmetric_bs[:, r_idx] = new_vals

    return symmetric_bs

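# A tiny sanity-check sketch (values are illustrative):
# frames = np.random.rand(10, 52).astype(np.float32)
# balanced = symmetrize_blendshapes(frames, mode="average")
# Paired channels such as mouthSmileLeft/mouthSmileRight now hold identical values.
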
def apply_random_eye_blinks(
    input: np.ndarray,
    blink_scale: tuple = (0.8, 1.0),
    blink_interval: tuple = (60, 120),
    blink_duration: int = 7
) -> np.ndarray:
    """
    Apply randomized eye blinks to blendshape parameters

    Args:
        input: Input array of shape (N, 52) containing blendshape parameters
        blink_scale: Tuple (min, max) for random blink intensity scaling
        blink_interval: Tuple (min, max) for random blink spacing in frames
        blink_duration: Number of frames for blink animation (fixed)

    Returns:
        The input array with randomized blinks written into columns 8 and 9 (modified in-place)
    """
    # Blink shapes come from the module-level BLINK_PATTERNS (normalized 0-1)

    # Initialize parameters
    n_frames = input.shape[0]
    input[:, 8:10] = np.zeros((n_frames, 2))
    current_frame = 0

    # Main blink application loop
    while current_frame < n_frames - blink_duration:
        # Randomize blink parameters
        scale = np.random.uniform(*blink_scale)
        pattern = BLINK_PATTERNS[np.random.randint(0, 4)]

        # Apply blink animation
        blink_values = pattern * scale
        input[current_frame:current_frame + blink_duration, 8] = blink_values
        input[current_frame:current_frame + blink_duration, 9] = blink_values

        # Advance to next blink position
        current_frame += blink_duration + np.random.randint(*blink_interval)

    return input

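# Illustrative call (the 300-frame buffer is an assumption):
# anim = np.zeros((300, 52), dtype=np.float32)
# anim = apply_random_eye_blinks(anim)  # columns 8/9 (eyeBlinkLeft/Right) now carry blink curves
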
def apply_random_eye_blinks_context(
    animation_params: np.ndarray,
    processed_frames: int = 0,
    intensity_range: tuple = (0.8, 1.0)
) -> np.ndarray:
    """Applies random eye blink patterns to facial animation parameters.

    Args:
        animation_params: Input facial animation parameters array with shape [num_frames, num_features].
            Columns 8 and 9 typically represent left/right eye blink parameters.
        processed_frames: Number of already processed frames that shouldn't be modified
        intensity_range: Tuple defining (min, max) scaling for blink intensity

    Returns:
        Modified animation parameters array with random eye blinks added to unprocessed frames
    """
    remaining_frames = animation_params.shape[0] - processed_frames

    # Only apply blinks if there are enough remaining frames (blink pattern requires 7 frames)
    if remaining_frames <= 7:
        return animation_params

    # Configure blink timing parameters
    min_blink_interval = 40   # Minimum frames between blinks
    max_blink_interval = 100  # Maximum frames between blinks

    # Find last blink in previously processed frames (column 8 > 0.5 indicates blink)
    previous_blink_indices = np.where(animation_params[:processed_frames, 8] > 0.5)[0]
    last_processed_blink = previous_blink_indices[-1] - 7 if previous_blink_indices.size > 0 else processed_frames

    # Calculate first new blink position
    blink_interval = np.random.randint(min_blink_interval, max_blink_interval)
    first_blink_start = max(0, blink_interval - last_processed_blink)

    # Apply first blink if there's enough space
    if first_blink_start <= (remaining_frames - 7):
        # Randomly select blink pattern and intensity
        blink_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
        intensity = np.random.uniform(*intensity_range)

        # Calculate blink frame range
        blink_start = processed_frames + first_blink_start
        blink_end = blink_start + 7

        # Apply pattern to both eyes
        animation_params[blink_start:blink_end, 8] = blink_pattern * intensity
        animation_params[blink_start:blink_end, 9] = blink_pattern * intensity

        # Check space for additional blink
        remaining_after_blink = animation_params.shape[0] - blink_end
        if remaining_after_blink > min_blink_interval:
            # Calculate second blink position
            second_intensity = np.random.uniform(*intensity_range)
            second_interval = np.random.randint(min_blink_interval, max_blink_interval)

            if (remaining_after_blink - 7) > second_interval:
                second_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
                second_blink_start = blink_end + second_interval
                second_blink_end = second_blink_start + 7

                # Apply second blink
                animation_params[second_blink_start:second_blink_end, 8] = second_pattern * second_intensity
                animation_params[second_blink_start:second_blink_end, 9] = second_pattern * second_intensity

    return animation_params

def export_blendshape_animation(
    blendshape_weights: np.ndarray,
    output_path: str,
    blendshape_names: List[str],
    fps: float,
    rotation_data: Optional[np.ndarray] = None
) -> None:
    """
    Export blendshape animation data to JSON format compatible with ARKit.

    Args:
        blendshape_weights: 2D numpy array of shape (N, 52) containing animation frames
        output_path: Full path for output JSON file (including .json extension)
        blendshape_names: Ordered list of 52 ARKit-standard blendshape names
        fps: Frame rate for timing calculations (frames per second)
        rotation_data: Optional 3D rotation data array of shape (N, 3)

    Raises:
        ValueError: If input dimensions are incompatible
        IOError: If file writing fails
    """
    # Validate input dimensions
    if blendshape_weights.shape[1] != 52:
        raise ValueError(f"Expected 52 blendshapes, got {blendshape_weights.shape[1]}")
    if len(blendshape_names) != 52:
        raise ValueError(f"Requires 52 blendshape names, got {len(blendshape_names)}")
    if rotation_data is not None and len(rotation_data) != len(blendshape_weights):
        raise ValueError("Rotation data length must match animation frames")

    # Build animation data structure
    animation_data = {
        "names": blendshape_names,
        "metadata": {
            "fps": fps,
            "frame_count": len(blendshape_weights),
            "blendshape_names": blendshape_names
        },
        "frames": []
    }

    # Convert numpy array to serializable format
    for frame_idx in range(blendshape_weights.shape[0]):
        frame_data = {
            "weights": blendshape_weights[frame_idx].tolist(),
            "time": frame_idx / fps,
            "rotation": rotation_data[frame_idx].tolist() if rotation_data is not None else []
        }
        animation_data["frames"].append(frame_data)

    # Safeguard against data loss
    if not output_path.endswith('.json'):
        output_path += '.json'

    # Write to file with error handling
    try:
        with open(output_path, 'w', encoding='utf-8') as json_file:
            json.dump(animation_data, json_file, indent=2, ensure_ascii=False)
    except Exception as e:
        raise IOError(f"Failed to write animation data: {str(e)}") from e

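# Example invocation (the file name is illustrative):
# export_blendshape_animation(np.zeros((90, 52)), "clip.json", ARKitBlendShape, fps=30.0)
# Produces a JSON file with per-frame "weights", "time" and (optional) "rotation" entries.
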
def apply_savitzky_golay_smoothing(
    input_data: np.ndarray,
    window_length: int = 5,
    polyorder: int = 2,
    axis: int = 0,
    validate: bool = True
) -> Tuple[np.ndarray, Optional[float]]:
    """
    Apply Savitzky-Golay filter smoothing along specified axis of input data.

    Args:
        input_data: 2D numpy array of shape (n_samples, n_features)
        window_length: Length of the filter window (must be odd and > polyorder)
        polyorder: Order of the polynomial fit
        axis: Axis along which to filter (0: column-wise, 1: row-wise)
        validate: Enable input validation checks when True

    Returns:
        tuple: (smoothed_data, processing_time)
            - smoothed_data: Smoothed output array, clipped to [0, 1]
            - processing_time: Execution time of the filtering step in seconds

    Raises:
        ValueError: For invalid input dimensions or filter parameters
    """
    processing_time = None

    if validate:
        # Input integrity checks
        if input_data.ndim != 2:
            raise ValueError(f"Expected 2D input, got {input_data.ndim}D array")

        if window_length % 2 == 0 or window_length < 3:
            raise ValueError("Window length must be an odd integer >= 3")

        if polyorder >= window_length:
            raise ValueError("Polynomial order must be < window length")

    # Store original dtype and convert to float64 for numerical stability
    original_dtype = input_data.dtype
    working_data = input_data.astype(np.float64)

    # Start performance timer
    timer_start = time.perf_counter()

    try:
        # Vectorized Savitzky-Golay application
        smoothed_data = savgol_filter(working_data,
                                      window_length=window_length,
                                      polyorder=polyorder,
                                      axis=axis,
                                      mode='mirror')
    except Exception as e:
        raise RuntimeError(f"Filtering failed: {str(e)}") from e

    # Stop timer and calculate duration
    processing_time = time.perf_counter() - timer_start

    # Restore original data type, clipping back to the valid blendshape range
    return (
        np.clip(smoothed_data, 0.0, 1.0).astype(original_dtype),
        processing_time
    )

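# A hedged usage sketch (window/order values below are just the defaults):
# smoothed, elapsed = apply_savitzky_golay_smoothing(raw_weights, window_length=5, polyorder=2)
# `raw_weights` is assumed to be an (N, 52) float array; the output is clipped back to [0, 1].
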
def _blend_region_start(
    array: np.ndarray,
    region: np.ndarray,
    processed_boundary: int,
    blend_frames: int
) -> None:
    """Applies linear blend between last active frame and silent region start."""
    blend_length = min(blend_frames, region[0] - processed_boundary)
    if blend_length <= 0:
        return

    pre_frame = array[region[0] - 1]
    for i in range(blend_length):
        weight = (i + 1) / (blend_length + 1)
        array[region[0] + i] = pre_frame * (1 - weight) + array[region[0] + i] * weight


def _blend_region_end(
    array: np.ndarray,
    region: np.ndarray,
    blend_frames: int
) -> None:
    """Applies linear blend between silent region end and next active frame."""
    blend_length = min(blend_frames, array.shape[0] - region[-1] - 1)
    if blend_length <= 0:
        return

    post_frame = array[region[-1] + 1]
    for i in range(blend_length):
        weight = (i + 1) / (blend_length + 1)
        array[region[-1] - i] = post_frame * (1 - weight) + array[region[-1] - i] * weight

def find_low_value_regions(
    signal: np.ndarray,
    threshold: float,
    min_region_length: int = 5
) -> list:
    """Identifies contiguous regions in a signal where values fall below a threshold.

    Args:
        signal: Input 1D array of numerical values
        threshold: Value threshold for identifying low regions
        min_region_length: Minimum consecutive samples required to qualify as a region

    Returns:
        List of numpy arrays, each containing indices for a qualifying low-value region
    """
    low_value_indices = np.where(signal < threshold)[0]
    contiguous_regions = []
    # The count includes the first element of the current region
    current_region_length = 1 if low_value_indices.size > 0 else 0
    region_start_idx = 0

    for i in range(1, len(low_value_indices)):
        # Check if current index continues a consecutive sequence
        if low_value_indices[i] != low_value_indices[i - 1] + 1:
            # Finalize previous region if it meets length requirement
            if current_region_length >= min_region_length:
                contiguous_regions.append(low_value_indices[region_start_idx:i])
            # Reset tracking for new potential region
            region_start_idx = i
            current_region_length = 0
        current_region_length += 1

    # Add the final region if it qualifies
    if current_region_length >= min_region_length:
        contiguous_regions.append(low_value_indices[region_start_idx:])

    return contiguous_regions

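# Small illustrative check:
# vol = np.array([0.5, 0.0, 0.0, 0.0, 0.6, 0.0])
# find_low_value_regions(vol, threshold=0.1, min_region_length=3)  # -> [array([1, 2, 3])]
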
def smooth_mouth_movements(
    blend_shapes: np.ndarray,
    processed_frames: int,
    volume: np.ndarray = None,
    silence_threshold: float = 0.001,
    min_silence_duration: int = 7,
    blend_window: int = 3
) -> np.ndarray:
    """Reduces jaw movement artifacts during silent periods in audio-driven animation.

    Args:
        blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
        processed_frames: Number of already processed frames that shouldn't be modified
        volume: Audio volume array used to detect silent periods
        silence_threshold: Volume threshold for considering a frame silent
        min_silence_duration: Minimum consecutive silent frames to qualify for processing
        blend_window: Number of frames to smooth at region boundaries

    Returns:
        Modified blend shape array with reduced mouth movements during silence
    """
    if volume is None:
        return blend_shapes

    # Detect silence periods using volume data
    silent_regions = find_low_value_regions(
        volume,
        threshold=silence_threshold,
        min_region_length=min_silence_duration
    )

    # Indices of the mouth-related blend shapes to attenuate
    mouth_blend_indices = [ARKitBlendShape.index(name) for name in MOUTH_BLENDSHAPES]

    for region_indices in silent_regions:
        # Reduce mouth blend shapes in silent region
        for frame_idx in region_indices.tolist():
            blend_shapes[frame_idx, mouth_blend_indices] *= 0.1

        try:
            # Smooth transition into silent region
            _blend_region_start(
                blend_shapes,
                region_indices,
                processed_frames,
                blend_window
            )

            # Smooth transition out of silent region
            _blend_region_end(
                blend_shapes,
                region_indices,
                blend_window
            )
        except IndexError as e:
            warnings.warn(f"Edge blending skipped at region {region_indices}: {str(e)}")

    return blend_shapes

def apply_frame_blending(
    blend_shapes: np.ndarray,
    processed_frames: int,
    initial_blend_window: int = 3,
    subsequent_blend_window: int = 5
) -> np.ndarray:
    """Smooths transitions between processed and unprocessed animation frames using linear blending.

    Args:
        blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
        processed_frames: Number of already processed frames (0 means no previous processing)
        initial_blend_window: Max frames to blend at sequence start
        subsequent_blend_window: Max frames to blend between processed and new frames

    Returns:
        Modified blend shape array with smoothed transitions
    """
    if processed_frames > 0:
        # Blend transition between existing and new animation
        _blend_animation_segment(
            blend_shapes,
            transition_start=processed_frames,
            blend_window=subsequent_blend_window,
            reference_frame=blend_shapes[processed_frames - 1]
        )
    else:
        # Smooth initial frames from neutral expression (zeros)
        _blend_animation_segment(
            blend_shapes,
            transition_start=0,
            blend_window=initial_blend_window,
            reference_frame=np.zeros_like(blend_shapes[0])
        )
    return blend_shapes

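# Streaming-style usage sketch (the chunk size below is hypothetical):
# chunk = np.random.rand(120, 52)                          # freshly generated frames
# chunk = apply_frame_blending(chunk, processed_frames=0)  # ease in from the neutral pose
# When the array also contains already-rendered frames, pass processed_frames > 0 so the
# seam is blended against the last processed frame instead.
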
def _blend_animation_segment(
    array: np.ndarray,
    transition_start: int,
    blend_window: int,
    reference_frame: np.ndarray
) -> None:
    """Applies linear interpolation between reference frame and target frames.

    Args:
        array: Blend shape array to modify
        transition_start: Starting index for blending
        blend_window: Maximum number of frames to blend
        reference_frame: The reference frame to blend from
    """
    actual_blend_length = min(blend_window, array.shape[0] - transition_start)

    for frame_offset in range(actual_blend_length):
        current_idx = transition_start + frame_offset
        blend_weight = (frame_offset + 1) / (actual_blend_length + 1)

        # Linear interpolation: ref_frame * (1 - weight) + current_frame * weight
        array[current_idx] = (reference_frame * (1 - blend_weight)
                              + array[current_idx] * blend_weight)

BROW1 = np.array([[0.05597309, 0.05727929, 0.07995935, 0., 0.],
                  [0.00757574, 0.00936678, 0.12242376, 0., 0.],
                  [0., 0., 0.14943372, 0.04535687, 0.04264118],
                  [0., 0., 0.18015374, 0.09019445, 0.08736137],
                  [0., 0., 0.20549579, 0.12802747, 0.12450772],
                  [0., 0., 0.21098022, 0.1369939, 0.13343132],
                  [0., 0., 0.20904602, 0.13903855, 0.13562402],
                  [0., 0., 0.20365039, 0.13977394, 0.13653506],
                  [0., 0., 0.19714841, 0.14096624, 0.13805152],
                  [0., 0., 0.20325482, 0.17303431, 0.17028868],
                  [0., 0., 0.21990852, 0.20164253, 0.19818163],
                  [0., 0., 0.23858181, 0.21908803, 0.21540019],
                  [0., 0., 0.2567876, 0.23762083, 0.23396946],
                  [0., 0., 0.34093422, 0.27898848, 0.27651772],
                  [0., 0., 0.45288125, 0.35008961, 0.34887788],
                  [0., 0., 0.48076251, 0.36878952, 0.36778417],
                  [0., 0., 0.47798249, 0.36362219, 0.36145973],
                  [0., 0., 0.46186113, 0.33865979, 0.33597934],
                  [0., 0., 0.45264384, 0.33152157, 0.32891783],
                  [0., 0., 0.40986338, 0.29646468, 0.2945672],
                  [0., 0., 0.35628179, 0.23356403, 0.23155804],
                  [0., 0., 0.30870566, 0.1780673, 0.17637439],
                  [0., 0., 0.25293985, 0.10710219, 0.10622486],
                  [0., 0., 0.18743332, 0.03252602, 0.03244236],
                  [0.02340254, 0.02364671, 0.15736724, 0., 0.]])

BROW2 = np.array([
    [0., 0., 0.09799323, 0.05944436, 0.05002545],
    [0., 0., 0.09780276, 0.07674237, 0.01636653],
    [0., 0., 0.11136199, 0.1027964, 0.04249811],
    [0., 0., 0.26883412, 0.15861984, 0.15832305],
    [0., 0., 0.42191629, 0.27038204, 0.27007768],
    [0., 0., 0.3404977, 0.21633868, 0.21597538],
    [0., 0., 0.27301185, 0.17176409, 0.17134669],
    [0., 0., 0.25960442, 0.15670464, 0.15622253],
    [0., 0., 0.22877269, 0.11805892, 0.11754539],
    [0., 0., 0.1451605, 0.06389034, 0.0636282]])

BROW3 = np.array([
    [0., 0., 0.124, 0.0295, 0.0295],
    [0., 0., 0.267, 0.184, 0.184],
    [0., 0., 0.359, 0.2765, 0.2765],
    [0., 0., 0.3945, 0.3125, 0.3125],
    [0., 0., 0.4125, 0.331, 0.331],
    [0., 0., 0.4235, 0.3445, 0.3445],
    [0., 0., 0.4085, 0.3305, 0.3305],
    [0., 0., 0.3695, 0.294, 0.294],
    [0., 0., 0.2835, 0.213, 0.213],
    [0., 0., 0.1795, 0.1005, 0.1005],
    [0., 0., 0.108, 0.014, 0.014]])

def apply_random_brow_movement(input_exp, volume):
    FRAME_SEGMENT = 150
    HOLD_THRESHOLD = 10
    VOLUME_THRESHOLD = 0.08
    MIN_REGION_LENGTH = 6
    STRENGTH_RANGE = (0.7, 1.3)

    BROW_PEAKS = {
        0: np.argmax(BROW1[:, 2]),
        1: np.argmax(BROW2[:, 2])
    }

    for seg_start in range(0, len(volume), FRAME_SEGMENT):
        seg_end = min(seg_start + FRAME_SEGMENT, len(volume))
        seg_volume = volume[seg_start:seg_end]

        candidate_regions = []

        high_vol_mask = seg_volume > VOLUME_THRESHOLD
        labeled_array, num_features = label(high_vol_mask)

        for i in range(1, num_features + 1):
            region = (labeled_array == i)
            region_indices = np.where(region)[0]
            if len(region_indices) >= MIN_REGION_LENGTH:
                candidate_regions.append(region_indices)

        if candidate_regions:
            selected_region = candidate_regions[np.random.choice(len(candidate_regions))]
            region_start = selected_region[0]
            region_end = selected_region[-1]
            region_length = region_end - region_start + 1

            brow_idx = np.random.randint(0, 2)
            base_brow = BROW1 if brow_idx == 0 else BROW2
            peak_idx = BROW_PEAKS[brow_idx]

            if region_length > HOLD_THRESHOLD:
                local_max_pos = seg_volume[selected_region].argmax()
                global_peak_frame = seg_start + selected_region[local_max_pos]

                rise_anim = base_brow[:peak_idx + 1]
                hold_frame = base_brow[peak_idx:peak_idx + 1]

                insert_start = max(global_peak_frame - peak_idx, seg_start)
                insert_end = min(global_peak_frame + (region_length - local_max_pos), seg_end)

                strength = np.random.uniform(*STRENGTH_RANGE)

                if insert_start + len(rise_anim) <= seg_end:
                    input_exp[insert_start:insert_start + len(rise_anim), :5] += rise_anim * strength
                    hold_duration = insert_end - (insert_start + len(rise_anim))
                    if hold_duration > 0:
                        input_exp[insert_start + len(rise_anim):insert_end, :5] += np.tile(hold_frame * strength,
                                                                                           (hold_duration, 1))
            else:
                anim_length = base_brow.shape[0]
                insert_pos = seg_start + region_start + (region_length - anim_length) // 2
                insert_pos = max(seg_start, min(insert_pos, seg_end - anim_length))

                if insert_pos + anim_length <= seg_end:
                    strength = np.random.uniform(*STRENGTH_RANGE)
                    input_exp[insert_pos:insert_pos + anim_length, :5] += base_brow * strength

    return np.clip(input_exp, 0, 1)

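# Illustrative call (frame counts and volume values are assumptions):
# exp = np.zeros((300, 52))
# vol = np.abs(np.random.randn(300)) * 0.1
# exp = apply_random_brow_movement(exp, vol)  # adds brow raises over high-volume stretches
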
10
requirements.txt
Normal file
10
requirements.txt
Normal file
@@ -0,0 +1,10 @@
spleeter==2.4.2
opencv_python_headless==4.11.0.86
gradio==5.25.2
omegaconf==2.3.0
addict==2.4.0
yapf==0.40.1
librosa==0.11.0
transformers==4.36.2
termcolor==3.0.1
numpy==1.26.3
9
scripts/install/install_cu118.sh
Normal file
9
scripts/install/install_cu118.sh
Normal file
@@ -0,0 +1,9 @@
# install torch 2.1.2
# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118

# install dependencies
pip install -r requirements.txt

# install H5-render
pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl
9
scripts/install/install_cu121.sh
Normal file
9
scripts/install/install_cu121.sh
Normal file
@@ -0,0 +1,9 @@
# install torch 2.1.2
# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121

# install dependencies
pip install -r requirements.txt

# install H5-render
pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
53
utils/cache.py
Normal file
53
utils/cache.py
Normal file
@@ -0,0 +1,53 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os

import SharedArray

try:
    from multiprocessing.shared_memory import ShareableList
except ImportError:
    import warnings

    warnings.warn("Please update python version >= 3.8 to enable shared_memory")
import numpy as np


def shared_array(name, var=None):
    if var is not None:
        # check exist
        if os.path.exists(f"/dev/shm/{name}"):
            return SharedArray.attach(f"shm://{name}")
        # create shared_array
        data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
        data[...] = var[...]
        data.flags.writeable = False
    else:
        data = SharedArray.attach(f"shm://{name}").copy()
    return data

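# Usage sketch (the cache name and array contents are illustrative):
# cached = shared_array("demo_cache", var=np.arange(4))  # first call publishes to /dev/shm
# again = shared_array("demo_cache")                     # later calls copy the cached array
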
def shared_dict(name, var=None):
    name = str(name)
    assert "." not in name  # '.' is used as sep flag
    data = {}
    if var is not None:
        assert isinstance(var, dict)
        keys = var.keys()
        # current version only caches np.ndarray values
        keys_valid = []
        for key in keys:
            if isinstance(var[key], np.ndarray):
                keys_valid.append(key)
        keys = keys_valid

        ShareableList(sequence=keys, name=name + ".keys")
        for key in keys:
            if isinstance(var[key], np.ndarray):
                data[key] = shared_array(name=f"{name}.{key}", var=var[key])
    else:
        keys = list(ShareableList(name=name + ".keys"))
        for key in keys:
            data[key] = shared_array(name=f"{name}.{key}")
    return data
192
utils/comm.py
Normal file
192
utils/comm.py
Normal file
@@ -0,0 +1,192 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import functools

import numpy as np
import torch
import torch.distributed as dist

_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""


def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert (
        _LOCAL_PROCESS_GROUP is not None
    ), "Local process group is not created! Please use launch() to spawn processes!"
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    if dist.get_backend() == dist.Backend.NCCL:
        # This argument is needed to avoid warnings.
        # It's valid only for NCCL backend.
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = (
            _get_global_gloo_group()
        )  # use CPU group by default, to reduce GPU RAM usage.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]
    rank = dist.get_rank(group=group)

    if rank == dst:
        output = [None for _ in range(world_size)]
        dist.gather_object(data, output, dst=dst, group=group)
        return output
    else:
        dist.gather_object(data, None, dst=dst, group=group)
        return []


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.
    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2**31)
    all_ints = all_gather(ints)
    return all_ints[0]


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.
    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum
    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
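# Typical pattern inside a training loop (the loss names are illustrative):
# losses = {"l2": l2_loss.detach(), "vel": velocity_loss.detach()}
# averaged = reduce_dict(losses)   # rank 0 receives the mean across workers
# if is_main_process():
#     logger.info(averaged)        # assumes a logger is available in scope
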
696
utils/config.py
Normal file
696
utils/config.py
Normal file
@@ -0,0 +1,696 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode

from .misc import import_modules_from_strings
from .path import check_file_exist

if platform.system() == "Windows":
    import regex as re
else:
    import re

BASE_KEY = "_base_"
DELETE_KEY = "_delete_"
DEPRECATION_KEY = "_deprecation_"
RESERVED_KEYS = ["filename", "text", "pretty_text"]


class ConfigDict(Dict):
    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(
                f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
            )
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


def add_args(parser, cfg, prefix=""):
    for k, v in cfg.items():
        if isinstance(v, str):
            parser.add_argument("--" + prefix + k)
        elif isinstance(v, int):
            parser.add_argument("--" + prefix + k, type=int)
        elif isinstance(v, float):
            parser.add_argument("--" + prefix + k, type=float)
        elif isinstance(v, bool):
            parser.add_argument("--" + prefix + k, action="store_true")
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + ".")
        elif isinstance(v, abc.Iterable):
            parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
        else:
            print(f"cannot parse key {prefix + k} of type {type(v)}")
    return parser


class Config:
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows accessing config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError(
                "There are syntax errors in config " f"file {filename}: {e}"
            )

    @staticmethod
    def _substitute_predefined_vars(filename, temp_config_name):
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname,
        )
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows
            config_file = f.read()
        for key, value in support_templates.items():
            regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
            value = value.replace("\\", "/")
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _pre_substitute_base_vars(filename, temp_config_name):
        """Substitute base variable placeholders to string, so that parsing
        would work."""
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows
            config_file = f.read()
        base_var_dict = {}
        regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
        base_vars = set(re.findall(regexp, config_file))
        for base_var in base_vars:
            randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
            base_var_dict[randstr] = base_var
            regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
            config_file = re.sub(regexp, f'"{randstr}"', config_file)
        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
            tmp_config_file.write(config_file)
        return base_var_dict

    @staticmethod
    def _substitute_base_vars(cfg, base_var_dict, base_cfg):
        """Substitute variable strings to their actual values."""
        cfg = copy.deepcopy(cfg)

        if isinstance(cfg, dict):
            for k, v in cfg.items():
                if isinstance(v, str) and v in base_var_dict:
                    new_v = base_cfg
                    for new_k in base_var_dict[v].split("."):
                        new_v = new_v[new_k]
                    cfg[k] = new_v
                elif isinstance(v, (list, tuple, dict)):
                    cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
        elif isinstance(cfg, tuple):
            cfg = tuple(
                Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
            )
        elif isinstance(cfg, list):
            cfg = [
                Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
            ]
        elif isinstance(cfg, str) and cfg in base_var_dict:
            new_v = base_cfg
            for new_k in base_var_dict[cfg].split("."):
                new_v = new_v[new_k]
            cfg = new_v

        return cfg

    @staticmethod
    def _file2dict(filename, use_predefined_variables=True):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
            raise IOError("Only py/yml/yaml/json types are supported now!")

        with tempfile.TemporaryDirectory() as temp_config_dir:
            temp_config_file = tempfile.NamedTemporaryFile(
                dir=temp_config_dir, suffix=fileExtname
            )
            if platform.system() == "Windows":
                temp_config_file.close()
            temp_config_name = osp.basename(temp_config_file.name)
            # Substitute predefined variables
            if use_predefined_variables:
                Config._substitute_predefined_vars(filename, temp_config_file.name)
            else:
                shutil.copyfile(filename, temp_config_file.name)
            # Substitute base variables from placeholders to strings
            base_var_dict = Config._pre_substitute_base_vars(
                temp_config_file.name, temp_config_file.name
            )

            if filename.endswith(".py"):
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith("__")
                }
                # delete imported module
                del sys.modules[temp_module_name]
            elif filename.endswith((".yml", ".yaml", ".json")):
                raise NotImplementedError
            # close temp file
            temp_config_file.close()

        # check deprecation information
        if DEPRECATION_KEY in cfg_dict:
            deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
            warning_msg = (
                f"The config file {filename} will be deprecated " "in the future."
            )
            if "expected" in deprecation_info:
                warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
            if "reference" in deprecation_info:
                warning_msg += (
                    " More information can be found at "
                    f'{deprecation_info["reference"]}'
                )
            warnings.warn(warning_msg)

        cfg_text = filename + "\n"
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = (
                base_filename if isinstance(base_filename, list) else [base_filename]
            )

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                duplicate_keys = base_cfg_dict.keys() & c.keys()
                if len(duplicate_keys) > 0:
                    raise KeyError(
                        "Duplicate key is not allowed among bases. "
                        f"Duplicate keys: {duplicate_keys}"
                    )
                base_cfg_dict.update(c)

            # Substitute base variables from strings to their actual values
            cfg_dict = Config._substitute_base_vars(
                cfg_dict, base_var_dict, base_cfg_dict
            )

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = "\n".join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b, allow_list_keys=False):
        """merge dict ``a`` into dict ``b`` (non-inplace).

        Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
        in-place modifications.

        Args:
            a (dict): The source dict to be merged into ``b``.
            b (dict): The origin dict to fetch keys from ``a``.
            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                are allowed in source ``a`` and will replace the element of the
                corresponding index in b if b is a list. Default: False.

        Returns:
            dict: The modified dict of ``b`` using ``a``.

        Examples:
            # Normally merge a into b.
            >>> Config._merge_a_into_b(
            ...     dict(obj=dict(a=2)), dict(obj=dict(a=1)))
            {'obj': {'a': 2}}

            # Delete b first and merge a into b.
            >>> Config._merge_a_into_b(
            ...     dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
            {'obj': {'a': 2}}

            # b is a list
            >>> Config._merge_a_into_b(
            ...     {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
            [{'a': 2}, {'b': 2}]
        """
        b = b.copy()
        for k, v in a.items():
            if allow_list_keys and k.isdigit() and isinstance(b, list):
                k = int(k)
                if len(b) <= k:
                    raise KeyError(f"Index {k} exceeds the length of list {b}")
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                allowed_types = (dict, list) if allow_list_keys else dict
                if not isinstance(b[k], allowed_types):
                    raise TypeError(
                        f"{k}={v} in child config cannot inherit from base "
                        f"because {k} is a dict in the child config but is of "
                        f"type {type(b[k])} in base config. You may set "
                        f"`{DELETE_KEY}=True` to ignore the base config"
                    )
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            else:
                b[k] = v
        return b

    @staticmethod
    def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
        cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
        if import_custom_modules and cfg_dict.get("custom_imports", None):
            import_modules_from_strings(**cfg_dict["custom_imports"])
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def fromstring(cfg_str, file_format):
        """Generate config from config str.

        Args:
            cfg_str (str): Config str.
            file_format (str): Config file format corresponding to the
                config str. Only py/yml/yaml/json types are supported now!

        Returns:
            obj:`Config`: Config obj.
        """
        if file_format not in [".py", ".json", ".yaml", ".yml"]:
            raise IOError("Only py/yml/yaml/json types are supported now!")
        if file_format != ".py" and "dict(" in cfg_str:
            # check if users specify a wrong suffix for python
            warnings.warn('Please check "file_format", the file format may be .py')
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", suffix=file_format, delete=False
        ) as temp_file:
            temp_file.write(cfg_str)
            # on Windows, the previous implementation caused an error
            # see PR 1077 for details
        cfg = Config.fromfile(temp_file.name)
        os.remove(temp_file.name)
        return cfg

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)"""
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument("config", help="config file path")
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument("config", help="config file path")
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f"{key} is reserved for config file")

        super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
        super(Config, self).__setattr__("_filename", filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, "r") as f:
                text = f.read()
        else:
            text = ""
        super(Config, self).__setattr__("_text", text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):
        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f"{k_str}: {v_str}"
            else:
                attr_str = f"{str(k)}={v_str}"
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = "[\n"
                v_str += "\n".join(
                    f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
                ).rstrip(",")
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f"{k_str}: {v_str}"
                else:
                    attr_str = f"{str(k)}={v_str}"
                attr_str = _indent(attr_str, indent) + "]"
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= not str(key_name).isidentifier()
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ""
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += "{"
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = "" if outest_level or is_last else ","
                if isinstance(v, dict):
                    v_str = "\n" + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f"{k_str}: dict({v_str}"
                    else:
                        attr_str = f"{str(k)}=dict({v_str}"
                    attr_str = _indent(attr_str, indent) + ")" + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += "\n".join(s)
            if use_mapping:
                r += "}"
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style="pep8",
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True,
        )
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def __getstate__(self):
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        _cfg_dict, _filename, _text = state
        super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
        super(Config, self).__setattr__("_filename", _filename)
        super(Config, self).__setattr__("_text", _text)

    def dump(self, file=None):
        cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
        if self.filename.endswith(".py"):
            if file is None:
                return self.pretty_text
            else:
                with open(file, "w", encoding="utf-8") as f:
                    f.write(self.pretty_text)
        else:
            import mmcv

            if file is None:
                file_format = self.filename.split(".")[-1]
                return mmcv.dump(cfg_dict, file_format=file_format)
            else:
                mmcv.dump(cfg_dict, file)

    def merge_from_dict(self, options, allow_list_keys=True):
        """Merge list into cfg_dict.

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'models.backbone.depth': 50,
            ...            'models.backbone.with_cp': True}
            >>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     models=dict(backbone=dict(depth=50, with_cp=True)))

            # Merge list element
            >>> cfg = Config(dict(pipeline=[
            ...     dict(type='LoadImage'), dict(type='LoadAnnotations')]))
            >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
            >>> cfg.merge_from_dict(options, allow_list_keys=True)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(pipeline=[
            ...     dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])

        Args:
            options (dict): dict of configs to merge from.
            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                are allowed in ``options`` and will replace the element of the
                corresponding index in the config if the config is a list.
                Default: True.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split(".")
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
        super(Config, self).__setattr__(
            "_cfg_dict",
            Config._merge_a_into_b(
                option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
            ),
        )


class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options can
    be passed as comma separated values, i.e. 'KEY=V1,V2,V3', or with explicit
    brackets, i.e. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
    list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ["true", "false"]:
            return True if val.lower() == "true" else False
        return val

    @staticmethod
    def _parse_iterable(val):
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.

        Args:
            val (str): Value string.

        Returns:
            list | tuple: The expanded list or tuple from the string.

        Examples:
            >>> DictAction._parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction._parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
        """

        def find_next_comma(string):
            """Find the position of next comma in the string.

            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element and thus ','
            inside these brackets are ignored.
            """
            assert (string.count("(") == string.count(")")) and (
                string.count("[") == string.count("]")
            ), f"Imbalanced brackets exist in {string}"
            end = len(string)
            for idx, char in enumerate(string):
                pre = string[:idx]
                # The string before this ',' is balanced
                if (
                    (char == ",")
                    and (pre.count("(") == pre.count(")"))
                    and (pre.count("[") == pre.count("]"))
                ):
                    end = idx
                    break
            return end

        # Strip ' and " characters and replace whitespace.
        val = val.strip("'\"").replace(" ", "")
        is_tuple = False
        if val.startswith("(") and val.endswith(")"):
            is_tuple = True
            val = val[1:-1]
        elif val.startswith("[") and val.endswith("]"):
            val = val[1:-1]
        elif "," not in val:
            # val is a single value
            return DictAction._parse_int_float_bool(val)

        values = []
        while len(val) > 0:
            comma_idx = find_next_comma(val)
            element = DictAction._parse_iterable(val[:comma_idx])
            values.append(element)
            val = val[comma_idx + 1 :]
        if is_tuple:
            values = tuple(values)
        return values

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split("=", maxsplit=1)
            options[key] = self._parse_iterable(val)
        setattr(namespace, self.dest, options)
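A short usage sketch for the Config facility above (the config path and keys
are illustrative, not files shipped with this repo):

    from utils.config import Config, DictAction

    cfg = Config.fromfile("configs/example.py")  # hypothetical config path
    print(cfg.model)                             # attribute-style access
    cfg.merge_from_dict({"model.lr": 1e-3})      # dotted-key override
    print(cfg.pretty_text)                       # yapf-formatted dump

    # DictAction turns "--options k=v k2=[a,b]" CLI arguments into a dict:
    # parser.add_argument("--options", nargs="+", action=DictAction)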
33
utils/env.py
Normal file
@@ -0,0 +1,33 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from datetime import datetime


def get_random_seed():
    seed = (
        os.getpid()
        + int(datetime.now().strftime("%S%f"))
        + int.from_bytes(os.urandom(2), "big")
    )
    return seed


def set_seed(seed=None):
    if seed is None:
        seed = get_random_seed()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(seed)
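A quick sketch of how the seeding helper is meant to be called (the seed value
is illustrative):

    from utils.env import set_seed

    set_seed(42)  # seeds python/numpy/torch RNGs, makes cuDNN deterministic
    set_seed()    # or derive a seed from pid, wall clock, and os.urandom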
585
utils/events.py
Normal file
@@ -0,0 +1,585 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""


import datetime
import json
import logging
import os
import time
import torch
import numpy as np

from typing import List, Optional, Tuple
from collections import defaultdict
from contextlib import contextmanager

__all__ = [
    "get_event_storage",
    "JSONWriter",
    "TensorboardXWriter",
    "CommonMetricPrinter",
    "EventStorage",
]

_CURRENT_STORAGE_STACK = []


def get_event_storage():
    """
    Returns:
        The :class:`EventStorage` object that's currently being used.
        Throws an error if no :class:`EventStorage` is currently enabled.
    """
    assert len(
        _CURRENT_STORAGE_STACK
    ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
    return _CURRENT_STORAGE_STACK[-1]


class EventWriter:
    """
    Base class for writers that obtain events from :class:`EventStorage` and process them.
    """

    def write(self):
        raise NotImplementedError

    def close(self):
        pass


class JSONWriter(EventWriter):
    """
    Write scalars to a json file.
    It saves scalars as one json per line (instead of a big json) for easy parsing.
    Examples parsing such a json file:
    ::
        $ cat metrics.json | jq -s '.[0:2]'
        [
          {
            "data_time": 0.008433341979980469,
            "iteration": 19,
            "loss": 1.9228371381759644,
            "loss_box_reg": 0.050025828182697296,
            "loss_classifier": 0.5316952466964722,
            "loss_mask": 0.7236229181289673,
            "loss_rpn_box": 0.0856662318110466,
            "loss_rpn_cls": 0.48198649287223816,
            "lr": 0.007173333333333333,
            "time": 0.25401854515075684
          },
          {
            "data_time": 0.007216215133666992,
            "iteration": 39,
            "loss": 1.282649278640747,
            "loss_box_reg": 0.06222952902317047,
            "loss_classifier": 0.30682939291000366,
            "loss_mask": 0.6970193982124329,
            "loss_rpn_box": 0.038663312792778015,
            "loss_rpn_cls": 0.1471673548221588,
            "lr": 0.007706666666666667,
            "time": 0.2490077018737793
          }
        ]
        $ cat metrics.json | jq '.loss_mask'
        0.7126231789588928
        0.689423680305481
        0.6776131987571716
        ...
    """

    def __init__(self, json_file, window_size=20):
        """
        Args:
            json_file (str): path to the json file. New data will be appended if the file exists.
            window_size (int): the window size of median smoothing for the scalars whose
                `smoothing_hint` are True.
        """
        self._file_handle = open(json_file, "a")
        self._window_size = window_size
        self._last_write = -1

    def write(self):
        storage = get_event_storage()
        to_save = defaultdict(dict)

        for k, (v, iter) in storage.latest_with_smoothing_hint(
            self._window_size
        ).items():
            # keep scalars that have not been written
            if iter <= self._last_write:
                continue
            to_save[iter][k] = v
        if len(to_save):
            all_iters = sorted(to_save.keys())
            self._last_write = max(all_iters)

        for itr, scalars_per_iter in to_save.items():
            scalars_per_iter["iteration"] = itr
            self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
        self._file_handle.flush()
        try:
            os.fsync(self._file_handle.fileno())
        except AttributeError:
            pass

    def close(self):
        self._file_handle.close()


class TensorboardXWriter(EventWriter):
    """
    Write all scalars to a tensorboard file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the directory to save the output events
            window_size (int): the scalars will be median-smoothed by this window size
            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size
        from torch.utils.tensorboard import SummaryWriter

        self._writer = SummaryWriter(log_dir, **kwargs)
        self._last_write = -1

    def write(self):
        storage = get_event_storage()
        new_last_write = self._last_write
        for k, (v, iter) in storage.latest_with_smoothing_hint(
            self._window_size
        ).items():
            if iter > self._last_write:
                self._writer.add_scalar(k, v, iter)
                new_last_write = max(new_last_write, iter)
        self._last_write = new_last_write

        # storage.put_{image,histogram} is only meant to be used by
        # tensorboard writer. So we access its internal fields directly from here.
        if len(storage._vis_data) >= 1:
            for img_name, img, step_num in storage._vis_data:
                self._writer.add_image(img_name, img, step_num)
            # Storage stores all image data and relies on this writer to clear them.
            # As a result it assumes only one writer will use its image data.
            # An alternative design is to let storage store limited recent
            # data (e.g. only the most recent image) that all writers can access.
            # In that case a writer may not see all image data if its period is long.
            storage.clear_images()

        if len(storage._histograms) >= 1:
            for params in storage._histograms:
                self._writer.add_histogram_raw(**params)
            storage.clear_histograms()

    def close(self):
        if hasattr(self, "_writer"):  # doesn't exist when the code fails at import
            self._writer.close()


class CommonMetricPrinter(EventWriter):
    """
    Print **common** metrics to the terminal, including
    iteration time, ETA, memory, all losses, and the learning rate.
    It also applies smoothing using a window of 20 elements.
    It's meant to print common metrics in common ways.
    To print something in more customized ways, please implement a similar printer by yourself.
    """

    def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
        """
        Args:
            max_iter: the maximum number of iterations to train.
                Used to compute ETA. If not given, ETA will not be printed.
            window_size (int): the losses will be median-smoothed by this window size
        """
        self.logger = logging.getLogger(__name__)
        self._max_iter = max_iter
        self._window_size = window_size
        self._last_write = (
            None  # (step, time) of last call to write(). Used to compute ETA
        )

    def _get_eta(self, storage) -> Optional[str]:
        if self._max_iter is None:
            return ""
        iteration = storage.iter
        try:
            eta_seconds = storage.history("time").median(1000) * (
                self._max_iter - iteration - 1
            )
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            return str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:
            # estimate eta on our own - more noisy
            eta_string = None
            if self._last_write is not None:
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            self._last_write = (iteration, time.perf_counter())
            return eta_string

    def write(self):
        storage = get_event_storage()
        iteration = storage.iter
        if iteration == self._max_iter:
            # This hook only reports training progress (loss, ETA, etc) but not other data,
            # therefore do not write anything after training succeeds, even if this method
            # is called.
            return

        try:
            data_time = storage.history("data_time").avg(20)
        except KeyError:
            # they may not exist in the first few iterations (due to warmup)
            # or when SimpleTrainer is not used
            data_time = None
        try:
            iter_time = storage.history("time").global_avg()
        except KeyError:
            iter_time = None
        try:
            lr = "{:.5g}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        eta_string = self._get_eta(storage)

        if torch.cuda.is_available():
            max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
        else:
            max_mem_mb = None

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
                eta=f"eta: {eta_string} " if eta_string else "",
                iter=iteration,
                losses=" ".join(
                    [
                        "{}: {:.4g}".format(k, v.median(self._window_size))
                        for k, v in storage.histories().items()
                        if "loss" in k
                    ]
                ),
                time="time: {:.4f} ".format(iter_time)
                if iter_time is not None
                else "",
                data_time="data_time: {:.4f} ".format(data_time)
                if data_time is not None
                else "",
                lr=lr,
                memory="max_mem: {:.0f}M".format(max_mem_mb)
                if max_mem_mb is not None
                else "",
            )
        )


class EventStorage:
    """
    The user-facing class that provides metric storage functionalities.
    In the future we may add support for storing / logging other types of data if needed.
    """

    def __init__(self, start_iter=0):
        """
        Args:
            start_iter (int): the iteration number to start with
        """
        self._history = defaultdict(AverageMeter)
        self._smoothing_hints = {}
        self._latest_scalars = {}
        self._iter = start_iter
        self._current_prefix = ""
        self._vis_data = []
        self._histograms = []

    # def put_image(self, img_name, img_tensor):
    #     """
    #     Add an `img_tensor` associated with `img_name`, to be shown on
    #     tensorboard.
    #     Args:
    #         img_name (str): The name of the image to put into tensorboard.
    #         img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
    #             Tensor of shape `[channel, height, width]` where `channel` is
    #             3. The image format should be RGB. The elements in img_tensor
    #             can either have values in [0, 1] (float32) or [0, 255] (uint8).
    #             The `img_tensor` will be visualized in tensorboard.
    #     """
    #     self._vis_data.append((img_name, img_tensor, self._iter))

    def put_scalar(self, name, value, n=1, smoothing_hint=False):
        """
        Add a scalar `value` to the `HistoryBuffer` associated with `name`.
        Args:
            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
                smoothed when logged. The hint will be accessible through
                :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
                and apply custom smoothing rule.
                It defaults to False here, so values are logged unsmoothed unless
                the caller explicitly asks for smoothing.
        """
        name = self._current_prefix + name
        history = self._history[name]
        history.update(value, n)
        self._latest_scalars[name] = (value, self._iter)

        existing_hint = self._smoothing_hints.get(name)
        if existing_hint is not None:
            assert (
                existing_hint == smoothing_hint
            ), "Scalar {} was put with a different smoothing_hint!".format(name)
        else:
            self._smoothing_hints[name] = smoothing_hint

    # def put_scalars(self, *, smoothing_hint=True, **kwargs):
    #     """
    #     Put multiple scalars from keyword arguments.
    #     Examples:
    #         storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
    #     """
    #     for k, v in kwargs.items():
    #         self.put_scalar(k, v, smoothing_hint=smoothing_hint)
    #
    # def put_histogram(self, hist_name, hist_tensor, bins=1000):
    #     """
    #     Create a histogram from a tensor.
    #     Args:
    #         hist_name (str): The name of the histogram to put into tensorboard.
    #         hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
    #             into a histogram.
    #         bins (int): Number of histogram bins.
    #     """
    #     ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
    #
    #     # Create a histogram with PyTorch
    #     hist_counts = torch.histc(hist_tensor, bins=bins)
    #     hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
    #
    #     # Parameter for the add_histogram_raw function of SummaryWriter
    #     hist_params = dict(
    #         tag=hist_name,
    #         min=ht_min,
    #         max=ht_max,
    #         num=len(hist_tensor),
    #         sum=float(hist_tensor.sum()),
    #         sum_squares=float(torch.sum(hist_tensor**2)),
    #         bucket_limits=hist_edges[1:].tolist(),
    #         bucket_counts=hist_counts.tolist(),
    #         global_step=self._iter,
    #     )
    #     self._histograms.append(hist_params)

    def history(self, name):
        """
        Returns:
            AverageMeter: the history for name
        """
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        return ret

    def histories(self):
        """
        Returns:
            dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
        """
        return self._history

    def latest(self):
        """
        Returns:
            dict[str -> (float, int)]: mapping from the name of each scalar to the most
                recent value and the iteration number at which it was added.
        """
        return self._latest_scalars

    def latest_with_smoothing_hint(self, window_size=20):
        """
        Similar to :meth:`latest`, but the returned values
        are either the un-smoothed original latest value,
        or a median of the given window_size,
        depending on whether the smoothing_hint is True.
        This provides a default behavior that other writers can use.
        """
        result = {}
        for k, (v, itr) in self._latest_scalars.items():
            result[k] = (
                self._history[k].median(window_size) if self._smoothing_hints[k] else v,
                itr,
            )
        return result

    def smoothing_hints(self):
        """
        Returns:
            dict[name -> bool]: the user-provided hint on whether the scalar
                is noisy and needs smoothing.
        """
        return self._smoothing_hints

    def step(self):
        """
        User should either: (1) Call this function to increment storage.iter when needed. Or
        (2) Set `storage.iter` to the correct iteration number before each iteration.
        The storage will then be able to associate the new data with an iteration number.
        """
        self._iter += 1

    @property
    def iter(self):
        """
        Returns:
            int: The current iteration number. When used together with a trainer,
                this is ensured to be the same as trainer.iter.
        """
        return self._iter

    @iter.setter
    def iter(self, val):
        self._iter = int(val)

    @property
    def iteration(self):
        # for backward compatibility
        return self._iter

    def __enter__(self):
        _CURRENT_STORAGE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert _CURRENT_STORAGE_STACK[-1] == self
        _CURRENT_STORAGE_STACK.pop()

    @contextmanager
    def name_scope(self, name):
        """
        Yields:
            A context within which all the events added to this storage
            will be prefixed by the name scope.
        """
        old_prefix = self._current_prefix
        self._current_prefix = name.rstrip("/") + "/"
        yield
        self._current_prefix = old_prefix

    def clear_images(self):
        """
        Delete all the stored images for visualization. This should be called
        after images are written to tensorboard.
        """
        self._vis_data = []

    def clear_histograms(self):
        """
        Delete all the stored histograms for visualization.
        This should be called after histograms are written to tensorboard.
        """
        self._histograms = []

    def reset_history(self, name):
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        ret.reset()

    def reset_histories(self):
        for name in self._history.keys():
            self._history[name].reset()


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.total = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.total = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.total += val * n
        self.count += n
        self.avg = self.total / self.count


class HistoryBuffer:
    """
    Track a series of scalar values and provide access to smoothed values over a
    window or the global average of the series.
    """

    def __init__(self, max_length: int = 1000000) -> None:
        """
        Args:
            max_length: maximal number of values that can be stored in the
                buffer. When the capacity of the buffer is exhausted, old
                values will be removed.
        """
        self._max_length: int = max_length
        self._data: List[Tuple[float, float]] = []  # (value, iteration) pairs
        self._count: int = 0
        self._global_avg: float = 0

    def update(self, value: float, iteration: Optional[float] = None) -> None:
        """
        Add a new scalar value produced at a certain iteration. If the length
        of the buffer exceeds self._max_length, the oldest element will be
        removed from the buffer.
        """
        if iteration is None:
            iteration = self._count
        if len(self._data) == self._max_length:
            self._data.pop(0)
        self._data.append((value, iteration))

        self._count += 1
        self._global_avg += (value - self._global_avg) / self._count

    def latest(self) -> float:
        """
        Return the latest scalar value added to the buffer.
        """
        return self._data[-1][0]

    def median(self, window_size: int) -> float:
        """
        Return the median of the latest `window_size` values in the buffer.
        """
        return np.median([x[0] for x in self._data[-window_size:]])

    def avg(self, window_size: int) -> float:
        """
        Return the mean of the latest `window_size` values in the buffer.
        """
        return np.mean([x[0] for x in self._data[-window_size:]])

    def global_avg(self) -> float:
        """
        Return the mean of all the elements in the buffer. Note that this
        includes those getting removed due to limited buffer storage.
        """
        return self._global_avg

    def values(self) -> List[Tuple[float, float]]:
        """
        Returns:
            list[(number, iteration)]: content of the current buffer.
        """
        return self._data
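A sketch of the intended wiring between EventStorage and the writers above
(train_one_step and the write period are illustrative):

    from utils.events import EventStorage, JSONWriter, CommonMetricPrinter

    writers = [CommonMetricPrinter(max_iter=1000), JSONWriter("metrics.json")]
    with EventStorage(start_iter=0) as storage:
        for it in range(1000):
            loss = train_one_step()          # hypothetical training step
            storage.put_scalar("loss", loss)
            if (it + 1) % 20 == 0:           # periodic flush to all writers
                for w in writers:
                    w.write()
            storage.step()
    for w in writers:
        w.close()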
167
utils/logger.py
Normal file
@@ -0,0 +1,167 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import logging
import torch
import torch.distributed as dist

from termcolor import colored

logger_initialized = {}
root_status = 0


class _ColorfulFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "ERROR" and thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'a'.
        color (bool): Colorful log output. Defaults to False.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)

    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    logger.propagate = False

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    plain_formatter = logging.Formatter(
        "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    )
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S",
            root_name=name,
        )
    else:
        formatter = plain_formatter
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger


def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            Some special loggers are:
            - "silent": no message will be printed.
            - other str: the logger obtained with `get_root_logger(logger)`.
            - None: The `print()` method will be used to print log messages.
        level (int): Logging level. Only available when `logger` is a Logger
            object or "root".
    """
    if logger is None:
        print(msg)
    elif isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger == "silent":
        pass
    elif isinstance(logger, str):
        _logger = get_logger(logger)
        _logger.log(level, msg)
    else:
        raise TypeError(
            "logger should be either a logging.Logger object, str, "
            f'"silent" or None, but got {type(logger)}'
        )


def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name.

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "ERROR" and be silent most of the time.
        file_mode (str): File mode of the logger. (w or a)

    Returns:
        logging.Logger: The root logger.
    """
    logger = get_logger(
        name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode
    )
    return logger


def _log_api_usage(identifier: str):
    """
    Internal function used to log the usage of different detectron2 components
    inside facebook's infra.
    """
    torch._C._log_api_usage_once("pointcept." + identifier)
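A minimal sketch of the logger helpers above (the log file path is
illustrative):

    from utils.logger import get_root_logger, print_log

    logger = get_root_logger(log_file="train.log")  # FileHandler on rank 0 only
    logger.info("start training")
    print_log("one-off message", logger=logger)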
156
utils/misc.py
Normal file
@@ -0,0 +1,156 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import warnings
from collections import abc
import numpy as np
import torch
from importlib import import_module


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def intersection_and_union(output, target, K, ignore_index=-1):
    # 'K' classes; output and target sizes are N, N * L, or N * H * W,
    # each value in range 0 to K - 1.
    assert output.ndim in [1, 2, 3]
    assert output.shape == target.shape
    output = output.reshape(output.size).copy()
    target = target.reshape(target.size)
    output[np.where(target == ignore_index)[0]] = ignore_index
    intersection = output[np.where(output == target)[0]]
    area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
    area_output, _ = np.histogram(output, bins=np.arange(K + 1))
    area_target, _ = np.histogram(target, bins=np.arange(K + 1))
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target


def intersection_and_union_gpu(output, target, k, ignore_index=-1):
    # 'k' classes; output and target sizes are N, N * L, or N * H * W,
    # each value in range 0 to k - 1.
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
    area_output = torch.histc(output, bins=k, min=0, max=k - 1)
    area_target = torch.histc(target, bins=k, min=0, max=k - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target


def make_dirs(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)


def find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us.
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        exp_seq_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        exp_seq_type = seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    for item in seq:
        if not isinstance(item, expected_type):
            return False
    return True


def is_str(x):
    """Whether the input is a string instance.

    Note: This method is deprecated since Python 2 is no longer supported.
    """
    return isinstance(x, str)


def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, failed imports will return
            None. Otherwise, an ImportError is raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
                imported_tmp = None
            else:
                raise ImportError(f"Failed to import {imp}")
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
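A minimal sketch of how these helpers are typically combined in an evaluation loop; the shapes and class count below are illustrative, and despite its name the "_gpu" variant also runs on CPU tensors:

meter = AverageMeter()
pred = torch.randint(0, 4, (8, 128))  # predicted labels, 4 classes
gt = torch.randint(0, 4, (8, 128))    # ground-truth labels
inter, union, target = intersection_and_union_gpu(pred, gt, k=4)
iou = inter / (union + 1e-10)         # per-class IoU; epsilon avoids 0/0
meter.update(iou.mean().item())
print(f"running mIoU: {meter.avg:.4f}")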
52
utils/optimizer.py
Normal file
@@ -0,0 +1,52 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import torch
from utils.logger import get_root_logger
from utils.registry import Registry

OPTIMIZERS = Registry("optimizers")


OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")


def build_optimizer(cfg, model, param_dicts=None):
    if param_dicts is None:
        cfg.params = model.parameters()
    else:
        # Group 0 is the default group; one extra group per param_dict entry.
        cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
        for i in range(len(param_dicts)):
            param_group = dict(names=[], params=[])
            if "lr" in param_dicts[i].keys():
                param_group["lr"] = param_dicts[i].lr
            if "momentum" in param_dicts[i].keys():
                param_group["momentum"] = param_dicts[i].momentum
            if "weight_decay" in param_dicts[i].keys():
                param_group["weight_decay"] = param_dicts[i].weight_decay
            cfg.params.append(param_group)

        # Assign each parameter to the first group whose keyword matches its
        # name; unmatched parameters fall back to the default group.
        for n, p in model.named_parameters():
            flag = False
            for i in range(len(param_dicts)):
                if param_dicts[i].keyword in n:
                    cfg.params[i + 1]["names"].append(n)
                    cfg.params[i + 1]["params"].append(p)
                    flag = True
                    break
            if not flag:
                cfg.params[0]["names"].append(n)
                cfg.params[0]["params"].append(p)

        logger = get_root_logger()
        for i in range(len(cfg.params)):
            param_names = cfg.params[i].pop("names")
            message = ""
            for key in cfg.params[i].keys():
                if key != "params":
                    message += f" {key}: {cfg.params[i][key]};"
            logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
    return OPTIMIZERS.build(cfg=cfg)
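A hedged sketch of build_optimizer with per-group learning rates. It assumes an attribute-style config dict (e.g. addict.Dict — not confirmed as a dependency of this repo; any dict subclass with attribute access works), since the function reads cfg.lr and param_dicts[i].keyword as attributes:

import torch.nn as nn
from addict import Dict  # assumption: attribute-style dict

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
cfg = Dict(type="AdamW", lr=1e-4, weight_decay=0.01)
param_dicts = [Dict(keyword="0", lr=1e-5)]  # params whose name contains "0" get a lower LR
optimizer = build_optimizer(cfg, model, param_dicts)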
105
utils/path.py
Normal file
@@ -0,0 +1,105 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import os.path as osp
from pathlib import Path

from .misc import is_str


def is_filepath(x):
    return is_str(x) or isinstance(x, Path)


def fopen(filepath, *args, **kwargs):
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    elif isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError("`filepath` should be a string or a Path")


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


def mkdir_or_exist(dir_name, mode=0o777):
    if dir_name == "":
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


def symlink(src, dst, overwrite=True, **kwargs):
    if os.path.lexists(dst) and overwrite:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)


def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional): If set to False, ignore the case of
            the suffix. Default: True.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        suffix = (
            suffix.lower()
            if isinstance(suffix, str)
            else tuple(item.lower() for item in suffix)
        )

    root = dir_path

    def _scandir(dir_path, suffix, recursive, case_sensitive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith(".") and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                _rel_path = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or _rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                # scan recursively if entry.path is a directory
                yield from _scandir(entry.path, suffix, recursive, case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)


def find_vcs_root(path, markers=(".git",)):
    """Finds the root directory (including itself) of specified markers.

    Args:
        path (str): Path of directory or file.
        markers (list[str], optional): List of file or directory names.

    Returns:
        The directory that contains one of the markers, or None if not found.
    """
    if osp.isfile(path):
        path = osp.dirname(path)

    prev, cur = None, osp.abspath(osp.expanduser(path))
    while cur != prev:
        if any(osp.exists(osp.join(cur, marker)) for marker in markers):
            return cur
        prev, cur = cur, osp.split(cur)[0]
    return None
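A usage sketch for scandir: collect all .wav files under a directory, case-insensitively (the directory path is illustrative; paths are yielded relative to it):

for rel_path in scandir("data/audio", suffix=".wav", recursive=True,
                        case_sensitive=False):
    print(rel_path)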
318
utils/registry.py
Normal file
@@ -0,0 +1,318 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import inspect
import warnings
from functools import partial

from .misc import is_seq_of


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from a config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
    if "type" not in cfg:
        if default_args is None or "type" not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f"but got {cfg}\n{default_args}"
            )
    if not isinstance(registry, Registry):
        raise TypeError(
            "registry must be a Registry object, " f"but got {type(registry)}"
        )
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError(
            "default_args must be a dict or None, " f"but got {type(default_args)}"
        )

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop("type")
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f"{obj_type} is not in the {registry.name} registry")
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
    try:
        return obj_cls(**args)
    except Exception as e:
        # A plain TypeError does not print the class name.
        raise type(e)(f"{obj_cls.__name__}: {e}")


class Registry:
    """A registry to map strings to classes.

    Registered objects can be built from the registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func (func, optional): Build function to construct instances from
            the registry; :func:`build_from_cfg` is used if neither ``parent``
            nor ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. Classes registered in a
            children registry can be built from the parent. Default: None.
        scope (str, optional): The scope of the registry. It is the key to
            search for children registries. If not specified, scope will be the
            name of the package where the class is defined, e.g. mmdet, mmcls,
            mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = (
            self.__class__.__name__ + f"(name={self._name}, "
            f"items={self._module_dict})"
        )
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of the registry.

        The name of the package where the registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() traces where this function is called; index 2
        # indicates the frame where `infer_scope()` is called.
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        split_filename = filename.split(".")
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from the key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find(".")
        if split_index != -1:
            return key[:split_index], key[split_index + 1 :]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # go to root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as a child based on its scope.
        The parent registry can build objects from a children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert (
            registry.scope not in self.children
        ), f"scope {registry.scope} exists in {self.name} registry"
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError("module must be a class, " f"but got {type(module_class)}")

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f"{name} is already registered " f"in {self.name}")
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            "The old API of register_module(module, force=False) "
            "is deprecated and will be removed, please use the new API "
            "register_module(name=None, force=False, module=None) instead."
        )
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and whose value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")
        # NOTE: This is a workaround to be compatible with the old API,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                "name must be either of None, an instance of str or a sequence"
                f" of str, but got {type(name)}"
            )

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls

        return _register
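A small end-to-end sketch of the registry: register a class, then build it from a config dict. The class name is illustrative; run this as a script (infer_scope walks the call stack, so it expects a module-level caller):

MODELS = Registry("models")

@MODELS.register_module()
class TinyHead:
    def __init__(self, dim=64):
        self.dim = dim

head = MODELS.build(dict(type="TinyHead", dim=128))
assert head.dim == 128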
144
utils/scheduler.py
Normal file
@@ -0,0 +1,144 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import torch.optim.lr_scheduler as lr_scheduler
from .registry import Registry

SCHEDULERS = Registry("schedulers")


@SCHEDULERS.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
    def __init__(
        self,
        optimizer,
        milestones,
        total_steps,
        gamma=0.1,
        last_epoch=-1,
        verbose=False,
    ):
        # milestones are given as fractions of total_steps.
        super().__init__(
            optimizer=optimizer,
            milestones=[rate * total_steps for rate in milestones],
            gamma=gamma,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
    def __init__(
        self,
        optimizer,
        milestones,
        total_steps,
        gamma=0.1,
        warmup_rate=0.05,
        warmup_scale=1e-6,
        last_epoch=-1,
        verbose=False,
    ):
        milestones = [rate * total_steps for rate in milestones]

        def multi_step_with_warmup(s):
            factor = 1.0
            for i in range(len(milestones)):
                if s < milestones[i]:
                    break
                factor *= gamma

            if s <= warmup_rate * total_steps:
                warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
                    1 - warmup_scale
                )
            else:
                warmup_coefficient = 1.0
            return warmup_coefficient * factor

        super().__init__(
            optimizer=optimizer,
            lr_lambda=multi_step_with_warmup,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class PolyLR(lr_scheduler.LambdaLR):
    def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
        super().__init__(
            optimizer=optimizer,
            lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class ExpLR(lr_scheduler.LambdaLR):
    def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
        super().__init__(
            optimizer=optimizer,
            lr_lambda=lambda s: gamma ** (s / total_steps),
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
    def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
        super().__init__(
            optimizer=optimizer,
            T_max=total_steps,
            eta_min=eta_min,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
    r"""
    torch.optim.lr_scheduler.OneCycleLR, with ``total_steps`` exposed in the
    registry signature.
    """

    def __init__(
        self,
        optimizer,
        max_lr,
        total_steps=None,
        pct_start=0.3,
        anneal_strategy="cos",
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25.0,
        final_div_factor=1e4,
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        super().__init__(
            optimizer=optimizer,
            max_lr=max_lr,
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            cycle_momentum=cycle_momentum,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
            last_epoch=last_epoch,
            verbose=verbose,
        )


def build_scheduler(cfg, optimizer):
    cfg.optimizer = optimizer
    return SCHEDULERS.build(cfg=cfg)
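A hedged sketch of build_scheduler; as in the optimizer sketch above, it assumes an attribute-style config dict (e.g. addict.Dict) and an existing `optimizer`. Note that these schedulers step per iteration, not per epoch:

from addict import Dict  # assumption: attribute-style dict

total_steps = 10_000
sched_cfg = Dict(type="OneCycleLR", max_lr=1e-3, total_steps=total_steps)
scheduler = build_scheduler(sched_cfg, optimizer)
for _ in range(total_steps):
    optimizer.step()
    scheduler.step()  # one scheduler step per training iteration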
71
utils/timer.py
Normal file
@@ -0,0 +1,71 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from time import perf_counter
from typing import Optional


class Timer:
    """
    A timer which computes the time elapsed since the start/reset of the timer.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """
        Reset the timer.
        """
        self._start = perf_counter()
        self._paused: Optional[float] = None
        self._total_paused = 0
        self._count_start = 1

    def pause(self) -> None:
        """
        Pause the timer.
        """
        if self._paused is not None:
            raise ValueError("Trying to pause a Timer that is already paused!")
        self._paused = perf_counter()

    def is_paused(self) -> bool:
        """
        Returns:
            bool: whether the timer is currently paused
        """
        return self._paused is not None

    def resume(self) -> None:
        """
        Resume the timer.
        """
        if self._paused is None:
            raise ValueError("Trying to resume a Timer that is not paused!")
        # pyre-fixme[58]: `-` is not supported for operand types `float` and
        #  `Optional[float]`.
        self._total_paused += perf_counter() - self._paused
        self._paused = None
        self._count_start += 1

    def seconds(self) -> float:
        """
        Returns:
            (float): the total number of seconds since the start/reset of the
                timer, excluding the time when the timer is paused.
        """
        if self._paused is not None:
            end_time: float = self._paused  # type: ignore
        else:
            end_time = perf_counter()
        return end_time - self._start - self._total_paused

    def avg_seconds(self) -> float:
        """
        Returns:
            (float): the average number of seconds between every start/reset
                and pause.
        """
        return self.seconds() / self._count_start
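A runnable usage sketch: time two pieces of work while excluding a paused section in between (the sleeps stand in for real work):

import time

timer = Timer()
time.sleep(0.2)        # timed work
timer.pause()
time.sleep(0.1)        # excluded from the measurement
timer.resume()
time.sleep(0.2)        # timed work
print(f"{timer.seconds():.2f}s measured, {timer.avg_seconds():.2f}s per segment")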
86
utils/visualization.py
Normal file
@@ -0,0 +1,86 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import open3d as o3d
import numpy as np
import torch


def to_numpy(x):
    if isinstance(x, torch.Tensor):
        x = x.clone().detach().cpu().numpy()
    assert isinstance(x, np.ndarray)
    return x


def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
    # Guard against a bare filename, where dirname would be "" and
    # os.makedirs("") would raise.
    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
    coord = to_numpy(coord)
    if color is not None:
        color = to_numpy(color)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coord)
    pcd.colors = o3d.utility.Vector3dVector(
        np.ones_like(coord) if color is None else color
    )
    o3d.io.write_point_cloud(file_path, pcd)
    if logger is not None:
        logger.info(f"Save Point Cloud to: {file_path}")


def save_bounding_boxes(
    bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None
):
    bboxes_corners = to_numpy(bboxes_corners)
    # point list
    points = bboxes_corners.reshape(-1, 3)
    # line list: bottom face, top face, then the four vertical edges.
    # ([7, 4] closes the top face; the original [7, 0] connected a top
    # corner to a bottom corner.)
    box_lines = np.array(
        [
            [0, 1],
            [1, 2],
            [2, 3],
            [3, 0],
            [4, 5],
            [5, 6],
            [6, 7],
            [7, 4],
            [0, 4],
            [1, 5],
            [2, 6],
            [3, 7],
        ]
    )
    lines = []
    for i, _ in enumerate(bboxes_corners):
        lines.append(box_lines + i * 8)
    lines = np.concatenate(lines)
    # color list
    color = np.array([color for _ in range(len(lines))])
    # generate line set
    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(points)
    line_set.lines = o3d.utility.Vector2iVector(lines)
    line_set.colors = o3d.utility.Vector3dVector(color)
    o3d.io.write_line_set(file_path, line_set)

    if logger is not None:
        logger.info(f"Save Boxes to: {file_path}")


def save_lines(
    points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None
):
    points = to_numpy(points)
    lines = to_numpy(lines)
    colors = np.array([color for _ in range(len(lines))])
    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(points)
    line_set.lines = o3d.utility.Vector2iVector(lines)
    line_set.colors = o3d.utility.Vector3dVector(colors)
    o3d.io.write_line_set(file_path, line_set)

    if logger is not None:
        logger.info(f"Save Lines to: {file_path}")
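A usage sketch for save_point_cloud: dump a random point cloud for inspection. The path is illustrative; colors are expected in [0, 1]:

coord = torch.rand(1024, 3)  # xyz positions
color = torch.rand(1024, 3)  # RGB in [0, 1]
save_point_cloud(coord, color, file_path="outputs/debug/pc.ply")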
BIN
wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl
Normal file
Binary file not shown.