commit ca93dd05724a5701b99a9710fe9a89f77cf60ab5
Author: fdyuandong
Date: Thu Apr 17 23:14:24 2025 +0800
feat: Initial commit
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..73c532f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+image/
+__pycache__
+**/build/
+**/*.egg-info/
+**/dist/
+*.so
+exp
+weights
+data
+log
+outputs/
+.vscode
+.idea
+*/.DS_Store
+TEMP/
+pretrained/
+**/*.out
+Dockerfile
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f49a4e1
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7636ef6
--- /dev/null
+++ b/README.md
@@ -0,0 +1,103 @@
+# LAM-A2E: Audio to Expression
+
+[Project Page](https://aigc3d.github.io/projects/LAM/)
+[License: Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
+
+#### This project leverages audio input to generate ARKit blendshape-driven facial expressions in ⚡real-time⚡, powering ultra-realistic 3D avatars generated by [LAM](https://github.com/aigc3d/LAM).
+
+## Demo
+
+
+
+
+
+## 📢 News
+
+
+### To-do list
+- [ ] Release Huggingface space.
+- [ ] Release Modelscope space.
+- [ ] Release the LAM-A2E model based on the Flame expression.
+- [ ] Release Interactive Chatting Avatar SDK with [OpenAvatarChat](https://github.com/HumanAIGC-Engineering/OpenAvatarChat), including LLM, ASR, TTS, LAM-Avatars.
+
+
+
+## 🚀 Get Started
+### Environment Setup
+```bash
+git clone git@github.com:aigc3d/LAM_Audio2Expression.git
+cd LAM_Audio2Expression
+# Install with CUDA 12.1
+sh ./scripts/install/install_cu121.sh
+# Or install with CUDA 11.8
+sh ./scripts/install/install_cu118.sh
+```
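+
+Optionally, you can sanity-check that PyTorch sees your GPU (a rough check, not part of the official setup steps):
+
+```bash
+python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+```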
+
+
+### Download
+
+```bash
+# HuggingFace download
+# Download Assets and Model Weights
+huggingface-cli download 3DAIGC/LAM_audio2exp --local-dir ./
+tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar
+tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar
+
+
+# Or OSS download (in case the HuggingFace download fails)
+# Download Assets
+wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_assets.tar
+tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar
+# Download Model Weights
+wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_streaming.tar
+tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar
+```
+
+
+### Quick Start Guide
+#### Using Gradio Interface:
+We provide a simple Gradio demo with a **WebGL renderer**; upload an audio clip and you can get rendered results within seconds.
+
+
+
+
+
+```bash
+python app_lam_audio2exp.py
+```
+
+### Inference
+```bash
+# example: python inference.py --config-file configs/lam_audio2exp_config_streaming.py --options save_path=exp/audio2exp weight=pretrained_models/lam_audio2exp_streaming.tar audio_input=./assets/sample_audio/BarackObama_english.wav
+python inference.py --config-file ${CONFIG_PATH} --options save_path=${SAVE_PATH} weight=${CHECKPOINT_PATH} audio_input=${AUDIO_INPUT}
+```
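+
+Any other top-level field in the config file (e.g. `fps`, `ex_vol`, `save_json_path`) can be overridden the same way, since `--options` takes space-separated `key=value` pairs that are merged into the config. A rough example, assuming the default streaming config and the bundled sample audio:
+
+```bash
+# Sketch: additionally disable vocal-track isolation and set the output FPS
+python inference.py --config-file configs/lam_audio2exp_config_streaming.py \
+    --options weight=pretrained_models/lam_audio2exp_streaming.tar \
+              audio_input=./assets/sample_audio/BarackObama_english.wav \
+              save_path=exp/audio2exp \
+              ex_vol=False fps=30.0
+```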
+
+### Acknowledgement
+This work is built on many amazing research works and open-source projects:
+- [FLAME](https://flame.is.tue.mpg.de)
+- [FaceFormer](https://github.com/EvelynFan/FaceFormer)
+- [Meshtalk](https://github.com/facebookresearch/meshtalk)
+- [Unitalker](https://github.com/X-niper/UniTalker)
+- [Pointcept](https://github.com/Pointcept/Pointcept)
+
+Thanks for their excellent work and great contributions.
+
+
+### Related Works
+You are also welcome to follow our other works:
+- [LAM](https://github.com/aigc3d/LAM)
+- [LHM](https://github.com/aigc3d/LHM)
+
+
+### Citation
+```bibtex
+@inproceedings{he2025LAM,
+ title={LAM: Large Avatar Model for One-shot Animatable Gaussian Head},
+ author={
+ Yisheng He and Xiaodong Gu and Xiaodan Ye and Chao Xu and Zhengyi Zhao and Yuan Dong and Weihao Yuan and Zilong Dong and Liefeng Bo
+ },
+ booktitle={arXiv preprint arXiv:2502.17796},
+ year={2025}
+}
+```
diff --git a/app_lam_audio2exp.py b/app_lam_audio2exp.py
new file mode 100644
index 0000000..96ce483
--- /dev/null
+++ b/app_lam_audio2exp.py
@@ -0,0 +1,271 @@
+"""
+Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import base64
+
+import gradio as gr
+import argparse
+from omegaconf import OmegaConf
+from gradio_gaussian_render import gaussian_render
+
+from engines.defaults import (
+ default_argument_parser,
+ default_config_parser,
+ default_setup,
+)
+from engines.infer import INFER
+from pathlib import Path
+
+try:
+ import spaces
+except ImportError:
+    # `spaces` is only available on Hugging Face Spaces; ignore it elsewhere.
+    pass
+
+h5_rendering = True
+
+
+def assert_input_image(input_image):
+ if input_image is None:
+ raise gr.Error('No image selected or uploaded!')
+
+
+def prepare_working_dir():
+ import tempfile
+ working_dir = tempfile.TemporaryDirectory()
+ return working_dir
+
+def get_image_base64(path):
+ with open(path, 'rb') as image_file:
+ encoded_string = base64.b64encode(image_file.read()).decode()
+ return f'data:image/png;base64,{encoded_string}'
+
+
+def doRender():
+ print('H5 rendering ....')
+
+def parse_configs():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--infer", type=str)
+ args, unknown = parser.parse_known_args()
+
+ cfg = OmegaConf.create()
+ cli_cfg = OmegaConf.from_cli(unknown)
+
+ # parse from ENV
+ if os.environ.get("APP_INFER") is not None:
+ args.infer = os.environ.get("APP_INFER")
+ if os.environ.get("APP_MODEL_NAME") is not None:
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
+
+ args.config = args.infer if args.config is None else args.config
+
+ if args.config is not None:
+ cfg_train = OmegaConf.load(args.config)
+ cfg.source_size = cfg_train.dataset.source_image_res
+ try:
+ cfg.src_head_size = cfg_train.dataset.src_head_size
+        except Exception:
+            # fall back to the default source head resolution
+            cfg.src_head_size = 112
+ cfg.render_size = cfg_train.dataset.render_image.high
+ _relative_path = os.path.join(
+ cfg_train.experiment.parent,
+ cfg_train.experiment.child,
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
+ )
+
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
+
+ if args.infer is not None:
+ cfg_infer = OmegaConf.load(args.infer)
+ cfg.merge_with(cfg_infer)
+ cfg.setdefault(
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
+ )
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
+ cfg.setdefault(
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
+ )
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
+
+ cfg.motion_video_read_fps = 30
+ cfg.merge_with(cli_cfg)
+
+ cfg.setdefault("logger", "INFO")
+
+ assert cfg.model_name is not None, "model_name is required"
+
+ return cfg, cfg_train
+
+
+def create_zip_archive(output_zip='assets/arkitWithBSData.zip', base_dir=""):
+    if os.path.exists(output_zip):
+        os.remove(output_zip)
+        print(f"Removed previous file: {output_zip}")
+    run_command = 'zip -r ' + output_zip + ' ' + base_dir
+    os.system(run_command)
+    # check that the archive was created
+    if os.path.exists(output_zip):
+        print(f"Archive created successfully: {output_zip}")
+    else:
+        raise ValueError(f"Archive creation failed: {output_zip}")
+
+
+def demo_lam_audio2exp(infer, cfg):
+ def core_fn(image_path: str, audio_params, working_dir):
+
+ base_id = os.path.basename(image_path).split(".")[0]
+
+ # set input audio
+ cfg.audio_input = audio_params
+ cfg.save_json_path = os.path.join("./assets/sample_lam", base_id, 'arkitWithBSData', 'bsData.json')
+ infer.infer()
+
+ create_zip_archive(output_zip='./assets/arkitWithBSData.zip', base_dir=os.path.join("./assets/sample_lam", base_id))
+
+ return
+
+ with gr.Blocks(analytics_enabled=False) as demo:
+ logo_url = './assets/images/logo.jpeg'
+ logo_base64 = get_image_base64(logo_url)
+ gr.HTML(f"""
+
+
+
+                LAM-A2E: Audio to Expression
+
+
+ """)
+
+ gr.HTML(
+ """ Notes: This project leverages audio input to generate ARKit blendshapes-driven facial expressions in ⚡real-time⚡, powering ultra-realistic 3D avatars generated by LAM.
"""
+ )
+
+ # DISPLAY
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id='lam_input_image'):
+ with gr.TabItem('Input Image'):
+ with gr.Row():
+ input_image = gr.Image(label='Input Image',
+ image_mode='RGB',
+ height=480,
+ width=270,
+ sources='upload',
+ type='filepath', # 'numpy',
+ elem_id='content_image')
+ # EXAMPLES
+ with gr.Row():
+ examples = [
+ ['assets/sample_input/barbara.jpg'],
+ ['assets/sample_input/status.png'],
+ ['assets/sample_input/james.png'],
+ ['assets/sample_input/vfhq_case1.png'],
+ ]
+ gr.Examples(
+ examples=examples,
+ inputs=[input_image],
+ examples_per_page=20,
+ )
+
+ with gr.Column():
+ with gr.Tabs(elem_id='lam_input_audio'):
+ with gr.TabItem('Input Audio'):
+ with gr.Row():
+ audio_input = gr.Audio(label='Input Audio',
+ type='filepath',
+ waveform_options={
+ 'sample_rate': 16000,
+ 'waveform_progress_color': '#4682b4'
+ },
+ elem_id='content_audio')
+
+ examples = [
+ ['assets/sample_audio/Nangyanwen_chinese.wav'],
+ ['assets/sample_audio/LiBai_TTS_chinese.wav'],
+ ['assets/sample_audio/LinJing_TTS_chinese.wav'],
+ ['assets/sample_audio/BarackObama_english.wav'],
+ ['assets/sample_audio/HillaryClinton_english.wav'],
+ ['assets/sample_audio/XitongShi_japanese.wav'],
+ ['assets/sample_audio/FangXiao_japanese.wav'],
+ ]
+ gr.Examples(
+ examples=examples,
+ inputs=[audio_input],
+ examples_per_page=10,
+ )
+
+ # SETTING
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ submit = gr.Button('Generate',
+ elem_id='lam_generate',
+ variant='primary')
+
+ if h5_rendering:
+ gr.set_static_paths(Path.cwd().absolute() / "assets/")
+ assetPrefix = 'gradio_api/file=assets/'
+ with gr.Row():
+ gs = gaussian_render(width=380, height=680, assets=assetPrefix + 'arkitWithBSData.zip')
+
+ working_dir = gr.State()
+ submit.click(
+ fn=assert_input_image,
+ inputs=[input_image],
+ queue=False,
+ ).success(
+ fn=prepare_working_dir,
+ outputs=[working_dir],
+ queue=False,
+ ).success(
+ fn=core_fn,
+ inputs=[input_image, audio_input,
+                    working_dir],  # audio_params is the uploaded audio file path
+ outputs=[],
+ queue=False,
+ ).success(
+ doRender, js='''() => window.start()'''
+ )
+
+ demo.queue()
+ demo.launch()
+
+
+
+def launch_gradio_app():
+ os.environ.update({
+ 'APP_ENABLED': '1',
+        'APP_MODEL_NAME': '',
+        'APP_INFER': 'configs/lam_audio2exp_config_streaming.py',
+ 'APP_TYPE': 'infer.audio2exp',
+ 'NUMBA_THREADING_LAYER': 'omp',
+ })
+
+ args = default_argument_parser().parse_args()
+ args.config_file = 'configs/lam_audio2exp_config_streaming.py'
+ cfg = default_config_parser(args.config_file, args.options)
+ cfg = default_setup(cfg)
+
+ cfg.ex_vol = True
+ infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
+
+ demo_lam_audio2exp(infer, cfg)
+
+
+if __name__ == '__main__':
+ launch_gradio_app()
diff --git a/assets/images/logo.jpeg b/assets/images/logo.jpeg
new file mode 100644
index 0000000..6fa8d78
Binary files /dev/null and b/assets/images/logo.jpeg differ
diff --git a/assets/images/snapshot.png b/assets/images/snapshot.png
new file mode 100644
index 0000000..8fc9bc9
Binary files /dev/null and b/assets/images/snapshot.png differ
diff --git a/assets/images/teaser.jpg b/assets/images/teaser.jpg
new file mode 100644
index 0000000..8c7c406
Binary files /dev/null and b/assets/images/teaser.jpg differ
diff --git a/configs/lam_audio2exp_config.py b/configs/lam_audio2exp_config.py
new file mode 100644
index 0000000..a1e4abb
--- /dev/null
+++ b/configs/lam_audio2exp_config.py
@@ -0,0 +1,92 @@
+weight = 'pretrained_models/lam_audio2exp.tar' # path to model weight
+ex_vol = True # Isolates vocal track from audio file
+audio_input = './assets/sample_audio/BarackObama.wav'
+save_json_path = 'bsData.json'
+
+audio_sr = 16000
+fps = 30.0
+
+movement_smooth = True
+brow_movement = True
+id_idx = 153
+
+resume = False  # whether to resume the training process
+evaluate = True  # evaluate after each training epoch
+test_only = False  # run the test process only
+
+seed = None  # the training process will init and record a random seed
+save_path = "exp/audio2exp"
+num_worker = 16  # total number of workers across all GPUs
+batch_size = 16  # total batch size across all GPUs
+batch_size_val = None  # auto adapt to batch size 1 per GPU
+batch_size_test = None  # auto adapt to batch size 1 per GPU
+epoch = 100  # total epochs, data loop = epoch // eval_epoch
+eval_epoch = 100  # scheduled total of eval & checkpoint epochs
+
+sync_bn = False
+enable_amp = False
+empty_cache = False
+find_unused_parameters = False
+
+mix_prob = 0
+param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
+
+# model settings
+model = dict(
+ type="DefaultEstimator",
+ backbone=dict(
+ type="Audio2Expression",
+ pretrained_encoder_type='wav2vec',
+ pretrained_encoder_path='facebook/wav2vec2-base-960h',
+ wav2vec2_config_path = 'configs/wav2vec2_config.json',
+ num_identity_classes=5016,
+ identity_feat_dim=64,
+ hidden_dim=512,
+ expression_dim=52,
+ norm_type='ln',
+ use_transformer=True,
+ num_attention_heads=8,
+ num_transformer_layers=6,
+ ),
+ criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
+)
+
+dataset_type = 'audio2exp'
+data_root = './'
+data = dict(
+ train=dict(
+ type=dataset_type,
+ split="train",
+ data_root=data_root,
+ test_mode=False,
+ ),
+ val=dict(
+ type=dataset_type,
+ split="val",
+ data_root=data_root,
+ test_mode=False,
+ ),
+ test=dict(
+ type=dataset_type,
+ split="val",
+ data_root=data_root,
+ test_mode=True
+ ),
+)
+
+# hook
+hooks = [
+ dict(type="CheckpointLoader"),
+ dict(type="IterationTimer", warmup_iter=2),
+ dict(type="InformationWriter"),
+ dict(type="SemSegEvaluator"),
+ dict(type="CheckpointSaver", save_freq=None),
+ dict(type="PreciseEvaluator", test_last=False),
+]
+
+# Trainer
+train = dict(type="DefaultTrainer")
+
+# Inferencer
+infer = dict(type="Audio2ExpressionInfer",
+ verbose=True)
diff --git a/configs/lam_audio2exp_config_streaming.py b/configs/lam_audio2exp_config_streaming.py
new file mode 100644
index 0000000..3f44b92
--- /dev/null
+++ b/configs/lam_audio2exp_config_streaming.py
@@ -0,0 +1,92 @@
+weight = 'pretrained_models/lam_audio2exp_streaming.tar' # path to model weight
+ex_vol = True  # isolate the vocal track from the audio file
+audio_input = './assets/sample_audio/BarackObama.wav'
+save_json_path = 'bsData.json'
+
+audio_sr = 16000
+fps = 30.0
+
+movement_smooth = False
+brow_movement = False
+id_idx = 0
+
+resume = False  # whether to resume the training process
+evaluate = True  # evaluate after each training epoch
+test_only = False  # run the test process only
+
+seed = None  # the training process will init and record a random seed
+save_path = "exp/audio2exp"
+num_worker = 16  # total number of workers across all GPUs
+batch_size = 16  # total batch size across all GPUs
+batch_size_val = None  # auto adapt to batch size 1 per GPU
+batch_size_test = None  # auto adapt to batch size 1 per GPU
+epoch = 100  # total epochs, data loop = epoch // eval_epoch
+eval_epoch = 100  # scheduled total of eval & checkpoint epochs
+
+sync_bn = False
+enable_amp = False
+empty_cache = False
+find_unused_parameters = False
+
+mix_prob = 0
+param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
+
+# model settings
+model = dict(
+ type="DefaultEstimator",
+ backbone=dict(
+ type="Audio2Expression",
+ pretrained_encoder_type='wav2vec',
+ pretrained_encoder_path='facebook/wav2vec2-base-960h',
+ wav2vec2_config_path = 'configs/wav2vec2_config.json',
+ num_identity_classes=12,
+ identity_feat_dim=64,
+ hidden_dim=512,
+ expression_dim=52,
+ norm_type='ln',
+ use_transformer=False,
+ num_attention_heads=8,
+ num_transformer_layers=6,
+ ),
+ criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
+)
+
+dataset_type = 'audio2exp'
+data_root = './'
+data = dict(
+ train=dict(
+ type=dataset_type,
+ split="train",
+ data_root=data_root,
+ test_mode=False,
+ ),
+ val=dict(
+ type=dataset_type,
+ split="val",
+ data_root=data_root,
+ test_mode=False,
+ ),
+ test=dict(
+ type=dataset_type,
+ split="val",
+ data_root=data_root,
+ test_mode=True
+ ),
+)
+
+# hook
+hooks = [
+ dict(type="CheckpointLoader"),
+ dict(type="IterationTimer", warmup_iter=2),
+ dict(type="InformationWriter"),
+ dict(type="SemSegEvaluator"),
+ dict(type="CheckpointSaver", save_freq=None),
+ dict(type="PreciseEvaluator", test_last=False),
+]
+
+# Trainer
+train = dict(type="DefaultTrainer")
+
+# Inferencer
+infer = dict(type="Audio2ExpressionInfer",
+ verbose=True)
diff --git a/configs/wav2vec2_config.json b/configs/wav2vec2_config.json
new file mode 100644
index 0000000..8ca9cc7
--- /dev/null
+++ b/configs/wav2vec2_config.json
@@ -0,0 +1,77 @@
+{
+ "_name_or_path": "facebook/wav2vec2-base-960h",
+ "activation_dropout": 0.1,
+ "apply_spec_augment": true,
+ "architectures": [
+ "Wav2Vec2ForCTC"
+ ],
+ "attention_dropout": 0.1,
+ "bos_token_id": 1,
+ "codevector_dim": 256,
+ "contrastive_logits_temperature": 0.1,
+ "conv_bias": false,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "diversity_loss_weight": 0.1,
+ "do_stable_layer_norm": false,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_dropout": 0.0,
+ "feat_extract_norm": "group",
+ "feat_proj_dropout": 0.1,
+ "feat_quantizer_dropout": 0.0,
+ "final_dropout": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.1,
+ "mask_feature_length": 10,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_prob": 0.05,
+ "model_type": "wav2vec2",
+ "num_attention_heads": 12,
+ "num_codevector_groups": 2,
+ "num_codevectors_per_group": 320,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 12,
+ "num_negatives": 100,
+ "pad_token_id": 0,
+ "proj_codevector_dim": 256,
+ "transformers_version": "4.7.0.dev0",
+ "vocab_size": 32
+}
diff --git a/engines/__init__.py b/engines/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/engines/defaults.py b/engines/defaults.py
new file mode 100644
index 0000000..488148b
--- /dev/null
+++ b/engines/defaults.py
@@ -0,0 +1,147 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import os
+import sys
+import argparse
+import multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel
+
+
+import utils.comm as comm
+from utils.env import get_random_seed, set_seed
+from utils.config import Config, DictAction
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+ """
+ Create a DistributedDataParallel model if there are >1 processes.
+ Args:
+ model: a torch.nn.Module
+ fp16_compression: add fp16 compression hooks to the ddp object.
+ See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+ kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
+ """
+ if comm.get_world_size() == 1:
+ return model
+ # kwargs['find_unused_parameters'] = True
+ if "device_ids" not in kwargs:
+ kwargs["device_ids"] = [comm.get_local_rank()]
+ if "output_device" not in kwargs:
+ kwargs["output_device"] = [comm.get_local_rank()]
+ ddp = DistributedDataParallel(model, **kwargs)
+ if fp16_compression:
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+ ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+ return ddp
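+
+# Usage sketch (an assumption, not from the original code): called from the trainer
+# after torch.distributed has been initialized with one process per GPU; extra kwargs
+# are forwarded to DistributedDataParallel, e.g.
+#
+#   model = create_ddp_model(
+#       model.cuda(),
+#       broadcast_buffers=False,
+#       find_unused_parameters=cfg.find_unused_parameters,
+#   )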
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ """Worker init func for dataloader.
+
+    The seed of each worker equals num_workers * rank + worker_id + seed.
+
+ Args:
+ worker_id (int): Worker id.
+ num_workers (int): Number of workers.
+ rank (int): The rank of current process.
+ seed (int): The random seed to use.
+ """
+
+ worker_seed = num_workers * rank + worker_id + seed
+ set_seed(worker_seed)
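+
+# Usage sketch (an assumption, not from the original code): a PyTorch DataLoader only
+# passes worker_id, so the remaining arguments are typically bound beforehand, e.g.
+#
+#   from functools import partial
+#   init_fn = partial(worker_init_fn, num_workers=cfg.num_worker_per_gpu,
+#                     rank=comm.get_rank(), seed=cfg.seed)
+#   loader = torch.utils.data.DataLoader(dataset,
+#                                        num_workers=cfg.num_worker_per_gpu,
+#                                        worker_init_fn=init_fn)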
+
+
+def default_argument_parser(epilog=None):
+ parser = argparse.ArgumentParser(
+ epilog=epilog
+ or f"""
+ Examples:
+ Run on single machine:
+ $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
+ Change some config options:
+ $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
+ Run on multiple machines:
+ (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
+ (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
+ """,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument(
+ "--config-file", default="", metavar="FILE", help="path to config file"
+ )
+ parser.add_argument(
+ "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
+ )
+ parser.add_argument(
+ "--num-machines", type=int, default=1, help="total number of machines"
+ )
+ parser.add_argument(
+ "--machine-rank",
+ type=int,
+ default=0,
+ help="the rank of this machine (unique per machine)",
+ )
+ # PyTorch still may leave orphan processes in multi-gpu training.
+ # Therefore we use a deterministic way to obtain port,
+ # so that users are aware of orphan processes by seeing the port occupied.
+ # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
+ parser.add_argument(
+ "--dist-url",
+ # default="tcp://127.0.0.1:{}".format(port),
+ default="auto",
+ help="initialization URL for pytorch distributed backend. See "
+ "https://pytorch.org/docs/stable/distributed.html for details.",
+ )
+ parser.add_argument(
+ "--options", nargs="+", action=DictAction, help="custom options"
+ )
+ return parser
+
+
+def default_config_parser(file_path, options):
+ # config name protocol: dataset_name/model_name-exp_name
+ if os.path.isfile(file_path):
+ cfg = Config.fromfile(file_path)
+ else:
+ sep = file_path.find("-")
+ cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))
+
+ if options is not None:
+ cfg.merge_from_dict(options)
+
+ if cfg.seed is None:
+ cfg.seed = get_random_seed()
+
+ cfg.data.train.loop = cfg.epoch // cfg.eval_epoch
+
+ os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
+ if not cfg.resume:
+ cfg.dump(os.path.join(cfg.save_path, "config.py"))
+ return cfg
+
+
+def default_setup(cfg):
+    # scale by world size
+ world_size = comm.get_world_size()
+ cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
+ cfg.num_worker_per_gpu = cfg.num_worker // world_size
+ assert cfg.batch_size % world_size == 0
+ assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
+ assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
+ cfg.batch_size_per_gpu = cfg.batch_size // world_size
+ cfg.batch_size_val_per_gpu = (
+ cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
+ )
+ cfg.batch_size_test_per_gpu = (
+ cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
+ )
+ # update data loop
+ assert cfg.epoch % cfg.eval_epoch == 0
+    # set the random seed
+ rank = comm.get_rank()
+ seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
+ set_seed(seed)
+ return cfg
diff --git a/engines/hooks/__init__.py b/engines/hooks/__init__.py
new file mode 100644
index 0000000..1ab2c4b
--- /dev/null
+++ b/engines/hooks/__init__.py
@@ -0,0 +1,5 @@
+from .default import HookBase
+from .misc import *
+from .evaluator import *
+
+from .builder import build_hooks
diff --git a/engines/hooks/builder.py b/engines/hooks/builder.py
new file mode 100644
index 0000000..e0a121c
--- /dev/null
+++ b/engines/hooks/builder.py
@@ -0,0 +1,15 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+from utils.registry import Registry
+
+
+HOOKS = Registry("hooks")
+
+
+def build_hooks(cfg):
+ hooks = []
+ for hook_cfg in cfg:
+ hooks.append(HOOKS.build(hook_cfg))
+ return hooks
diff --git a/engines/hooks/default.py b/engines/hooks/default.py
new file mode 100644
index 0000000..57150a7
--- /dev/null
+++ b/engines/hooks/default.py
@@ -0,0 +1,29 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+
+class HookBase:
+ """
+ Base class for hooks that can be registered with :class:`TrainerBase`.
+ """
+
+ trainer = None # A weak reference to the trainer object.
+
+ def before_train(self):
+ pass
+
+ def before_epoch(self):
+ pass
+
+ def before_step(self):
+ pass
+
+ def after_step(self):
+ pass
+
+ def after_epoch(self):
+ pass
+
+ def after_train(self):
+ pass
diff --git a/engines/hooks/evaluator.py b/engines/hooks/evaluator.py
new file mode 100644
index 0000000..c0d2717
--- /dev/null
+++ b/engines/hooks/evaluator.py
@@ -0,0 +1,577 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from uuid import uuid4
+
+try:
+    # pointops is only needed by the point-cloud evaluators below when
+    # "origin_coord" inputs are provided; keep the import optional so this
+    # module still imports when pointops is not installed.
+    import pointops
+except ImportError:
+    pointops = None
+
+import utils.comm as comm
+from utils.misc import intersection_and_union_gpu
+
+from .default import HookBase
+from .builder import HOOKS
+
+
+@HOOKS.register_module()
+class ClsEvaluator(HookBase):
+ def after_epoch(self):
+ if self.trainer.cfg.evaluate:
+ self.eval()
+
+ def eval(self):
+ self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+ self.trainer.model.eval()
+ for i, input_dict in enumerate(self.trainer.val_loader):
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.no_grad():
+ output_dict = self.trainer.model(input_dict)
+ output = output_dict["cls_logits"]
+ loss = output_dict["loss"]
+ pred = output.max(1)[1]
+ label = input_dict["category"]
+ intersection, union, target = intersection_and_union_gpu(
+ pred,
+ label,
+ self.trainer.cfg.data.num_classes,
+ self.trainer.cfg.data.ignore_index,
+ )
+ if comm.get_world_size() > 1:
+ dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
+ target
+ )
+ intersection, union, target = (
+ intersection.cpu().numpy(),
+ union.cpu().numpy(),
+ target.cpu().numpy(),
+ )
+ # Here there is no need to sync since sync happened in dist.all_reduce
+ self.trainer.storage.put_scalar("val_intersection", intersection)
+ self.trainer.storage.put_scalar("val_union", union)
+ self.trainer.storage.put_scalar("val_target", target)
+ self.trainer.storage.put_scalar("val_loss", loss.item())
+ self.trainer.logger.info(
+ "Test: [{iter}/{max_iter}] "
+ "Loss {loss:.4f} ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
+ )
+ )
+ loss_avg = self.trainer.storage.history("val_loss").avg
+ intersection = self.trainer.storage.history("val_intersection").total
+ union = self.trainer.storage.history("val_union").total
+ target = self.trainer.storage.history("val_target").total
+ iou_class = intersection / (union + 1e-10)
+ acc_class = intersection / (target + 1e-10)
+ m_iou = np.mean(iou_class)
+ m_acc = np.mean(acc_class)
+ all_acc = sum(intersection) / (sum(target) + 1e-10)
+ self.trainer.logger.info(
+ "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
+ m_iou, m_acc, all_acc
+ )
+ )
+ for i in range(self.trainer.cfg.data.num_classes):
+ self.trainer.logger.info(
+ "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
+ idx=i,
+ name=self.trainer.cfg.data.names[i],
+ iou=iou_class[i],
+ accuracy=acc_class[i],
+ )
+ )
+ current_epoch = self.trainer.epoch + 1
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
+ self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
+ self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
+ self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
+ self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+ self.trainer.comm_info["current_metric_value"] = all_acc # save for saver
+ self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver
+
+ def after_train(self):
+ self.trainer.logger.info(
+ "Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
+ )
+
+
+@HOOKS.register_module()
+class SemSegEvaluator(HookBase):
+ def after_epoch(self):
+ if self.trainer.cfg.evaluate:
+ self.eval()
+
+ def eval(self):
+ self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+ self.trainer.model.eval()
+ for i, input_dict in enumerate(self.trainer.val_loader):
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.no_grad():
+ output_dict = self.trainer.model(input_dict)
+ output = output_dict["seg_logits"]
+ loss = output_dict["loss"]
+ pred = output.max(1)[1]
+ segment = input_dict["segment"]
+ if "origin_coord" in input_dict.keys():
+ idx, _ = pointops.knn_query(
+ 1,
+ input_dict["coord"].float(),
+ input_dict["offset"].int(),
+ input_dict["origin_coord"].float(),
+ input_dict["origin_offset"].int(),
+ )
+ pred = pred[idx.flatten().long()]
+ segment = input_dict["origin_segment"]
+ intersection, union, target = intersection_and_union_gpu(
+ pred,
+ segment,
+ self.trainer.cfg.data.num_classes,
+ self.trainer.cfg.data.ignore_index,
+ )
+ if comm.get_world_size() > 1:
+ dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
+ target
+ )
+ intersection, union, target = (
+ intersection.cpu().numpy(),
+ union.cpu().numpy(),
+ target.cpu().numpy(),
+ )
+ # Here there is no need to sync since sync happened in dist.all_reduce
+ self.trainer.storage.put_scalar("val_intersection", intersection)
+ self.trainer.storage.put_scalar("val_union", union)
+ self.trainer.storage.put_scalar("val_target", target)
+ self.trainer.storage.put_scalar("val_loss", loss.item())
+ info = "Test: [{iter}/{max_iter}] ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader)
+ )
+ if "origin_coord" in input_dict.keys():
+ info = "Interp. " + info
+ self.trainer.logger.info(
+ info
+ + "Loss {loss:.4f} ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
+ )
+ )
+ loss_avg = self.trainer.storage.history("val_loss").avg
+ intersection = self.trainer.storage.history("val_intersection").total
+ union = self.trainer.storage.history("val_union").total
+ target = self.trainer.storage.history("val_target").total
+ iou_class = intersection / (union + 1e-10)
+ acc_class = intersection / (target + 1e-10)
+ m_iou = np.mean(iou_class)
+ m_acc = np.mean(acc_class)
+ all_acc = sum(intersection) / (sum(target) + 1e-10)
+ self.trainer.logger.info(
+ "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
+ m_iou, m_acc, all_acc
+ )
+ )
+ for i in range(self.trainer.cfg.data.num_classes):
+ self.trainer.logger.info(
+ "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
+ idx=i,
+ name=self.trainer.cfg.data.names[i],
+ iou=iou_class[i],
+ accuracy=acc_class[i],
+ )
+ )
+ current_epoch = self.trainer.epoch + 1
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
+ self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
+ self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
+ self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
+ self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+ self.trainer.comm_info["current_metric_value"] = m_iou # save for saver
+ self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver
+
+ def after_train(self):
+ self.trainer.logger.info(
+ "Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
+ )
+
+
+@HOOKS.register_module()
+class InsSegEvaluator(HookBase):
+ def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
+ self.segment_ignore_index = segment_ignore_index
+ self.instance_ignore_index = instance_ignore_index
+
+ self.valid_class_names = None # update in before train
+ self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
+ self.min_region_sizes = 100
+ self.distance_threshes = float("inf")
+ self.distance_confs = -float("inf")
+
+ def before_train(self):
+ self.valid_class_names = [
+ self.trainer.cfg.data.names[i]
+ for i in range(self.trainer.cfg.data.num_classes)
+ if i not in self.segment_ignore_index
+ ]
+
+ def after_epoch(self):
+ if self.trainer.cfg.evaluate:
+ self.eval()
+
+ def associate_instances(self, pred, segment, instance):
+ segment = segment.cpu().numpy()
+ instance = instance.cpu().numpy()
+ void_mask = np.in1d(segment, self.segment_ignore_index)
+
+ assert (
+ pred["pred_classes"].shape[0]
+ == pred["pred_scores"].shape[0]
+ == pred["pred_masks"].shape[0]
+ )
+ assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
+ # get gt instances
+ gt_instances = dict()
+ for i in range(self.trainer.cfg.data.num_classes):
+ if i not in self.segment_ignore_index:
+ gt_instances[self.trainer.cfg.data.names[i]] = []
+ instance_ids, idx, counts = np.unique(
+ instance, return_index=True, return_counts=True
+ )
+ segment_ids = segment[idx]
+ for i in range(len(instance_ids)):
+ if instance_ids[i] == self.instance_ignore_index:
+ continue
+ if segment_ids[i] in self.segment_ignore_index:
+ continue
+ gt_inst = dict()
+ gt_inst["instance_id"] = instance_ids[i]
+ gt_inst["segment_id"] = segment_ids[i]
+ gt_inst["dist_conf"] = 0.0
+ gt_inst["med_dist"] = -1.0
+ gt_inst["vert_count"] = counts[i]
+ gt_inst["matched_pred"] = []
+ gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)
+
+ # get pred instances and associate with gt
+ pred_instances = dict()
+ for i in range(self.trainer.cfg.data.num_classes):
+ if i not in self.segment_ignore_index:
+ pred_instances[self.trainer.cfg.data.names[i]] = []
+ instance_id = 0
+ for i in range(len(pred["pred_classes"])):
+ if pred["pred_classes"][i] in self.segment_ignore_index:
+ continue
+ pred_inst = dict()
+ pred_inst["uuid"] = uuid4()
+ pred_inst["instance_id"] = instance_id
+ pred_inst["segment_id"] = pred["pred_classes"][i]
+ pred_inst["confidence"] = pred["pred_scores"][i]
+ pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
+ pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
+ pred_inst["void_intersection"] = np.count_nonzero(
+ np.logical_and(void_mask, pred_inst["mask"])
+ )
+ if pred_inst["vert_count"] < self.min_region_sizes:
+ continue # skip if empty
+ segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
+ matched_gt = []
+ for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
+ intersection = np.count_nonzero(
+ np.logical_and(
+ instance == gt_inst["instance_id"], pred_inst["mask"]
+ )
+ )
+ if intersection > 0:
+ gt_inst_ = gt_inst.copy()
+ pred_inst_ = pred_inst.copy()
+ gt_inst_["intersection"] = intersection
+ pred_inst_["intersection"] = intersection
+ matched_gt.append(gt_inst_)
+ gt_inst["matched_pred"].append(pred_inst_)
+ pred_inst["matched_gt"] = matched_gt
+ pred_instances[segment_name].append(pred_inst)
+ instance_id += 1
+ return gt_instances, pred_instances
+
+ def evaluate_matches(self, scenes):
+ overlaps = self.overlaps
+ min_region_sizes = [self.min_region_sizes]
+ dist_threshes = [self.distance_threshes]
+ dist_confs = [self.distance_confs]
+
+ # results: class x overlap
+ ap_table = np.zeros(
+ (len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
+ )
+ for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
+ zip(min_region_sizes, dist_threshes, dist_confs)
+ ):
+ for oi, overlap_th in enumerate(overlaps):
+ pred_visited = {}
+ for scene in scenes:
+ for _ in scene["pred"]:
+ for label_name in self.valid_class_names:
+ for p in scene["pred"][label_name]:
+ if "uuid" in p:
+ pred_visited[p["uuid"]] = False
+ for li, label_name in enumerate(self.valid_class_names):
+ y_true = np.empty(0)
+ y_score = np.empty(0)
+ hard_false_negatives = 0
+ has_gt = False
+ has_pred = False
+ for scene in scenes:
+ pred_instances = scene["pred"][label_name]
+ gt_instances = scene["gt"][label_name]
+ # filter groups in ground truth
+ gt_instances = [
+ gt
+ for gt in gt_instances
+ if gt["vert_count"] >= min_region_size
+ and gt["med_dist"] <= distance_thresh
+ and gt["dist_conf"] >= distance_conf
+ ]
+ if gt_instances:
+ has_gt = True
+ if pred_instances:
+ has_pred = True
+
+ cur_true = np.ones(len(gt_instances))
+ cur_score = np.ones(len(gt_instances)) * (-float("inf"))
+ cur_match = np.zeros(len(gt_instances), dtype=bool)
+ # collect matches
+ for gti, gt in enumerate(gt_instances):
+ found_match = False
+ for pred in gt["matched_pred"]:
+ # greedy assignments
+ if pred_visited[pred["uuid"]]:
+ continue
+ overlap = float(pred["intersection"]) / (
+ gt["vert_count"]
+ + pred["vert_count"]
+ - pred["intersection"]
+ )
+ if overlap > overlap_th:
+ confidence = pred["confidence"]
+ # if already have a prediction for this gt,
+ # the prediction with the lower score is automatically a false positive
+ if cur_match[gti]:
+ max_score = max(cur_score[gti], confidence)
+ min_score = min(cur_score[gti], confidence)
+ cur_score[gti] = max_score
+ # append false positive
+ cur_true = np.append(cur_true, 0)
+ cur_score = np.append(cur_score, min_score)
+ cur_match = np.append(cur_match, True)
+ # otherwise set score
+ else:
+ found_match = True
+ cur_match[gti] = True
+ cur_score[gti] = confidence
+ pred_visited[pred["uuid"]] = True
+ if not found_match:
+ hard_false_negatives += 1
+ # remove non-matched ground truth instances
+ cur_true = cur_true[cur_match]
+ cur_score = cur_score[cur_match]
+
+ # collect non-matched predictions as false positive
+ for pred in pred_instances:
+ found_gt = False
+ for gt in pred["matched_gt"]:
+ overlap = float(gt["intersection"]) / (
+ gt["vert_count"]
+ + pred["vert_count"]
+ - gt["intersection"]
+ )
+ if overlap > overlap_th:
+ found_gt = True
+ break
+ if not found_gt:
+ num_ignore = pred["void_intersection"]
+ for gt in pred["matched_gt"]:
+ if gt["segment_id"] in self.segment_ignore_index:
+ num_ignore += gt["intersection"]
+ # small ground truth instances
+ if (
+ gt["vert_count"] < min_region_size
+ or gt["med_dist"] > distance_thresh
+ or gt["dist_conf"] < distance_conf
+ ):
+ num_ignore += gt["intersection"]
+ proportion_ignore = (
+ float(num_ignore) / pred["vert_count"]
+ )
+ # if not ignored append false positive
+ if proportion_ignore <= overlap_th:
+ cur_true = np.append(cur_true, 0)
+ confidence = pred["confidence"]
+ cur_score = np.append(cur_score, confidence)
+
+ # append to overall results
+ y_true = np.append(y_true, cur_true)
+ y_score = np.append(y_score, cur_score)
+
+ # compute average precision
+ if has_gt and has_pred:
+ # compute precision recall curve first
+
+ # sorting and cumsum
+ score_arg_sort = np.argsort(y_score)
+ y_score_sorted = y_score[score_arg_sort]
+ y_true_sorted = y_true[score_arg_sort]
+ y_true_sorted_cumsum = np.cumsum(y_true_sorted)
+
+ # unique thresholds
+ (thresholds, unique_indices) = np.unique(
+ y_score_sorted, return_index=True
+ )
+ num_prec_recall = len(unique_indices) + 1
+
+ # prepare precision recall
+ num_examples = len(y_score_sorted)
+ # https://github.com/ScanNet/ScanNet/pull/26
+ # all predictions are non-matched but also all of them are ignored and not counted as FP
+ # y_true_sorted_cumsum is empty
+ # num_true_examples = y_true_sorted_cumsum[-1]
+ num_true_examples = (
+ y_true_sorted_cumsum[-1]
+ if len(y_true_sorted_cumsum) > 0
+ else 0
+ )
+ precision = np.zeros(num_prec_recall)
+ recall = np.zeros(num_prec_recall)
+
+ # deal with the first point
+ y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
+ # deal with remaining
+ for idx_res, idx_scores in enumerate(unique_indices):
+ cumsum = y_true_sorted_cumsum[idx_scores - 1]
+ tp = num_true_examples - cumsum
+ fp = num_examples - idx_scores - tp
+ fn = cumsum + hard_false_negatives
+ p = float(tp) / (tp + fp)
+ r = float(tp) / (tp + fn)
+ precision[idx_res] = p
+ recall[idx_res] = r
+
+ # first point in curve is artificial
+ precision[-1] = 1.0
+ recall[-1] = 0.0
+
+ # compute average of precision-recall curve
+ recall_for_conv = np.copy(recall)
+ recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
+ recall_for_conv = np.append(recall_for_conv, 0.0)
+
+ stepWidths = np.convolve(
+ recall_for_conv, [-0.5, 0, 0.5], "valid"
+ )
+ # integrate is now simply a dot product
+ ap_current = np.dot(precision, stepWidths)
+
+ elif has_gt:
+ ap_current = 0.0
+ else:
+ ap_current = float("nan")
+ ap_table[di, li, oi] = ap_current
+ d_inf = 0
+ o50 = np.where(np.isclose(self.overlaps, 0.5))
+ o25 = np.where(np.isclose(self.overlaps, 0.25))
+ oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
+ ap_scores = dict()
+ ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
+ ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
+ ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
+ ap_scores["classes"] = {}
+ for li, label_name in enumerate(self.valid_class_names):
+ ap_scores["classes"][label_name] = {}
+ ap_scores["classes"][label_name]["ap"] = np.average(
+ ap_table[d_inf, li, oAllBut25]
+ )
+ ap_scores["classes"][label_name]["ap50%"] = np.average(
+ ap_table[d_inf, li, o50]
+ )
+ ap_scores["classes"][label_name]["ap25%"] = np.average(
+ ap_table[d_inf, li, o25]
+ )
+ return ap_scores
+
+ def eval(self):
+ self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+ self.trainer.model.eval()
+ scenes = []
+ for i, input_dict in enumerate(self.trainer.val_loader):
+ assert (
+ len(input_dict["offset"]) == 1
+ ) # currently only support bs 1 for each GPU
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.no_grad():
+ output_dict = self.trainer.model(input_dict)
+
+ loss = output_dict["loss"]
+
+ segment = input_dict["segment"]
+ instance = input_dict["instance"]
+ # map to origin
+ if "origin_coord" in input_dict.keys():
+ idx, _ = pointops.knn_query(
+ 1,
+ input_dict["coord"].float(),
+ input_dict["offset"].int(),
+ input_dict["origin_coord"].float(),
+ input_dict["origin_offset"].int(),
+ )
+ idx = idx.cpu().flatten().long()
+ output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
+ segment = input_dict["origin_segment"]
+ instance = input_dict["origin_instance"]
+
+ gt_instances, pred_instance = self.associate_instances(
+ output_dict, segment, instance
+ )
+ scenes.append(dict(gt=gt_instances, pred=pred_instance))
+
+ self.trainer.storage.put_scalar("val_loss", loss.item())
+ self.trainer.logger.info(
+ "Test: [{iter}/{max_iter}] "
+ "Loss {loss:.4f} ".format(
+ iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
+ )
+ )
+
+ loss_avg = self.trainer.storage.history("val_loss").avg
+ comm.synchronize()
+ scenes_sync = comm.gather(scenes, dst=0)
+ scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
+ ap_scores = self.evaluate_matches(scenes)
+ all_ap = ap_scores["all_ap"]
+ all_ap_50 = ap_scores["all_ap_50%"]
+ all_ap_25 = ap_scores["all_ap_25%"]
+ self.trainer.logger.info(
+ "Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
+ all_ap, all_ap_50, all_ap_25
+ )
+ )
+ for i, label_name in enumerate(self.valid_class_names):
+ ap = ap_scores["classes"][label_name]["ap"]
+ ap_50 = ap_scores["classes"][label_name]["ap50%"]
+ ap_25 = ap_scores["classes"][label_name]["ap25%"]
+ self.trainer.logger.info(
+ "Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
+ idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
+ )
+ )
+ current_epoch = self.trainer.epoch + 1
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
+ self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
+ self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
+ self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
+ self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+ self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver
+ self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver
diff --git a/engines/hooks/misc.py b/engines/hooks/misc.py
new file mode 100644
index 0000000..52b398e
--- /dev/null
+++ b/engines/hooks/misc.py
@@ -0,0 +1,460 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import sys
+import glob
+import os
+import shutil
+import time
+import torch
+import torch.utils.data
+from collections import OrderedDict
+
+if sys.version_info >= (3, 10):
+ from collections.abc import Sequence
+else:
+ from collections import Sequence
+from utils.timer import Timer
+from utils.comm import is_main_process, synchronize, get_world_size
+from utils.cache import shared_dict
+
+import utils.comm as comm
+from engines.test import TESTERS
+
+from .default import HookBase
+from .builder import HOOKS
+
+
+@HOOKS.register_module()
+class IterationTimer(HookBase):
+ def __init__(self, warmup_iter=1):
+ self._warmup_iter = warmup_iter
+ self._start_time = time.perf_counter()
+ self._iter_timer = Timer()
+ self._remain_iter = 0
+
+ def before_train(self):
+ self._start_time = time.perf_counter()
+ self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)
+
+ def before_epoch(self):
+ self._iter_timer.reset()
+
+ def before_step(self):
+ data_time = self._iter_timer.seconds()
+ self.trainer.storage.put_scalar("data_time", data_time)
+
+ def after_step(self):
+ batch_time = self._iter_timer.seconds()
+ self._iter_timer.reset()
+ self.trainer.storage.put_scalar("batch_time", batch_time)
+ self._remain_iter -= 1
+ remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
+ t_m, t_s = divmod(remain_time, 60)
+ t_h, t_m = divmod(t_m, 60)
+ remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
+ if "iter_info" in self.trainer.comm_info.keys():
+ info = (
+ "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
+ "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
+ "Remain {remain_time} ".format(
+ data_time_val=self.trainer.storage.history("data_time").val,
+ data_time_avg=self.trainer.storage.history("data_time").avg,
+ batch_time_val=self.trainer.storage.history("batch_time").val,
+ batch_time_avg=self.trainer.storage.history("batch_time").avg,
+ remain_time=remain_time,
+ )
+ )
+ self.trainer.comm_info["iter_info"] += info
+ if self.trainer.comm_info["iter"] <= self._warmup_iter:
+ self.trainer.storage.history("data_time").reset()
+ self.trainer.storage.history("batch_time").reset()
+
+
+@HOOKS.register_module()
+class InformationWriter(HookBase):
+ def __init__(self):
+ self.curr_iter = 0
+ self.model_output_keys = []
+
+ def before_train(self):
+ self.trainer.comm_info["iter_info"] = ""
+ self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)
+
+ def before_step(self):
+ self.curr_iter += 1
+        # MSC pretraining does not provide offset information; keep the block below commented out to support MSC.
+ # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
+ # "Scan {batch_size} ({points_num}) ".format(
+ # epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
+ # iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
+ # batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
+ # points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
+ # )
+ info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
+ epoch=self.trainer.epoch + 1,
+ max_epoch=self.trainer.max_epoch,
+ iter=self.trainer.comm_info["iter"] + 1,
+ max_iter=len(self.trainer.train_loader),
+ )
+ self.trainer.comm_info["iter_info"] += info
+
+ def after_step(self):
+ if "model_output_dict" in self.trainer.comm_info.keys():
+ model_output_dict = self.trainer.comm_info["model_output_dict"]
+ self.model_output_keys = model_output_dict.keys()
+ for key in self.model_output_keys:
+ self.trainer.storage.put_scalar(key, model_output_dict[key].item())
+
+ for key in self.model_output_keys:
+ self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
+ key=key, value=self.trainer.storage.history(key).val
+ )
+ lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
+ self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
+ self.trainer.logger.info(self.trainer.comm_info["iter_info"])
+ self.trainer.comm_info["iter_info"] = "" # reset iter info
+ if self.trainer.writer is not None:
+ self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
+ for key in self.model_output_keys:
+ self.trainer.writer.add_scalar(
+ "train_batch/" + key,
+ self.trainer.storage.history(key).val,
+ self.curr_iter,
+ )
+
+ def after_epoch(self):
+ epoch_info = "Train result: "
+ for key in self.model_output_keys:
+ epoch_info += "{key}: {value:.4f} ".format(
+ key=key, value=self.trainer.storage.history(key).avg
+ )
+ self.trainer.logger.info(epoch_info)
+ if self.trainer.writer is not None:
+ for key in self.model_output_keys:
+ self.trainer.writer.add_scalar(
+ "train/" + key,
+ self.trainer.storage.history(key).avg,
+ self.trainer.epoch + 1,
+ )
+
+
+@HOOKS.register_module()
+class CheckpointSaver(HookBase):
+ def __init__(self, save_freq=None):
+        self.save_freq = save_freq  # None or int; None means only the latest checkpoint is kept
+
+ def after_epoch(self):
+ if is_main_process():
+ is_best = False
+ if self.trainer.cfg.evaluate:
+ current_metric_value = self.trainer.comm_info["current_metric_value"]
+ current_metric_name = self.trainer.comm_info["current_metric_name"]
+ if current_metric_value > self.trainer.best_metric_value:
+ self.trainer.best_metric_value = current_metric_value
+ is_best = True
+ self.trainer.logger.info(
+ "Best validation {} updated to: {:.4f}".format(
+ current_metric_name, current_metric_value
+ )
+ )
+ self.trainer.logger.info(
+ "Currently Best {}: {:.4f}".format(
+ current_metric_name, self.trainer.best_metric_value
+ )
+ )
+
+ filename = os.path.join(
+ self.trainer.cfg.save_path, "model", "model_last.pth"
+ )
+ self.trainer.logger.info("Saving checkpoint to: " + filename)
+ torch.save(
+ {
+ "epoch": self.trainer.epoch + 1,
+ "state_dict": self.trainer.model.state_dict(),
+ "optimizer": self.trainer.optimizer.state_dict(),
+ "scheduler": self.trainer.scheduler.state_dict(),
+ "scaler": self.trainer.scaler.state_dict()
+ if self.trainer.cfg.enable_amp
+ else None,
+ "best_metric_value": self.trainer.best_metric_value,
+ },
+ filename + ".tmp",
+ )
+ os.replace(filename + ".tmp", filename)
+ if is_best:
+ shutil.copyfile(
+ filename,
+ os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
+ )
+ if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
+ shutil.copyfile(
+ filename,
+ os.path.join(
+ self.trainer.cfg.save_path,
+ "model",
+ f"epoch_{self.trainer.epoch + 1}.pth",
+ ),
+ )
+
+
+@HOOKS.register_module()
+class CheckpointLoader(HookBase):
+ def __init__(self, keywords="", replacement=None, strict=False):
+ self.keywords = keywords
+ self.replacement = replacement if replacement is not None else keywords
+ self.strict = strict
+
+ def before_train(self):
+ self.trainer.logger.info("=> Loading checkpoint & weight ...")
+ if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
+ self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
+ checkpoint = torch.load(
+ self.trainer.cfg.weight,
+ map_location=lambda storage, loc: storage.cuda(),
+ )
+ self.trainer.logger.info(
+ f"Loading layer weights with keyword: {self.keywords}, "
+ f"replace keyword with: {self.replacement}"
+ )
+ weight = OrderedDict()
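+            # Re-map checkpoint keys to match the current model: add or strip the DDP "module."
+            # prefix depending on the world size, and optionally swap `keywords` for
+            # `replacement` to load weights into a renamed sub-module.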
+ for key, value in checkpoint["state_dict"].items():
+ if not key.startswith("module."):
+ if comm.get_world_size() > 1:
+ key = "module." + key # xxx.xxx -> module.xxx.xxx
+                # At this point every key carries the "module." prefix, whether or not DDP is used.
+ if self.keywords in key:
+ key = key.replace(self.keywords, self.replacement)
+ if comm.get_world_size() == 1:
+ key = key[7:] # module.xxx.xxx -> xxx.xxx
+ weight[key] = value
+ load_state_info = self.trainer.model.load_state_dict(
+ weight, strict=self.strict
+ )
+ self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
+ if self.trainer.cfg.resume:
+ self.trainer.logger.info(
+ f"Resuming train at eval epoch: {checkpoint['epoch']}"
+ )
+ self.trainer.start_epoch = checkpoint["epoch"]
+ self.trainer.best_metric_value = checkpoint["best_metric_value"]
+ self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
+ self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
+ if self.trainer.cfg.enable_amp:
+ self.trainer.scaler.load_state_dict(checkpoint["scaler"])
+ else:
+ self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
+
+
+@HOOKS.register_module()
+class PreciseEvaluator(HookBase):
+ def __init__(self, test_last=False):
+ self.test_last = test_last
+
+ def after_train(self):
+ self.trainer.logger.info(
+ ">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
+ )
+ torch.cuda.empty_cache()
+ cfg = self.trainer.cfg
+ tester = TESTERS.build(
+ dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
+ )
+ if self.test_last:
+ self.trainer.logger.info("=> Testing on model_last ...")
+ else:
+ self.trainer.logger.info("=> Testing on model_best ...")
+ best_path = os.path.join(
+ self.trainer.cfg.save_path, "model", "model_best.pth"
+ )
+ checkpoint = torch.load(best_path)
+ state_dict = checkpoint["state_dict"]
+ tester.model.load_state_dict(state_dict, strict=True)
+ tester.test()
+
+
+@HOOKS.register_module()
+class DataCacheOperator(HookBase):
+ def __init__(self, data_root, split):
+ self.data_root = data_root
+ self.split = split
+ self.data_list = self.get_data_list()
+
+ def get_data_list(self):
+ if isinstance(self.split, str):
+ data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
+ elif isinstance(self.split, Sequence):
+ data_list = []
+ for split in self.split:
+ data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
+ else:
+ raise NotImplementedError
+ return data_list
+
+ def get_cache_name(self, data_path):
+ data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
+ return "pointcept" + data_name.replace(os.path.sep, "-")
+
+ def before_train(self):
+ self.trainer.logger.info(
+ f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
+ )
+ if is_main_process():
+ for data_path in self.data_list:
+ cache_name = self.get_cache_name(data_path)
+ data = torch.load(data_path)
+ shared_dict(cache_name, data)
+ synchronize()
+
+
+@HOOKS.register_module()
+class RuntimeProfiler(HookBase):
+ def __init__(
+ self,
+ forward=True,
+ backward=True,
+ interrupt=False,
+ warm_up=2,
+ sort_by="cuda_time_total",
+ row_limit=30,
+ ):
+ self.forward = forward
+ self.backward = backward
+ self.interrupt = interrupt
+ self.warm_up = warm_up
+ self.sort_by = sort_by
+ self.row_limit = row_limit
+
+ def before_train(self):
+ self.trainer.logger.info("Profiling runtime ...")
+ from torch.profiler import profile, record_function, ProfilerActivity
+
+ for i, input_dict in enumerate(self.trainer.train_loader):
+ if i == self.warm_up + 1:
+ break
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ if self.forward:
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ ) as forward_prof:
+ with record_function("model_inference"):
+ output_dict = self.trainer.model(input_dict)
+ else:
+ output_dict = self.trainer.model(input_dict)
+ loss = output_dict["loss"]
+ if self.backward:
+ with profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ ) as backward_prof:
+ with record_function("model_inference"):
+ loss.backward()
+ self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
+ if self.forward:
+ self.trainer.logger.info(
+ "Forward profile: \n"
+ + str(
+ forward_prof.key_averages().table(
+ sort_by=self.sort_by, row_limit=self.row_limit
+ )
+ )
+ )
+ forward_prof.export_chrome_trace(
+ os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
+ )
+
+ if self.backward:
+ self.trainer.logger.info(
+ "Backward profile: \n"
+ + str(
+ backward_prof.key_averages().table(
+ sort_by=self.sort_by, row_limit=self.row_limit
+ )
+ )
+ )
+ backward_prof.export_chrome_trace(
+ os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
+ )
+ if self.interrupt:
+ sys.exit(0)
+
+
+@HOOKS.register_module()
+class RuntimeProfilerV2(HookBase):
+ def __init__(
+ self,
+ interrupt=False,
+ wait=1,
+ warmup=1,
+ active=10,
+ repeat=1,
+ sort_by="cuda_time_total",
+ row_limit=30,
+ ):
+ self.interrupt = interrupt
+ self.wait = wait
+ self.warmup = warmup
+ self.active = active
+ self.repeat = repeat
+ self.sort_by = sort_by
+ self.row_limit = row_limit
+
+ def before_train(self):
+ self.trainer.logger.info("Profiling runtime ...")
+ from torch.profiler import (
+ profile,
+ record_function,
+ ProfilerActivity,
+ schedule,
+ tensorboard_trace_handler,
+ )
+
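+        # Profiler schedule: skip `wait` steps, warm up for `warmup`, record `active` steps, and
+        # repeat the cycle `repeat` times; finished traces are written for TensorBoard into
+        # cfg.save_path by the trace handler below.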
+ prof = profile(
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ schedule=schedule(
+ wait=self.wait,
+ warmup=self.warmup,
+ active=self.active,
+ repeat=self.repeat,
+ ),
+ on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ )
+ prof.start()
+ for i, input_dict in enumerate(self.trainer.train_loader):
+ if i >= (self.wait + self.warmup + self.active) * self.repeat:
+ break
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with record_function("model_forward"):
+ output_dict = self.trainer.model(input_dict)
+ loss = output_dict["loss"]
+ with record_function("model_backward"):
+ loss.backward()
+ prof.step()
+ self.trainer.logger.info(
+ f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
+ )
+ self.trainer.logger.info(
+ "Profile: \n"
+ + str(
+ prof.key_averages().table(
+ sort_by=self.sort_by, row_limit=self.row_limit
+ )
+ )
+ )
+ prof.stop()
+
+ if self.interrupt:
+ sys.exit(0)
diff --git a/engines/infer.py b/engines/infer.py
new file mode 100644
index 0000000..9b7b8f1
--- /dev/null
+++ b/engines/infer.py
@@ -0,0 +1,285 @@
+"""
+Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+"""
+
+import os
+import math
+import time
+import librosa
+import numpy as np
+from collections import OrderedDict
+
+import torch
+import torch.utils.data
+import torch.nn.functional as F
+
+from .defaults import create_ddp_model
+import utils.comm as comm
+from models import build_model
+from utils.logger import get_root_logger
+from utils.registry import Registry
+from utils.misc import (
+ AverageMeter,
+)
+
+from models.utils import (
+    smooth_mouth_movements, apply_frame_blending, apply_savitzky_golay_smoothing, apply_random_brow_movement,
+    symmetrize_blendshapes, apply_random_eye_blinks, apply_random_eye_blinks_context, export_blendshape_animation,
+    RETURN_CODE, DEFAULT_CONTEXT, ARKitBlendShape,
+)
+
+INFER = Registry("infer")
+
+class InferBase:
+ def __init__(self, cfg, model=None, verbose=False) -> None:
+ torch.multiprocessing.set_sharing_strategy("file_system")
+ self.logger = get_root_logger(
+ log_file=os.path.join(cfg.save_path, "infer.log"),
+ file_mode="a" if cfg.resume else "w",
+ )
+ self.logger.info("=> Loading config ...")
+ self.cfg = cfg
+ self.verbose = verbose
+ if self.verbose:
+ self.logger.info(f"Save path: {cfg.save_path}")
+ self.logger.info(f"Config:\n{cfg.pretty_text}")
+ if model is None:
+ self.logger.info("=> Building model ...")
+ self.model = self.build_model()
+ else:
+ self.model = model
+
+ def build_model(self):
+ model = build_model(self.cfg.model)
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ self.logger.info(f"Num params: {n_parameters}")
+ model = create_ddp_model(
+ model.cuda(),
+ broadcast_buffers=False,
+ find_unused_parameters=self.cfg.find_unused_parameters,
+ )
+ if os.path.isfile(self.cfg.weight):
+ self.logger.info(f"Loading weight at: {self.cfg.weight}")
+ checkpoint = torch.load(self.cfg.weight)
+ weight = OrderedDict()
+ for key, value in checkpoint["state_dict"].items():
+ if key.startswith("module."):
+ if comm.get_world_size() == 1:
+ key = key[7:] # module.xxx.xxx -> xxx.xxx
+ else:
+ if comm.get_world_size() > 1:
+ key = "module." + key # xxx.xxx -> module.xxx.xxx
+ weight[key] = value
+ model.load_state_dict(weight, strict=True)
+ self.logger.info(
+ "=> Loaded weight '{}'".format(
+ self.cfg.weight
+ )
+ )
+ else:
+ raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
+ return model
+
+
+ def infer(self):
+ raise NotImplementedError
+
+
+
+@INFER.register_module()
+class Audio2ExpressionInfer(InferBase):
+ def infer(self):
+ logger = get_root_logger()
+ logger.info(">>>>>>>>>>>>>>>> Start Inference >>>>>>>>>>>>>>>>")
+ batch_time = AverageMeter()
+ self.model.eval()
+
+ # process audio-input
+ assert os.path.exists(self.cfg.audio_input)
+        if self.cfg.ex_vol:
+            logger.info("Extracting vocals ...")
+            vocal_path = self.extract_vocal_track(self.cfg.audio_input)
+            logger.info("=> Extracted vocals at: {}".format(vocal_path if os.path.exists(vocal_path) else "... Failed"))
+            if os.path.exists(vocal_path):
+                self.cfg.audio_input = vocal_path
+
+ with torch.no_grad():
+ input_dict = {}
+            input_dict['id_idx'] = F.one_hot(
+                torch.tensor(self.cfg.id_idx),
+                self.cfg.model.backbone.num_identity_classes,
+            ).cuda(non_blocking=True)[None, ...]
+            speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
+            input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None, ...]
+
+ end = time.time()
+ output_dict = self.model(input_dict)
+ batch_time.update(time.time() - end)
+
+ logger.info(
+ "Infer: [{}] "
+ "Running Time: {batch_time.avg:.3f} ".format(
+ self.cfg.audio_input,
+ batch_time=batch_time,
+ )
+ )
+
+ out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()
+
+ frame_length = math.ceil(speech_array.shape[0] / ssr * 30)
+ volume = librosa.feature.rms(y=speech_array, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
+ if (volume.shape[0] > frame_length):
+ volume = volume[:frame_length]
+
+ if(self.cfg.movement_smooth):
+ out_exp = smooth_mouth_movements(out_exp, 0, volume)
+
+ if (self.cfg.brow_movement):
+ out_exp = apply_random_brow_movement(out_exp, volume)
+
+ pred_exp = self.blendshape_postprocess(out_exp)
+
+ if(self.cfg.save_json_path is not None):
+ export_blendshape_animation(pred_exp,
+ self.cfg.save_json_path,
+ ARKitBlendShape,
+ fps=self.cfg.fps)
+
+ logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
+
+ def infer_streaming_audio(self,
+ audio: np.ndarray,
+ ssr: float,
+ context: dict):
+
+        if context is None:
+            context = DEFAULT_CONTEXT.copy()
+ max_frame_length = 64
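+        # The model always consumes a fixed window worth of `max_frame_length` output frames
+        # (at 30 fps); shorter chunks are left-padded below with silence (first chunk) or with
+        # the tail of the previously received audio.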
+
+ frame_length = math.ceil(audio.shape[0] / ssr * 30)
+ output_context = DEFAULT_CONTEXT.copy()
+
+ volume = librosa.feature.rms(y=audio, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
+        if volume.shape[0] > frame_length:
+            volume = volume[:frame_length]
+
+ # resample audio
+        if ssr != self.cfg.audio_sr:
+            in_audio = librosa.resample(audio.astype(np.float32), orig_sr=ssr, target_sr=self.cfg.audio_sr)
+        else:
+            in_audio = audio.copy()
+
+        start_frame = int(max_frame_length - in_audio.shape[0] / self.cfg.audio_sr * 30)
+
+        if context['is_initial_input'] or context['previous_audio'] is None:
+ blank_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
+ blank_audio = np.zeros(blank_audio_length, dtype=np.float32)
+
+ # pre-append
+ input_audio = np.concatenate([blank_audio, in_audio])
+ output_context['previous_audio'] = input_audio
+
+ else:
+ clip_pre_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
+ clip_pre_audio = context['previous_audio'][-clip_pre_audio_length:]
+ input_audio = np.concatenate([clip_pre_audio, in_audio])
+ output_context['previous_audio'] = input_audio
+
+        with torch.no_grad():
+            try:
+                input_dict = {}
+                input_dict['id_idx'] = F.one_hot(
+                    torch.tensor(self.cfg.id_idx),
+                    self.cfg.model.backbone.num_identity_classes,
+                ).cuda(non_blocking=True)[None, ...]
+                input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
+                output_dict = self.model(input_dict)
+                out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
+            except Exception:
+                self.logger.error('Error: failed to predict expression.')
+                # Fall back to a neutral (all-zero) expression so the caller still receives
+                # a (result, context) tuple instead of None.
+                out_exp = np.zeros((max_frame_length - start_frame, 52), dtype=np.float32)
+
+
+ # post-process
+ if (context['previous_expression'] is None):
+ out_exp = self.apply_expression_postprocessing(out_exp, audio_volume=volume)
+ else:
+ previous_length = context['previous_expression'].shape[0]
+ out_exp = self.apply_expression_postprocessing(expression_params = np.concatenate([context['previous_expression'], out_exp], axis=0),
+ audio_volume=np.concatenate([context['previous_volume'], volume], axis=0),
+ processed_frames=previous_length)[previous_length:, :]
+
+ if (context['previous_expression'] is not None):
+ output_context['previous_expression'] = np.concatenate([context['previous_expression'], out_exp], axis=0)[
+ -max_frame_length:, :]
+ output_context['previous_volume'] = np.concatenate([context['previous_volume'], volume], axis=0)[-max_frame_length:]
+ else:
+ output_context['previous_expression'] = out_exp.copy()
+ output_context['previous_volume'] = volume.copy()
+
+        output_context['is_initial_input'] = False  # later chunks should reuse the cached previous audio above
+
+ return {"code": RETURN_CODE['SUCCESS'],
+ "expression": out_exp,
+ "headpose": None}, output_context
+
+    def apply_expression_postprocessing(
+ self,
+ expression_params: np.ndarray,
+ processed_frames: int = 0,
+ audio_volume: np.ndarray = None
+ ) -> np.ndarray:
+ """Applies full post-processing pipeline to facial expression parameters.
+
+ Args:
+ expression_params: Raw output from animation model [num_frames, num_parameters]
+ processed_frames: Number of frames already processed in previous batches
+ audio_volume: Optional volume array for audio-visual synchronization
+
+ Returns:
+ Processed expression parameters ready for animation synthesis
+ """
+ # Pipeline execution order matters - maintain sequence
+ expression_params = smooth_mouth_movements(expression_params, processed_frames, audio_volume)
+ expression_params = apply_frame_blending(expression_params, processed_frames)
+ expression_params, _ = apply_savitzky_golay_smoothing(expression_params, window_length=5)
+ expression_params = symmetrize_blendshapes(expression_params)
+ expression_params = apply_random_eye_blinks_context(expression_params, processed_frames=processed_frames)
+
+ return expression_params
+
+ def extract_vocal_track(
+ self,
+ input_audio_path: str
+ ) -> str:
+ """Isolates vocal track from audio file using source separation.
+
+ Args:
+ input_audio_path: Path to input audio file containing vocals+accompaniment
+
+ Returns:
+ Path to isolated vocal track in WAV format
+ """
+ separation_command = f'spleeter separate -p spleeter:2stems -o {self.cfg.save_path} {input_audio_path}'
+ os.system(separation_command)
+
+ base_name = os.path.splitext(os.path.basename(input_audio_path))[0]
+ return os.path.join(self.cfg.save_path, base_name, 'vocals.wav')
+
+    def blendshape_postprocess(self, bs_array: np.ndarray) -> np.ndarray:
+ bs_array, _ = apply_savitzky_golay_smoothing(bs_array, window_length=5)
+ bs_array = symmetrize_blendshapes(bs_array)
+ bs_array = apply_random_eye_blinks(bs_array)
+
+ return bs_array
diff --git a/engines/launch.py b/engines/launch.py
new file mode 100644
index 0000000..05f5671
--- /dev/null
+++ b/engines/launch.py
@@ -0,0 +1,135 @@
+"""
+Launcher
+
+modified from detectron2(https://github.com/facebookresearch/detectron2)
+
+"""
+
+import os
+import logging
+from datetime import timedelta
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from utils import comm
+
+__all__ = ["DEFAULT_TIMEOUT", "launch"]
+
+DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+def _find_free_port():
+ import socket
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+
+def launch(
+ main_func,
+ num_gpus_per_machine,
+ num_machines=1,
+ machine_rank=0,
+ dist_url=None,
+ cfg=(),
+ timeout=DEFAULT_TIMEOUT,
+):
+ """
+ Launch multi-gpu or distributed training.
+ This function must be called on all machines involved in the training.
+ It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
+ Args:
+ main_func: a function that will be called by `main_func(*args)`
+ num_gpus_per_machine (int): number of GPUs per machine
+ num_machines (int): the total number of machines
+ machine_rank (int): the rank of this machine
+ dist_url (str): url to connect to for distributed jobs, including protocol
+ e.g. "tcp://127.0.0.1:8686".
+ Can be set to "auto" to automatically select a free port on localhost
+ timeout (timedelta): timeout of the distributed workers
+        cfg (tuple): arguments passed to main_func
+ """
+ world_size = num_machines * num_gpus_per_machine
+ if world_size > 1:
+ if dist_url == "auto":
+ assert (
+ num_machines == 1
+ ), "dist_url=auto not supported in multi-machine jobs."
+ port = _find_free_port()
+ dist_url = f"tcp://127.0.0.1:{port}"
+ if num_machines > 1 and dist_url.startswith("file://"):
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
+ )
+
+ mp.spawn(
+ _distributed_worker,
+ nprocs=num_gpus_per_machine,
+ args=(
+ main_func,
+ world_size,
+ num_gpus_per_machine,
+ machine_rank,
+ dist_url,
+ cfg,
+ timeout,
+ ),
+ daemon=False,
+ )
+ else:
+ main_func(*cfg)
+
+
+def _distributed_worker(
+ local_rank,
+ main_func,
+ world_size,
+ num_gpus_per_machine,
+ machine_rank,
+ dist_url,
+ cfg,
+ timeout=DEFAULT_TIMEOUT,
+):
+ assert (
+ torch.cuda.is_available()
+ ), "cuda is not available. Please check your installation."
+ global_rank = machine_rank * num_gpus_per_machine + local_rank
+ try:
+ dist.init_process_group(
+ backend="NCCL",
+ init_method=dist_url,
+ world_size=world_size,
+ rank=global_rank,
+ timeout=timeout,
+ )
+ except Exception as e:
+ logger = logging.getLogger(__name__)
+ logger.error("Process group URL: {}".format(dist_url))
+ raise e
+
+ # Setup the local process group (which contains ranks within the same machine)
+ assert comm._LOCAL_PROCESS_GROUP is None
+ num_machines = world_size // num_gpus_per_machine
+ for i in range(num_machines):
+ ranks_on_i = list(
+ range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
+ )
+ pg = dist.new_group(ranks_on_i)
+ if i == machine_rank:
+ comm._LOCAL_PROCESS_GROUP = pg
+
+ assert num_gpus_per_machine <= torch.cuda.device_count()
+ torch.cuda.set_device(local_rank)
+
+ # synchronize is needed here to prevent a possible timeout after calling init_process_group
+ # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+ comm.synchronize()
+
+ main_func(*cfg)
diff --git a/engines/train.py b/engines/train.py
new file mode 100644
index 0000000..7de2364
--- /dev/null
+++ b/engines/train.py
@@ -0,0 +1,299 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import os
+import sys
+import weakref
+import torch
+import torch.nn as nn
+import torch.utils.data
+from functools import partial
+
+if sys.version_info >= (3, 10):
+ from collections.abc import Iterator
+else:
+ from collections import Iterator
+from tensorboardX import SummaryWriter
+
+from .defaults import create_ddp_model, worker_init_fn
+from .hooks import HookBase, build_hooks
+import utils.comm as comm
+from datasets import build_dataset, point_collate_fn, collate_fn
+from models import build_model
+from utils.logger import get_root_logger
+from utils.optimizer import build_optimizer
+from utils.scheduler import build_scheduler
+from utils.events import EventStorage
+from utils.registry import Registry
+
+
+TRAINERS = Registry("trainers")
+
+
+class TrainerBase:
+ def __init__(self) -> None:
+ self.hooks = []
+ self.epoch = 0
+ self.start_epoch = 0
+ self.max_epoch = 0
+ self.max_iter = 0
+ self.comm_info = dict()
+ self.data_iterator: Iterator = enumerate([])
+ self.storage: EventStorage
+ self.writer: SummaryWriter
+
+ def register_hooks(self, hooks) -> None:
+ hooks = build_hooks(hooks)
+ for h in hooks:
+ assert isinstance(h, HookBase)
+ # To avoid circular reference, hooks and trainer cannot own each other.
+ # This normally does not matter, but will cause memory leak if the
+ # involved objects contain __del__:
+ # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
+ h.trainer = weakref.proxy(self)
+ self.hooks.extend(hooks)
+
+ def train(self):
+ with EventStorage() as self.storage:
+ # => before train
+ self.before_train()
+ for self.epoch in range(self.start_epoch, self.max_epoch):
+ # => before epoch
+ self.before_epoch()
+ # => run_epoch
+ for (
+ self.comm_info["iter"],
+ self.comm_info["input_dict"],
+ ) in self.data_iterator:
+ # => before_step
+ self.before_step()
+ # => run_step
+ self.run_step()
+ # => after_step
+ self.after_step()
+ # => after epoch
+ self.after_epoch()
+ # => after train
+ self.after_train()
+
+ def before_train(self):
+ for h in self.hooks:
+ h.before_train()
+
+ def before_epoch(self):
+ for h in self.hooks:
+ h.before_epoch()
+
+ def before_step(self):
+ for h in self.hooks:
+ h.before_step()
+
+ def run_step(self):
+ raise NotImplementedError
+
+ def after_step(self):
+ for h in self.hooks:
+ h.after_step()
+
+ def after_epoch(self):
+ for h in self.hooks:
+ h.after_epoch()
+ self.storage.reset_histories()
+
+ def after_train(self):
+ # Sync GPU before running train hooks
+ comm.synchronize()
+ for h in self.hooks:
+ h.after_train()
+ if comm.is_main_process():
+ self.writer.close()
+
+
+@TRAINERS.register_module("DefaultTrainer")
+class Trainer(TrainerBase):
+ def __init__(self, cfg):
+ super(Trainer, self).__init__()
+ self.epoch = 0
+ self.start_epoch = 0
+ self.max_epoch = cfg.eval_epoch
+ self.best_metric_value = -torch.inf
+ self.logger = get_root_logger(
+ log_file=os.path.join(cfg.save_path, "train.log"),
+ file_mode="a" if cfg.resume else "w",
+ )
+ self.logger.info("=> Loading config ...")
+ self.cfg = cfg
+ self.logger.info(f"Save path: {cfg.save_path}")
+ self.logger.info(f"Config:\n{cfg.pretty_text}")
+ self.logger.info("=> Building model ...")
+ self.model = self.build_model()
+ self.logger.info("=> Building writer ...")
+ self.writer = self.build_writer()
+ self.logger.info("=> Building train dataset & dataloader ...")
+ self.train_loader = self.build_train_loader()
+ self.logger.info("=> Building val dataset & dataloader ...")
+ self.val_loader = self.build_val_loader()
+ self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
+ self.optimizer = self.build_optimizer()
+ self.scheduler = self.build_scheduler()
+ self.scaler = self.build_scaler()
+ self.logger.info("=> Building hooks ...")
+ self.register_hooks(self.cfg.hooks)
+
+ def train(self):
+ with EventStorage() as self.storage:
+ # => before train
+ self.before_train()
+ self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
+ for self.epoch in range(self.start_epoch, self.max_epoch):
+ # => before epoch
+ # TODO: optimize to iteration based
+ if comm.get_world_size() > 1:
+ self.train_loader.sampler.set_epoch(self.epoch)
+ self.model.train()
+ self.data_iterator = enumerate(self.train_loader)
+ self.before_epoch()
+ # => run_epoch
+ for (
+ self.comm_info["iter"],
+ self.comm_info["input_dict"],
+ ) in self.data_iterator:
+ # => before_step
+ self.before_step()
+ # => run_step
+ self.run_step()
+ # => after_step
+ self.after_step()
+ # => after epoch
+ self.after_epoch()
+ # => after train
+ self.after_train()
+
+ def run_step(self):
+ input_dict = self.comm_info["input_dict"]
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
+ output_dict = self.model(input_dict)
+ loss = output_dict["loss"]
+ self.optimizer.zero_grad()
+ if self.cfg.enable_amp:
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+
+            # With AMP, scaler.step() skips optimizer.step() when gradients overflow; in that case
+            # scaler.update() lowers the scale. Only stepping the scheduler when the scale did not
+            # decrease avoids the PyTorch warning about scheduler.step() before optimizer.step().
+ scaler = self.scaler.get_scale()
+ self.scaler.update()
+ if scaler <= self.scaler.get_scale():
+ self.scheduler.step()
+ else:
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+ if self.cfg.empty_cache:
+ torch.cuda.empty_cache()
+ self.comm_info["model_output_dict"] = output_dict
+
+ def build_model(self):
+ model = build_model(self.cfg.model)
+ if self.cfg.sync_bn:
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ # logger.info(f"Model: \n{self.model}")
+ self.logger.info(f"Num params: {n_parameters}")
+ model = create_ddp_model(
+ model.cuda(),
+ broadcast_buffers=False,
+ find_unused_parameters=self.cfg.find_unused_parameters,
+ )
+ return model
+
+ def build_writer(self):
+ writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
+ self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
+ return writer
+
+ def build_train_loader(self):
+ train_data = build_dataset(self.cfg.data.train)
+
+ if comm.get_world_size() > 1:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
+ else:
+ train_sampler = None
+
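+        # Seed each dataloader worker deterministically (per rank and worker id) so random
+        # augmentations are reproducible; only applied when cfg.seed is set.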
+ init_fn = (
+ partial(
+ worker_init_fn,
+ num_workers=self.cfg.num_worker_per_gpu,
+ rank=comm.get_rank(),
+ seed=self.cfg.seed,
+ )
+ if self.cfg.seed is not None
+ else None
+ )
+
+ train_loader = torch.utils.data.DataLoader(
+ train_data,
+ batch_size=self.cfg.batch_size_per_gpu,
+ shuffle=(train_sampler is None),
+ num_workers=0,
+ sampler=train_sampler,
+ collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
+ pin_memory=True,
+ worker_init_fn=init_fn,
+ drop_last=True,
+ # persistent_workers=True,
+ )
+ return train_loader
+
+ def build_val_loader(self):
+ val_loader = None
+ if self.cfg.evaluate:
+ val_data = build_dataset(self.cfg.data.val)
+ if comm.get_world_size() > 1:
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
+ else:
+ val_sampler = None
+ val_loader = torch.utils.data.DataLoader(
+ val_data,
+ batch_size=self.cfg.batch_size_val_per_gpu,
+ shuffle=False,
+ num_workers=self.cfg.num_worker_per_gpu,
+ pin_memory=True,
+ sampler=val_sampler,
+ collate_fn=collate_fn,
+ )
+ return val_loader
+
+ def build_optimizer(self):
+ return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
+
+ def build_scheduler(self):
+ assert hasattr(self, "optimizer")
+ assert hasattr(self, "train_loader")
+ self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
+ return build_scheduler(self.cfg.scheduler, self.optimizer)
+
+ def build_scaler(self):
+ scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
+ return scaler
+
+
+@TRAINERS.register_module("MultiDatasetTrainer")
+class MultiDatasetTrainer(Trainer):
+ def build_train_loader(self):
+ from datasets import MultiDatasetDataloader
+
+ train_data = build_dataset(self.cfg.data.train)
+ train_loader = MultiDatasetDataloader(
+ train_data,
+ self.cfg.batch_size_per_gpu,
+ self.cfg.num_worker_per_gpu,
+ self.cfg.mix_prob,
+ self.cfg.seed,
+ )
+ self.comm_info["iter_per_epoch"] = len(train_loader)
+ return train_loader
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..37ac22e
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,48 @@
+"""
+# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+"""
+
+from engines.defaults import (
+ default_argument_parser,
+ default_config_parser,
+ default_setup,
+)
+from engines.infer import INFER
+from engines.launch import launch
+
+
+def main_worker(cfg):
+ cfg = default_setup(cfg)
+ infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
+ infer.infer()
+
+
+def main():
+ args = default_argument_parser().parse_args()
+ cfg = default_config_parser(args.config_file, args.options)
+
+ launch(
+ main_worker,
+ num_gpus_per_machine=args.num_gpus,
+ num_machines=args.num_machines,
+ machine_rank=args.machine_rank,
+ dist_url=args.dist_url,
+ cfg=(cfg,),
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/inference_streaming_audio.py b/inference_streaming_audio.py
new file mode 100644
index 0000000..c14b084
--- /dev/null
+++ b/inference_streaming_audio.py
@@ -0,0 +1,60 @@
+"""
+# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+"""
+
+import numpy as np
+
+from engines.defaults import (
+ default_argument_parser,
+ default_config_parser,
+ default_setup,
+)
+from engines.infer import INFER
+import librosa
+from tqdm import tqdm
+import time
+
+
+def export_json(bs_array, json_path):
+ from models.utils import export_blendshape_animation, ARKitBlendShape
+ export_blendshape_animation(bs_array, json_path, ARKitBlendShape, fps=30.0)
+
+if __name__ == '__main__':
+ args = default_argument_parser().parse_args()
+ args.config_file = 'configs/lam_audio2exp_config_streaming.py'
+ cfg = default_config_parser(args.config_file, args.options)
+
+
+ cfg = default_setup(cfg)
+ infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg))
+ infer.model.eval()
+
+ audio, sample_rate = librosa.load(cfg.audio_input, sr=16000)
+ context = None
+    input_num = audio.shape[0] // 16000 + 1
+ gap = 16000
+ all_exp = []
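+    # Stream the audio in 1-second chunks (`gap` samples at 16 kHz); `context` carries the tail
+    # of the previous audio/expressions between calls so consecutive chunks join smoothly.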
+ for i in tqdm(range(input_num)):
+
+ start = time.time()
+ output, context = infer.infer_streaming_audio(audio[i*gap:(i+1)*gap], sample_rate, context)
+ end = time.time()
+ print('Inference time {}'.format(end - start))
+ all_exp.append(output['expression'])
+
+    all_exp = np.concatenate(all_exp, axis=0)
+
+ export_json(all_exp, cfg.save_json_path)
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..f4beb83
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,7 @@
+from .builder import build_model
+
+from .default import DefaultEstimator
+
+# Backbones
+from .network import Audio2Expression
+
diff --git a/models/builder.py b/models/builder.py
new file mode 100644
index 0000000..eed2627
--- /dev/null
+++ b/models/builder.py
@@ -0,0 +1,13 @@
+"""
+Modified from https://github.com/Pointcept/Pointcept
+"""
+
+from utils.registry import Registry
+
+MODELS = Registry("models")
+MODULES = Registry("modules")
+
+
+def build_model(cfg):
+ """Build models."""
+ return MODELS.build(cfg)
diff --git a/models/default.py b/models/default.py
new file mode 100644
index 0000000..07655f6
--- /dev/null
+++ b/models/default.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+
+from models.losses import build_criteria
+from .builder import MODELS, build_model
+
+@MODELS.register_module()
+class DefaultEstimator(nn.Module):
+ def __init__(self, backbone=None, criteria=None):
+ super().__init__()
+ self.backbone = build_model(backbone)
+ self.criteria = build_criteria(criteria)
+
+ def forward(self, input_dict):
+ pred_exp = self.backbone(input_dict)
+ # train
+ if self.training:
+ loss = self.criteria(pred_exp, input_dict["gt_exp"])
+ return dict(loss=loss)
+ # eval
+ elif "gt_exp" in input_dict.keys():
+ loss = self.criteria(pred_exp, input_dict["gt_exp"])
+ return dict(loss=loss, pred_exp=pred_exp)
+ # infer
+ else:
+ return dict(pred_exp=pred_exp)
diff --git a/models/encoder/wav2vec.py b/models/encoder/wav2vec.py
new file mode 100644
index 0000000..f11fc57
--- /dev/null
+++ b/models/encoder/wav2vec.py
@@ -0,0 +1,248 @@
+import numpy as np
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from dataclasses import dataclass
+from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.file_utils import ModelOutput
+
+
+_CONFIG_FOR_DOC = "Wav2Vec2Config"
+_HIDDEN_STATES_START_POSITION = 2
+
+
+# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
+# initialize our encoder with the pre-trained wav2vec 2.0 weights.
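+# _compute_mask_indices draws SpecAugment-style spans of length `mask_length` so that roughly
+# `mask_prob` of each sequence is masked during training (mirrors the fairseq/wav2vec 2.0 masking).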
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ all_num_mask = int(
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
+ mask_idcs = []
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ num_mask = int(
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask
+
+ lengths = np.full(num_mask, mask_length)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+ return mask
+
+
+# linear interpolation layer
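+# Resamples (B, T, C) features along the time axis, e.g. from the 50 Hz wav2vec feature rate to
+# the 30 fps animation rate used downstream.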
+def linear_interpolation(features, input_fps, output_fps, output_len=None):
+ features = features.transpose(1, 2)
+ seq_len = features.shape[2] / float(input_fps)
+ if output_len is None:
+ output_len = int(seq_len * output_fps)
+ output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
+
+
+class Wav2Vec2Model(Wav2Vec2Model):
+ def __init__(self, config):
+ super().__init__(config)
+ self.lm_head = nn.Linear(1024, 32)
+
+ def forward(
+ self,
+ input_values,
+ attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ frame_num=None
+ ):
+ self.config.output_attentions = True
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.feature_extractor(input_values)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
+
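+        # Rebuild the attention mask at the interpolated feature rate: mark the last valid frame
+        # of each sample, then flip/cumsum/flip so every frame up to it is attended.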
+ if attention_mask is not None:
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+ attention_mask = torch.zeros(
+ hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
+ )
+ attention_mask[
+ (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
+ ] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+
+ hidden_states = self.feature_projection(hidden_states)[0]
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = encoder_outputs[0]
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@dataclass
+class SpeechClassifierOutput(ModelOutput):
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class Wav2Vec2ClassificationHead(nn.Module):
+ """Head for wav2vec classification task."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.final_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+ x = features
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.pooling_mode = config.pooling_mode
+ self.config = config
+
+ self.wav2vec2 = Wav2Vec2Model(config)
+ self.classifier = Wav2Vec2ClassificationHead(config)
+
+ self.init_weights()
+
+ def freeze_feature_extractor(self):
+ self.wav2vec2.feature_extractor._freeze_parameters()
+
+ def merged_strategy(
+ self,
+ hidden_states,
+ mode="mean"
+ ):
+ if mode == "mean":
+ outputs = torch.mean(hidden_states, dim=1)
+ elif mode == "sum":
+ outputs = torch.sum(hidden_states, dim=1)
+ elif mode == "max":
+ outputs = torch.max(hidden_states, dim=1)[0]
+ else:
+            raise ValueError(
+                "Undefined pooling mode; it must be one of ['mean', 'sum', 'max']."
+            )
+
+ return outputs
+
+ def forward(
+ self,
+ input_values,
+ attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ labels=None,
+ frame_num=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ outputs = self.wav2vec2(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ hidden_states1 = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
+ hidden_states = self.merged_strategy(hidden_states1, mode=self.pooling_mode)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SpeechClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=hidden_states1,
+ attentions=outputs.attentions,
+ )
diff --git a/models/encoder/wavlm.py b/models/encoder/wavlm.py
new file mode 100644
index 0000000..0e39b9b
--- /dev/null
+++ b/models/encoder/wavlm.py
@@ -0,0 +1,87 @@
+import numpy as np
+import torch
+from transformers import WavLMModel
+from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
+from typing import Optional, Tuple, Union
+import torch.nn.functional as F
+
+def linear_interpolation(features, output_len: int):
+ features = features.transpose(1, 2)
+ output_features = F.interpolate(
+ features, size=output_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
+
+# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
+# initialize our encoder with the pre-trained wav2vec 2.0 weights.
+
+
+class WavLMModel(WavLMModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
+ for param in self.parameters():
+ param.requires_grad = (not do_freeze)
+
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ frame_num=None,
+ interpolate_pos: int = 0,
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if interpolate_pos == 0:
+ extract_features = linear_interpolation(
+ extract_features, output_len=frame_num)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if interpolate_pos == 1:
+ hidden_states = linear_interpolation(
+ hidden_states, output_len=frame_num)
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
\ No newline at end of file
diff --git a/models/losses/__init__.py b/models/losses/__init__.py
new file mode 100644
index 0000000..782a0d3
--- /dev/null
+++ b/models/losses/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_criteria
+
+from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss
+from .lovasz import LovaszLoss
diff --git a/models/losses/builder.py b/models/losses/builder.py
new file mode 100644
index 0000000..ec936be
--- /dev/null
+++ b/models/losses/builder.py
@@ -0,0 +1,28 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+from utils.registry import Registry
+
+LOSSES = Registry("losses")
+
+
+class Criteria(object):
+ def __init__(self, cfg=None):
+ self.cfg = cfg if cfg is not None else []
+ self.criteria = []
+ for loss_cfg in self.cfg:
+ self.criteria.append(LOSSES.build(cfg=loss_cfg))
+
+ def __call__(self, pred, target):
+ if len(self.criteria) == 0:
+            # loss computation occurs in the model
+ return pred
+ loss = 0
+ for c in self.criteria:
+ loss += c(pred, target)
+ return loss
+
+
+def build_criteria(cfg):
+ return Criteria(cfg)
diff --git a/models/losses/lovasz.py b/models/losses/lovasz.py
new file mode 100644
index 0000000..dbdb844
--- /dev/null
+++ b/models/losses/lovasz.py
@@ -0,0 +1,253 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+from typing import Optional
+from itertools import filterfalse
+import torch
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+from .builder import LOSSES
+
+BINARY_MODE: str = "binary"
+MULTICLASS_MODE: str = "multiclass"
+MULTILABEL_MODE: str = "multilabel"
+
+
+def _lovasz_grad(gt_sorted):
+ """Compute gradient of the Lovasz extension w.r.t sorted errors
+ See Alg. 1 in paper
+ """
+ p = len(gt_sorted)
+ gts = gt_sorted.sum()
+ intersection = gts - gt_sorted.float().cumsum(0)
+ union = gts + (1 - gt_sorted).float().cumsum(0)
+ jaccard = 1.0 - intersection / union
+ if p > 1: # cover 1-pixel case
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+ return jaccard
+
+
+def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
+ """
+ Binary Lovasz hinge loss
+ logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
+ labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+ per_image: compute the loss per image instead of per batch
+ ignore: void class id
+ """
+ if per_image:
+ loss = mean(
+ _lovasz_hinge_flat(
+ *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
+ )
+ for log, lab in zip(logits, labels)
+ )
+ else:
+ loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
+ return loss
+
+
+def _lovasz_hinge_flat(logits, labels):
+ """Binary Lovasz hinge loss
+ Args:
+ logits: [P] Logits at each prediction (between -infinity and +infinity)
+ labels: [P] Tensor, binary ground truth labels (0 or 1)
+ """
+ if len(labels) == 0:
+ # only void pixels, the gradients should be 0
+ return logits.sum() * 0.0
+ signs = 2.0 * labels.float() - 1.0
+ errors = 1.0 - logits * signs
+ errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+ perm = perm.data
+ gt_sorted = labels[perm]
+ grad = _lovasz_grad(gt_sorted)
+ loss = torch.dot(F.relu(errors_sorted), grad)
+ return loss
+
+
+def _flatten_binary_scores(scores, labels, ignore=None):
+ """Flattens predictions in the batch (binary case)
+ Remove labels equal to 'ignore'
+ """
+ scores = scores.view(-1)
+ labels = labels.view(-1)
+ if ignore is None:
+ return scores, labels
+ valid = labels != ignore
+ vscores = scores[valid]
+ vlabels = labels[valid]
+ return vscores, vlabels
+
+
+def _lovasz_softmax(
+ probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
+):
+ """Multi-class Lovasz-Softmax loss
+ Args:
+ @param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
+ Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
+ @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
+ @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+ @param per_image: compute the loss per image instead of per batch
+ @param ignore: void class labels
+ """
+ if per_image:
+ loss = mean(
+ _lovasz_softmax_flat(
+ *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
+ classes=classes
+ )
+ for prob, lab in zip(probas, labels)
+ )
+ else:
+ loss = _lovasz_softmax_flat(
+ *_flatten_probas(probas, labels, ignore),
+ classes=classes,
+ class_seen=class_seen
+ )
+ return loss
+
+
+def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
+ """Multi-class Lovasz-Softmax loss
+ Args:
+ @param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
+ @param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
+ @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+ """
+ if probas.numel() == 0:
+ # only void pixels, the gradients should be 0
+ return probas * 0.0
+ C = probas.size(1)
+ losses = []
+ class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
+ # for c in class_to_sum:
+ for c in labels.unique():
+ if class_seen is None:
+ fg = (labels == c).type_as(probas) # foreground for class c
+ if classes == "present" and fg.sum() == 0:
+ continue
+ if C == 1:
+ if len(classes) > 1:
+ raise ValueError("Sigmoid output possible only with 1 class")
+ class_pred = probas[:, 0]
+ else:
+ class_pred = probas[:, c]
+ errors = (fg - class_pred).abs()
+ errors_sorted, perm = torch.sort(errors, 0, descending=True)
+ perm = perm.data
+ fg_sorted = fg[perm]
+ losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
+ else:
+ if c in class_seen:
+ fg = (labels == c).type_as(probas) # foreground for class c
+ if classes == "present" and fg.sum() == 0:
+ continue
+ if C == 1:
+ if len(classes) > 1:
+ raise ValueError("Sigmoid output possible only with 1 class")
+ class_pred = probas[:, 0]
+ else:
+ class_pred = probas[:, c]
+ errors = (fg - class_pred).abs()
+ errors_sorted, perm = torch.sort(errors, 0, descending=True)
+ perm = perm.data
+ fg_sorted = fg[perm]
+ losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
+ return mean(losses)
+
+
+def _flatten_probas(probas, labels, ignore=None):
+ """Flattens predictions in the batch"""
+ if probas.dim() == 3:
+ # assumes output of a sigmoid layer
+ B, H, W = probas.size()
+ probas = probas.view(B, 1, H, W)
+
+ C = probas.size(1)
+ probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
+ probas = probas.contiguous().view(-1, C) # [P, C]
+
+ labels = labels.view(-1)
+ if ignore is None:
+ return probas, labels
+ valid = labels != ignore
+ vprobas = probas[valid]
+ vlabels = labels[valid]
+ return vprobas, vlabels
+
+
+def isnan(x):
+ return x != x
+
+
+def mean(values, ignore_nan=False, empty=0):
+ """Nan-mean compatible with generators."""
+ values = iter(values)
+ if ignore_nan:
+ values = filterfalse(isnan, values)
+ try:
+ n = 1
+ acc = next(values)
+ except StopIteration:
+ if empty == "raise":
+ raise ValueError("Empty mean")
+ return empty
+ for n, v in enumerate(values, 2):
+ acc += v
+ if n == 1:
+ return acc
+ return acc / n
+
+
+@LOSSES.register_module()
+class LovaszLoss(_Loss):
+ def __init__(
+ self,
+ mode: str,
+ class_seen: Optional[int] = None,
+ per_image: bool = False,
+ ignore_index: Optional[int] = None,
+ loss_weight: float = 1.0,
+ ):
+ """Lovasz loss for segmentation task.
+ It supports binary, multiclass and multilabel cases
+ Args:
+ mode: Loss mode 'binary', 'multiclass' or 'multilabel'
+ ignore_index: Label that indicates ignored pixels (does not contribute to loss)
+ per_image: If True loss computed per each image and then averaged, else computed per whole batch
+ Shape
+ - **y_pred** - torch.Tensor of shape (N, C, H, W)
+ - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
+ Reference
+ https://github.com/BloodAxe/pytorch-toolbelt
+ """
+ assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
+ super().__init__()
+
+ self.mode = mode
+ self.ignore_index = ignore_index
+ self.per_image = per_image
+ self.class_seen = class_seen
+ self.loss_weight = loss_weight
+
+ def forward(self, y_pred, y_true):
+ if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
+ loss = _lovasz_hinge(
+ y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
+ )
+ elif self.mode == MULTICLASS_MODE:
+ y_pred = y_pred.softmax(dim=1)
+ loss = _lovasz_softmax(
+ y_pred,
+ y_true,
+ class_seen=self.class_seen,
+ per_image=self.per_image,
+ ignore=self.ignore_index,
+ )
+ else:
+ raise ValueError("Wrong mode {}.".format(self.mode))
+ return loss * self.loss_weight
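+
+
+# --- Illustrative usage sketch (random inputs; not part of the original training code).
+# The 4-class shapes below are examples only.
+if __name__ == "__main__":
+    criterion = LovaszLoss(mode=MULTICLASS_MODE, per_image=False)
+    y_pred = torch.randn(2, 4, 8, 8)          # [B, C, H, W] raw logits
+    y_true = torch.randint(0, 4, (2, 8, 8))   # [B, H, W] integer class labels
+    print("lovasz-softmax loss:", criterion(y_pred, y_true).item())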
diff --git a/models/losses/misc.py b/models/losses/misc.py
new file mode 100644
index 0000000..48e26bb
--- /dev/null
+++ b/models/losses/misc.py
@@ -0,0 +1,241 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .builder import LOSSES
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+ def __init__(
+ self,
+ weight=None,
+ size_average=None,
+ reduce=None,
+ reduction="mean",
+ label_smoothing=0.0,
+ loss_weight=1.0,
+ ignore_index=-1,
+ ):
+ super(CrossEntropyLoss, self).__init__()
+ weight = torch.tensor(weight).cuda() if weight is not None else None
+ self.loss_weight = loss_weight
+ self.loss = nn.CrossEntropyLoss(
+ weight=weight,
+ size_average=size_average,
+ ignore_index=ignore_index,
+ reduce=reduce,
+ reduction=reduction,
+ label_smoothing=label_smoothing,
+ )
+
+ def forward(self, pred, target):
+ return self.loss(pred, target) * self.loss_weight
+
+
+@LOSSES.register_module()
+class L1Loss(nn.Module):
+ def __init__(
+ self,
+ weight=None,
+ size_average=None,
+ reduce=None,
+ reduction="mean",
+ label_smoothing=0.0,
+ loss_weight=1.0,
+ ignore_index=-1,
+ ):
+ super(L1Loss, self).__init__()
+ weight = torch.tensor(weight).cuda() if weight is not None else None
+ self.loss_weight = loss_weight
+        self.loss = nn.L1Loss(reduction=reduction)
+
+ def forward(self, pred, target):
+ return self.loss(pred, target[:,None]) * self.loss_weight
+
+
+@LOSSES.register_module()
+class SmoothCELoss(nn.Module):
+ def __init__(self, smoothing_ratio=0.1):
+ super(SmoothCELoss, self).__init__()
+ self.smoothing_ratio = smoothing_ratio
+
+ def forward(self, pred, target):
+ eps = self.smoothing_ratio
+ n_class = pred.size(1)
+ one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(pred, dim=1)
+        loss = -(one_hot * log_prb).sum(dim=1)
+ loss = loss[torch.isfinite(loss)].mean()
+ return loss
+
+
+@LOSSES.register_module()
+class BinaryFocalLoss(nn.Module):
+ def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
+ """Binary Focal Loss
+ `
+ """
+ super(BinaryFocalLoss, self).__init__()
+ assert 0 < alpha < 1
+ self.gamma = gamma
+ self.alpha = alpha
+ self.logits = logits
+ self.reduce = reduce
+ self.loss_weight = loss_weight
+
+ def forward(self, pred, target, **kwargs):
+ """Forward function.
+ Args:
+ pred (torch.Tensor): The prediction with shape (N)
+ target (torch.Tensor): The ground truth. If containing class
+                indices, shape (N) where each value satisfies 0 ≤ targets[i] ≤ 1; if containing class probabilities,
+ same shape as the input.
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if self.logits:
+ bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
+ else:
+ bce = F.binary_cross_entropy(pred, target, reduction="none")
+ pt = torch.exp(-bce)
+ alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
+ focal_loss = alpha * (1 - pt) ** self.gamma * bce
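+        # Derivation note (sketch, for hard 0/1 targets): bce = -log(p_t), so
+        # pt = exp(-bce) recovers p_t and (1 - pt) ** gamma is the focal modulating
+        # factor that down-weights easy samples; e.g. with gamma = 2, p_t = 0.9 gives
+        # a factor of 0.01 while p_t = 0.25 keeps a factor of ~0.56.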
+
+ if self.reduce:
+ focal_loss = torch.mean(focal_loss)
+ return focal_loss * self.loss_weight
+
+
+@LOSSES.register_module()
+class FocalLoss(nn.Module):
+ def __init__(
+ self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
+ ):
+ """Focal Loss
+ `
+ """
+ super(FocalLoss, self).__init__()
+ assert reduction in (
+ "mean",
+ "sum",
+ ), "AssertionError: reduction should be 'mean' or 'sum'"
+ assert isinstance(
+ alpha, (float, list)
+ ), "AssertionError: alpha should be of type float"
+ assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
+ assert isinstance(
+ loss_weight, float
+ ), "AssertionError: loss_weight should be of type float"
+ assert isinstance(ignore_index, int), "ignore_index must be of type int"
+ self.gamma = gamma
+ self.alpha = alpha
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.ignore_index = ignore_index
+
+ def forward(self, pred, target, **kwargs):
+ """Forward function.
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes.
+ target (torch.Tensor): The ground truth. If containing class
+                indices, shape (N) where each value satisfies 0 ≤ targets[i] ≤ C−1; if containing class probabilities,
+ same shape as the input.
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
+ pred = pred.transpose(0, 1)
+ # [C, B, d_1, d_2, ..., d_k] -> [C, N]
+ pred = pred.reshape(pred.size(0), -1)
+ # [C, N] -> [N, C]
+ pred = pred.transpose(0, 1).contiguous()
+ # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
+ target = target.view(-1).contiguous()
+ assert pred.size(0) == target.size(
+ 0
+ ), "The shape of pred doesn't match the shape of target"
+ valid_mask = target != self.ignore_index
+ target = target[valid_mask]
+ pred = pred[valid_mask]
+
+ if len(target) == 0:
+ return 0.0
+
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes)
+
+ alpha = self.alpha
+ if isinstance(alpha, list):
+ alpha = pred.new_tensor(alpha)
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
+ self.gamma
+ )
+
+ loss = (
+ F.binary_cross_entropy_with_logits(pred, target, reduction="none")
+ * focal_weight
+ )
+ if self.reduction == "mean":
+ loss = loss.mean()
+ elif self.reduction == "sum":
+            loss = loss.sum()
+ return self.loss_weight * loss
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+ def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
+ """DiceLoss.
+        This loss is proposed in "V-Net: Fully Convolutional Neural Networks for
+        Volumetric Medical Image Segmentation".
+ """
+ super(DiceLoss, self).__init__()
+ self.smooth = smooth
+ self.exponent = exponent
+ self.loss_weight = loss_weight
+ self.ignore_index = ignore_index
+
+ def forward(self, pred, target, **kwargs):
+ # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
+ pred = pred.transpose(0, 1)
+ # [C, B, d_1, d_2, ..., d_k] -> [C, N]
+ pred = pred.reshape(pred.size(0), -1)
+ # [C, N] -> [N, C]
+ pred = pred.transpose(0, 1).contiguous()
+ # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
+ target = target.view(-1).contiguous()
+ assert pred.size(0) == target.size(
+ 0
+ ), "The shape of pred doesn't match the shape of target"
+ valid_mask = target != self.ignore_index
+ target = target[valid_mask]
+ pred = pred[valid_mask]
+
+ pred = F.softmax(pred, dim=1)
+ num_classes = pred.shape[1]
+ target = F.one_hot(
+ torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
+ )
+
+ total_loss = 0
+ for i in range(num_classes):
+ if i != self.ignore_index:
+ num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
+ den = (
+ torch.sum(
+ pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
+ )
+ + self.smooth
+ )
+ dice_loss = 1 - num / den
+ total_loss += dice_loss
+ loss = total_loss / num_classes
+ return self.loss_weight * loss
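+
+
+# --- Illustrative usage sketch (random inputs; shapes are hypothetical, not from the original repo) ---
+if __name__ == "__main__":
+    logits = torch.randn(8, 4)            # [N, C] raw predictions for 4 classes
+    labels = torch.randint(0, 4, (8,))    # [N] integer targets
+    print("focal loss:", FocalLoss(gamma=2.0, alpha=0.5)(logits, labels).item())
+    print("dice  loss:", DiceLoss(smooth=1, exponent=2)(logits, labels).item())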
diff --git a/models/network.py b/models/network.py
new file mode 100644
index 0000000..cdedbed
--- /dev/null
+++ b/models/network.py
@@ -0,0 +1,646 @@
+import math
+import os.path
+
+import numpy as np
+import torch
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio as ta
+
+from models.encoder.wav2vec import Wav2Vec2Model
+from models.encoder.wavlm import WavLMModel
+
+from models.builder import MODELS
+
+from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
+
+@MODELS.register_module("Audio2Expression")
+class Audio2Expression(nn.Module):
+ def __init__(self,
+ device: torch.device = None,
+ pretrained_encoder_type: str = 'wav2vec',
+ pretrained_encoder_path: str = '',
+ wav2vec2_config_path: str = '',
+ num_identity_classes: int = 0,
+ identity_feat_dim: int = 64,
+ hidden_dim: int = 512,
+ expression_dim: int = 52,
+ norm_type: str = 'ln',
+ decoder_depth: int = 3,
+ use_transformer: bool = False,
+ num_attention_heads: int = 8,
+ num_transformer_layers: int = 6,
+ ):
+ super().__init__()
+
+ self.device = device
+
+ # Initialize audio feature encoder
+ if pretrained_encoder_type == 'wav2vec':
+ if os.path.exists(pretrained_encoder_path):
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path)
+ else:
+ config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path)
+ self.audio_encoder = Wav2Vec2Model(config)
+ encoder_output_dim = 768
+ elif pretrained_encoder_type == 'wavlm':
+ self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path)
+ encoder_output_dim = 768
+ else:
+ raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported")
+
+ self.audio_encoder.feature_extractor._freeze_parameters()
+ self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim)
+
+ self.identity_encoder = AudioIdentityEncoder(
+ hidden_dim,
+ num_identity_classes,
+ identity_feat_dim,
+ use_transformer,
+ num_attention_heads,
+ num_transformer_layers
+ )
+
+ self.decoder = nn.ModuleList([
+ nn.Sequential(*[
+ ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type)
+ for _ in range(decoder_depth)
+ ])
+ ])
+
+ self.output_proj = nn.Linear(hidden_dim, expression_dim)
+
+ def freeze_encoder_parameters(self, do_freeze=False):
+
+ for name, param in self.audio_encoder.named_parameters():
+ if('feature_extractor' in name):
+ param.requires_grad = False
+ else:
+ param.requires_grad = (not do_freeze)
+
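+    # Input-contract sketch (illustrative; inferred from the code below, not from original docs):
+    #   input_dict["input_audio_array"]: [B, num_samples] mono 16 kHz audio
+    #   input_dict["id_idx"]:            [B, num_identity_classes] identity one-hot (any real-valued
+    #                                    vector of that length also works; it is mapped by a 1x1 conv)
+    #   input_dict["time_steps"]:        optional; defaults to ceil(num_samples / 16000 * 30),
+    #                                    e.g. 1 s of audio -> 30 output frames of `expression_dim` values.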
+ def forward(self, input_dict):
+
+ if 'time_steps' not in input_dict:
+ audio_length = input_dict['input_audio_array'].shape[1]
+ time_steps = math.ceil(audio_length / 16000 * 30)
+ else:
+ time_steps = input_dict['time_steps']
+
+ # Process audio through encoder
+ audio_input = input_dict['input_audio_array'].flatten(start_dim=1)
+ hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state
+
+ # Project features to hidden dimension
+ audio_features = self.feature_projection(hidden_states).transpose(1, 2)
+
+ # Process identity-conditioned features
+ audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx'])
+
+ # Refine features through decoder
+ audio_features = self.decoder[0](audio_features)
+
+ # Generate output parameters
+ audio_features = audio_features.permute(0, 2, 1)
+ expression_params = self.output_proj(audio_features)
+
+ return torch.sigmoid(expression_params)
+
+
+class AudioIdentityEncoder(nn.Module):
+ def __init__(self,
+ hidden_dim,
+ num_identity_classes=0,
+ identity_feat_dim=64,
+ use_transformer=False,
+ num_attention_heads = 8,
+ num_transformer_layers = 6,
+ dropout_ratio=0.1,
+ ):
+ super().__init__()
+
+ in_dim = hidden_dim + identity_feat_dim
+ self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1)
+ self.first_net = SeqTranslator1D(in_dim, hidden_dim,
+ min_layers_num=3,
+ residual=True,
+ norm='ln'
+ )
+ self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
+ self.dropout = nn.Dropout(dropout_ratio)
+
+ self.use_transformer = use_transformer
+ if(self.use_transformer):
+ encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_attention_heads, dim_feedforward= 2 * hidden_dim, batch_first=True)
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
+
+ def forward(self,
+ audio_features: torch.Tensor,
+ identity: torch.Tensor = None,
+ time_steps: int = None) -> tuple:
+
+ audio_features = self.dropout(audio_features)
+ identity = identity.reshape(identity.shape[0], -1, 1).repeat(1, 1, audio_features.shape[2]).to(torch.float32)
+ identity = self.id_mlp(identity)
+ audio_features = torch.cat([audio_features, identity], dim=1)
+
+ x = self.first_net(audio_features)
+
+ if time_steps is not None:
+ x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear')
+
+ if(self.use_transformer):
+ x = x.permute(0, 2, 1)
+ x = self.transformer_encoder(x)
+ x = x.permute(0, 2, 1)
+
+ return x
+
+class ConvNormRelu(nn.Module):
+ '''
+    (B, C_in, H, W) -> (B, C_out, H, W)
+    Note: some kernel size / stride combinations do not yield an output length of exactly H/s.
+ '''
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ type='1d',
+ leaky=False,
+ downsample=False,
+ kernel_size=None,
+ stride=None,
+ padding=None,
+ p=0,
+ groups=1,
+ residual=False,
+ norm='bn'):
+ '''
+ conv-bn-relu
+ '''
+ super(ConvNormRelu, self).__init__()
+ self.residual = residual
+ self.norm_type = norm
+ # kernel_size = k
+ # stride = s
+
+ if kernel_size is None and stride is None:
+ if not downsample:
+ kernel_size = 3
+ stride = 1
+ else:
+ kernel_size = 4
+ stride = 2
+
+ if padding is None:
+ if isinstance(kernel_size, int) and isinstance(stride, tuple):
+ padding = tuple(int((kernel_size - st) / 2) for st in stride)
+ elif isinstance(kernel_size, tuple) and isinstance(stride, int):
+ padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
+ elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
+ padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
+ else:
+ padding = int((kernel_size - stride) / 2)
+
+ if self.residual:
+ if downsample:
+ if type == '1d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+ elif type == '2d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+ else:
+ if in_channels == out_channels:
+ self.residual_layer = nn.Identity()
+ else:
+ if type == '1d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+ elif type == '2d':
+ self.residual_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding
+ )
+ )
+
+ in_channels = in_channels * groups
+ out_channels = out_channels * groups
+ if type == '1d':
+ self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm1d(out_channels)
+ self.dropout = nn.Dropout(p=p)
+ elif type == '2d':
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups)
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.dropout = nn.Dropout2d(p=p)
+ if norm == 'gn':
+ self.norm = nn.GroupNorm(2, out_channels)
+ elif norm == 'ln':
+ self.norm = nn.LayerNorm(out_channels)
+ if leaky:
+ self.relu = nn.LeakyReLU(negative_slope=0.2)
+ else:
+ self.relu = nn.ReLU()
+
+ def forward(self, x, **kwargs):
+ if self.norm_type == 'ln':
+ out = self.dropout(self.conv(x))
+ out = self.norm(out.transpose(1,2)).transpose(1,2)
+ else:
+ out = self.norm(self.dropout(self.conv(x)))
+ if self.residual:
+ residual = self.residual_layer(x)
+ out += residual
+ return self.relu(out)
+
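+# Shape note (sketch): with the defaults (kernel_size=3, stride=1, padding=1) the temporal/spatial
+# length is preserved, e.g. ConvNormRelu(64, 128, type='1d', norm='ln')(x) maps (B, 64, T) -> (B, 128, T);
+# with downsample=True (kernel_size=4, stride=2, padding=1) the length is halved for even T.
+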
+""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
+class SeqTranslator1D(nn.Module):
+ '''
+ (B, C, T)->(B, C_out, T)
+ '''
+ def __init__(self,
+ C_in,
+ C_out,
+ kernel_size=None,
+ stride=None,
+ min_layers_num=None,
+ residual=True,
+ norm='bn'
+ ):
+ super(SeqTranslator1D, self).__init__()
+
+ conv_layers = nn.ModuleList([])
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_in,
+ out_channels=C_out,
+ type='1d',
+ kernel_size=kernel_size,
+ stride=stride,
+ residual=residual,
+ norm=norm
+ ))
+ self.num_layers = 1
+ if min_layers_num is not None and self.num_layers < min_layers_num:
+ while self.num_layers < min_layers_num:
+ conv_layers.append(ConvNormRelu(
+ in_channels=C_out,
+ out_channels=C_out,
+ type='1d',
+ kernel_size=kernel_size,
+ stride=stride,
+ residual=residual,
+ norm=norm
+ ))
+ self.num_layers += 1
+ self.conv_layers = nn.Sequential(*conv_layers)
+
+ def forward(self, x):
+ return self.conv_layers(x)
+
+
+def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
+ """
+ :param audio: 1 x T tensor containing a 16kHz audio signal
+ :param frame_rate: frame rate for video (we need one audio chunk per video frame)
+ :param chunk_size: number of audio samples per chunk
+ :return: num_chunks x chunk_size tensor containing sliced audio
+ """
+ samples_per_frame = 16000 // frame_rate
+ padding = (chunk_size - samples_per_frame) // 2
+ audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
+ anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
+ audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
+ return audio
+
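+# Worked example (illustrative): for a 1 s, 16 kHz input of shape (1, 16000) with frame_rate=30,
+# samples_per_frame = 16000 // 30 = 533 and padding = (16000 - 533) // 2 = 7733, which yields
+# 30 chunks of 16000 samples each, one centred on every video frame.
+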
+""" https://github.com/facebookresearch/meshtalk """
+class MeshtalkEncoder(nn.Module):
+ def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
+ """
+ :param latent_dim: size of the latent audio embedding
+ :param model_name: name of the model, used to load and save the model
+ """
+ super().__init__()
+
+ self.melspec = ta.transforms.MelSpectrogram(
+ sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
+ )
+
+ conv_len = 5
+ self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
+ self.weights_init(self.convert_dimensions)
+ self.receptive_field = conv_len
+
+ convs = []
+ for i in range(6):
+ dilation = 2 * (i % 3 + 1)
+ self.receptive_field += (conv_len - 1) * dilation
+ convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
+ self.weights_init(convs[-1])
+ self.convs = torch.nn.ModuleList(convs)
+ self.code = torch.nn.Linear(128, latent_dim)
+
+ self.apply(lambda x: self.weights_init(x))
+
+ def weights_init(self, m):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.init.xavier_uniform_(m.weight)
+ try:
+ torch.nn.init.constant_(m.bias, .01)
+ except:
+ pass
+
+ def forward(self, audio: torch.Tensor):
+ """
+ :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
+ :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
+ """
+ B, T = audio.shape[0], audio.shape[1]
+ x = self.melspec(audio).squeeze(1)
+ x = torch.log(x.clamp(min=1e-10, max=None))
+ if T == 1:
+ x = x.unsqueeze(1)
+
+ # Convert to the right dimensionality
+ x = x.view(-1, x.shape[2], x.shape[3])
+ x = F.leaky_relu(self.convert_dimensions(x), .2)
+
+ # Process stacks
+ for conv in self.convs:
+ x_ = F.leaky_relu(conv(x), .2)
+ if self.training:
+ x_ = F.dropout(x_, .2)
+ l = (x.shape[2] - x_.shape[2]) // 2
+ x = (x[:, :, l:-l] + x_) / 2
+
+ x = torch.mean(x, dim=-1)
+ x = x.view(B, T, x.shape[-1])
+ x = self.code(x)
+
+ return {"code": x}
+
+class PeriodicPositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64):
+ super(PeriodicPositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ pe = torch.zeros(period, d_model)
+ position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0) # (1, period, d_model)
+ repeat_num = (max_seq_len//period) + 1
+ pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model)
+ self.register_buffer('pe', pe)
+ def forward(self, x):
+ # print(self.pe.shape, x.shape)
+ x = x + self.pe[:, :x.size(1), :]
+ return self.dropout(x)
+
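+# Note (sketch): the registered buffer `pe` has length period * (max_seq_len // period + 1) >= max_seq_len,
+# so forward() can always slice pe[:, :T, :] for T <= max_seq_len; the encoding repeats with the given period.
+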
+
+class GeneratorTransformer(nn.Module):
+ def __init__(self,
+ n_poses,
+ each_dim: list,
+ dim_list: list,
+ training=True,
+ device=None,
+ identity=False,
+ num_classes=0,
+ ):
+ super().__init__()
+
+ self.training = training
+ self.device = device
+ self.gen_length = n_poses
+
+ norm = 'ln'
+ in_dim = 256
+ out_dim = 256
+
+ self.encoder_choice = 'faceformer'
+
+        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")  # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
+ self.audio_encoder.feature_extractor._freeze_parameters()
+ self.audio_feature_map = nn.Linear(768, in_dim)
+
+ self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes)
+
+ self.dim_list = dim_list
+
+ self.decoder = nn.ModuleList()
+ self.final_out = nn.ModuleList()
+
+ self.hidden_size = 768
+ self.transformer_de_layer = nn.TransformerDecoderLayer(
+ d_model=self.hidden_size,
+ nhead=4,
+ dim_feedforward=self.hidden_size*2,
+ batch_first=True
+ )
+ self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)
+ self.feature2face = nn.Linear(256, self.hidden_size)
+
+ self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64)
+        self.id_mapping = nn.Linear(12, self.hidden_size)
+
+
+ self.decoder.append(self.face_decoder)
+ self.final_out.append(nn.Linear(self.hidden_size, 32))
+
+ def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
+ if gt_poses is None:
+ time_steps = 64
+ else:
+ time_steps = gt_poses.shape[1]
+
+ # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
+ if self.encoder_choice == 'meshtalk':
+ in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
+ feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
+ elif self.encoder_choice == 'faceformer':
+ hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
+ feature = self.audio_feature_map(hidden_states).transpose(1, 2)
+ else:
+ feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
+
+ feature, _ = self.audio_middle(feature, id=None)
+ feature = self.feature2face(feature.permute(0,2,1))
+
+ id = id.unsqueeze(1).repeat(1,64,1).to(torch.float32)
+        id_feature = self.id_mapping(id)
+ id_feature = self.position_embeddings(id_feature)
+
+ for i in range(self.decoder.__len__()):
+ mid = self.decoder[i](tgt=id_feature, memory=feature)
+ out = self.final_out[i](mid)
+
+ return out, None
+
+def linear_interpolation(features, output_len: int):
+ features = features.transpose(1, 2)
+ output_features = F.interpolate(
+ features, size=output_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
+
+def init_biased_mask(n_head, max_seq_len, period):
+
+ def get_slopes(n):
+
+ def get_slopes_power_of_2(n):
+ start = (2**(-2**-(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio**i for i in range(n)]
+
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2**math.floor(math.log2(n))
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
+ 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
+
+ slopes = torch.Tensor(get_slopes(n_head))
+ bias = torch.div(
+ torch.arange(start=0, end=max_seq_len,
+ step=period).unsqueeze(1).repeat(1, period).view(-1),
+ period,
+ rounding_mode='floor')
+ bias = -torch.flip(bias, dims=[0])
+ alibi = torch.zeros(max_seq_len, max_seq_len)
+ for i in range(max_seq_len):
+ alibi[i, :i + 1] = bias[-(i + 1):]
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
+ mask = (torch.triu(torch.ones(max_seq_len,
+ max_seq_len)) == 1).transpose(0, 1)
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
+ mask == 1, float(0.0))
+ mask = mask.unsqueeze(0) + alibi
+ return mask
+
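+# Note (sketch): init_biased_mask returns a (n_head, max_seq_len, max_seq_len) tensor combining a causal
+# mask (upper triangle set to -inf) with per-head ALiBi-style linear biases that decay with distance,
+# quantised to the given period.
+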
+
+# Alignment Bias
+def enc_dec_mask(device, T, S):
+ mask = torch.ones(T, S)
+ for i in range(T):
+ mask[i, i] = 0
+ return (mask == 1).to(device=device)
+
+
+# Periodic Positional Encoding
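+# NOTE: this redefines the PeriodicPositionalEncoding class above; Python binds the name to the most
+# recent definition, so instances created after module import (e.g. in UniTalkerDecoderTransformer
+# below) use this version.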
+class PeriodicPositionalEncoding(nn.Module):
+
+ def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
+ super(PeriodicPositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ pe = torch.zeros(period, d_model)
+ position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2).float() *
+ (-math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0) # (1, period, d_model)
+ repeat_num = (max_seq_len // period) + 1
+ pe = pe.repeat(1, repeat_num, 1)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1), :]
+ return self.dropout(x)
+
+
+class BaseModel(nn.Module):
+ """Base class for all models."""
+
+ def __init__(self):
+ super(BaseModel, self).__init__()
+ # self.logger = logging.getLogger(self.__class__.__name__)
+
+ def forward(self, *x):
+ """Forward pass logic.
+
+ :return: Model output
+ """
+ raise NotImplementedError
+
+ def freeze_model(self, do_freeze: bool = True):
+ for param in self.parameters():
+ param.requires_grad = (not do_freeze)
+
+ def summary(self, logger, writer=None):
+ """Model summary."""
+ model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+ params = sum([np.prod(p.size())
+ for p in model_parameters]) / 1e6 # Unit is Mega
+ logger.info('===>Trainable parameters: %.3f M' % params)
+ if writer is not None:
+ writer.add_text('Model Summary',
+ 'Trainable parameters: %.3f M' % params)
+
+
+"""https://github.com/X-niper/UniTalker"""
+class UniTalkerDecoderTransformer(BaseModel):
+
+ def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None:
+ super().__init__()
+ self.learnable_style_emb = nn.Embedding(identity_num, out_dim)
+ self.PPE = PeriodicPositionalEncoding(
+ out_dim, period=period, max_seq_len=3000)
+ self.biased_mask = init_biased_mask(
+ n_head=4, max_seq_len=3000, period=period)
+ decoder_layer = nn.TransformerDecoderLayer(
+ d_model=out_dim,
+ nhead=4,
+ dim_feedforward=2 * out_dim,
+ batch_first=True)
+ self.transformer_decoder = nn.TransformerDecoder(
+ decoder_layer, num_layers=1)
+ self.interpolate_pos = interpolate_pos
+
+ def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
+ frame_num: int):
+ style_idx = torch.argmax(style_idx, dim=1)
+ obj_embedding = self.learnable_style_emb(style_idx)
+ obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
+ style_input = self.PPE(obj_embedding)
+        tgt_mask = self.biased_mask.repeat(style_idx.shape[0], 1, 1)
+        tgt_mask = tgt_mask[:, :style_input.shape[1], :style_input.shape[1]]
+        tgt_mask = tgt_mask.clone().detach().to(device=style_input.device)
+ memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
+ frame_num)
+ feat_out = self.transformer_decoder(
+ style_input,
+ hidden_states,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask)
+ if self.interpolate_pos == 2:
+ feat_out = linear_interpolation(feat_out, output_len=frame_num)
+ return feat_out
\ No newline at end of file
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000..4b15130
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,752 @@
+import json
+import time
+import warnings
+import numpy as np
+from typing import List, Optional,Tuple
+from scipy.signal import savgol_filter
+
+
+ARKitLeftRightPair = [
+ ("jawLeft", "jawRight"),
+ ("mouthLeft", "mouthRight"),
+ ("mouthSmileLeft", "mouthSmileRight"),
+ ("mouthFrownLeft", "mouthFrownRight"),
+ ("mouthDimpleLeft", "mouthDimpleRight"),
+ ("mouthStretchLeft", "mouthStretchRight"),
+ ("mouthPressLeft", "mouthPressRight"),
+ ("mouthLowerDownLeft", "mouthLowerDownRight"),
+ ("mouthUpperUpLeft", "mouthUpperUpRight"),
+ ("cheekSquintLeft", "cheekSquintRight"),
+ ("noseSneerLeft", "noseSneerRight"),
+ ("browDownLeft", "browDownRight"),
+ ("browOuterUpLeft", "browOuterUpRight"),
+ ("eyeBlinkLeft","eyeBlinkRight"),
+ ("eyeLookDownLeft","eyeLookDownRight"),
+ ("eyeLookInLeft", "eyeLookInRight"),
+ ("eyeLookOutLeft","eyeLookOutRight"),
+ ("eyeLookUpLeft","eyeLookUpRight"),
+ ("eyeSquintLeft","eyeSquintRight"),
+ ("eyeWideLeft","eyeWideRight")
+ ]
+
+ARKitBlendShape =[
+ "browDownLeft",
+ "browDownRight",
+ "browInnerUp",
+ "browOuterUpLeft",
+ "browOuterUpRight",
+ "cheekPuff",
+ "cheekSquintLeft",
+ "cheekSquintRight",
+ "eyeBlinkLeft",
+ "eyeBlinkRight",
+ "eyeLookDownLeft",
+ "eyeLookDownRight",
+ "eyeLookInLeft",
+ "eyeLookInRight",
+ "eyeLookOutLeft",
+ "eyeLookOutRight",
+ "eyeLookUpLeft",
+ "eyeLookUpRight",
+ "eyeSquintLeft",
+ "eyeSquintRight",
+ "eyeWideLeft",
+ "eyeWideRight",
+ "jawForward",
+ "jawLeft",
+ "jawOpen",
+ "jawRight",
+ "mouthClose",
+ "mouthDimpleLeft",
+ "mouthDimpleRight",
+ "mouthFrownLeft",
+ "mouthFrownRight",
+ "mouthFunnel",
+ "mouthLeft",
+ "mouthLowerDownLeft",
+ "mouthLowerDownRight",
+ "mouthPressLeft",
+ "mouthPressRight",
+ "mouthPucker",
+ "mouthRight",
+ "mouthRollLower",
+ "mouthRollUpper",
+ "mouthShrugLower",
+ "mouthShrugUpper",
+ "mouthSmileLeft",
+ "mouthSmileRight",
+ "mouthStretchLeft",
+ "mouthStretchRight",
+ "mouthUpperUpLeft",
+ "mouthUpperUpRight",
+ "noseSneerLeft",
+ "noseSneerRight",
+ "tongueOut"
+]
+
+MOUTH_BLENDSHAPES = [ "mouthDimpleLeft",
+ "mouthDimpleRight",
+ "mouthFrownLeft",
+ "mouthFrownRight",
+ "mouthFunnel",
+ "mouthLeft",
+ "mouthLowerDownLeft",
+ "mouthLowerDownRight",
+ "mouthPressLeft",
+ "mouthPressRight",
+ "mouthPucker",
+ "mouthRight",
+ "mouthRollLower",
+ "mouthRollUpper",
+ "mouthShrugLower",
+ "mouthShrugUpper",
+ "mouthSmileLeft",
+ "mouthSmileRight",
+ "mouthStretchLeft",
+ "mouthStretchRight",
+ "mouthUpperUpLeft",
+ "mouthUpperUpRight",
+ "jawForward",
+ "jawLeft",
+ "jawOpen",
+ "jawRight",
+ "noseSneerLeft",
+ "noseSneerRight",
+ "cheekPuff",
+ ]
+
+DEFAULT_CONTEXT ={
+ 'is_initial_input': True,
+ 'previous_audio': None,
+ 'previous_expression': None,
+ 'previous_volume': None,
+ 'previous_headpose': None,
+}
+
+RETURN_CODE = {
+ "SUCCESS": 0,
+ "AUDIO_LENGTH_ERROR": 1,
+ "CHECKPOINT_PATH_ERROR":2,
+ "MODEL_INFERENCE_ERROR":3,
+}
+
+DEFAULT_CONTEXTRETURN = {
+ "code": RETURN_CODE['SUCCESS'],
+ "expression": None,
+ "headpose": None,
+}
+
+BLINK_PATTERNS = [
+ np.array([0.365, 0.950, 0.956, 0.917, 0.367, 0.119, 0.025]),
+ np.array([0.235, 0.910, 0.945, 0.778, 0.191, 0.235, 0.089]),
+ np.array([0.870, 0.950, 0.949, 0.696, 0.191, 0.073, 0.007]),
+ np.array([0.000, 0.557, 0.953, 0.942, 0.426, 0.148, 0.018])
+]
+
+# Postprocess
+def symmetrize_blendshapes(
+ bs_params: np.ndarray,
+ mode: str = "average",
+ symmetric_pairs: list = ARKitLeftRightPair
+) -> np.ndarray:
+ """
+ Apply symmetrization to ARKit blendshape parameters (batched version)
+
+ Args:
+ bs_params: numpy array of shape (N, 52), batch of ARKit parameters
+ mode: symmetrization mode ["average", "max", "min", "left_dominant", "right_dominant"]
+ symmetric_pairs: list of left-right parameter pairs
+
+ Returns:
+ Symmetrized parameters with same shape (N, 52)
+ """
+
+ name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)}
+
+ # Input validation
+ if bs_params.ndim != 2 or bs_params.shape[1] != 52:
+ raise ValueError("Input must be of shape (N, 52)")
+
+ symmetric_bs = bs_params.copy() # Shape (N, 52)
+
+ # Precompute valid index pairs
+ valid_pairs = []
+ for left, right in symmetric_pairs:
+ left_idx = name_to_idx.get(left)
+ right_idx = name_to_idx.get(right)
+ if None not in (left_idx, right_idx):
+ valid_pairs.append((left_idx, right_idx))
+
+ # Vectorized processing
+ for l_idx, r_idx in valid_pairs:
+ left_col = symmetric_bs[:, l_idx]
+ right_col = symmetric_bs[:, r_idx]
+
+ if mode == "average":
+ new_vals = (left_col + right_col) / 2
+ elif mode == "max":
+ new_vals = np.maximum(left_col, right_col)
+ elif mode == "min":
+ new_vals = np.minimum(left_col, right_col)
+ elif mode == "left_dominant":
+ new_vals = left_col
+ elif mode == "right_dominant":
+ new_vals = right_col
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+
+ # Update both columns simultaneously
+ symmetric_bs[:, l_idx] = new_vals
+ symmetric_bs[:, r_idx] = new_vals
+
+ return symmetric_bs
+
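+# --- Illustrative usage sketch (random parameters, not from the original pipeline) ---
+if __name__ == "__main__":
+    params = np.random.rand(4, 52).astype(np.float32)
+    mirrored = symmetrize_blendshapes(params, mode="average")
+    left = ARKitBlendShape.index("mouthSmileLeft")
+    right = ARKitBlendShape.index("mouthSmileRight")
+    print("smile symmetrized:", bool(np.allclose(mirrored[:, left], mirrored[:, right])))   # -> True
+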
+
+def apply_random_eye_blinks(
+ input: np.ndarray,
+ blink_scale: tuple = (0.8, 1.0),
+ blink_interval: tuple = (60, 120),
+ blink_duration: int = 7
+) -> np.ndarray:
+ """
+ Apply randomized eye blinks to blendshape parameters
+
+ Args:
+        input: Input array of shape (N, 52) containing blendshape parameters
+ blink_scale: Tuple (min, max) for random blink intensity scaling
+ blink_interval: Tuple (min, max) for random blink spacing in frames
+ blink_duration: Number of frames for blink animation (fixed)
+
+ Returns:
+        The modified array (the input array is also updated in place)
+ """
+ # Define eye blink patterns (normalized 0-1)
+
+ # Initialize parameters
+ n_frames = input.shape[0]
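+    # Columns 8 and 9 correspond to eyeBlinkLeft / eyeBlinkRight in the ARKitBlendShape ordering above.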
+ input[:,8:10] = np.zeros((n_frames,2))
+ current_frame = 0
+
+ # Main blink application loop
+ while current_frame < n_frames - blink_duration:
+ # Randomize blink parameters
+ scale = np.random.uniform(*blink_scale)
+ pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
+
+ # Apply blink animation
+ blink_values = pattern * scale
+ input[current_frame:current_frame + blink_duration, 8] = blink_values
+ input[current_frame:current_frame + blink_duration, 9] = blink_values
+
+ # Advance to next blink position
+ current_frame += blink_duration + np.random.randint(*blink_interval)
+
+ return input
+
+
+def apply_random_eye_blinks_context(
+ animation_params: np.ndarray,
+ processed_frames: int = 0,
+ intensity_range: tuple = (0.8, 1.0)
+) -> np.ndarray:
+ """Applies random eye blink patterns to facial animation parameters.
+
+ Args:
+ animation_params: Input facial animation parameters array with shape [num_frames, num_features].
+ Columns 8 and 9 typically represent left/right eye blink parameters.
+ processed_frames: Number of already processed frames that shouldn't be modified
+ intensity_range: Tuple defining (min, max) scaling for blink intensity
+
+ Returns:
+ Modified animation parameters array with random eye blinks added to unprocessed frames
+ """
+ remaining_frames = animation_params.shape[0] - processed_frames
+
+    # Only apply blinks if there are enough remaining frames (blink pattern requires 7 frames)
+ if remaining_frames <= 7:
+ return animation_params
+
+ # Configure blink timing parameters
+ min_blink_interval = 40 # Minimum frames between blinks
+ max_blink_interval = 100 # Maximum frames between blinks
+
+ # Find last blink in previously processed frames (column 8 > 0.5 indicates blink)
+ previous_blink_indices = np.where(animation_params[:processed_frames, 8] > 0.5)[0]
+ last_processed_blink = previous_blink_indices[-1] - 7 if previous_blink_indices.size > 0 else processed_frames
+
+ # Calculate first new blink position
+ blink_interval = np.random.randint(min_blink_interval, max_blink_interval)
+ first_blink_start = max(0, blink_interval - last_processed_blink)
+
+ # Apply first blink if there's enough space
+ if first_blink_start <= (remaining_frames - 7):
+ # Randomly select blink pattern and intensity
+ blink_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
+ intensity = np.random.uniform(*intensity_range)
+
+ # Calculate blink frame range
+ blink_start = processed_frames + first_blink_start
+ blink_end = blink_start + 7
+
+ # Apply pattern to both eyes
+ animation_params[blink_start:blink_end, 8] = blink_pattern * intensity
+ animation_params[blink_start:blink_end, 9] = blink_pattern * intensity
+
+ # Check space for additional blink
+ remaining_after_blink = animation_params.shape[0] - blink_end
+ if remaining_after_blink > min_blink_interval:
+ # Calculate second blink position
+ second_intensity = np.random.uniform(*intensity_range)
+ second_interval = np.random.randint(min_blink_interval, max_blink_interval)
+
+ if (remaining_after_blink - 7) > second_interval:
+ second_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
+ second_blink_start = blink_end + second_interval
+ second_blink_end = second_blink_start + 7
+
+ # Apply second blink
+ animation_params[second_blink_start:second_blink_end, 8] = second_pattern * second_intensity
+ animation_params[second_blink_start:second_blink_end, 9] = second_pattern * second_intensity
+
+ return animation_params
+
+
+def export_blendshape_animation(
+ blendshape_weights: np.ndarray,
+ output_path: str,
+ blendshape_names: List[str],
+ fps: float,
+ rotation_data: Optional[np.ndarray] = None
+) -> None:
+ """
+ Export blendshape animation data to JSON format compatible with ARKit.
+
+ Args:
+ blendshape_weights: 2D numpy array of shape (N, 52) containing animation frames
+ output_path: Full path for output JSON file (including .json extension)
+ blendshape_names: Ordered list of 52 ARKit-standard blendshape names
+ fps: Frame rate for timing calculations (frames per second)
+ rotation_data: Optional 3D rotation data array of shape (N, 3)
+
+ Raises:
+ ValueError: If input dimensions are incompatible
+ IOError: If file writing fails
+ """
+ # Validate input dimensions
+ if blendshape_weights.shape[1] != 52:
+ raise ValueError(f"Expected 52 blendshapes, got {blendshape_weights.shape[1]}")
+ if len(blendshape_names) != 52:
+ raise ValueError(f"Requires 52 blendshape names, got {len(blendshape_names)}")
+ if rotation_data is not None and len(rotation_data) != len(blendshape_weights):
+ raise ValueError("Rotation data length must match animation frames")
+
+ # Build animation data structure
+ animation_data = {
+ "names":blendshape_names,
+ "metadata": {
+ "fps": fps,
+ "frame_count": len(blendshape_weights),
+ "blendshape_names": blendshape_names
+ },
+ "frames": []
+ }
+
+ # Convert numpy array to serializable format
+ for frame_idx in range(blendshape_weights.shape[0]):
+ frame_data = {
+ "weights": blendshape_weights[frame_idx].tolist(),
+ "time": frame_idx / fps,
+ "rotation": rotation_data[frame_idx].tolist() if rotation_data else []
+ }
+ animation_data["frames"].append(frame_data)
+
+ # Safeguard against data loss
+ if not output_path.endswith('.json'):
+ output_path += '.json'
+
+ # Write to file with error handling
+ try:
+ with open(output_path, 'w', encoding='utf-8') as json_file:
+ json.dump(animation_data, json_file, indent=2, ensure_ascii=False)
+ except Exception as e:
+ raise IOError(f"Failed to write animation data: {str(e)}") from e
+
+
+def apply_savitzky_golay_smoothing(
+ input_data: np.ndarray,
+ window_length: int = 5,
+ polyorder: int = 2,
+ axis: int = 0,
+ validate: bool = True
+) -> Tuple[np.ndarray, Optional[float]]:
+ """
+ Apply Savitzky-Golay filter smoothing along specified axis of input data.
+
+ Args:
+ input_data: 2D numpy array of shape (n_samples, n_features)
+ window_length: Length of the filter window (must be odd and > polyorder)
+ polyorder: Order of the polynomial fit
+ axis: Axis along which to filter (0: column-wise, 1: row-wise)
+ validate: Enable input validation checks when True
+
+ Returns:
+ tuple: (smoothed_data, processing_time)
+ - smoothed_data: Smoothed output array
+ - processing_time: Execution time in seconds (None in validation mode)
+
+ Raises:
+ ValueError: For invalid input dimensions or filter parameters
+ """
+ # Validation mode timing bypass
+ processing_time = None
+
+ if validate:
+ # Input integrity checks
+ if input_data.ndim != 2:
+ raise ValueError(f"Expected 2D input, got {input_data.ndim}D array")
+
+ if window_length % 2 == 0 or window_length < 3:
+ raise ValueError("Window length must be odd integer ≥ 3")
+
+ if polyorder >= window_length:
+ raise ValueError("Polynomial order must be < window length")
+
+ # Store original dtype and convert to float64 for numerical stability
+ original_dtype = input_data.dtype
+ working_data = input_data.astype(np.float64)
+
+ # Start performance timer
+ timer_start = time.perf_counter()
+
+ try:
+ # Vectorized Savitzky-Golay application
+ smoothed_data = savgol_filter(working_data,
+ window_length=window_length,
+ polyorder=polyorder,
+ axis=axis,
+ mode='mirror')
+ except Exception as e:
+ raise RuntimeError(f"Filtering failed: {str(e)}") from e
+
+ # Stop timer and calculate duration
+ processing_time = time.perf_counter() - timer_start
+
+ # Restore original data type with overflow protection
+ return (
+ np.clip(smoothed_data,
+ 0.0,
+ 1.0
+ ).astype(original_dtype),
+ processing_time
+ )
+
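+# --- Illustrative usage sketch (random blendshape track; parameters are examples only) ---
+if __name__ == "__main__":
+    noisy = np.random.rand(100, 52).astype(np.float32)
+    smoothed, elapsed = apply_savitzky_golay_smoothing(noisy, window_length=5, polyorder=2)
+    print(f"smoothed {smoothed.shape} in {elapsed:.4f} s")
+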
+
+def _blend_region_start(
+ array: np.ndarray,
+ region: np.ndarray,
+ processed_boundary: int,
+ blend_frames: int
+) -> None:
+ """Applies linear blend between last active frame and silent region start."""
+ blend_length = min(blend_frames, region[0] - processed_boundary)
+ if blend_length <= 0:
+ return
+
+ pre_frame = array[region[0] - 1]
+ for i in range(blend_length):
+ weight = (i + 1) / (blend_length + 1)
+ array[region[0] + i] = pre_frame * (1 - weight) + array[region[0] + i] * weight
+
+def _blend_region_end(
+ array: np.ndarray,
+ region: np.ndarray,
+ blend_frames: int
+) -> None:
+ """Applies linear blend between silent region end and next active frame."""
+ blend_length = min(blend_frames, array.shape[0] - region[-1] - 1)
+ if blend_length <= 0:
+ return
+
+ post_frame = array[region[-1] + 1]
+ for i in range(blend_length):
+ weight = (i + 1) / (blend_length + 1)
+ array[region[-1] - i] = post_frame * (1 - weight) + array[region[-1] - i] * weight
+
+def find_low_value_regions(
+ signal: np.ndarray,
+ threshold: float,
+ min_region_length: int = 5
+) -> list:
+ """Identifies contiguous regions in a signal where values fall below a threshold.
+
+ Args:
+ signal: Input 1D array of numerical values
+ threshold: Value threshold for identifying low regions
+ min_region_length: Minimum consecutive samples required to qualify as a region
+
+ Returns:
+ List of numpy arrays, each containing indices for a qualifying low-value region
+ """
+ low_value_indices = np.where(signal < threshold)[0]
+ contiguous_regions = []
+ current_region_length = 0
+ region_start_idx = 0
+
+ for i in range(1, len(low_value_indices)):
+ # Check if current index continues a consecutive sequence
+ if low_value_indices[i] != low_value_indices[i - 1] + 1:
+ # Finalize previous region if it meets length requirement
+ if current_region_length >= min_region_length:
+ contiguous_regions.append(low_value_indices[region_start_idx:i])
+ # Reset tracking for new potential region
+ region_start_idx = i
+ current_region_length = 0
+ current_region_length += 1
+
+ # Add the final region if it qualifies
+ if current_region_length >= min_region_length:
+ contiguous_regions.append(low_value_indices[region_start_idx:])
+
+ return contiguous_regions
+
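+# --- Illustrative usage sketch (toy volume envelope) ---
+if __name__ == "__main__":
+    toy_volume = np.array([0.5] * 10 + [0.0] * 8 + [0.5] * 10)
+    quiet = find_low_value_regions(toy_volume, threshold=0.001, min_region_length=5)
+    print("silent regions:", [(int(r[0]), int(r[-1])) for r in quiet])   # -> [(10, 17)]
+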
+
+def smooth_mouth_movements(
+ blend_shapes: np.ndarray,
+ processed_frames: int,
+ volume: np.ndarray = None,
+ silence_threshold: float = 0.001,
+ min_silence_duration: int = 7,
+ blend_window: int = 3
+) -> np.ndarray:
+ """Reduces jaw movement artifacts during silent periods in audio-driven animation.
+
+ Args:
+ blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
+ processed_frames: Number of already processed frames that shouldn't be modified
+ volume: Audio volume array used to detect silent periods
+ silence_threshold: Volume threshold for considering a frame silent
+ min_silence_duration: Minimum consecutive silent frames to qualify for processing
+ blend_window: Number of frames to smooth at region boundaries
+
+ Returns:
+ Modified blend shape array with reduced mouth movements during silence
+ """
+ if volume is None:
+ return blend_shapes
+
+ # Detect silence periods using volume data
+ silent_regions = find_low_value_regions(
+ volume,
+ threshold=silence_threshold,
+ min_region_length=min_silence_duration
+ )
+
+ for region_indices in silent_regions:
+ # Reduce mouth blend shapes in silent region
+ mouth_blend_indices = [ARKitBlendShape.index(name) for name in MOUTH_BLENDSHAPES]
+ for region_indice in region_indices.tolist():
+ blend_shapes[region_indice, mouth_blend_indices] *= 0.1
+
+ try:
+ # Smooth transition into silent region
+ _blend_region_start(
+ blend_shapes,
+ region_indices,
+ processed_frames,
+ blend_window
+ )
+
+ # Smooth transition out of silent region
+ _blend_region_end(
+ blend_shapes,
+ region_indices,
+ blend_window
+ )
+ except IndexError as e:
+ warnings.warn(f"Edge blending skipped at region {region_indices}: {str(e)}")
+
+ return blend_shapes
+
+
+def apply_frame_blending(
+ blend_shapes: np.ndarray,
+ processed_frames: int,
+ initial_blend_window: int = 3,
+ subsequent_blend_window: int = 5
+) -> np.ndarray:
+ """Smooths transitions between processed and unprocessed animation frames using linear blending.
+
+ Args:
+ blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
+ processed_frames: Number of already processed frames (0 means no previous processing)
+ initial_blend_window: Max frames to blend at sequence start
+ subsequent_blend_window: Max frames to blend between processed and new frames
+
+ Returns:
+ Modified blend shape array with smoothed transitions
+ """
+ if processed_frames > 0:
+ # Blend transition between existing and new animation
+ _blend_animation_segment(
+ blend_shapes,
+ transition_start=processed_frames,
+ blend_window=subsequent_blend_window,
+ reference_frame=blend_shapes[processed_frames - 1]
+ )
+ else:
+ # Smooth initial frames from neutral expression (zeros)
+ _blend_animation_segment(
+ blend_shapes,
+ transition_start=0,
+ blend_window=initial_blend_window,
+ reference_frame=np.zeros_like(blend_shapes[0])
+ )
+ return blend_shapes
+
+
+def _blend_animation_segment(
+ array: np.ndarray,
+ transition_start: int,
+ blend_window: int,
+ reference_frame: np.ndarray
+) -> None:
+ """Applies linear interpolation between reference frame and target frames.
+
+ Args:
+ array: Blend shape array to modify
+ transition_start: Starting index for blending
+ blend_window: Maximum number of frames to blend
+ reference_frame: The reference frame to blend from
+ """
+ actual_blend_length = min(blend_window, array.shape[0] - transition_start)
+
+ for frame_offset in range(actual_blend_length):
+ current_idx = transition_start + frame_offset
+ blend_weight = (frame_offset + 1) / (actual_blend_length + 1)
+
+ # Linear interpolation: ref_frame * (1 - weight) + current_frame * weight
+ array[current_idx] = (reference_frame * (1 - blend_weight)
+ + array[current_idx] * blend_weight)
+
+
+BROW1 = np.array([[0.05597309, 0.05727929, 0.07995935, 0. , 0. ],
+ [0.00757574, 0.00936678, 0.12242376, 0. , 0. ],
+ [0. , 0. , 0.14943372, 0.04535687, 0.04264118],
+ [0. , 0. , 0.18015374, 0.09019445, 0.08736137],
+ [0. , 0. , 0.20549579, 0.12802747, 0.12450772],
+ [0. , 0. , 0.21098022, 0.1369939 , 0.13343132],
+ [0. , 0. , 0.20904602, 0.13903855, 0.13562402],
+ [0. , 0. , 0.20365039, 0.13977394, 0.13653506],
+ [0. , 0. , 0.19714841, 0.14096624, 0.13805152],
+ [0. , 0. , 0.20325482, 0.17303431, 0.17028868],
+ [0. , 0. , 0.21990852, 0.20164253, 0.19818163],
+ [0. , 0. , 0.23858181, 0.21908803, 0.21540019],
+ [0. , 0. , 0.2567876 , 0.23762083, 0.23396946],
+ [0. , 0. , 0.34093422, 0.27898848, 0.27651772],
+ [0. , 0. , 0.45288125, 0.35008961, 0.34887788],
+ [0. , 0. , 0.48076251, 0.36878952, 0.36778417],
+ [0. , 0. , 0.47798249, 0.36362219, 0.36145973],
+ [0. , 0. , 0.46186113, 0.33865979, 0.33597934],
+ [0. , 0. , 0.45264384, 0.33152157, 0.32891783],
+ [0. , 0. , 0.40986338, 0.29646468, 0.2945672 ],
+ [0. , 0. , 0.35628179, 0.23356403, 0.23155804],
+ [0. , 0. , 0.30870566, 0.1780673 , 0.17637439],
+ [0. , 0. , 0.25293985, 0.10710219, 0.10622486],
+ [0. , 0. , 0.18743332, 0.03252602, 0.03244236],
+ [0.02340254, 0.02364671, 0.15736724, 0. , 0. ]])
+
+BROW2 = np.array([
+ [0. , 0. , 0.09799323, 0.05944436, 0.05002545],
+ [0. , 0. , 0.09780276, 0.07674237, 0.01636653],
+ [0. , 0. , 0.11136199, 0.1027964 , 0.04249811],
+ [0. , 0. , 0.26883412, 0.15861984, 0.15832305],
+ [0. , 0. , 0.42191629, 0.27038204, 0.27007768],
+ [0. , 0. , 0.3404977 , 0.21633868, 0.21597538],
+ [0. , 0. , 0.27301185, 0.17176409, 0.17134669],
+ [0. , 0. , 0.25960442, 0.15670464, 0.15622253],
+ [0. , 0. , 0.22877269, 0.11805892, 0.11754539],
+ [0. , 0. , 0.1451605 , 0.06389034, 0.0636282 ]])
+
+BROW3 = np.array([
+ [0. , 0. , 0.124 , 0.0295, 0.0295],
+ [0. , 0. , 0.267 , 0.184 , 0.184 ],
+ [0. , 0. , 0.359 , 0.2765, 0.2765],
+ [0. , 0. , 0.3945, 0.3125, 0.3125],
+ [0. , 0. , 0.4125, 0.331 , 0.331 ],
+ [0. , 0. , 0.4235, 0.3445, 0.3445],
+ [0. , 0. , 0.4085, 0.3305, 0.3305],
+ [0. , 0. , 0.3695, 0.294 , 0.294 ],
+ [0. , 0. , 0.2835, 0.213 , 0.213 ],
+ [0. , 0. , 0.1795, 0.1005, 0.1005],
+ [0. , 0. , 0.108 , 0.014 , 0.014 ]])
+
+
+from scipy.ndimage import label
+
+
+def apply_random_brow_movement(input_exp, volume):
+ FRAME_SEGMENT = 150
+ HOLD_THRESHOLD = 10
+ VOLUME_THRESHOLD = 0.08
+ MIN_REGION_LENGTH = 6
+ STRENGTH_RANGE = (0.7, 1.3)
+
+ BROW_PEAKS = {
+ 0: np.argmax(BROW1[:, 2]),
+ 1: np.argmax(BROW2[:, 2])
+ }
+
+ for seg_start in range(0, len(volume), FRAME_SEGMENT):
+ seg_end = min(seg_start + FRAME_SEGMENT, len(volume))
+ seg_volume = volume[seg_start:seg_end]
+
+ candidate_regions = []
+
+ high_vol_mask = seg_volume > VOLUME_THRESHOLD
+ labeled_array, num_features = label(high_vol_mask)
+
+ for i in range(1, num_features + 1):
+ region = (labeled_array == i)
+ region_indices = np.where(region)[0]
+ if len(region_indices) >= MIN_REGION_LENGTH:
+ candidate_regions.append(region_indices)
+
+ if candidate_regions:
+ selected_region = candidate_regions[np.random.choice(len(candidate_regions))]
+ region_start = selected_region[0]
+ region_end = selected_region[-1]
+ region_length = region_end - region_start + 1
+
+ brow_idx = np.random.randint(0, 2)
+ base_brow = BROW1 if brow_idx == 0 else BROW2
+ peak_idx = BROW_PEAKS[brow_idx]
+
+ if region_length > HOLD_THRESHOLD:
+ local_max_pos = seg_volume[selected_region].argmax()
+ global_peak_frame = seg_start + selected_region[local_max_pos]
+
+ rise_anim = base_brow[:peak_idx + 1]
+ hold_frame = base_brow[peak_idx:peak_idx + 1]
+
+ insert_start = max(global_peak_frame - peak_idx, seg_start)
+ insert_end = min(global_peak_frame + (region_length - local_max_pos), seg_end)
+
+ strength = np.random.uniform(*STRENGTH_RANGE)
+
+ if insert_start + len(rise_anim) <= seg_end:
+ input_exp[insert_start:insert_start + len(rise_anim), :5] += rise_anim * strength
+ hold_duration = insert_end - (insert_start + len(rise_anim))
+ if hold_duration > 0:
+ input_exp[insert_start + len(rise_anim):insert_end, :5] += np.tile(hold_frame * strength,
+ (hold_duration, 1))
+ else:
+ anim_length = base_brow.shape[0]
+ insert_pos = seg_start + region_start + (region_length - anim_length) // 2
+ insert_pos = max(seg_start, min(insert_pos, seg_end - anim_length))
+
+ if insert_pos + anim_length <= seg_end:
+ strength = np.random.uniform(*STRENGTH_RANGE)
+ input_exp[insert_pos:insert_pos + anim_length, :5] += base_brow * strength
+
+ return np.clip(input_exp, 0, 1)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2e09298
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+spleeter==2.4.2
+opencv_python_headless==4.11.0.86
+gradio==5.25.2
+omegaconf==2.3.0
+addict==2.4.0
+yapf==0.40.1
+librosa==0.11.0
+transformers==4.36.2
+termcolor==3.0.1
+numpy==1.26.3
\ No newline at end of file
diff --git a/scripts/install/install_cu118.sh b/scripts/install/install_cu118.sh
new file mode 100644
index 0000000..0a16bc9
--- /dev/null
+++ b/scripts/install/install_cu118.sh
@@ -0,0 +1,9 @@
+# install torch 2.1.2
+# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
+
+# install dependencies
+pip install -r requirements.txt
+
+# install H5-render
+pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl
\ No newline at end of file
diff --git a/scripts/install/install_cu121.sh b/scripts/install/install_cu121.sh
new file mode 100644
index 0000000..7c39e52
--- /dev/null
+++ b/scripts/install/install_cu121.sh
@@ -0,0 +1,9 @@
+# install torch 2.1.2
+# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
+pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
+
+# install dependencies
+pip install -r requirements.txt
+
+# install H5-render
+pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/cache.py b/utils/cache.py
new file mode 100644
index 0000000..ac8bc33
--- /dev/null
+++ b/utils/cache.py
@@ -0,0 +1,53 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import os
+import SharedArray
+
+try:
+ from multiprocessing.shared_memory import ShareableList
+except ImportError:
+ import warnings
+
+ warnings.warn("Please update python version >= 3.8 to enable shared_memory")
+import numpy as np
+
+
+def shared_array(name, var=None):
+ if var is not None:
+ # check exist
+ if os.path.exists(f"/dev/shm/{name}"):
+ return SharedArray.attach(f"shm://{name}")
+ # create shared_array
+ data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
+ data[...] = var[...]
+ data.flags.writeable = False
+ else:
+ data = SharedArray.attach(f"shm://{name}").copy()
+ return data
+
+
+def shared_dict(name, var=None):
+ name = str(name)
+ assert "." not in name # '.' is used as sep flag
+ data = {}
+ if var is not None:
+ assert isinstance(var, dict)
+ keys = var.keys()
+ # current version only cache np.array
+ keys_valid = []
+ for key in keys:
+ if isinstance(var[key], np.ndarray):
+ keys_valid.append(key)
+ keys = keys_valid
+
+ ShareableList(sequence=keys, name=name + ".keys")
+ for key in keys:
+ if isinstance(var[key], np.ndarray):
+ data[key] = shared_array(name=f"{name}.{key}", var=var[key])
+ else:
+ keys = list(ShareableList(name=name + ".keys"))
+ for key in keys:
+ data[key] = shared_array(name=f"{name}.{key}")
+ return data
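
A minimal usage sketch for the shared-memory cache above, assuming a Linux host with /dev/shm, the SharedArray package (imported above but not listed in requirements.txt), and the repository root on PYTHONPATH; the block name and array shapes are placeholders.

import numpy as np
from utils.cache import shared_dict

# Producer: publish a dict of numpy arrays under a named shared-memory block.
sample = {"coord": np.random.rand(1024, 3), "feat": np.random.rand(1024, 6)}
shared_dict("train_sample0", sample)      # the name must not contain '.'

# Consumer (e.g. another dataloader worker): attach by name and read a copy.
cached = shared_dict("train_sample0")
print(cached["coord"].shape)              # (1024, 3)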
diff --git a/utils/comm.py b/utils/comm.py
new file mode 100644
index 0000000..23bec8e
--- /dev/null
+++ b/utils/comm.py
@@ -0,0 +1,192 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import functools
+import numpy as np
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert (
+ _LOCAL_PROCESS_GROUP is not None
+ ), "Local process group is not created! Please use launch() to spawn processes!"
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ if dist.get_backend() == dist.Backend.NCCL:
+ # This argument is needed to avoid warnings.
+ # It's valid only for NCCL backend.
+ dist.barrier(device_ids=[torch.cuda.current_device()])
+ else:
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = (
+ _get_global_gloo_group()
+ ) # use CPU group by default, to reduce GPU RAM usage.
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return [data]
+
+ output = [None for _ in range(world_size)]
+ dist.all_gather_object(output, data, group=group)
+ return output
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ world_size = dist.get_world_size(group=group)
+ if world_size == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ if rank == dst:
+ output = [None for _ in range(world_size)]
+ dist.gather_object(data, output, dst=dst, group=group)
+ return output
+ else:
+ dist.gather_object(data, None, dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2**31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
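
A minimal sketch of how these helpers are typically used in a training loop; the loss values are placeholders. Every call below also works in a single-process run, because each helper short-circuits when the world size is 1.

import torch
from utils import comm

loss_dict = {"loss_total": torch.tensor(0.42), "loss_aux": torch.tensor(0.08)}
reduced = comm.reduce_dict(loss_dict, average=True)    # rank 0 receives the mean over ranks
if comm.is_main_process():
    print({k: v.item() for k, v in reduced.items()})

# Gather an arbitrary picklable object from every rank (one list entry per rank).
gathered = comm.all_gather({"rank": comm.get_rank()})
comm.synchronize()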
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000..0513d64
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,696 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+from .misc import import_modules_from_strings
+from .path import check_file_exist
+
+if platform.system() == "Windows":
+ import regex as re
+else:
+ import re
+
+BASE_KEY = "_base_"
+DELETE_KEY = "_delete_"
+DEPRECATION_KEY = "_deprecation_"
+RESERVED_KEYS = ["filename", "text", "pretty_text"]
+
+
+class ConfigDict(Dict):
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(
+ f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
+ )
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+def add_args(parser, cfg, prefix=""):
+ for k, v in cfg.items():
+ if isinstance(v, str):
+ parser.add_argument("--" + prefix + k)
+ elif isinstance(v, int):
+ parser.add_argument("--" + prefix + k, type=int)
+ elif isinstance(v, float):
+ parser.add_argument("--" + prefix + k, type=float)
+ elif isinstance(v, bool):
+ parser.add_argument("--" + prefix + k, action="store_true")
+ elif isinstance(v, dict):
+ add_args(parser, v, prefix + k + ".")
+ elif isinstance(v, abc.Iterable):
+ parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
+ else:
+ print(f"cannot parse key {prefix + k} of type {type(v)}")
+ return parser
+
+
+class Config:
+ """A facility for config and config files.
+
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError as e:
+ raise SyntaxError(
+ "There are syntax errors in config " f"file {filename}: {e}"
+ )
+
+ @staticmethod
+ def _substitute_predefined_vars(filename, temp_config_name):
+ file_dirname = osp.dirname(filename)
+ file_basename = osp.basename(filename)
+ file_basename_no_extension = osp.splitext(file_basename)[0]
+ file_extname = osp.splitext(filename)[1]
+ support_templates = dict(
+ fileDirname=file_dirname,
+ fileBasename=file_basename,
+ fileBasenameNoExtension=file_basename_no_extension,
+ fileExtname=file_extname,
+ )
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ for key, value in support_templates.items():
+ regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
+ value = value.replace("\\", "/")
+ config_file = re.sub(regexp, value, config_file)
+ with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
+ tmp_config_file.write(config_file)
+
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+        """Substitute base variable placeholders with strings so that parsing
+        works."""
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ base_var_dict = {}
+ regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
+ base_vars = set(re.findall(regexp, config_file))
+ for base_var in base_vars:
+ randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
+ base_var_dict[randstr] = base_var
+ regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
+ with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
+ tmp_config_file.write(config_file)
+ return base_var_dict
+
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+ """Substitute variable strings to their actual values."""
+ cfg = copy.deepcopy(cfg)
+
+ if isinstance(cfg, dict):
+ for k, v in cfg.items():
+ if isinstance(v, str) and v in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[v].split("."):
+ new_v = new_v[new_k]
+ cfg[k] = new_v
+ elif isinstance(v, (list, tuple, dict)):
+ cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
+ elif isinstance(cfg, tuple):
+ cfg = tuple(
+ Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
+ )
+ elif isinstance(cfg, list):
+ cfg = [
+ Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg
+ ]
+ elif isinstance(cfg, str) and cfg in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[cfg].split("."):
+ new_v = new_v[new_k]
+ cfg = new_v
+
+ return cfg
+
+ @staticmethod
+ def _file2dict(filename, use_predefined_variables=True):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ fileExtname = osp.splitext(filename)[1]
+ if fileExtname not in [".py", ".json", ".yaml", ".yml"]:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix=fileExtname
+ )
+ if platform.system() == "Windows":
+ temp_config_file.close()
+ temp_config_name = osp.basename(temp_config_file.name)
+ # Substitute predefined variables
+ if use_predefined_variables:
+ Config._substitute_predefined_vars(filename, temp_config_file.name)
+ else:
+ shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name
+ )
+
+ if filename.endswith(".py"):
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ Config._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ elif filename.endswith((".yml", ".yaml", ".json")):
+ raise NotImplementedError
+ # close temp file
+ temp_config_file.close()
+
+ # check deprecation information
+ if DEPRECATION_KEY in cfg_dict:
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+ warning_msg = (
+ f"The config file {filename} will be deprecated " "in the future."
+ )
+ if "expected" in deprecation_info:
+ warning_msg += f' Please use {deprecation_info["expected"]} ' "instead."
+ if "reference" in deprecation_info:
+ warning_msg += (
+ " More information can be found at "
+ f'{deprecation_info["reference"]}'
+ )
+ warnings.warn(warning_msg)
+
+ cfg_text = filename + "\n"
+ with open(filename, "r", encoding="utf-8") as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = (
+ base_filename if isinstance(base_filename, list) else [base_filename]
+ )
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ duplicate_keys = base_cfg_dict.keys() & c.keys()
+ if len(duplicate_keys) > 0:
+ raise KeyError(
+ "Duplicate key is not allowed among bases. "
+ f"Duplicate keys: {duplicate_keys}"
+ )
+ base_cfg_dict.update(c)
+
+ # Substitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(
+ cfg_dict, base_var_dict, base_cfg_dict
+ )
+
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = "\n".join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b, allow_list_keys=False):
+ """merge dict ``a`` into dict ``b`` (non-inplace).
+
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
+ in-place modifications.
+
+ Args:
+ a (dict): The source dict to be merged into ``b``.
+            b (dict): The origin dict into which ``a`` is merged.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in source ``a`` and will replace the element of the
+ corresponding index in b if b is a list. Default: False.
+
+ Returns:
+ dict: The modified dict of ``b`` using ``a``.
+
+ Examples:
+ # Normally merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # Delete b first and merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # b is a list
+ >>> Config._merge_a_into_b(
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
+ [{'a': 2}, {'b': 2}]
+ """
+ b = b.copy()
+ for k, v in a.items():
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
+ k = int(k)
+ if len(b) <= k:
+ raise KeyError(f"Index {k} exceeds the length of list {b}")
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
+ allowed_types = (dict, list) if allow_list_keys else dict
+ if not isinstance(b[k], allowed_types):
+ raise TypeError(
+ f"{k}={v} in child config cannot inherit from base "
+ f"because {k} is a dict in the child config but is of "
+ f"type {type(b[k])} in base config. You may set "
+ f"`{DELETE_KEY}=True` to ignore the base config"
+ )
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ else:
+ b[k] = v
+ return b
+
+ @staticmethod
+ def fromfile(filename, use_predefined_variables=True, import_custom_modules=True):
+ cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
+ if import_custom_modules and cfg_dict.get("custom_imports", None):
+ import_modules_from_strings(**cfg_dict["custom_imports"])
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ @staticmethod
+ def fromstring(cfg_str, file_format):
+ """Generate config from config str.
+
+ Args:
+ cfg_str (str): Config str.
+ file_format (str): Config file format corresponding to the
+ config str. Only py/yml/yaml/json type are supported now!
+
+ Returns:
+ obj:`Config`: Config obj.
+ """
+ if file_format not in [".py", ".json", ".yaml", ".yml"]:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+ if file_format != ".py" and "dict(" in cfg_str:
+ # check if users specify a wrong suffix for python
+ warnings.warn('Please check "file_format", the file format may be .py')
+ with tempfile.NamedTemporaryFile(
+ "w", encoding="utf-8", suffix=file_format, delete=False
+ ) as temp_file:
+ temp_file.write(cfg_str)
+            # On Windows, the previous implementation caused an error;
+            # see PR 1077 for details.
+ cfg = Config.fromfile(temp_file.name)
+ os.remove(temp_file.name)
+ return cfg
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)"""
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument("config", help="config file path")
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument("config", help="config file path")
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f"{key} is reserved for config file")
+
+ super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
+ super(Config, self).__setattr__("_filename", filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, "r") as f:
+ text = f.read()
+ else:
+ text = ""
+ super(Config, self).__setattr__("_text", text)
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split("\n")
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * " ") + line for line in s]
+ s = "\n".join(s)
+ s = first + "\n" + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: {v_str}"
+ else:
+ attr_str = f"{str(k)}={v_str}"
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = "[\n"
+ v_str += "\n".join(
+ f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
+ ).rstrip(",")
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: {v_str}"
+ else:
+ attr_str = f"{str(k)}={v_str}"
+ attr_str = _indent(attr_str, indent) + "]"
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= not str(key_name).isidentifier()
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ""
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += "{"
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = "" if outest_level or is_last else ","
+ if isinstance(v, dict):
+ v_str = "\n" + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: dict({v_str}"
+ else:
+ attr_str = f"{str(k)}=dict({v_str}"
+ attr_str = _indent(attr_str, indent) + ")" + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += "\n".join(s)
+ if use_mapping:
+ r += "}"
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style="pep8",
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True,
+ )
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+ def __repr__(self):
+ return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def __getstate__(self):
+ return (self._cfg_dict, self._filename, self._text)
+
+ def __setstate__(self, state):
+ _cfg_dict, _filename, _text = state
+ super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
+ super(Config, self).__setattr__("_filename", _filename)
+ super(Config, self).__setattr__("_text", _text)
+
+ def dump(self, file=None):
+ cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict()
+ if self.filename.endswith(".py"):
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, "w", encoding="utf-8") as f:
+ f.write(self.pretty_text)
+ else:
+ import mmcv
+
+ if file is None:
+ file_format = self.filename.split(".")[-1]
+ return mmcv.dump(cfg_dict, file_format=file_format)
+ else:
+ mmcv.dump(cfg_dict, file)
+
+ def merge_from_dict(self, options, allow_list_keys=True):
+ """Merge list into cfg_dict.
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'models.backbone.depth': 50,
+ ... 'models.backbone.with_cp':True}
+ >>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... models=dict(backbone=dict(depth=50, with_cp=True)))
+
+ # Merge list element
+ >>> cfg = Config(dict(pipeline=[
+ ... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
+ >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
+ >>> cfg.merge_from_dict(options, allow_list_keys=True)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(pipeline=[
+ ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
+
+ Args:
+ options (dict): dict of configs to merge from.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in ``options`` and will replace the element of the
+ corresponding index in the config if the config is a list.
+ Default: True.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split(".")
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(Config, self).__getattribute__("_cfg_dict")
+ super(Config, self).__setattr__(
+ "_cfg_dict",
+ Config._merge_a_into_b(
+ option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys
+ ),
+ )
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options can
+    be passed as comma-separated values, i.e. 'KEY=V1,V2,V3', or with explicit
+    brackets, i.e. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
+    list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'.
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ["true", "false"]:
+ return True if val.lower() == "true" else False
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+
+ All elements inside '()' or '[]' are treated as iterable values.
+
+ Args:
+ val (str): Value string.
+
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+ [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+
+ If no ',' is found in the string, return the string length. All
+ chars inside '()' and '[]' are treated as one element and thus ','
+ inside these brackets are ignored.
+ """
+ assert (string.count("(") == string.count(")")) and (
+ string.count("[") == string.count("]")
+ ), f"Imbalanced brackets exist in {string}"
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if (
+ (char == ",")
+ and (pre.count("(") == pre.count(")"))
+ and (pre.count("[") == pre.count("]"))
+ ):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip("'\"").replace(" ", "")
+ is_tuple = False
+ if val.startswith("(") and val.endswith(")"):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith("[") and val.endswith("]"):
+ val = val[1:-1]
+ elif "," not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1 :]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split("=", maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
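
A minimal sketch of the intended Config/DictAction workflow: build or load a config, then override nested values from the command line. The config keys and values are placeholders.

from argparse import ArgumentParser
from utils.config import Config, DictAction

cfg = Config(dict(model=dict(backbone=dict(type="ResNet", depth=18)), lr=0.1))
cfg.merge_from_dict({"model.backbone.depth": 50, "lr": 0.01})
print(cfg.model.backbone.depth, cfg.lr)   # 50 0.01

# Typical CLI wiring: repeated KEY=VALUE pairs are collected into a dict by DictAction.
parser = ArgumentParser()
parser.add_argument("--options", nargs="+", action=DictAction, default={})
args = parser.parse_args(["--options", "model.backbone.depth=101", "lr=3e-4"])
cfg.merge_from_dict(args.options)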
diff --git a/utils/env.py b/utils/env.py
new file mode 100644
index 0000000..802ed90
--- /dev/null
+++ b/utils/env.py
@@ -0,0 +1,33 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import os
+import random
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+
+from datetime import datetime
+
+
+def get_random_seed():
+ seed = (
+ os.getpid()
+ + int(datetime.now().strftime("%S%f"))
+ + int.from_bytes(os.urandom(2), "big")
+ )
+ return seed
+
+
+def set_seed(seed=None):
+ if seed is None:
+ seed = get_random_seed()
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+ os.environ["PYTHONHASHSEED"] = str(seed)
diff --git a/utils/events.py b/utils/events.py
new file mode 100644
index 0000000..90412dd
--- /dev/null
+++ b/utils/events.py
@@ -0,0 +1,585 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+
+import datetime
+import json
+import logging
+import os
+import time
+import torch
+import numpy as np
+
+from typing import List, Optional, Tuple
+from collections import defaultdict
+from contextlib import contextmanager
+
+__all__ = [
+ "get_event_storage",
+ "JSONWriter",
+ "TensorboardXWriter",
+ "CommonMetricPrinter",
+ "EventStorage",
+]
+
+_CURRENT_STORAGE_STACK = []
+
+
+def get_event_storage():
+ """
+ Returns:
+ The :class:`EventStorage` object that's currently being used.
+ Throws an error if no :class:`EventStorage` is currently enabled.
+ """
+ assert len(
+ _CURRENT_STORAGE_STACK
+ ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
+ return _CURRENT_STORAGE_STACK[-1]
+
+
+class EventWriter:
+ """
+ Base class for writers that obtain events from :class:`EventStorage` and process them.
+ """
+
+ def write(self):
+ raise NotImplementedError
+
+ def close(self):
+ pass
+
+
+class JSONWriter(EventWriter):
+ """
+ Write scalars to a json file.
+ It saves scalars as one json per line (instead of a big json) for easy parsing.
+    Examples of parsing such a json file:
+ ::
+ $ cat metrics.json | jq -s '.[0:2]'
+ [
+ {
+ "data_time": 0.008433341979980469,
+ "iteration": 19,
+ "loss": 1.9228371381759644,
+ "loss_box_reg": 0.050025828182697296,
+ "loss_classifier": 0.5316952466964722,
+ "loss_mask": 0.7236229181289673,
+ "loss_rpn_box": 0.0856662318110466,
+ "loss_rpn_cls": 0.48198649287223816,
+ "lr": 0.007173333333333333,
+ "time": 0.25401854515075684
+ },
+ {
+ "data_time": 0.007216215133666992,
+ "iteration": 39,
+ "loss": 1.282649278640747,
+ "loss_box_reg": 0.06222952902317047,
+ "loss_classifier": 0.30682939291000366,
+ "loss_mask": 0.6970193982124329,
+ "loss_rpn_box": 0.038663312792778015,
+ "loss_rpn_cls": 0.1471673548221588,
+ "lr": 0.007706666666666667,
+ "time": 0.2490077018737793
+ }
+ ]
+ $ cat metrics.json | jq '.loss_mask'
+ 0.7126231789588928
+ 0.689423680305481
+ 0.6776131987571716
+ ...
+ """
+
+ def __init__(self, json_file, window_size=20):
+ """
+ Args:
+ json_file (str): path to the json file. New data will be appended if the file exists.
+ window_size (int): the window size of median smoothing for the scalars whose
+                `smoothing_hint` is True.
+ """
+ self._file_handle = open(json_file, "a")
+ self._window_size = window_size
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ to_save = defaultdict(dict)
+
+ for k, (v, iter) in storage.latest_with_smoothing_hint(
+ self._window_size
+ ).items():
+ # keep scalars that have not been written
+ if iter <= self._last_write:
+ continue
+ to_save[iter][k] = v
+ if len(to_save):
+ all_iters = sorted(to_save.keys())
+ self._last_write = max(all_iters)
+
+ for itr, scalars_per_iter in to_save.items():
+ scalars_per_iter["iteration"] = itr
+ self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
+ self._file_handle.flush()
+ try:
+ os.fsync(self._file_handle.fileno())
+ except AttributeError:
+ pass
+
+ def close(self):
+ self._file_handle.close()
+
+
+class TensorboardXWriter(EventWriter):
+ """
+ Write all scalars to a tensorboard file.
+ """
+
+ def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
+ """
+ Args:
+ log_dir (str): the directory to save the output events
+ window_size (int): the scalars will be median-smoothed by this window size
+ kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
+ """
+ self._window_size = window_size
+ from torch.utils.tensorboard import SummaryWriter
+
+ self._writer = SummaryWriter(log_dir, **kwargs)
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ new_last_write = self._last_write
+ for k, (v, iter) in storage.latest_with_smoothing_hint(
+ self._window_size
+ ).items():
+ if iter > self._last_write:
+ self._writer.add_scalar(k, v, iter)
+ new_last_write = max(new_last_write, iter)
+ self._last_write = new_last_write
+
+ # storage.put_{image,histogram} is only meant to be used by
+ # tensorboard writer. So we access its internal fields directly from here.
+ if len(storage._vis_data) >= 1:
+ for img_name, img, step_num in storage._vis_data:
+ self._writer.add_image(img_name, img, step_num)
+            # Storage stores all image data and relies on this writer to clear them.
+ # As a result it assumes only one writer will use its image data.
+ # An alternative design is to let storage store limited recent
+ # data (e.g. only the most recent image) that all writers can access.
+ # In that case a writer may not see all image data if its period is long.
+ storage.clear_images()
+
+ if len(storage._histograms) >= 1:
+ for params in storage._histograms:
+ self._writer.add_histogram_raw(**params)
+ storage.clear_histograms()
+
+ def close(self):
+ if hasattr(self, "_writer"): # doesn't exist when the code fails at import
+ self._writer.close()
+
+
+class CommonMetricPrinter(EventWriter):
+ """
+ Print **common** metrics to the terminal, including
+ iteration time, ETA, memory, all losses, and the learning rate.
+ It also applies smoothing using a window of 20 elements.
+ It's meant to print common metrics in common ways.
+ To print something in more customized ways, please implement a similar printer by yourself.
+ """
+
+ def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
+ """
+ Args:
+ max_iter: the maximum number of iterations to train.
+ Used to compute ETA. If not given, ETA will not be printed.
+ window_size (int): the losses will be median-smoothed by this window size
+ """
+ self.logger = logging.getLogger(__name__)
+ self._max_iter = max_iter
+ self._window_size = window_size
+ self._last_write = (
+ None # (step, time) of last call to write(). Used to compute ETA
+ )
+
+ def _get_eta(self, storage) -> Optional[str]:
+ if self._max_iter is None:
+ return ""
+ iteration = storage.iter
+ try:
+ eta_seconds = storage.history("time").median(1000) * (
+ self._max_iter - iteration - 1
+ )
+ storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+ return str(datetime.timedelta(seconds=int(eta_seconds)))
+ except KeyError:
+ # estimate eta on our own - more noisy
+ eta_string = None
+ if self._last_write is not None:
+ estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+ iteration - self._last_write[0]
+ )
+ eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ self._last_write = (iteration, time.perf_counter())
+ return eta_string
+
+ def write(self):
+ storage = get_event_storage()
+ iteration = storage.iter
+ if iteration == self._max_iter:
+ # This hook only reports training progress (loss, ETA, etc) but not other data,
+ # therefore do not write anything after training succeeds, even if this method
+ # is called.
+ return
+
+ try:
+ data_time = storage.history("data_time").avg(20)
+ except KeyError:
+ # they may not exist in the first few iterations (due to warmup)
+ # or when SimpleTrainer is not used
+ data_time = None
+ try:
+ iter_time = storage.history("time").global_avg()
+ except KeyError:
+ iter_time = None
+ try:
+ lr = "{:.5g}".format(storage.history("lr").latest())
+ except KeyError:
+ lr = "N/A"
+
+ eta_string = self._get_eta(storage)
+
+ if torch.cuda.is_available():
+ max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+ else:
+ max_mem_mb = None
+
+ # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
+ self.logger.info(
+ " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
+ eta=f"eta: {eta_string} " if eta_string else "",
+ iter=iteration,
+ losses=" ".join(
+ [
+ "{}: {:.4g}".format(k, v.median(self._window_size))
+ for k, v in storage.histories().items()
+ if "loss" in k
+ ]
+ ),
+ time="time: {:.4f} ".format(iter_time)
+ if iter_time is not None
+ else "",
+ data_time="data_time: {:.4f} ".format(data_time)
+ if data_time is not None
+ else "",
+ lr=lr,
+ memory="max_mem: {:.0f}M".format(max_mem_mb)
+ if max_mem_mb is not None
+ else "",
+ )
+ )
+
+
+class EventStorage:
+ """
+ The user-facing class that provides metric storage functionalities.
+ In the future we may add support for storing / logging other types of data if needed.
+ """
+
+ def __init__(self, start_iter=0):
+ """
+ Args:
+ start_iter (int): the iteration number to start with
+ """
+ self._history = defaultdict(AverageMeter)
+ self._smoothing_hints = {}
+ self._latest_scalars = {}
+ self._iter = start_iter
+ self._current_prefix = ""
+ self._vis_data = []
+ self._histograms = []
+
+ # def put_image(self, img_name, img_tensor):
+ # """
+ # Add an `img_tensor` associated with `img_name`, to be shown on
+ # tensorboard.
+ # Args:
+ # img_name (str): The name of the image to put into tensorboard.
+ # img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
+ # Tensor of shape `[channel, height, width]` where `channel` is
+ # 3. The image format should be RGB. The elements in img_tensor
+ # can either have values in [0, 1] (float32) or [0, 255] (uint8).
+ # The `img_tensor` will be visualized in tensorboard.
+ # """
+ # self._vis_data.append((img_name, img_tensor, self._iter))
+
+ def put_scalar(self, name, value, n=1, smoothing_hint=False):
+ """
+ Add a scalar `value` to the `HistoryBuffer` associated with `name`.
+ Args:
+ smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
+ smoothed when logged. The hint will be accessible through
+ :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
+                and apply a custom smoothing rule. It defaults to False.
+ """
+ name = self._current_prefix + name
+ history = self._history[name]
+ history.update(value, n)
+ self._latest_scalars[name] = (value, self._iter)
+
+ existing_hint = self._smoothing_hints.get(name)
+ if existing_hint is not None:
+ assert (
+ existing_hint == smoothing_hint
+ ), "Scalar {} was put with a different smoothing_hint!".format(name)
+ else:
+ self._smoothing_hints[name] = smoothing_hint
+
+ # def put_scalars(self, *, smoothing_hint=True, **kwargs):
+ # """
+ # Put multiple scalars from keyword arguments.
+ # Examples:
+ # storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
+ # """
+ # for k, v in kwargs.items():
+ # self.put_scalar(k, v, smoothing_hint=smoothing_hint)
+ #
+ # def put_histogram(self, hist_name, hist_tensor, bins=1000):
+ # """
+ # Create a histogram from a tensor.
+ # Args:
+ # hist_name (str): The name of the histogram to put into tensorboard.
+ # hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
+ # into a histogram.
+ # bins (int): Number of histogram bins.
+ # """
+ # ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
+ #
+ # # Create a histogram with PyTorch
+ # hist_counts = torch.histc(hist_tensor, bins=bins)
+ # hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
+ #
+ # # Parameter for the add_histogram_raw function of SummaryWriter
+ # hist_params = dict(
+ # tag=hist_name,
+ # min=ht_min,
+ # max=ht_max,
+ # num=len(hist_tensor),
+ # sum=float(hist_tensor.sum()),
+ # sum_squares=float(torch.sum(hist_tensor**2)),
+ # bucket_limits=hist_edges[1:].tolist(),
+ # bucket_counts=hist_counts.tolist(),
+ # global_step=self._iter,
+ # )
+ # self._histograms.append(hist_params)
+
+ def history(self, name):
+ """
+ Returns:
+ AverageMeter: the history for name
+ """
+ ret = self._history.get(name, None)
+ if ret is None:
+ raise KeyError("No history metric available for {}!".format(name))
+ return ret
+
+ def histories(self):
+ """
+ Returns:
+            dict[name -> AverageMeter]: the AverageMeter for all scalars
+ """
+ return self._history
+
+ def latest(self):
+ """
+ Returns:
+ dict[str -> (float, int)]: mapping from the name of each scalar to the most
+            recent value and the iteration number at which it was added.
+ """
+ return self._latest_scalars
+
+ def latest_with_smoothing_hint(self, window_size=20):
+ """
+ Similar to :meth:`latest`, but the returned values
+ are either the un-smoothed original latest value,
+ or a median of the given window_size,
+        depending on whether the smoothing_hint is True.
+ This provides a default behavior that other writers can use.
+ """
+ result = {}
+ for k, (v, itr) in self._latest_scalars.items():
+ result[k] = (
+ self._history[k].median(window_size) if self._smoothing_hints[k] else v,
+ itr,
+ )
+ return result
+
+ def smoothing_hints(self):
+ """
+ Returns:
+ dict[name -> bool]: the user-provided hint on whether the scalar
+ is noisy and needs smoothing.
+ """
+ return self._smoothing_hints
+
+ def step(self):
+ """
+ User should either: (1) Call this function to increment storage.iter when needed. Or
+ (2) Set `storage.iter` to the correct iteration number before each iteration.
+ The storage will then be able to associate the new data with an iteration number.
+ """
+ self._iter += 1
+
+ @property
+ def iter(self):
+ """
+ Returns:
+ int: The current iteration number. When used together with a trainer,
+ this is ensured to be the same as trainer.iter.
+ """
+ return self._iter
+
+ @iter.setter
+ def iter(self, val):
+ self._iter = int(val)
+
+ @property
+ def iteration(self):
+ # for backward compatibility
+ return self._iter
+
+ def __enter__(self):
+ _CURRENT_STORAGE_STACK.append(self)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ assert _CURRENT_STORAGE_STACK[-1] == self
+ _CURRENT_STORAGE_STACK.pop()
+
+ @contextmanager
+ def name_scope(self, name):
+ """
+ Yields:
+ A context within which all the events added to this storage
+ will be prefixed by the name scope.
+ """
+ old_prefix = self._current_prefix
+ self._current_prefix = name.rstrip("/") + "/"
+ yield
+ self._current_prefix = old_prefix
+
+ def clear_images(self):
+ """
+ Delete all the stored images for visualization. This should be called
+ after images are written to tensorboard.
+ """
+ self._vis_data = []
+
+ def clear_histograms(self):
+ """
+ Delete all the stored histograms for visualization.
+ This should be called after histograms are written to tensorboard.
+ """
+ self._histograms = []
+
+ def reset_history(self, name):
+ ret = self._history.get(name, None)
+ if ret is None:
+ raise KeyError("No history metric available for {}!".format(name))
+ ret.reset()
+
+ def reset_histories(self):
+ for name in self._history.keys():
+ self._history[name].reset()
+
+
+class AverageMeter:
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.total = 0
+ self.count = 0
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.total = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.total += val * n
+ self.count += n
+ self.avg = self.total / self.count
+
+
+class HistoryBuffer:
+ """
+ Track a series of scalar values and provide access to smoothed values over a
+ window or the global average of the series.
+ """
+
+ def __init__(self, max_length: int = 1000000) -> None:
+ """
+ Args:
+ max_length: maximal number of values that can be stored in the
+ buffer. When the capacity of the buffer is exhausted, old
+ values will be removed.
+ """
+ self._max_length: int = max_length
+ self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
+ self._count: int = 0
+ self._global_avg: float = 0
+
+ def update(self, value: float, iteration: Optional[float] = None) -> None:
+ """
+ Add a new scalar value produced at certain iteration. If the length
+ of the buffer exceeds self._max_length, the oldest element will be
+ removed from the buffer.
+ """
+ if iteration is None:
+ iteration = self._count
+ if len(self._data) == self._max_length:
+ self._data.pop(0)
+ self._data.append((value, iteration))
+
+ self._count += 1
+ self._global_avg += (value - self._global_avg) / self._count
+
+ def latest(self) -> float:
+ """
+ Return the latest scalar value added to the buffer.
+ """
+ return self._data[-1][0]
+
+ def median(self, window_size: int) -> float:
+ """
+ Return the median of the latest `window_size` values in the buffer.
+ """
+ return np.median([x[0] for x in self._data[-window_size:]])
+
+ def avg(self, window_size: int) -> float:
+ """
+ Return the mean of the latest `window_size` values in the buffer.
+ """
+ return np.mean([x[0] for x in self._data[-window_size:]])
+
+ def global_avg(self) -> float:
+ """
+ Return the mean of all the elements in the buffer. Note that this
+        includes values that have already been removed due to the limited buffer size.
+ """
+ return self._global_avg
+
+ def values(self) -> List[Tuple[float, float]]:
+ """
+ Returns:
+ list[(number, iteration)]: content of the current buffer.
+ """
+ return self._data
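
A minimal sketch of the storage/writer pattern above: scalars are put into an EventStorage inside a `with` block and periodically flushed by a writer. The metric names, values, and output path are placeholders.

from utils.events import EventStorage, JSONWriter

writer = JSONWriter("metrics.json")              # appends one JSON object per iteration
with EventStorage(start_iter=0) as storage:
    for it in range(5):
        storage.put_scalar("loss", 1.0 / (it + 1))   # stand-in for a real loss value
        storage.put_scalar("lr", 1e-3)
        writer.write()
        storage.step()
writer.close()
print(storage.history("loss").avg)               # running average kept by AverageMeter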
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..6e30c5d
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,167 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import logging
+import torch
+import torch.distributed as dist
+
+from termcolor import colored
+
+logger_initialized = {}
+root_status = 0
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+            "Error" and thus be silent most of the time.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'a'.
+        color (bool): Colorful log output. Defaults to False.
+
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ logger.propagate = False
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+        # The FileHandler is opened with `file_mode`, which defaults to 'a'
+        # (append); pass 'w' to overwrite an existing log file instead.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
+ )
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ )
+ else:
+ formatter = plain_formatter
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+ return logger
+
+
+def print_log(msg, logger=None, level=logging.INFO):
+ """Print a log message.
+
+ Args:
+ msg (str): The message to be logged.
+ logger (logging.Logger | str | None): The logger to be used.
+ Some special loggers are:
+ - "silent": no message will be printed.
+            - other str: the logger obtained with `get_logger(logger)`.
+ - None: The `print()` method will be used to print log messages.
+ level (int): Logging level. Only available when `logger` is a Logger
+ object or "root".
+ """
+ if logger is None:
+ print(msg)
+ elif isinstance(logger, logging.Logger):
+ logger.log(level, msg)
+ elif logger == "silent":
+ pass
+ elif isinstance(logger, str):
+ _logger = get_logger(logger)
+ _logger.log(level, msg)
+ else:
+ raise TypeError(
+ "logger should be either a logging.Logger object, str, "
+ f'"silent" or None, but got {type(logger)}'
+ )
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name.
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ file_mode (str): File Mode of logger. (w or a)
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = get_logger(
+ name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode
+ )
+ return logger
+
+
+def _log_api_usage(identifier: str):
+ """
+    Internal helper (adapted from detectron2) that logs the usage of different
+    components through torch's API-usage hook.
+ """
+ torch._C._log_api_usage_once("pointcept." + identifier)
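
A minimal usage sketch; the log file name is a placeholder and only rank 0 writes to it.

import logging
from utils.logger import get_root_logger, print_log

logger = get_root_logger(log_file="train.log", log_level=logging.INFO)
logger.info("training started")

# print_log routes through the same named logger ("pointcept") when given its name.
print_log("epoch 1 finished", logger="pointcept")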
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000..dbd257e
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,156 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import os
+import warnings
+from collections import abc
+import numpy as np
+import torch
+from importlib import import_module
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def intersection_and_union(output, target, K, ignore_index=-1):
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert output.ndim in [1, 2, 3]
+ assert output.shape == target.shape
+ output = output.reshape(output.size).copy()
+ target = target.reshape(target.size)
+ output[np.where(target == ignore_index)[0]] = ignore_index
+ intersection = output[np.where(output == target)[0]]
+ area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
+ area_output, _ = np.histogram(output, bins=np.arange(K + 1))
+ area_target, _ = np.histogram(target, bins=np.arange(K + 1))
+ area_union = area_output + area_target - area_intersection
+ return area_intersection, area_union, area_target
+
+
+def intersection_and_union_gpu(output, target, k, ignore_index=-1):
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert output.dim() in [1, 2, 3]
+ assert output.shape == target.shape
+ output = output.view(-1)
+ target = target.view(-1)
+ output[target == ignore_index] = ignore_index
+ intersection = output[output == target]
+ area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
+ area_output = torch.histc(output, bins=k, min=0, max=k - 1)
+ area_target = torch.histc(target, bins=k, min=0, max=k - 1)
+ area_union = area_output + area_target - area_intersection
+ return area_intersection, area_union, area_target
+
+
+def make_dirs(dir_name):
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name, exist_ok=True)
+
+
+def find_free_port():
+ import socket
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def is_str(x):
+    """Whether the input is a string instance.
+
+    Note: This method is deprecated since Python 2 is no longer supported.
+ """
+ return isinstance(x, str)
+
+
+def import_modules_from_strings(imports, allow_failed_imports=False):
+ """Import modules from the given list of strings.
+
+ Args:
+ imports (list | str | None): The given module names to be imported.
+ allow_failed_imports (bool): If True, the failed imports will return
+            None. Otherwise, an ImportError is raised. Default: False.
+
+ Returns:
+ list[module] | module | None: The imported modules.
+
+ Examples:
+ >>> osp, sys = import_modules_from_strings(
+ ... ['os.path', 'sys'])
+ >>> import os.path as osp_
+ >>> import sys as sys_
+ >>> assert osp == osp_
+ >>> assert sys == sys_
+ """
+ if not imports:
+ return
+ single_import = False
+ if isinstance(imports, str):
+ single_import = True
+ imports = [imports]
+ if not isinstance(imports, list):
+ raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
+ imported = []
+ for imp in imports:
+ if not isinstance(imp, str):
+ raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
+ try:
+ imported_tmp = import_module(imp)
+ except ImportError:
+ if allow_failed_imports:
+ warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
+ imported_tmp = None
+ else:
+ raise ImportError
+ imported.append(imported_tmp)
+ if single_import:
+ imported = imported[0]
+ return imported
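
A small worked example for the IoU helper above; the arrays are toy data, while real inputs are flattened prediction/label maps with values in [0, K - 1].

import numpy as np
from utils.misc import intersection_and_union

pred = np.array([0, 1, 2, 2, 3, 1])
target = np.array([0, 1, 2, 3, 3, 0])
inter, union, _ = intersection_and_union(pred, target, K=4)
iou = inter / np.maximum(union, 1)                    # guard against empty classes
print("per-class IoU:", iou, "mIoU:", iou.mean())     # every class has IoU 0.5 here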
diff --git a/utils/optimizer.py b/utils/optimizer.py
new file mode 100644
index 0000000..2eb70a3
--- /dev/null
+++ b/utils/optimizer.py
@@ -0,0 +1,52 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import torch
+from utils.logger import get_root_logger
+from utils.registry import Registry
+
+OPTIMIZERS = Registry("optimizers")
+
+
+OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD")
+OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam")
+OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW")
+
+
+def build_optimizer(cfg, model, param_dicts=None):
+ if param_dicts is None:
+ cfg.params = model.parameters()
+ else:
+ cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
+ for i in range(len(param_dicts)):
+ param_group = dict(names=[], params=[])
+ if "lr" in param_dicts[i].keys():
+ param_group["lr"] = param_dicts[i].lr
+ if "momentum" in param_dicts[i].keys():
+ param_group["momentum"] = param_dicts[i].momentum
+ if "weight_decay" in param_dicts[i].keys():
+ param_group["weight_decay"] = param_dicts[i].weight_decay
+ cfg.params.append(param_group)
+
+ for n, p in model.named_parameters():
+ flag = False
+ for i in range(len(param_dicts)):
+ if param_dicts[i].keyword in n:
+ cfg.params[i + 1]["names"].append(n)
+ cfg.params[i + 1]["params"].append(p)
+ flag = True
+ break
+ if not flag:
+ cfg.params[0]["names"].append(n)
+ cfg.params[0]["params"].append(p)
+
+ logger = get_root_logger()
+ for i in range(len(cfg.params)):
+ param_names = cfg.params[i].pop("names")
+ message = ""
+ for key in cfg.params[i].keys():
+ if key != "params":
+ message += f" {key}: {cfg.params[i][key]};"
+ logger.info(f"Params Group {i+1} -{message} Params: {param_names}.")
+ return OPTIMIZERS.build(cfg=cfg)
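
A minimal sketch of how build_optimizer is meant to be called, assuming the registry's build() follows the usual type-keyed convention; the model and hyper-parameters are placeholders.

import torch.nn as nn
from utils.config import ConfigDict
from utils.optimizer import build_optimizer

model = nn.Linear(8, 2)
cfg = ConfigDict(dict(type="AdamW", lr=1e-3, weight_decay=0.01))
# Optional per-group overrides; `keyword` is substring-matched against parameter names.
param_dicts = [ConfigDict(dict(keyword="bias", lr=1e-4, weight_decay=0.0))]
optimizer = build_optimizer(cfg, model, param_dicts)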
diff --git a/utils/path.py b/utils/path.py
new file mode 100644
index 0000000..5d1da76
--- /dev/null
+++ b/utils/path.py
@@ -0,0 +1,105 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+import os
+import os.path as osp
+from pathlib import Path
+
+from .misc import is_str
+
+
+def is_filepath(x):
+ return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+ raise ValueError("`filepath` should be a string or a Path")
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == "":
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str | obj:`Path`): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ case_sensitive (bool, optional) : If set to False, ignore the case of
+ suffix. Default: True.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+ if isinstance(dir_path, (str, Path)):
+ dir_path = str(dir_path)
+ else:
+ raise TypeError('"dir_path" must be a string or Path object')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ if suffix is not None and not case_sensitive:
+ suffix = (
+ suffix.lower()
+ if isinstance(suffix, str)
+ else tuple(item.lower() for item in suffix)
+ )
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith(".") and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
+ if suffix is None or _rel_path.endswith(suffix):
+ yield rel_path
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(entry.path, suffix, recursive, case_sensitive)
+
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
+
+
+def find_vcs_root(path, markers=(".git",)):
+ """Finds the root directory (including itself) of specified markers.
+
+ Args:
+ path (str): Path of directory or file.
+ markers (list[str], optional): List of file or directory names.
+
+ Returns:
+ The directory containing one of the markers, or None if not found.
+ """
+ if osp.isfile(path):
+ path = osp.dirname(path)
+
+ prev, cur = None, osp.abspath(osp.expanduser(path))
+ while cur != prev:
+ if any(osp.exists(osp.join(cur, marker)) for marker in markers):
+ return cur
+ prev, cur = cur, osp.split(cur)[0]
+ return None
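
For example, the `scandir` generator above can be used to enumerate files by suffix (the directory names are illustrative):

```python
from utils.path import mkdir_or_exist, scandir

mkdir_or_exist("outputs/vis")  # no-op if the directory already exists

# Yields paths relative to "data", matching .ply/.PLY anywhere in the tree.
for rel_path in scandir("data", suffix=".ply", recursive=True, case_sensitive=False):
    print(rel_path)
```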
diff --git a/utils/registry.py b/utils/registry.py
new file mode 100644
index 0000000..bd0e55c
--- /dev/null
+++ b/utils/registry.py
@@ -0,0 +1,318 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+import inspect
+import warnings
+from functools import partial
+
+from .misc import is_seq_of
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from configs dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
+ if "type" not in cfg:
+ if default_args is None or "type" not in default_args:
+ raise KeyError(
+ '`cfg` or `default_args` must contain the key "type", '
+ f"but got {cfg}\n{default_args}"
+ )
+ if not isinstance(registry, Registry):
+ raise TypeError(
+ "registry must be an mmcv.Registry object, " f"but got {type(registry)}"
+ )
+ if not (isinstance(default_args, dict) or default_args is None):
+ raise TypeError(
+ "default_args must be a dict or None, " f"but got {type(default_args)}"
+ )
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop("type")
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(f"{obj_type} is not in the {registry.name} registry")
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
+ try:
+ return obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f"{obj_cls.__name__}: {e}")
+
+
+class Registry:
+ """A registry to map strings to classes.
+
+ Registered object could be built from registry.
+ Example:
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = MODELS.build(dict(type='ResNet'))
+
+ Please refer to
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+ advanced usage.
+
+ Args:
+ name (str): Registry name.
+ build_func (func, optional): Build function to construct an instance from
+ the Registry. :func:`build_from_cfg` is used if neither ``parent`` nor
+ ``build_func`` is specified. If ``parent`` is specified and
+ ``build_func`` is not given, ``build_func`` will be inherited
+ from ``parent``. Default: None.
+ parent (Registry, optional): Parent registry. The class registered in
+ children registry could be built from parent. Default: None.
+ scope (str, optional): The scope of registry. It is the key to search
+ for children registry. If not specified, scope will be the name of
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
+ Default: None.
+ """
+
+ def __init__(self, name, build_func=None, parent=None, scope=None):
+ self._name = name
+ self._module_dict = dict()
+ self._children = dict()
+ self._scope = self.infer_scope() if scope is None else scope
+
+ # self.build_func will be set with the following priority:
+ # 1. build_func
+ # 2. parent.build_func
+ # 3. build_from_cfg
+ if build_func is None:
+ if parent is not None:
+ self.build_func = parent.build_func
+ else:
+ self.build_func = build_from_cfg
+ else:
+ self.build_func = build_func
+ if parent is not None:
+ assert isinstance(parent, Registry)
+ parent._add_children(self)
+ self.parent = parent
+ else:
+ self.parent = None
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ def __contains__(self, key):
+ return self.get(key) is not None
+
+ def __repr__(self):
+ format_str = (
+ self.__class__.__name__ + f"(name={self._name}, "
+ f"items={self._module_dict})"
+ )
+ return format_str
+
+ @staticmethod
+ def infer_scope():
+ """Infer the scope of registry.
+
+ The name of the package where registry is defined will be returned.
+
+ Example:
+ # in mmdet/models/backbone/resnet.py
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ The scope of ``ResNet`` will be ``mmdet``.
+
+
+ Returns:
+ scope (str): The inferred scope name.
+ """
+ # inspect.stack() traces back to where this function is called; index 2
+ # corresponds to the frame from which `infer_scope()` is called
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+ split_filename = filename.split(".")
+ return split_filename[0]
+
+ @staticmethod
+ def split_scope_key(key):
+ """Split scope and key.
+
+ The first scope will be split from key.
+
+ Examples:
+ >>> Registry.split_scope_key('mmdet.ResNet')
+ 'mmdet', 'ResNet'
+ >>> Registry.split_scope_key('ResNet')
+ None, 'ResNet'
+
+ Return:
+ scope (str, None): The first scope.
+ key (str): The remaining key.
+ """
+ split_index = key.find(".")
+ if split_index != -1:
+ return key[:split_index], key[split_index + 1 :]
+ else:
+ return None, key
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ @property
+ def children(self):
+ return self._children
+
+ def get(self, key):
+ """Get the registry record.
+
+ Args:
+ key (str): The class name in string format.
+
+ Returns:
+ class: The corresponding class.
+ """
+ scope, real_key = self.split_scope_key(key)
+ if scope is None or scope == self._scope:
+ # get from self
+ if real_key in self._module_dict:
+ return self._module_dict[real_key]
+ else:
+ # get from self._children
+ if scope in self._children:
+ return self._children[scope].get(real_key)
+ else:
+ # goto root
+ parent = self.parent
+ while parent.parent is not None:
+ parent = parent.parent
+ return parent.get(key)
+
+ def build(self, *args, **kwargs):
+ return self.build_func(*args, **kwargs, registry=self)
+
+ def _add_children(self, registry):
+ """Add children for a registry.
+
+ The ``registry`` will be added as children based on its scope.
+ The parent registry could build objects from children registry.
+
+ Example:
+ >>> models = Registry('models')
+ >>> mmdet_models = Registry('models', parent=models)
+ >>> @mmdet_models.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
+ """
+
+ assert isinstance(registry, Registry)
+ assert registry.scope is not None
+ assert (
+ registry.scope not in self.children
+ ), f"scope {registry.scope} exists in {self.name} registry"
+ self.children[registry.scope] = registry
+
+ def _register_module(self, module_class, module_name=None, force=False):
+ if not inspect.isclass(module_class):
+ raise TypeError("module must be a class, " f"but got {type(module_class)}")
+
+ if module_name is None:
+ module_name = module_class.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in self._module_dict:
+ raise KeyError(f"{name} is already registered " f"in {self.name}")
+ self._module_dict[name] = module_class
+
+ def deprecated_register_module(self, cls=None, force=False):
+ warnings.warn(
+ "The old API of register_module(module, force=False) "
+ "is deprecated and will be removed, please use the new API "
+ "register_module(name=None, force=False, module=None) instead."
+ )
+ if cls is None:
+ return partial(self.deprecated_register_module, force=force)
+ self._register_module(cls, force=force)
+ return cls
+
+ def register_module(self, name=None, force=False, module=None):
+ """Register a module.
+
+ A record will be added to `self._module_dict`, whose key is the class
+ name or the specified name, and value is the class itself.
+ It can be used as a decorator or a normal function.
+
+ Example:
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module()
+ >>> class ResNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module(name='mnet')
+ >>> class MobileNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> class ResNet:
+ >>> pass
+ >>> backbones.register_module(ResNet)
+
+ Args:
+ name (str | None): The module name to be registered. If not
+ specified, the class name will be used.
+ force (bool, optional): Whether to override an existing class with
+ the same name. Default: False.
+ module (type): Module class to be registered.
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
+ # NOTE: This is a workaround to stay compatible with the old API,
+ # while it may introduce unexpected bugs.
+ if isinstance(name, type):
+ return self.deprecated_register_module(name, force=force)
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+ raise TypeError(
+ "name must be either of None, an instance of str or a sequence"
+ f" of str, but got {type(name)}"
+ )
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ self._register_module(module_class=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(cls):
+ self._register_module(module_class=cls, module_name=name, force=force)
+ return cls
+
+ return _register
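
A small self-contained sketch of the registry pattern, using a toy class rather than anything defined in this repository:

```python
from utils.registry import Registry

MODELS = Registry("models")

@MODELS.register_module()  # registered under its class name, "TinyMLP"
class TinyMLP:
    def __init__(self, in_dim, out_dim):
        self.in_dim, self.out_dim = in_dim, out_dim

# build_from_cfg pops "type", resolves it in the registry and forwards the rest as kwargs.
model = MODELS.build(dict(type="TinyMLP", in_dim=8, out_dim=2))
assert isinstance(model, TinyMLP)
```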
diff --git a/utils/scheduler.py b/utils/scheduler.py
new file mode 100644
index 0000000..bb31459
--- /dev/null
+++ b/utils/scheduler.py
@@ -0,0 +1,144 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import torch.optim.lr_scheduler as lr_scheduler
+from .registry import Registry
+
+SCHEDULERS = Registry("schedulers")
+
+
+@SCHEDULERS.register_module()
+class MultiStepLR(lr_scheduler.MultiStepLR):
+ def __init__(
+ self,
+ optimizer,
+ milestones,
+ total_steps,
+ gamma=0.1,
+ last_epoch=-1,
+ verbose=False,
+ ):
+ super().__init__(
+ optimizer=optimizer,
+ milestones=[rate * total_steps for rate in milestones],
+ gamma=gamma,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
+ def __init__(
+ self,
+ optimizer,
+ milestones,
+ total_steps,
+ gamma=0.1,
+ warmup_rate=0.05,
+ warmup_scale=1e-6,
+ last_epoch=-1,
+ verbose=False,
+ ):
+ milestones = [rate * total_steps for rate in milestones]
+
+ def multi_step_with_warmup(s):
+ factor = 1.0
+ for i in range(len(milestones)):
+ if s < milestones[i]:
+ break
+ factor *= gamma
+
+ if s <= warmup_rate * total_steps:
+ warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
+ 1 - warmup_scale
+ )
+ else:
+ warmup_coefficient = 1.0
+ return warmup_coefficient * factor
+
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=multi_step_with_warmup,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class PolyLR(lr_scheduler.LambdaLR):
+ def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class ExpLR(lr_scheduler.LambdaLR):
+ def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
+ super().__init__(
+ optimizer=optimizer,
+ lr_lambda=lambda s: gamma ** (s / total_steps),
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
+ def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
+ super().__init__(
+ optimizer=optimizer,
+ T_max=total_steps,
+ eta_min=eta_min,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+@SCHEDULERS.register_module()
+class OneCycleLR(lr_scheduler.OneCycleLR):
+ r"""
+ Wrapper of torch.optim.lr_scheduler.OneCycleLR whose schedule length is set by ``total_steps``.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ max_lr,
+ total_steps=None,
+ pct_start=0.3,
+ anneal_strategy="cos",
+ cycle_momentum=True,
+ base_momentum=0.85,
+ max_momentum=0.95,
+ div_factor=25.0,
+ final_div_factor=1e4,
+ three_phase=False,
+ last_epoch=-1,
+ verbose=False,
+ ):
+ super().__init__(
+ optimizer=optimizer,
+ max_lr=max_lr,
+ total_steps=total_steps,
+ pct_start=pct_start,
+ anneal_strategy=anneal_strategy,
+ cycle_momentum=cycle_momentum,
+ base_momentum=base_momentum,
+ max_momentum=max_momentum,
+ div_factor=div_factor,
+ final_div_factor=final_div_factor,
+ three_phase=three_phase,
+ last_epoch=last_epoch,
+ verbose=verbose,
+ )
+
+
+def build_scheduler(cfg, optimizer):
+ cfg.optimizer = optimizer
+ return SCHEDULERS.build(cfg=cfg)
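
`build_scheduler` attaches the optimizer to the config before dispatching through the SCHEDULERS registry; a sketch, again with a hypothetical attribute-dict config:

```python
import torch
from utils.scheduler import build_scheduler

class AttrDict(dict):
    """Hypothetical stand-in for the project's attribute-accessible config dict."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

# The wrappers above are parameterized by total_steps, i.e. the number of optimizer updates.
cfg = AttrDict(type="OneCycleLR", max_lr=0.1, total_steps=1000, pct_start=0.05)
scheduler = build_scheduler(cfg, optimizer)

for _ in range(3):
    optimizer.step()
    scheduler.step()  # stepped once per iteration
```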
diff --git a/utils/timer.py b/utils/timer.py
new file mode 100644
index 0000000..7b7e9cb
--- /dev/null
+++ b/utils/timer.py
@@ -0,0 +1,71 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+from time import perf_counter
+from typing import Optional
+
+
+class Timer:
+ """
+ A timer which computes the time elapsed since the start/reset of the timer.
+ """
+
+ def __init__(self) -> None:
+ self.reset()
+
+ def reset(self) -> None:
+ """
+ Reset the timer.
+ """
+ self._start = perf_counter()
+ self._paused: Optional[float] = None
+ self._total_paused = 0
+ self._count_start = 1
+
+ def pause(self) -> None:
+ """
+ Pause the timer.
+ """
+ if self._paused is not None:
+ raise ValueError("Trying to pause a Timer that is already paused!")
+ self._paused = perf_counter()
+
+ def is_paused(self) -> bool:
+ """
+ Returns:
+ bool: whether the timer is currently paused
+ """
+ return self._paused is not None
+
+ def resume(self) -> None:
+ """
+ Resume the timer.
+ """
+ if self._paused is None:
+ raise ValueError("Trying to resume a Timer that is not paused!")
+ # pyre-fixme[58]: `-` is not supported for operand types `float` and
+ # `Optional[float]`.
+ self._total_paused += perf_counter() - self._paused
+ self._paused = None
+ self._count_start += 1
+
+ def seconds(self) -> float:
+ """
+ Returns:
+ (float): the total number of seconds since the start/reset of the
+ timer, excluding the time when the timer is paused.
+ """
+ if self._paused is not None:
+ end_time: float = self._paused # type: ignore
+ else:
+ end_time = perf_counter()
+ return end_time - self._start - self._total_paused
+
+ def avg_seconds(self) -> float:
+ """
+ Returns:
+ (float): the average number of seconds between every start/reset and
+ pause.
+ """
+ return self.seconds() / self._count_start
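
Typical use of the Timer: paused intervals are excluded from `seconds()`, and `avg_seconds()` divides by the number of start/resume segments:

```python
import time
from utils.timer import Timer

timer = Timer()
time.sleep(0.10)
timer.pause()
time.sleep(0.50)   # not counted while paused
timer.resume()
time.sleep(0.10)

print(f"elapsed (approx. 0.2s): {timer.seconds():.3f}")
print(f"avg per segment (approx. 0.1s): {timer.avg_seconds():.3f}")
```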
diff --git a/utils/visualization.py b/utils/visualization.py
new file mode 100644
index 0000000..053cb64
--- /dev/null
+++ b/utils/visualization.py
@@ -0,0 +1,86 @@
+"""
+The code is based on https://github.com/Pointcept/Pointcept
+"""
+
+import os
+import open3d as o3d
+import numpy as np
+import torch
+
+
+def to_numpy(x):
+ if isinstance(x, torch.Tensor):
+ x = x.clone().detach().cpu().numpy()
+ assert isinstance(x, np.ndarray)
+ return x
+
+
+def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ coord = to_numpy(coord)
+ if color is not None:
+ color = to_numpy(color)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(coord)
+ pcd.colors = o3d.utility.Vector3dVector(
+ np.ones_like(coord) if color is None else color
+ )
+ o3d.io.write_point_cloud(file_path, pcd)
+ if logger is not None:
+ logger.info(f"Save Point Cloud to: {file_path}")
+
+
+def save_bounding_boxes(
+ bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None
+):
+ bboxes_corners = to_numpy(bboxes_corners)
+ # point list
+ points = bboxes_corners.reshape(-1, 3)
+ # line list
+ box_lines = np.array(
+ [
+ [0, 1],  # bottom face
+ [1, 2],
+ [2, 3],
+ [3, 0],
+ [4, 5],  # top face
+ [5, 6],
+ [6, 7],
+ [7, 4],
+ [0, 4],  # vertical edges
+ [1, 5],
+ [2, 6],
+ [3, 7],
+ ]
+ )
+ lines = []
+ for i, _ in enumerate(bboxes_corners):
+ lines.append(box_lines + i * 8)
+ lines = np.concatenate(lines)
+ # color list
+ color = np.array([color for _ in range(len(lines))])
+ # generate line set
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(points)
+ line_set.lines = o3d.utility.Vector2iVector(lines)
+ line_set.colors = o3d.utility.Vector3dVector(color)
+ o3d.io.write_line_set(file_path, line_set)
+
+ if logger is not None:
+ logger.info(f"Save Boxes to: {file_path}")
+
+
+def save_lines(
+ points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None
+):
+ points = to_numpy(points)
+ lines = to_numpy(lines)
+ colors = np.array([color for _ in range(len(lines))])
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(points)
+ line_set.lines = o3d.utility.Vector2iVector(lines)
+ line_set.colors = o3d.utility.Vector3dVector(colors)
+ o3d.io.write_line_set(file_path, line_set)
+
+ if logger is not None:
+ logger.info(f"Save Lines to: {file_path}")
diff --git a/wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl b/wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl
new file mode 100644
index 0000000..afd5c61
Binary files /dev/null and b/wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl differ