commit ca93dd05724a5701b99a9710fe9a89f77cf60ab5 Author: fdyuandong Date: Thu Apr 17 23:14:24 2025 +0800 feat: Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..73c532f --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +image/ +__pycache__ +**/build/ +**/*.egg-info/ +**/dist/ +*.so +exp +weights +data +log +outputs/ +.vscode +.idea +*/.DS_Store +TEMP/ +pretrained/ +**/*.out +Dockerfile \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f49a4e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..7636ef6 --- /dev/null +++ b/README.md @@ -0,0 +1,103 @@ +# LAM-A2E: Audio to Expression + +[![Website](https://raw.githubusercontent.com/prs-eth/Marigold/main/doc/badges/badge-website.svg)](https://aigc3d.github.io/projects/LAM/) +[![Apache License](https://img.shields.io/badge/📃-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) + +#### This project leverages audio input to generate ARKit blendshapes-driven facial expressions in ⚡real-time⚡, powering ultra-realistic 3D avatars generated by [LAM](https://github.com/aigc3d/LAM). + +## Demo + +
+ +
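+
+### Output format (reference sketch)
+Per the configs, inference predicts 52 ARKit blendshape coefficients per frame at 30 fps and writes them to the file given by `save_json_path` (`bsData.json` by default). The snippet below is a minimal, hypothetical sketch for inspecting such an export; the key names inside `bsData.json` are assumptions and should be checked against a real output file.
+
+```python
+# Hypothetical sketch: inspect an exported bsData.json.
+# The "frames"/"weights" key names are assumptions -- verify against a real export.
+import json
+
+with open("exp/audio2exp/bsData.json", "r") as f:
+    data = json.load(f)
+
+frames = data.get("frames", data) if isinstance(data, dict) else data
+print(f"frames: {len(frames)} (~{len(frames) / 30.0:.2f}s at 30 fps)")
+
+first = frames[0]
+weights = first.get("weights", first) if isinstance(first, dict) else first
+assert len(weights) == 52, "expected 52 ARKit blendshape coefficients per frame"
+```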
+ +## 📢 News + + +### To do list +- [ ] Release Huggingface space. +- [ ] Release Modelscope space. +- [ ] Release the LAM-A2E model based on the Flame expression. +- [ ] Release Interactive Chatting Avatar SDK with [OpenAvatarChat](https://github.com/HumanAIGC-Engineering/OpenAvatarChat), including LLM, ASR, TTS, LAM-Avatars. + + + +## 🚀 Get Started +### Environment Setup +```bash +git clone git@github.com:aigc3d/LAM_Audio2Expression.git +cd LAM_Audio2Expression +# Install with Cuda 12.1 +sh ./scripts/install/install_cu121.sh +# Or Install with Cuda 11.8 +sh ./scripts/install/install_cu118.sh +``` + + +### Download + +``` +# HuggingFace download +# Download Assets and Model Weights +huggingface-cli download 3DAIGC/LAM_audio2exp --local-dir ./ +tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar +tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar + + +# Or OSS Download (In case of HuggingFace download failing) +# Download Assets +wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_assets.tar +tar -xzvf LAM_audio2exp_assets.tar && rm -f LAM_audio2exp_assets.tar +# Download Model Weights +wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/data/LAM/LAM_audio2exp_streaming.tar +tar -xzvf LAM_audio2exp_streaming.tar && rm -f LAM_audio2exp_streaming.tar +``` + + +### Quick Start Guide +#### Using Gradio Interface: +We provide a simple Gradio demo with **WebGLGL Render**, and you can get rendering results by uploading audio in seconds. + +teaser + + + +``` +python app_lam_audio2exp.py +``` + +### Inference +```bash +# example: python inference.py --config-file configs/lam_audio2exp_config_streaming.py --options save_path=exp/audio2exp weight=pretrained_models/lam_audio2exp_streaming.tar audio_input=./assets/sample_audio/BarackObama_english.wav +python inference.py --config-file ${CONFIG_PATH} --options save_path=${SAVE_PATH} weight=${CHECKPOINT_PATH} audio_input=${AUDIO_INPUT} +``` + +### Acknowledgement +This work is built on many amazing research works and open-source projects: +- [FLAME](https://flame.is.tue.mpg.de) +- [FaceFormer](https://github.com/EvelynFan/FaceFormer) +- [Meshtalk](https://github.com/facebookresearch/meshtalk) +- [Unitalker](https://github.com/X-niper/UniTalker) +- [Pointcept](https://github.com/Pointcept/Pointcept) + +Thanks for their excellent works and great contribution. + + +### Related Works +Welcome to follow our other interesting works: +- [LAM](https://github.com/aigc3d/LAM) +- [LHM](https://github.com/aigc3d/LHM) + + +### Citation +``` +@inproceedings{he2025LAM, + title={LAM: Large Avatar Model for One-shot Animatable Gaussian Head}, + author={ + Yisheng He and Xiaodong Gu and Xiaodan Ye and Chao Xu and Zhengyi Zhao and Yuan Dong and Weihao Yuan and Zilong Dong and Liefeng Bo + }, + booktitle={arXiv preprint arXiv:2502.17796}, + year={2025} +} +``` diff --git a/app_lam_audio2exp.py b/app_lam_audio2exp.py new file mode 100644 index 0000000..96ce483 --- /dev/null +++ b/app_lam_audio2exp.py @@ -0,0 +1,271 @@ +""" +Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import base64 + +import gradio as gr +import argparse +from omegaconf import OmegaConf +from gradio_gaussian_render import gaussian_render + +from engines.defaults import ( + default_argument_parser, + default_config_parser, + default_setup, +) +from engines.infer import INFER +from pathlib import Path + +try: + import spaces +except: + pass + +h5_rendering = True + + +def assert_input_image(input_image): + if input_image is None: + raise gr.Error('No image selected or uploaded!') + + +def prepare_working_dir(): + import tempfile + working_dir = tempfile.TemporaryDirectory() + return working_dir + +def get_image_base64(path): + with open(path, 'rb') as image_file: + encoded_string = base64.b64encode(image_file.read()).decode() + return f'data:image/png;base64,{encoded_string}' + + +def doRender(): + print('H5 rendering ....') + +def parse_configs(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str) + parser.add_argument("--infer", type=str) + args, unknown = parser.parse_known_args() + + cfg = OmegaConf.create() + cli_cfg = OmegaConf.from_cli(unknown) + + # parse from ENV + if os.environ.get("APP_INFER") is not None: + args.infer = os.environ.get("APP_INFER") + if os.environ.get("APP_MODEL_NAME") is not None: + cli_cfg.model_name = os.environ.get("APP_MODEL_NAME") + + args.config = args.infer if args.config is None else args.config + + if args.config is not None: + cfg_train = OmegaConf.load(args.config) + cfg.source_size = cfg_train.dataset.source_image_res + try: + cfg.src_head_size = cfg_train.dataset.src_head_size + except: + cfg.src_head_size = 112 + cfg.render_size = cfg_train.dataset.render_image.high + _relative_path = os.path.join( + cfg_train.experiment.parent, + cfg_train.experiment.child, + os.path.basename(cli_cfg.model_name).split("_")[-1], + ) + + cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path) + cfg.image_dump = os.path.join("exps", "images", _relative_path) + cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path + + if args.infer is not None: + cfg_infer = OmegaConf.load(args.infer) + cfg.merge_with(cfg_infer) + cfg.setdefault( + "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp") + ) + cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images")) + cfg.setdefault( + "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos") + ) + cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes")) + + cfg.motion_video_read_fps = 30 + cfg.merge_with(cli_cfg) + + cfg.setdefault("logger", "INFO") + + assert cfg.model_name is not None, "model_name is required" + + return cfg, cfg_train + + +def create_zip_archive(output_zip='assets/arkitWithBSData.zip', base_dir=""): + import os + if (os.path.exists(output_zip)): + os.remove(output_zip) + print(f"Reomve previous file: {output_zip}") + run_command = 'zip -r '+output_zip+' '+base_dir + os.system(run_command) + # check file + if(os.path.exists(output_zip)): + print(f"Archive created successfully: {output_zip}") + else: + raise ValueError(f"Archive created failed: {output_zip}") + + +def 
demo_lam_audio2exp(infer, cfg): + def core_fn(image_path: str, audio_params, working_dir): + + base_id = os.path.basename(image_path).split(".")[0] + + # set input audio + cfg.audio_input = audio_params + cfg.save_json_path = os.path.join("./assets/sample_lam", base_id, 'arkitWithBSData', 'bsData.json') + infer.infer() + + create_zip_archive(output_zip='./assets/arkitWithBSData.zip', base_dir=os.path.join("./assets/sample_lam", base_id)) + + return + + with gr.Blocks(analytics_enabled=False) as demo: + logo_url = './assets/images/logo.jpeg' + logo_base64 = get_image_base64(logo_url) + gr.HTML(f""" +
+            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+                <img src="{logo_base64}" alt="LAM-A2E logo" style="height: 50px; margin-right: 16px;">
+                <h1>LAM-A2E: Audio to Expression</h1>
+            </div>
+        """)
+
+        gr.HTML(
+            """
+            <div style="text-align: center;">
+                <p>Notes: This project leverages audio input to generate ARKit blendshapes-driven facial expressions in ⚡real-time⚡, powering ultra-realistic 3D avatars generated by <a href="https://github.com/aigc3d/LAM">LAM</a>.</p>
+            </div>
""" + ) + + # DISPLAY + with gr.Row(): + with gr.Column(variant='panel', scale=1): + with gr.Tabs(elem_id='lam_input_image'): + with gr.TabItem('Input Image'): + with gr.Row(): + input_image = gr.Image(label='Input Image', + image_mode='RGB', + height=480, + width=270, + sources='upload', + type='filepath', # 'numpy', + elem_id='content_image') + # EXAMPLES + with gr.Row(): + examples = [ + ['assets/sample_input/barbara.jpg'], + ['assets/sample_input/status.png'], + ['assets/sample_input/james.png'], + ['assets/sample_input/vfhq_case1.png'], + ] + gr.Examples( + examples=examples, + inputs=[input_image], + examples_per_page=20, + ) + + with gr.Column(): + with gr.Tabs(elem_id='lam_input_audio'): + with gr.TabItem('Input Audio'): + with gr.Row(): + audio_input = gr.Audio(label='Input Audio', + type='filepath', + waveform_options={ + 'sample_rate': 16000, + 'waveform_progress_color': '#4682b4' + }, + elem_id='content_audio') + + examples = [ + ['assets/sample_audio/Nangyanwen_chinese.wav'], + ['assets/sample_audio/LiBai_TTS_chinese.wav'], + ['assets/sample_audio/LinJing_TTS_chinese.wav'], + ['assets/sample_audio/BarackObama_english.wav'], + ['assets/sample_audio/HillaryClinton_english.wav'], + ['assets/sample_audio/XitongShi_japanese.wav'], + ['assets/sample_audio/FangXiao_japanese.wav'], + ] + gr.Examples( + examples=examples, + inputs=[audio_input], + examples_per_page=10, + ) + + # SETTING + with gr.Row(): + with gr.Column(variant='panel', scale=1): + submit = gr.Button('Generate', + elem_id='lam_generate', + variant='primary') + + if h5_rendering: + gr.set_static_paths(Path.cwd().absolute() / "assets/") + assetPrefix = 'gradio_api/file=assets/' + with gr.Row(): + gs = gaussian_render(width=380, height=680, assets=assetPrefix + 'arkitWithBSData.zip') + + working_dir = gr.State() + submit.click( + fn=assert_input_image, + inputs=[input_image], + queue=False, + ).success( + fn=prepare_working_dir, + outputs=[working_dir], + queue=False, + ).success( + fn=core_fn, + inputs=[input_image, audio_input, + working_dir], # video_params refer to smpl dir + outputs=[], + queue=False, + ).success( + doRender, js='''() => window.start()''' + ) + + demo.queue() + demo.launch() + + + +def launch_gradio_app(): + os.environ.update({ + 'APP_ENABLED': '1', + 'APP_MODEL_NAME':'', + 'APP_INFER': 'configs/lam_audio2exp_streaming_config.py', + 'APP_TYPE': 'infer.audio2exp', + 'NUMBA_THREADING_LAYER': 'omp', + }) + + args = default_argument_parser().parse_args() + args.config_file = 'configs/lam_audio2exp_config_streaming.py' + cfg = default_config_parser(args.config_file, args.options) + cfg = default_setup(cfg) + + cfg.ex_vol = True + infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg)) + + demo_lam_audio2exp(infer, cfg) + + +if __name__ == '__main__': + launch_gradio_app() diff --git a/assets/images/logo.jpeg b/assets/images/logo.jpeg new file mode 100644 index 0000000..6fa8d78 Binary files /dev/null and b/assets/images/logo.jpeg differ diff --git a/assets/images/snapshot.png b/assets/images/snapshot.png new file mode 100644 index 0000000..8fc9bc9 Binary files /dev/null and b/assets/images/snapshot.png differ diff --git a/assets/images/teaser.jpg b/assets/images/teaser.jpg new file mode 100644 index 0000000..8c7c406 Binary files /dev/null and b/assets/images/teaser.jpg differ diff --git a/configs/lam_audio2exp_config.py b/configs/lam_audio2exp_config.py new file mode 100644 index 0000000..a1e4abb --- /dev/null +++ b/configs/lam_audio2exp_config.py @@ -0,0 +1,92 @@ +weight = 
'pretrained_models/lam_audio2exp.tar' # path to model weight +ex_vol = True # Isolates vocal track from audio file +audio_input = './assets/sample_audio/BarackObama.wav' +save_json_path = 'bsData.json' + +audio_sr = 16000 +fps = 30.0 + +movement_smooth = True +brow_movement = True +id_idx = 153 + +resume = False # whether to resume training process +evaluate = True # evaluate after each epoch training process +test_only = False # test process + +seed = None # train process will init a random seed and record +save_path = "exp/audio2exp" +num_worker = 16 # total worker in all gpu +batch_size = 16 # total batch size in all gpu +batch_size_val = None # auto adapt to bs 1 for each gpu +batch_size_test = None # auto adapt to bs 1 for each gpu +epoch = 100 # total epoch, data loop = epoch // eval_epoch +eval_epoch = 100 # sche total eval & checkpoint epoch + +sync_bn = False +enable_amp = False +empty_cache = False +find_unused_parameters = False + +mix_prob = 0 +param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)] + +# model settings +model = dict( + type="DefaultEstimator", + backbone=dict( + type="Audio2Expression", + pretrained_encoder_type='wav2vec', + pretrained_encoder_path='facebook/wav2vec2-base-960h', + wav2vec2_config_path = 'configs/wav2vec2_config.json', + num_identity_classes=5016, + identity_feat_dim=64, + hidden_dim=512, + expression_dim=52, + norm_type='ln', + use_transformer=True, + num_attention_heads=8, + num_transformer_layers=6, + ), + criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)], +) + +dataset_type = 'audio2exp' +data_root = './' +data = dict( + train=dict( + type=dataset_type, + split="train", + data_root=data_root, + test_mode=False, + ), + val=dict( + type=dataset_type, + split="val", + data_root=data_root, + test_mode=False, + ), + test=dict( + type=dataset_type, + split="val", + data_root=data_root, + test_mode=True + ), +) + +# hook +hooks = [ + dict(type="CheckpointLoader"), + dict(type="IterationTimer", warmup_iter=2), + dict(type="InformationWriter"), + dict(type="SemSegEvaluator"), + dict(type="CheckpointSaver", save_freq=None), + dict(type="PreciseEvaluator", test_last=False), +] + +# Trainer +train = dict(type="DefaultTrainer") + +# Tester +infer = dict(type="Audio2ExpressionInfer", + verbose=True) diff --git a/configs/lam_audio2exp_config_streaming.py b/configs/lam_audio2exp_config_streaming.py new file mode 100644 index 0000000..3f44b92 --- /dev/null +++ b/configs/lam_audio2exp_config_streaming.py @@ -0,0 +1,92 @@ +weight = 'pretrained_models/lam_audio2exp_streaming.tar' # path to model weight +ex_vol = True # extract +audio_input = './assets/sample_audio/BarackObama.wav' +save_json_path = 'bsData.json' + +audio_sr = 16000 +fps = 30.0 + +movement_smooth = False +brow_movement = False +id_idx = 0 + +resume = False # whether to resume training process +evaluate = True # evaluate after each epoch training process +test_only = False # test process + +seed = None # train process will init a random seed and record +save_path = "exp/audio2exp" +num_worker = 16 # total worker in all gpu +batch_size = 16 # total batch size in all gpu +batch_size_val = None # auto adapt to bs 1 for each gpu +batch_size_test = None # auto adapt to bs 1 for each gpu +epoch = 100 # total epoch, data loop = epoch // eval_epoch +eval_epoch = 100 # sche total eval & checkpoint epoch + +sync_bn = False +enable_amp = False +empty_cache = False +find_unused_parameters = False + +mix_prob = 0 +param_dicts = None # example: param_dicts = 
[dict(keyword="block", lr_scale=0.1)] + +# model settings +model = dict( + type="DefaultEstimator", + backbone=dict( + type="Audio2Expression", + pretrained_encoder_type='wav2vec', + pretrained_encoder_path='facebook/wav2vec2-base-960h', + wav2vec2_config_path = 'configs/wav2vec2_config.json', + num_identity_classes=12, + identity_feat_dim=64, + hidden_dim=512, + expression_dim=52, + norm_type='ln', + use_transformer=False, + num_attention_heads=8, + num_transformer_layers=6, + ), + criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)], +) + +dataset_type = 'audio2exp' +data_root = './' +data = dict( + train=dict( + type=dataset_type, + split="train", + data_root=data_root, + test_mode=False, + ), + val=dict( + type=dataset_type, + split="val", + data_root=data_root, + test_mode=False, + ), + test=dict( + type=dataset_type, + split="val", + data_root=data_root, + test_mode=True + ), +) + +# hook +hooks = [ + dict(type="CheckpointLoader"), + dict(type="IterationTimer", warmup_iter=2), + dict(type="InformationWriter"), + dict(type="SemSegEvaluator"), + dict(type="CheckpointSaver", save_freq=None), + dict(type="PreciseEvaluator", test_last=False), +] + +# Trainer +train = dict(type="DefaultTrainer") + +# Tester +infer = dict(type="Audio2ExpressionInfer", + verbose=True) diff --git a/configs/wav2vec2_config.json b/configs/wav2vec2_config.json new file mode 100644 index 0000000..8ca9cc7 --- /dev/null +++ b/configs/wav2vec2_config.json @@ -0,0 +1,77 @@ +{ + "_name_or_path": "facebook/wav2vec2-base-960h", + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "codevector_dim": 256, + "contrastive_logits_temperature": 0.1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 12, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 12, + "num_negatives": 100, + "pad_token_id": 0, + "proj_codevector_dim": 256, + "transformers_version": "4.7.0.dev0", + "vocab_size": 32 +} diff --git a/engines/__init__.py b/engines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engines/defaults.py b/engines/defaults.py new file mode 100644 index 0000000..488148b --- /dev/null +++ b/engines/defaults.py @@ -0,0 +1,147 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import os +import sys +import argparse +import multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel + + +import utils.comm as comm +from utils.env import 
get_random_seed, set_seed +from utils.config import Config, DictAction + + +def create_ddp_model(model, *, fp16_compression=False, **kwargs): + """ + Create a DistributedDataParallel model if there are >1 processes. + Args: + model: a torch.nn.Module + fp16_compression: add fp16 compression hooks to the ddp object. + See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook + kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`. + """ + if comm.get_world_size() == 1: + return model + # kwargs['find_unused_parameters'] = True + if "device_ids" not in kwargs: + kwargs["device_ids"] = [comm.get_local_rank()] + if "output_device" not in kwargs: + kwargs["output_device"] = [comm.get_local_rank()] + ddp = DistributedDataParallel(model, **kwargs) + if fp16_compression: + from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks + + ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook) + return ddp + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + + worker_seed = num_workers * rank + worker_id + seed + set_seed(worker_seed) + + +def default_argument_parser(epilog=None): + parser = argparse.ArgumentParser( + epilog=epilog + or f""" + Examples: + Run on single machine: + $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml + Change some config options: + $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001 + Run on multiple machines: + (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--config-file", default="", metavar="FILE", help="path to config file" + ) + parser.add_argument( + "--num-gpus", type=int, default=1, help="number of gpus *per machine*" + ) + parser.add_argument( + "--num-machines", type=int, default=1, help="total number of machines" + ) + parser.add_argument( + "--machine-rank", + type=int, + default=0, + help="the rank of this machine (unique per machine)", + ) + # PyTorch still may leave orphan processes in multi-gpu training. + # Therefore we use a deterministic way to obtain port, + # so that users are aware of orphan processes by seeing the port occupied. + # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 + parser.add_argument( + "--dist-url", + # default="tcp://127.0.0.1:{}".format(port), + default="auto", + help="initialization URL for pytorch distributed backend. 
See " + "https://pytorch.org/docs/stable/distributed.html for details.", + ) + parser.add_argument( + "--options", nargs="+", action=DictAction, help="custom options" + ) + return parser + + +def default_config_parser(file_path, options): + # config name protocol: dataset_name/model_name-exp_name + if os.path.isfile(file_path): + cfg = Config.fromfile(file_path) + else: + sep = file_path.find("-") + cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :])) + + if options is not None: + cfg.merge_from_dict(options) + + if cfg.seed is None: + cfg.seed = get_random_seed() + + cfg.data.train.loop = cfg.epoch // cfg.eval_epoch + + os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True) + if not cfg.resume: + cfg.dump(os.path.join(cfg.save_path, "config.py")) + return cfg + + +def default_setup(cfg): + # scalar by world size + world_size = comm.get_world_size() + cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count() + cfg.num_worker_per_gpu = cfg.num_worker // world_size + assert cfg.batch_size % world_size == 0 + assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0 + assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0 + cfg.batch_size_per_gpu = cfg.batch_size // world_size + cfg.batch_size_val_per_gpu = ( + cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1 + ) + cfg.batch_size_test_per_gpu = ( + cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1 + ) + # update data loop + assert cfg.epoch % cfg.eval_epoch == 0 + # settle random seed + rank = comm.get_rank() + seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank + set_seed(seed) + return cfg diff --git a/engines/hooks/__init__.py b/engines/hooks/__init__.py new file mode 100644 index 0000000..1ab2c4b --- /dev/null +++ b/engines/hooks/__init__.py @@ -0,0 +1,5 @@ +from .default import HookBase +from .misc import * +from .evaluator import * + +from .builder import build_hooks diff --git a/engines/hooks/builder.py b/engines/hooks/builder.py new file mode 100644 index 0000000..e0a121c --- /dev/null +++ b/engines/hooks/builder.py @@ -0,0 +1,15 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +from utils.registry import Registry + + +HOOKS = Registry("hooks") + + +def build_hooks(cfg): + hooks = [] + for hook_cfg in cfg: + hooks.append(HOOKS.build(hook_cfg)) + return hooks diff --git a/engines/hooks/default.py b/engines/hooks/default.py new file mode 100644 index 0000000..57150a7 --- /dev/null +++ b/engines/hooks/default.py @@ -0,0 +1,29 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + + +class HookBase: + """ + Base class for hooks that can be registered with :class:`TrainerBase`. + """ + + trainer = None # A weak reference to the trainer object. 
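+    # A custom hook subclasses HookBase, overrides any of the callbacks below, and is
+    # registered through the HOOKS registry, e.g. (hypothetical minimal example):
+    #
+    #   @HOOKS.register_module()
+    #   class PrintEpochHook(HookBase):
+    #       def after_epoch(self):
+    #           self.trainer.logger.info(f"Finished epoch {self.trainer.epoch + 1}")
+    #
+    # It is then enabled by adding dict(type="PrintEpochHook") to the `hooks` list in a config.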
+ + def before_train(self): + pass + + def before_epoch(self): + pass + + def before_step(self): + pass + + def after_step(self): + pass + + def after_epoch(self): + pass + + def after_train(self): + pass diff --git a/engines/hooks/evaluator.py b/engines/hooks/evaluator.py new file mode 100644 index 0000000..c0d2717 --- /dev/null +++ b/engines/hooks/evaluator.py @@ -0,0 +1,577 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import numpy as np +import torch +import torch.distributed as dist +from uuid import uuid4 + +import utils.comm as comm +from utils.misc import intersection_and_union_gpu + +from .default import HookBase +from .builder import HOOKS + + +@HOOKS.register_module() +class ClsEvaluator(HookBase): + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + for i, input_dict in enumerate(self.trainer.val_loader): + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + output = output_dict["cls_logits"] + loss = output_dict["loss"] + pred = output.max(1)[1] + label = input_dict["category"] + intersection, union, target = intersection_and_union_gpu( + pred, + label, + self.trainer.cfg.data.num_classes, + self.trainer.cfg.data.ignore_index, + ) + if comm.get_world_size() > 1: + dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce( + target + ) + intersection, union, target = ( + intersection.cpu().numpy(), + union.cpu().numpy(), + target.cpu().numpy(), + ) + # Here there is no need to sync since sync happened in dist.all_reduce + self.trainer.storage.put_scalar("val_intersection", intersection) + self.trainer.storage.put_scalar("val_union", union) + self.trainer.storage.put_scalar("val_target", target) + self.trainer.storage.put_scalar("val_loss", loss.item()) + self.trainer.logger.info( + "Test: [{iter}/{max_iter}] " + "Loss {loss:.4f} ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item() + ) + ) + loss_avg = self.trainer.storage.history("val_loss").avg + intersection = self.trainer.storage.history("val_intersection").total + union = self.trainer.storage.history("val_union").total + target = self.trainer.storage.history("val_target").total + iou_class = intersection / (union + 1e-10) + acc_class = intersection / (target + 1e-10) + m_iou = np.mean(iou_class) + m_acc = np.mean(acc_class) + all_acc = sum(intersection) / (sum(target) + 1e-10) + self.trainer.logger.info( + "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format( + m_iou, m_acc, all_acc + ) + ) + for i in range(self.trainer.cfg.data.num_classes): + self.trainer.logger.info( + "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=self.trainer.cfg.data.names[i], + iou=iou_class[i], + accuracy=acc_class[i], + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch) + self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch) + self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = all_acc # save for saver 
+ self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver + + def after_train(self): + self.trainer.logger.info( + "Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value) + ) + + +@HOOKS.register_module() +class SemSegEvaluator(HookBase): + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + for i, input_dict in enumerate(self.trainer.val_loader): + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + output = output_dict["seg_logits"] + loss = output_dict["loss"] + pred = output.max(1)[1] + segment = input_dict["segment"] + if "origin_coord" in input_dict.keys(): + idx, _ = pointops.knn_query( + 1, + input_dict["coord"].float(), + input_dict["offset"].int(), + input_dict["origin_coord"].float(), + input_dict["origin_offset"].int(), + ) + pred = pred[idx.flatten().long()] + segment = input_dict["origin_segment"] + intersection, union, target = intersection_and_union_gpu( + pred, + segment, + self.trainer.cfg.data.num_classes, + self.trainer.cfg.data.ignore_index, + ) + if comm.get_world_size() > 1: + dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce( + target + ) + intersection, union, target = ( + intersection.cpu().numpy(), + union.cpu().numpy(), + target.cpu().numpy(), + ) + # Here there is no need to sync since sync happened in dist.all_reduce + self.trainer.storage.put_scalar("val_intersection", intersection) + self.trainer.storage.put_scalar("val_union", union) + self.trainer.storage.put_scalar("val_target", target) + self.trainer.storage.put_scalar("val_loss", loss.item()) + info = "Test: [{iter}/{max_iter}] ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader) + ) + if "origin_coord" in input_dict.keys(): + info = "Interp. 
" + info + self.trainer.logger.info( + info + + "Loss {loss:.4f} ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item() + ) + ) + loss_avg = self.trainer.storage.history("val_loss").avg + intersection = self.trainer.storage.history("val_intersection").total + union = self.trainer.storage.history("val_union").total + target = self.trainer.storage.history("val_target").total + iou_class = intersection / (union + 1e-10) + acc_class = intersection / (target + 1e-10) + m_iou = np.mean(iou_class) + m_acc = np.mean(acc_class) + all_acc = sum(intersection) / (sum(target) + 1e-10) + self.trainer.logger.info( + "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format( + m_iou, m_acc, all_acc + ) + ) + for i in range(self.trainer.cfg.data.num_classes): + self.trainer.logger.info( + "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=self.trainer.cfg.data.names[i], + iou=iou_class[i], + accuracy=acc_class[i], + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch) + self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch) + self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = m_iou # save for saver + self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver + + def after_train(self): + self.trainer.logger.info( + "Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value) + ) + + +@HOOKS.register_module() +class InsSegEvaluator(HookBase): + def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1): + self.segment_ignore_index = segment_ignore_index + self.instance_ignore_index = instance_ignore_index + + self.valid_class_names = None # update in before train + self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25) + self.min_region_sizes = 100 + self.distance_threshes = float("inf") + self.distance_confs = -float("inf") + + def before_train(self): + self.valid_class_names = [ + self.trainer.cfg.data.names[i] + for i in range(self.trainer.cfg.data.num_classes) + if i not in self.segment_ignore_index + ] + + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def associate_instances(self, pred, segment, instance): + segment = segment.cpu().numpy() + instance = instance.cpu().numpy() + void_mask = np.in1d(segment, self.segment_ignore_index) + + assert ( + pred["pred_classes"].shape[0] + == pred["pred_scores"].shape[0] + == pred["pred_masks"].shape[0] + ) + assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0] + # get gt instances + gt_instances = dict() + for i in range(self.trainer.cfg.data.num_classes): + if i not in self.segment_ignore_index: + gt_instances[self.trainer.cfg.data.names[i]] = [] + instance_ids, idx, counts = np.unique( + instance, return_index=True, return_counts=True + ) + segment_ids = segment[idx] + for i in range(len(instance_ids)): + if instance_ids[i] == self.instance_ignore_index: + continue + if segment_ids[i] in self.segment_ignore_index: + continue + gt_inst = dict() + gt_inst["instance_id"] = instance_ids[i] + gt_inst["segment_id"] = segment_ids[i] + gt_inst["dist_conf"] = 0.0 + gt_inst["med_dist"] = -1.0 + gt_inst["vert_count"] = counts[i] + gt_inst["matched_pred"] = [] + 
gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst) + + # get pred instances and associate with gt + pred_instances = dict() + for i in range(self.trainer.cfg.data.num_classes): + if i not in self.segment_ignore_index: + pred_instances[self.trainer.cfg.data.names[i]] = [] + instance_id = 0 + for i in range(len(pred["pred_classes"])): + if pred["pred_classes"][i] in self.segment_ignore_index: + continue + pred_inst = dict() + pred_inst["uuid"] = uuid4() + pred_inst["instance_id"] = instance_id + pred_inst["segment_id"] = pred["pred_classes"][i] + pred_inst["confidence"] = pred["pred_scores"][i] + pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0) + pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"]) + pred_inst["void_intersection"] = np.count_nonzero( + np.logical_and(void_mask, pred_inst["mask"]) + ) + if pred_inst["vert_count"] < self.min_region_sizes: + continue # skip if empty + segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]] + matched_gt = [] + for gt_idx, gt_inst in enumerate(gt_instances[segment_name]): + intersection = np.count_nonzero( + np.logical_and( + instance == gt_inst["instance_id"], pred_inst["mask"] + ) + ) + if intersection > 0: + gt_inst_ = gt_inst.copy() + pred_inst_ = pred_inst.copy() + gt_inst_["intersection"] = intersection + pred_inst_["intersection"] = intersection + matched_gt.append(gt_inst_) + gt_inst["matched_pred"].append(pred_inst_) + pred_inst["matched_gt"] = matched_gt + pred_instances[segment_name].append(pred_inst) + instance_id += 1 + return gt_instances, pred_instances + + def evaluate_matches(self, scenes): + overlaps = self.overlaps + min_region_sizes = [self.min_region_sizes] + dist_threshes = [self.distance_threshes] + dist_confs = [self.distance_confs] + + # results: class x overlap + ap_table = np.zeros( + (len(dist_threshes), len(self.valid_class_names), len(overlaps)), float + ) + for di, (min_region_size, distance_thresh, distance_conf) in enumerate( + zip(min_region_sizes, dist_threshes, dist_confs) + ): + for oi, overlap_th in enumerate(overlaps): + pred_visited = {} + for scene in scenes: + for _ in scene["pred"]: + for label_name in self.valid_class_names: + for p in scene["pred"][label_name]: + if "uuid" in p: + pred_visited[p["uuid"]] = False + for li, label_name in enumerate(self.valid_class_names): + y_true = np.empty(0) + y_score = np.empty(0) + hard_false_negatives = 0 + has_gt = False + has_pred = False + for scene in scenes: + pred_instances = scene["pred"][label_name] + gt_instances = scene["gt"][label_name] + # filter groups in ground truth + gt_instances = [ + gt + for gt in gt_instances + if gt["vert_count"] >= min_region_size + and gt["med_dist"] <= distance_thresh + and gt["dist_conf"] >= distance_conf + ] + if gt_instances: + has_gt = True + if pred_instances: + has_pred = True + + cur_true = np.ones(len(gt_instances)) + cur_score = np.ones(len(gt_instances)) * (-float("inf")) + cur_match = np.zeros(len(gt_instances), dtype=bool) + # collect matches + for gti, gt in enumerate(gt_instances): + found_match = False + for pred in gt["matched_pred"]: + # greedy assignments + if pred_visited[pred["uuid"]]: + continue + overlap = float(pred["intersection"]) / ( + gt["vert_count"] + + pred["vert_count"] + - pred["intersection"] + ) + if overlap > overlap_th: + confidence = pred["confidence"] + # if already have a prediction for this gt, + # the prediction with the lower score is automatically a false positive + if cur_match[gti]: + max_score = max(cur_score[gti], confidence) 
+ min_score = min(cur_score[gti], confidence) + cur_score[gti] = max_score + # append false positive + cur_true = np.append(cur_true, 0) + cur_score = np.append(cur_score, min_score) + cur_match = np.append(cur_match, True) + # otherwise set score + else: + found_match = True + cur_match[gti] = True + cur_score[gti] = confidence + pred_visited[pred["uuid"]] = True + if not found_match: + hard_false_negatives += 1 + # remove non-matched ground truth instances + cur_true = cur_true[cur_match] + cur_score = cur_score[cur_match] + + # collect non-matched predictions as false positive + for pred in pred_instances: + found_gt = False + for gt in pred["matched_gt"]: + overlap = float(gt["intersection"]) / ( + gt["vert_count"] + + pred["vert_count"] + - gt["intersection"] + ) + if overlap > overlap_th: + found_gt = True + break + if not found_gt: + num_ignore = pred["void_intersection"] + for gt in pred["matched_gt"]: + if gt["segment_id"] in self.segment_ignore_index: + num_ignore += gt["intersection"] + # small ground truth instances + if ( + gt["vert_count"] < min_region_size + or gt["med_dist"] > distance_thresh + or gt["dist_conf"] < distance_conf + ): + num_ignore += gt["intersection"] + proportion_ignore = ( + float(num_ignore) / pred["vert_count"] + ) + # if not ignored append false positive + if proportion_ignore <= overlap_th: + cur_true = np.append(cur_true, 0) + confidence = pred["confidence"] + cur_score = np.append(cur_score, confidence) + + # append to overall results + y_true = np.append(y_true, cur_true) + y_score = np.append(y_score, cur_score) + + # compute average precision + if has_gt and has_pred: + # compute precision recall curve first + + # sorting and cumsum + score_arg_sort = np.argsort(y_score) + y_score_sorted = y_score[score_arg_sort] + y_true_sorted = y_true[score_arg_sort] + y_true_sorted_cumsum = np.cumsum(y_true_sorted) + + # unique thresholds + (thresholds, unique_indices) = np.unique( + y_score_sorted, return_index=True + ) + num_prec_recall = len(unique_indices) + 1 + + # prepare precision recall + num_examples = len(y_score_sorted) + # https://github.com/ScanNet/ScanNet/pull/26 + # all predictions are non-matched but also all of them are ignored and not counted as FP + # y_true_sorted_cumsum is empty + # num_true_examples = y_true_sorted_cumsum[-1] + num_true_examples = ( + y_true_sorted_cumsum[-1] + if len(y_true_sorted_cumsum) > 0 + else 0 + ) + precision = np.zeros(num_prec_recall) + recall = np.zeros(num_prec_recall) + + # deal with the first point + y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0) + # deal with remaining + for idx_res, idx_scores in enumerate(unique_indices): + cumsum = y_true_sorted_cumsum[idx_scores - 1] + tp = num_true_examples - cumsum + fp = num_examples - idx_scores - tp + fn = cumsum + hard_false_negatives + p = float(tp) / (tp + fp) + r = float(tp) / (tp + fn) + precision[idx_res] = p + recall[idx_res] = r + + # first point in curve is artificial + precision[-1] = 1.0 + recall[-1] = 0.0 + + # compute average of precision-recall curve + recall_for_conv = np.copy(recall) + recall_for_conv = np.append(recall_for_conv[0], recall_for_conv) + recall_for_conv = np.append(recall_for_conv, 0.0) + + stepWidths = np.convolve( + recall_for_conv, [-0.5, 0, 0.5], "valid" + ) + # integrate is now simply a dot product + ap_current = np.dot(precision, stepWidths) + + elif has_gt: + ap_current = 0.0 + else: + ap_current = float("nan") + ap_table[di, li, oi] = ap_current + d_inf = 0 + o50 = np.where(np.isclose(self.overlaps, 0.5)) + o25 = 
np.where(np.isclose(self.overlaps, 0.25)) + oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25))) + ap_scores = dict() + ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25]) + ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50]) + ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25]) + ap_scores["classes"] = {} + for li, label_name in enumerate(self.valid_class_names): + ap_scores["classes"][label_name] = {} + ap_scores["classes"][label_name]["ap"] = np.average( + ap_table[d_inf, li, oAllBut25] + ) + ap_scores["classes"][label_name]["ap50%"] = np.average( + ap_table[d_inf, li, o50] + ) + ap_scores["classes"][label_name]["ap25%"] = np.average( + ap_table[d_inf, li, o25] + ) + return ap_scores + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + scenes = [] + for i, input_dict in enumerate(self.trainer.val_loader): + assert ( + len(input_dict["offset"]) == 1 + ) # currently only support bs 1 for each GPU + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + + loss = output_dict["loss"] + + segment = input_dict["segment"] + instance = input_dict["instance"] + # map to origin + if "origin_coord" in input_dict.keys(): + idx, _ = pointops.knn_query( + 1, + input_dict["coord"].float(), + input_dict["offset"].int(), + input_dict["origin_coord"].float(), + input_dict["origin_offset"].int(), + ) + idx = idx.cpu().flatten().long() + output_dict["pred_masks"] = output_dict["pred_masks"][:, idx] + segment = input_dict["origin_segment"] + instance = input_dict["origin_instance"] + + gt_instances, pred_instance = self.associate_instances( + output_dict, segment, instance + ) + scenes.append(dict(gt=gt_instances, pred=pred_instance)) + + self.trainer.storage.put_scalar("val_loss", loss.item()) + self.trainer.logger.info( + "Test: [{iter}/{max_iter}] " + "Loss {loss:.4f} ".format( + iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item() + ) + ) + + loss_avg = self.trainer.storage.history("val_loss").avg + comm.synchronize() + scenes_sync = comm.gather(scenes, dst=0) + scenes = [scene for scenes_ in scenes_sync for scene in scenes_] + ap_scores = self.evaluate_matches(scenes) + all_ap = ap_scores["all_ap"] + all_ap_50 = ap_scores["all_ap_50%"] + all_ap_25 = ap_scores["all_ap_25%"] + self.trainer.logger.info( + "Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format( + all_ap, all_ap_50, all_ap_25 + ) + ) + for i, label_name in enumerate(self.valid_class_names): + ap = ap_scores["classes"][label_name]["ap"] + ap_50 = ap_scores["classes"][label_name]["ap50%"] + ap_25 = ap_scores["classes"][label_name]["ap25%"] + self.trainer.logger.info( + "Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format( + idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25 + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch) + self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch) + self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver + 
self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver diff --git a/engines/hooks/misc.py b/engines/hooks/misc.py new file mode 100644 index 0000000..52b398e --- /dev/null +++ b/engines/hooks/misc.py @@ -0,0 +1,460 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import sys +import glob +import os +import shutil +import time +import torch +import torch.utils.data +from collections import OrderedDict + +if sys.version_info >= (3, 10): + from collections.abc import Sequence +else: + from collections import Sequence +from utils.timer import Timer +from utils.comm import is_main_process, synchronize, get_world_size +from utils.cache import shared_dict + +import utils.comm as comm +from engines.test import TESTERS + +from .default import HookBase +from .builder import HOOKS + + +@HOOKS.register_module() +class IterationTimer(HookBase): + def __init__(self, warmup_iter=1): + self._warmup_iter = warmup_iter + self._start_time = time.perf_counter() + self._iter_timer = Timer() + self._remain_iter = 0 + + def before_train(self): + self._start_time = time.perf_counter() + self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader) + + def before_epoch(self): + self._iter_timer.reset() + + def before_step(self): + data_time = self._iter_timer.seconds() + self.trainer.storage.put_scalar("data_time", data_time) + + def after_step(self): + batch_time = self._iter_timer.seconds() + self._iter_timer.reset() + self.trainer.storage.put_scalar("batch_time", batch_time) + self._remain_iter -= 1 + remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg + t_m, t_s = divmod(remain_time, 60) + t_h, t_m = divmod(t_m, 60) + remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s)) + if "iter_info" in self.trainer.comm_info.keys(): + info = ( + "Data {data_time_val:.3f} ({data_time_avg:.3f}) " + "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) " + "Remain {remain_time} ".format( + data_time_val=self.trainer.storage.history("data_time").val, + data_time_avg=self.trainer.storage.history("data_time").avg, + batch_time_val=self.trainer.storage.history("batch_time").val, + batch_time_avg=self.trainer.storage.history("batch_time").avg, + remain_time=remain_time, + ) + ) + self.trainer.comm_info["iter_info"] += info + if self.trainer.comm_info["iter"] <= self._warmup_iter: + self.trainer.storage.history("data_time").reset() + self.trainer.storage.history("batch_time").reset() + + +@HOOKS.register_module() +class InformationWriter(HookBase): + def __init__(self): + self.curr_iter = 0 + self.model_output_keys = [] + + def before_train(self): + self.trainer.comm_info["iter_info"] = "" + self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader) + + def before_step(self): + self.curr_iter += 1 + # MSC pretrain do not have offset information. 
Comment the code for support MSC + # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \ + # "Scan {batch_size} ({points_num}) ".format( + # epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch, + # iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader), + # batch_size=len(self.trainer.comm_info["input_dict"]["offset"]), + # points_num=self.trainer.comm_info["input_dict"]["offset"][-1] + # ) + info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format( + epoch=self.trainer.epoch + 1, + max_epoch=self.trainer.max_epoch, + iter=self.trainer.comm_info["iter"] + 1, + max_iter=len(self.trainer.train_loader), + ) + self.trainer.comm_info["iter_info"] += info + + def after_step(self): + if "model_output_dict" in self.trainer.comm_info.keys(): + model_output_dict = self.trainer.comm_info["model_output_dict"] + self.model_output_keys = model_output_dict.keys() + for key in self.model_output_keys: + self.trainer.storage.put_scalar(key, model_output_dict[key].item()) + + for key in self.model_output_keys: + self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format( + key=key, value=self.trainer.storage.history(key).val + ) + lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"] + self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr) + self.trainer.logger.info(self.trainer.comm_info["iter_info"]) + self.trainer.comm_info["iter_info"] = "" # reset iter info + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("lr", lr, self.curr_iter) + for key in self.model_output_keys: + self.trainer.writer.add_scalar( + "train_batch/" + key, + self.trainer.storage.history(key).val, + self.curr_iter, + ) + + def after_epoch(self): + epoch_info = "Train result: " + for key in self.model_output_keys: + epoch_info += "{key}: {value:.4f} ".format( + key=key, value=self.trainer.storage.history(key).avg + ) + self.trainer.logger.info(epoch_info) + if self.trainer.writer is not None: + for key in self.model_output_keys: + self.trainer.writer.add_scalar( + "train/" + key, + self.trainer.storage.history(key).avg, + self.trainer.epoch + 1, + ) + + +@HOOKS.register_module() +class CheckpointSaver(HookBase): + def __init__(self, save_freq=None): + self.save_freq = save_freq # None or int, None indicate only save model last + + def after_epoch(self): + if is_main_process(): + is_best = False + if self.trainer.cfg.evaluate: + current_metric_value = self.trainer.comm_info["current_metric_value"] + current_metric_name = self.trainer.comm_info["current_metric_name"] + if current_metric_value > self.trainer.best_metric_value: + self.trainer.best_metric_value = current_metric_value + is_best = True + self.trainer.logger.info( + "Best validation {} updated to: {:.4f}".format( + current_metric_name, current_metric_value + ) + ) + self.trainer.logger.info( + "Currently Best {}: {:.4f}".format( + current_metric_name, self.trainer.best_metric_value + ) + ) + + filename = os.path.join( + self.trainer.cfg.save_path, "model", "model_last.pth" + ) + self.trainer.logger.info("Saving checkpoint to: " + filename) + torch.save( + { + "epoch": self.trainer.epoch + 1, + "state_dict": self.trainer.model.state_dict(), + "optimizer": self.trainer.optimizer.state_dict(), + "scheduler": self.trainer.scheduler.state_dict(), + "scaler": self.trainer.scaler.state_dict() + if self.trainer.cfg.enable_amp + else None, + "best_metric_value": self.trainer.best_metric_value, + }, + filename + ".tmp", + ) + os.replace(filename + ".tmp", filename) + if is_best: + 
shutil.copyfile( + filename, + os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"), + ) + if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0: + shutil.copyfile( + filename, + os.path.join( + self.trainer.cfg.save_path, + "model", + f"epoch_{self.trainer.epoch + 1}.pth", + ), + ) + + +@HOOKS.register_module() +class CheckpointLoader(HookBase): + def __init__(self, keywords="", replacement=None, strict=False): + self.keywords = keywords + self.replacement = replacement if replacement is not None else keywords + self.strict = strict + + def before_train(self): + self.trainer.logger.info("=> Loading checkpoint & weight ...") + if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight): + self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}") + checkpoint = torch.load( + self.trainer.cfg.weight, + map_location=lambda storage, loc: storage.cuda(), + ) + self.trainer.logger.info( + f"Loading layer weights with keyword: {self.keywords}, " + f"replace keyword with: {self.replacement}" + ) + weight = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + if not key.startswith("module."): + if comm.get_world_size() > 1: + key = "module." + key # xxx.xxx -> module.xxx.xxx + # Now all keys contain "module." no matter DDP or not. + if self.keywords in key: + key = key.replace(self.keywords, self.replacement) + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + weight[key] = value + load_state_info = self.trainer.model.load_state_dict( + weight, strict=self.strict + ) + self.trainer.logger.info(f"Missing keys: {load_state_info[0]}") + if self.trainer.cfg.resume: + self.trainer.logger.info( + f"Resuming train at eval epoch: {checkpoint['epoch']}" + ) + self.trainer.start_epoch = checkpoint["epoch"] + self.trainer.best_metric_value = checkpoint["best_metric_value"] + self.trainer.optimizer.load_state_dict(checkpoint["optimizer"]) + self.trainer.scheduler.load_state_dict(checkpoint["scheduler"]) + if self.trainer.cfg.enable_amp: + self.trainer.scaler.load_state_dict(checkpoint["scaler"]) + else: + self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}") + + +@HOOKS.register_module() +class PreciseEvaluator(HookBase): + def __init__(self, test_last=False): + self.test_last = test_last + + def after_train(self): + self.trainer.logger.info( + ">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>" + ) + torch.cuda.empty_cache() + cfg = self.trainer.cfg + tester = TESTERS.build( + dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model) + ) + if self.test_last: + self.trainer.logger.info("=> Testing on model_last ...") + else: + self.trainer.logger.info("=> Testing on model_best ...") + best_path = os.path.join( + self.trainer.cfg.save_path, "model", "model_best.pth" + ) + checkpoint = torch.load(best_path) + state_dict = checkpoint["state_dict"] + tester.model.load_state_dict(state_dict, strict=True) + tester.test() + + +@HOOKS.register_module() +class DataCacheOperator(HookBase): + def __init__(self, data_root, split): + self.data_root = data_root + self.split = split + self.data_list = self.get_data_list() + + def get_data_list(self): + if isinstance(self.split, str): + data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth")) + elif isinstance(self.split, Sequence): + data_list = [] + for split in self.split: + data_list += glob.glob(os.path.join(self.data_root, split, "*.pth")) + else: + raise NotImplementedError + return data_list + + def get_cache_name(self, 
data_path): + data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0] + return "pointcept" + data_name.replace(os.path.sep, "-") + + def before_train(self): + self.trainer.logger.info( + f"=> Caching dataset: {self.data_root}, split: {self.split} ..." + ) + if is_main_process(): + for data_path in self.data_list: + cache_name = self.get_cache_name(data_path) + data = torch.load(data_path) + shared_dict(cache_name, data) + synchronize() + + +@HOOKS.register_module() +class RuntimeProfiler(HookBase): + def __init__( + self, + forward=True, + backward=True, + interrupt=False, + warm_up=2, + sort_by="cuda_time_total", + row_limit=30, + ): + self.forward = forward + self.backward = backward + self.interrupt = interrupt + self.warm_up = warm_up + self.sort_by = sort_by + self.row_limit = row_limit + + def before_train(self): + self.trainer.logger.info("Profiling runtime ...") + from torch.profiler import profile, record_function, ProfilerActivity + + for i, input_dict in enumerate(self.trainer.train_loader): + if i == self.warm_up + 1: + break + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + if self.forward: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as forward_prof: + with record_function("model_inference"): + output_dict = self.trainer.model(input_dict) + else: + output_dict = self.trainer.model(input_dict) + loss = output_dict["loss"] + if self.backward: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as backward_prof: + with record_function("model_inference"): + loss.backward() + self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]") + if self.forward: + self.trainer.logger.info( + "Forward profile: \n" + + str( + forward_prof.key_averages().table( + sort_by=self.sort_by, row_limit=self.row_limit + ) + ) + ) + forward_prof.export_chrome_trace( + os.path.join(self.trainer.cfg.save_path, "forward_trace.json") + ) + + if self.backward: + self.trainer.logger.info( + "Backward profile: \n" + + str( + backward_prof.key_averages().table( + sort_by=self.sort_by, row_limit=self.row_limit + ) + ) + ) + backward_prof.export_chrome_trace( + os.path.join(self.trainer.cfg.save_path, "backward_trace.json") + ) + if self.interrupt: + sys.exit(0) + + +@HOOKS.register_module() +class RuntimeProfilerV2(HookBase): + def __init__( + self, + interrupt=False, + wait=1, + warmup=1, + active=10, + repeat=1, + sort_by="cuda_time_total", + row_limit=30, + ): + self.interrupt = interrupt + self.wait = wait + self.warmup = warmup + self.active = active + self.repeat = repeat + self.sort_by = sort_by + self.row_limit = row_limit + + def before_train(self): + self.trainer.logger.info("Profiling runtime ...") + from torch.profiler import ( + profile, + record_function, + ProfilerActivity, + schedule, + tensorboard_trace_handler, + ) + + prof = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule( + wait=self.wait, + warmup=self.warmup, + active=self.active, + repeat=self.repeat, + ), + on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + prof.start() + for i, input_dict in enumerate(self.trainer.train_loader): + if i >= (self.wait + self.warmup + self.active) * self.repeat: + break 
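The schedule-based torch.profiler setup above can be reproduced in isolation. The toy model and the wait/warmup/active/repeat values below are placeholders, and the sketch is restricted to CPU activity so it runs without a GPU:

import torch
import torch.nn as nn
from torch.profiler import profile, schedule, ProfilerActivity, record_function

model = nn.Linear(16, 4)                      # toy stand-in for the real model
wait, warmup, active, repeat = 1, 1, 2, 1     # profiler records `active` steps after wait + warmup

prof = profile(
    activities=[ProfilerActivity.CPU],        # CPU-only so the sketch runs anywhere
    schedule=schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
    record_shapes=True,
)
prof.start()
for step in range((wait + warmup + active) * repeat):
    x = torch.randn(8, 16)
    with record_function("model_forward"):
        loss = model(x).sum()
    with record_function("model_backward"):
        loss.backward()
    prof.step()                               # advances the wait/warmup/active state machine
prof.stop()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))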
+ for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with record_function("model_forward"): + output_dict = self.trainer.model(input_dict) + loss = output_dict["loss"] + with record_function("model_backward"): + loss.backward() + prof.step() + self.trainer.logger.info( + f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]" + ) + self.trainer.logger.info( + "Profile: \n" + + str( + prof.key_averages().table( + sort_by=self.sort_by, row_limit=self.row_limit + ) + ) + ) + prof.stop() + + if self.interrupt: + sys.exit(0) diff --git a/engines/infer.py b/engines/infer.py new file mode 100644 index 0000000..9b7b8f1 --- /dev/null +++ b/engines/infer.py @@ -0,0 +1,285 @@ +""" +Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +""" + +import os +import math +import time +import librosa +import numpy as np +from collections import OrderedDict + +import torch +import torch.utils.data +import torch.nn.functional as F + +from .defaults import create_ddp_model +import utils.comm as comm +from models import build_model +from utils.logger import get_root_logger +from utils.registry import Registry +from utils.misc import ( + AverageMeter, +) + +from models.utils import smooth_mouth_movements, apply_frame_blending, apply_savitzky_golay_smoothing, apply_random_brow_movement, \ + symmetrize_blendshapes, apply_random_eye_blinks, apply_random_eye_blinks_context, export_blendshape_animation, \ + RETURN_CODE, DEFAULT_CONTEXT, ARKitBlendShape + +INFER = Registry("infer") + +class InferBase: + def __init__(self, cfg, model=None, verbose=False) -> None: + torch.multiprocessing.set_sharing_strategy("file_system") + self.logger = get_root_logger( + log_file=os.path.join(cfg.save_path, "infer.log"), + file_mode="a" if cfg.resume else "w", + ) + self.logger.info("=> Loading config ...") + self.cfg = cfg + self.verbose = verbose + if self.verbose: + self.logger.info(f"Save path: {cfg.save_path}") + self.logger.info(f"Config:\n{cfg.pretty_text}") + if model is None: + self.logger.info("=> Building model ...") + self.model = self.build_model() + else: + self.model = model + + def build_model(self): + model = build_model(self.cfg.model) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.logger.info(f"Num params: {n_parameters}") + model = create_ddp_model( + model.cuda(), + broadcast_buffers=False, + find_unused_parameters=self.cfg.find_unused_parameters, + ) + if os.path.isfile(self.cfg.weight): + self.logger.info(f"Loading weight at: {self.cfg.weight}") + checkpoint = torch.load(self.cfg.weight) + weight = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + if key.startswith("module."): + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + else: + if comm.get_world_size() > 1: + key = "module." 
+ key # xxx.xxx -> module.xxx.xxx + weight[key] = value + model.load_state_dict(weight, strict=True) + self.logger.info( + "=> Loaded weight '{}'".format( + self.cfg.weight + ) + ) + else: + raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight)) + return model + + + def infer(self): + raise NotImplementedError + + + +@INFER.register_module() +class Audio2ExpressionInfer(InferBase): + def infer(self): + logger = get_root_logger() + logger.info(">>>>>>>>>>>>>>>> Start Inference >>>>>>>>>>>>>>>>") + batch_time = AverageMeter() + self.model.eval() + + # process audio-input + assert os.path.exists(self.cfg.audio_input) + if(self.cfg.ex_vol): + logger.info("Extract vocals ...") + vocal_path = self.extract_vocal_track(self.cfg.audio_input) + logger.info("=> Extract vocals at: {}".format(vocal_path if os.path.exists(vocal_path) else '... Failed')) + if(os.path.exists(vocal_path)): + self.cfg.audio_input = vocal_path + + with torch.no_grad(): + input_dict = {} + input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx), + self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None,...] + speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000) + input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None,...] + + end = time.time() + output_dict = self.model(input_dict) + batch_time.update(time.time() - end) + + logger.info( + "Infer: [{}] " + "Running Time: {batch_time.avg:.3f} ".format( + self.cfg.audio_input, + batch_time=batch_time, + ) + ) + + out_exp = output_dict['pred_exp'].squeeze().cpu().numpy() + + frame_length = math.ceil(speech_array.shape[0] / ssr * 30) + volume = librosa.feature.rms(y=speech_array, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0] + if (volume.shape[0] > frame_length): + volume = volume[:frame_length] + + if(self.cfg.movement_smooth): + out_exp = smooth_mouth_movements(out_exp, 0, volume) + + if (self.cfg.brow_movement): + out_exp = apply_random_brow_movement(out_exp, volume) + + pred_exp = self.blendshape_postprocess(out_exp) + + if(self.cfg.save_json_path is not None): + export_blendshape_animation(pred_exp, + self.cfg.save_json_path, + ARKitBlendShape, + fps=self.cfg.fps) + + logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + + def infer_streaming_audio(self, + audio: np.ndarray, + ssr: float, + context: dict): + + if (context is None): + context = DEFAULT_CONTEXT.copy() + max_frame_length = 64 + + frame_length = math.ceil(audio.shape[0] / ssr * 30) + output_context = DEFAULT_CONTEXT.copy() + + volume = librosa.feature.rms(y=audio, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0] + if (volume.shape[0] > frame_length): + volume = volume[:frame_length] + + # resample audio + if (ssr != self.cfg.audio_sr): + in_audio = librosa.resample(audio.astype(np.float32), orig_sr=ssr, target_sr=self.cfg.audio_sr) + else: + in_audio = audio.copy() + + start_frame = int(max_frame_length - in_audio.shape[0] / self.cfg.audio_sr * 30) + + if (context['is_initial_input'] or (context['previous_audio'] is None)): + blank_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0] + blank_audio = np.zeros(blank_audio_length, dtype=np.float32) + + # pre-append + input_audio = np.concatenate([blank_audio, in_audio]) + output_context['previous_audio'] = input_audio + + else: + clip_pre_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0] + clip_pre_audio = context['previous_audio'][-clip_pre_audio_length:] + 
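A small self-contained sketch of the framing and volume computation that the streaming path above relies on, using synthetic audio and the same assumed constants (16 kHz input, 30 fps output, a 64-frame rolling window):

import math
import numpy as np
import librosa

sr, fps, max_frames = 16000, 30, 64           # assumed constants matching the config above
audio = np.sin(2 * np.pi * 220 * np.arange(sr) / sr).astype(np.float32)  # 1 s synthetic chunk

frame_length = math.ceil(audio.shape[0] / sr * fps)          # 30 output frames for 1 s of audio
hop = int(1 / fps * sr)                                      # samples per output frame
volume = librosa.feature.rms(y=audio, frame_length=hop, hop_length=hop)[0][:frame_length]

# Left-pad the chunk with silence so the model always sees a fixed 64-frame audio window.
window_samples = sr * max_frames // fps
padded = np.concatenate([np.zeros(window_samples - audio.shape[0], dtype=np.float32), audio])
start_frame = max_frames - frame_length                      # frames that belong to the padding
print(frame_length, volume.shape, padded.shape, start_frame)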
input_audio = np.concatenate([clip_pre_audio, in_audio]) + output_context['previous_audio'] = input_audio + + with torch.no_grad(): + try: + input_dict = {} + input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx), + self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[ + None, ...] + input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...] + output_dict = self.model(input_dict) + out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :] + except: + self.logger.error('Error: faided to predict expression.') + output_dict['pred_exp'] = torch.zeros((max_frame_length, 52)).float() + return + + + # post-process + if (context['previous_expression'] is None): + out_exp = self.apply_expression_postprocessing(out_exp, audio_volume=volume) + else: + previous_length = context['previous_expression'].shape[0] + out_exp = self.apply_expression_postprocessing(expression_params = np.concatenate([context['previous_expression'], out_exp], axis=0), + audio_volume=np.concatenate([context['previous_volume'], volume], axis=0), + processed_frames=previous_length)[previous_length:, :] + + if (context['previous_expression'] is not None): + output_context['previous_expression'] = np.concatenate([context['previous_expression'], out_exp], axis=0)[ + -max_frame_length:, :] + output_context['previous_volume'] = np.concatenate([context['previous_volume'], volume], axis=0)[-max_frame_length:] + else: + output_context['previous_expression'] = out_exp.copy() + output_context['previous_volume'] = volume.copy() + + output_context['first_input_flag'] = False + + return {"code": RETURN_CODE['SUCCESS'], + "expression": out_exp, + "headpose": None}, output_context + def apply_expression_postprocessing( + self, + expression_params: np.ndarray, + processed_frames: int = 0, + audio_volume: np.ndarray = None + ) -> np.ndarray: + """Applies full post-processing pipeline to facial expression parameters. + + Args: + expression_params: Raw output from animation model [num_frames, num_parameters] + processed_frames: Number of frames already processed in previous batches + audio_volume: Optional volume array for audio-visual synchronization + + Returns: + Processed expression parameters ready for animation synthesis + """ + # Pipeline execution order matters - maintain sequence + expression_params = smooth_mouth_movements(expression_params, processed_frames, audio_volume) + expression_params = apply_frame_blending(expression_params, processed_frames) + expression_params, _ = apply_savitzky_golay_smoothing(expression_params, window_length=5) + expression_params = symmetrize_blendshapes(expression_params) + expression_params = apply_random_eye_blinks_context(expression_params, processed_frames=processed_frames) + + return expression_params + + def extract_vocal_track( + self, + input_audio_path: str + ) -> str: + """Isolates vocal track from audio file using source separation. 
+ + Args: + input_audio_path: Path to input audio file containing vocals+accompaniment + + Returns: + Path to isolated vocal track in WAV format + """ + separation_command = f'spleeter separate -p spleeter:2stems -o {self.cfg.save_path} {input_audio_path}' + os.system(separation_command) + + base_name = os.path.splitext(os.path.basename(input_audio_path))[0] + return os.path.join(self.cfg.save_path, base_name, 'vocals.wav') + + def blendshape_postprocess(self, + bs_array: np.ndarray + )->np.array: + + bs_array, _ = apply_savitzky_golay_smoothing(bs_array, window_length=5) + bs_array = symmetrize_blendshapes(bs_array) + bs_array = apply_random_eye_blinks(bs_array) + + return bs_array diff --git a/engines/launch.py b/engines/launch.py new file mode 100644 index 0000000..05f5671 --- /dev/null +++ b/engines/launch.py @@ -0,0 +1,135 @@ +""" +Launcher + +modified from detectron2(https://github.com/facebookresearch/detectron2) + +""" + +import os +import logging +from datetime import timedelta +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from utils import comm + +__all__ = ["DEFAULT_TIMEOUT", "launch"] + +DEFAULT_TIMEOUT = timedelta(minutes=30) + + +def _find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def launch( + main_func, + num_gpus_per_machine, + num_machines=1, + machine_rank=0, + dist_url=None, + cfg=(), + timeout=DEFAULT_TIMEOUT, +): + """ + Launch multi-gpu or distributed training. + This function must be called on all machines involved in the training. + It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine. + Args: + main_func: a function that will be called by `main_func(*args)` + num_gpus_per_machine (int): number of GPUs per machine + num_machines (int): the total number of machines + machine_rank (int): the rank of this machine + dist_url (str): url to connect to for distributed jobs, including protocol + e.g. "tcp://127.0.0.1:8686". + Can be set to "auto" to automatically select a free port on localhost + timeout (timedelta): timeout of the distributed workers + args (tuple): arguments passed to main_func + """ + world_size = num_machines * num_gpus_per_machine + if world_size > 1: + if dist_url == "auto": + assert ( + num_machines == 1 + ), "dist_url=auto not supported in multi-machine jobs." + port = _find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + if num_machines > 1 and dist_url.startswith("file://"): + logger = logging.getLogger(__name__) + logger.warning( + "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://" + ) + + mp.spawn( + _distributed_worker, + nprocs=num_gpus_per_machine, + args=( + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + cfg, + timeout, + ), + daemon=False, + ) + else: + main_func(*cfg) + + +def _distributed_worker( + local_rank, + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + cfg, + timeout=DEFAULT_TIMEOUT, +): + assert ( + torch.cuda.is_available() + ), "cuda is not available. Please check your installation." 
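As a self-contained sketch of how launch() resolves dist_url="auto" above: bind to port 0 so the OS assigns a free port, then build a tcp:// init URL. The printed address is only an example, and, as noted in the code, the port could still be taken by another process before it is used:

import socket

def find_free_port() -> int:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))                 # port 0 -> OS picks an available port
    port = sock.getsockname()[1]
    sock.close()
    return port                        # note: not reserved; a race with other processes is possible

dist_url = f"tcp://127.0.0.1:{find_free_port()}"
print(dist_url)                        # e.g. tcp://127.0.0.1:54631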
+ global_rank = machine_rank * num_gpus_per_machine + local_rank + try: + dist.init_process_group( + backend="NCCL", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + timeout=timeout, + ) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error("Process group URL: {}".format(dist_url)) + raise e + + # Setup the local process group (which contains ranks within the same machine) + assert comm._LOCAL_PROCESS_GROUP is None + num_machines = world_size // num_gpus_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + comm._LOCAL_PROCESS_GROUP = pg + + assert num_gpus_per_machine <= torch.cuda.device_count() + torch.cuda.set_device(local_rank) + + # synchronize is needed here to prevent a possible timeout after calling init_process_group + # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 + comm.synchronize() + + main_func(*cfg) diff --git a/engines/train.py b/engines/train.py new file mode 100644 index 0000000..7de2364 --- /dev/null +++ b/engines/train.py @@ -0,0 +1,299 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import os +import sys +import weakref +import torch +import torch.nn as nn +import torch.utils.data +from functools import partial + +if sys.version_info >= (3, 10): + from collections.abc import Iterator +else: + from collections import Iterator +from tensorboardX import SummaryWriter + +from .defaults import create_ddp_model, worker_init_fn +from .hooks import HookBase, build_hooks +import utils.comm as comm +from datasets import build_dataset, point_collate_fn, collate_fn +from models import build_model +from utils.logger import get_root_logger +from utils.optimizer import build_optimizer +from utils.scheduler import build_scheduler +from utils.events import EventStorage +from utils.registry import Registry + + +TRAINERS = Registry("trainers") + + +class TrainerBase: + def __init__(self) -> None: + self.hooks = [] + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = 0 + self.max_iter = 0 + self.comm_info = dict() + self.data_iterator: Iterator = enumerate([]) + self.storage: EventStorage + self.writer: SummaryWriter + + def register_hooks(self, hooks) -> None: + hooks = build_hooks(hooks) + for h in hooks: + assert isinstance(h, HookBase) + # To avoid circular reference, hooks and trainer cannot own each other. 
+ # This normally does not matter, but will cause memory leak if the + # involved objects contain __del__: + # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ + h.trainer = weakref.proxy(self) + self.hooks.extend(hooks) + + def train(self): + with EventStorage() as self.storage: + # => before train + self.before_train() + for self.epoch in range(self.start_epoch, self.max_epoch): + # => before epoch + self.before_epoch() + # => run_epoch + for ( + self.comm_info["iter"], + self.comm_info["input_dict"], + ) in self.data_iterator: + # => before_step + self.before_step() + # => run_step + self.run_step() + # => after_step + self.after_step() + # => after epoch + self.after_epoch() + # => after train + self.after_train() + + def before_train(self): + for h in self.hooks: + h.before_train() + + def before_epoch(self): + for h in self.hooks: + h.before_epoch() + + def before_step(self): + for h in self.hooks: + h.before_step() + + def run_step(self): + raise NotImplementedError + + def after_step(self): + for h in self.hooks: + h.after_step() + + def after_epoch(self): + for h in self.hooks: + h.after_epoch() + self.storage.reset_histories() + + def after_train(self): + # Sync GPU before running train hooks + comm.synchronize() + for h in self.hooks: + h.after_train() + if comm.is_main_process(): + self.writer.close() + + +@TRAINERS.register_module("DefaultTrainer") +class Trainer(TrainerBase): + def __init__(self, cfg): + super(Trainer, self).__init__() + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = cfg.eval_epoch + self.best_metric_value = -torch.inf + self.logger = get_root_logger( + log_file=os.path.join(cfg.save_path, "train.log"), + file_mode="a" if cfg.resume else "w", + ) + self.logger.info("=> Loading config ...") + self.cfg = cfg + self.logger.info(f"Save path: {cfg.save_path}") + self.logger.info(f"Config:\n{cfg.pretty_text}") + self.logger.info("=> Building model ...") + self.model = self.build_model() + self.logger.info("=> Building writer ...") + self.writer = self.build_writer() + self.logger.info("=> Building train dataset & dataloader ...") + self.train_loader = self.build_train_loader() + self.logger.info("=> Building val dataset & dataloader ...") + self.val_loader = self.build_val_loader() + self.logger.info("=> Building optimize, scheduler, scaler(amp) ...") + self.optimizer = self.build_optimizer() + self.scheduler = self.build_scheduler() + self.scaler = self.build_scaler() + self.logger.info("=> Building hooks ...") + self.register_hooks(self.cfg.hooks) + + def train(self): + with EventStorage() as self.storage: + # => before train + self.before_train() + self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>") + for self.epoch in range(self.start_epoch, self.max_epoch): + # => before epoch + # TODO: optimize to iteration based + if comm.get_world_size() > 1: + self.train_loader.sampler.set_epoch(self.epoch) + self.model.train() + self.data_iterator = enumerate(self.train_loader) + self.before_epoch() + # => run_epoch + for ( + self.comm_info["iter"], + self.comm_info["input_dict"], + ) in self.data_iterator: + # => before_step + self.before_step() + # => run_step + self.run_step() + # => after_step + self.after_step() + # => after epoch + self.after_epoch() + # => after train + self.after_train() + + def run_step(self): + input_dict = self.comm_info["input_dict"] + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with 
torch.cuda.amp.autocast(enabled=self.cfg.enable_amp): + output_dict = self.model(input_dict) + loss = output_dict["loss"] + self.optimizer.zero_grad() + if self.cfg.enable_amp: + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + + # When enable amp, optimizer.step call are skipped if the loss scaling factor is too large. + # Fix torch warning scheduler step before optimizer step. + scaler = self.scaler.get_scale() + self.scaler.update() + if scaler <= self.scaler.get_scale(): + self.scheduler.step() + else: + loss.backward() + self.optimizer.step() + self.scheduler.step() + if self.cfg.empty_cache: + torch.cuda.empty_cache() + self.comm_info["model_output_dict"] = output_dict + + def build_model(self): + model = build_model(self.cfg.model) + if self.cfg.sync_bn: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + # logger.info(f"Model: \n{self.model}") + self.logger.info(f"Num params: {n_parameters}") + model = create_ddp_model( + model.cuda(), + broadcast_buffers=False, + find_unused_parameters=self.cfg.find_unused_parameters, + ) + return model + + def build_writer(self): + writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None + self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}") + return writer + + def build_train_loader(self): + train_data = build_dataset(self.cfg.data.train) + + if comm.get_world_size() > 1: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) + else: + train_sampler = None + + init_fn = ( + partial( + worker_init_fn, + num_workers=self.cfg.num_worker_per_gpu, + rank=comm.get_rank(), + seed=self.cfg.seed, + ) + if self.cfg.seed is not None + else None + ) + + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=self.cfg.batch_size_per_gpu, + shuffle=(train_sampler is None), + num_workers=0, + sampler=train_sampler, + collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob), + pin_memory=True, + worker_init_fn=init_fn, + drop_last=True, + # persistent_workers=True, + ) + return train_loader + + def build_val_loader(self): + val_loader = None + if self.cfg.evaluate: + val_data = build_dataset(self.cfg.data.val) + if comm.get_world_size() > 1: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) + else: + val_sampler = None + val_loader = torch.utils.data.DataLoader( + val_data, + batch_size=self.cfg.batch_size_val_per_gpu, + shuffle=False, + num_workers=self.cfg.num_worker_per_gpu, + pin_memory=True, + sampler=val_sampler, + collate_fn=collate_fn, + ) + return val_loader + + def build_optimizer(self): + return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts) + + def build_scheduler(self): + assert hasattr(self, "optimizer") + assert hasattr(self, "train_loader") + self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch + return build_scheduler(self.cfg.scheduler, self.optimizer) + + def build_scaler(self): + scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None + return scaler + + +@TRAINERS.register_module("MultiDatasetTrainer") +class MultiDatasetTrainer(Trainer): + def build_train_loader(self): + from datasets import MultiDatasetDataloader + + train_data = build_dataset(self.cfg.data.train) + train_loader = MultiDatasetDataloader( + train_data, + self.cfg.batch_size_per_gpu, + self.cfg.num_worker_per_gpu, + self.cfg.mix_prob, + self.cfg.seed, + ) + self.comm_info["iter_per_epoch"] = 
len(train_loader) + return train_loader diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..37ac22e --- /dev/null +++ b/inference.py @@ -0,0 +1,48 @@ +""" +# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +""" + +from engines.defaults import ( + default_argument_parser, + default_config_parser, + default_setup, +) +from engines.infer import INFER +from engines.launch import launch + + +def main_worker(cfg): + cfg = default_setup(cfg) + infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg)) + infer.infer() + + +def main(): + args = default_argument_parser().parse_args() + cfg = default_config_parser(args.config_file, args.options) + + launch( + main_worker, + num_gpus_per_machine=args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + cfg=(cfg,), + ) + + +if __name__ == "__main__": + main() diff --git a/inference_streaming_audio.py b/inference_streaming_audio.py new file mode 100644 index 0000000..c14b084 --- /dev/null +++ b/inference_streaming_audio.py @@ -0,0 +1,60 @@ +""" +# Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +""" + +import numpy as np + +from engines.defaults import ( + default_argument_parser, + default_config_parser, + default_setup, +) +from engines.infer import INFER +import librosa +from tqdm import tqdm +import time + + +def export_json(bs_array, json_path): + from models.utils import export_blendshape_animation, ARKitBlendShape + export_blendshape_animation(bs_array, json_path, ARKitBlendShape, fps=30.0) + +if __name__ == '__main__': + args = default_argument_parser().parse_args() + args.config_file = 'configs/lam_audio2exp_config_streaming.py' + cfg = default_config_parser(args.config_file, args.options) + + + cfg = default_setup(cfg) + infer = INFER.build(dict(type=cfg.infer.type, cfg=cfg)) + infer.model.eval() + + audio, sample_rate = librosa.load(cfg.audio_input, sr=16000) + context = None + input_num = audio.shape[0]//16000+1 + gap = 16000 + all_exp = [] + for i in tqdm(range(input_num)): + + start = time.time() + output, context = infer.infer_streaming_audio(audio[i*gap:(i+1)*gap], sample_rate, context) + end = time.time() + print('Inference time {}'.format(end - start)) + all_exp.append(output['expression']) + + all_exp = np.concatenate(all_exp,axis=0) + + export_json(all_exp, cfg.save_json_path) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..f4beb83 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,7 @@ +from .builder import build_model + +from .default import DefaultEstimator + +# Backbones +from .network import Audio2Expression + diff --git a/models/builder.py b/models/builder.py new file mode 100644 index 0000000..eed2627 --- /dev/null +++ b/models/builder.py @@ -0,0 +1,13 @@ +""" +Modified by https://github.com/Pointcept/Pointcept +""" + +from utils.registry import Registry + +MODELS = Registry("models") +MODULES = Registry("modules") + + +def build_model(cfg): + """Build models.""" + return MODELS.build(cfg) diff --git a/models/default.py b/models/default.py new file mode 100644 index 0000000..07655f6 --- /dev/null +++ b/models/default.py @@ -0,0 +1,25 @@ +import torch.nn as nn + +from models.losses import build_criteria +from .builder import MODELS, build_model + +@MODELS.register_module() +class DefaultEstimator(nn.Module): + def __init__(self, backbone=None, criteria=None): + super().__init__() + self.backbone = build_model(backbone) + self.criteria = build_criteria(criteria) + + def forward(self, input_dict): + pred_exp = self.backbone(input_dict) + # train + if self.training: + loss = self.criteria(pred_exp, input_dict["gt_exp"]) + return dict(loss=loss) + # eval + elif "gt_exp" in input_dict.keys(): + loss = self.criteria(pred_exp, input_dict["gt_exp"]) + return dict(loss=loss, pred_exp=pred_exp) + # infer + else: + return dict(pred_exp=pred_exp) diff --git a/models/encoder/wav2vec.py b/models/encoder/wav2vec.py new file mode 100644 index 0000000..f11fc57 --- /dev/null +++ b/models/encoder/wav2vec.py @@ -0,0 +1,248 @@ +import numpy as np +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from dataclasses import dataclass +from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel +from transformers.modeling_outputs import BaseModelOutput +from transformers.file_utils import ModelOutput + + +_CONFIG_FOR_DOC = "Wav2Vec2Config" +_HIDDEN_STATES_START_POSITION = 2 + + +# the implementation of Wav2Vec2Model is borrowed from 
https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model +# initialize our encoder with the pre-trained wav2vec 2.0 weights. +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.Tensor] = None, + min_masks: int = 0, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + mask_idcs = [] + padding_mask = attention_mask.ne(1) if attention_mask is not None else None + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + lengths = np.full(num_mask, mask_length) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + return mask + + +# linear interpolation layer +def linear_interpolation(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + + +class Wav2Vec2Model(Wav2Vec2Model): + def __init__(self, config): + super().__init__(config) + self.lm_head = nn.Linear(1024, 32) + + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + frame_num=None + ): + self.config.output_attentions = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.feature_extractor(input_values) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num) + + if attention_mask is not None: + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + attention_mask = torch.zeros( + hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device + ) + attention_mask[ + (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + + hidden_states = self.feature_projection(hidden_states)[0] + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = 
encoder_outputs[0] + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@dataclass +class SpeechClassifierOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class Wav2Vec2ClassificationHead(nn.Module): + """Head for wav2vec classification task.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.final_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.pooling_mode = config.pooling_mode + self.config = config + + self.wav2vec2 = Wav2Vec2Model(config) + self.classifier = Wav2Vec2ClassificationHead(config) + + self.init_weights() + + def freeze_feature_extractor(self): + self.wav2vec2.feature_extractor._freeze_parameters() + + def merged_strategy( + self, + hidden_states, + mode="mean" + ): + if mode == "mean": + outputs = torch.mean(hidden_states, dim=1) + elif mode == "sum": + outputs = torch.sum(hidden_states, dim=1) + elif mode == "max": + outputs = torch.max(hidden_states, dim=1)[0] + else: + raise Exception( + "The pooling method hasn't been defined! 
Your pooling mode must be one of these ['mean', 'sum', 'max']") + + return outputs + + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + frame_num=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + hidden_states1 = linear_interpolation(hidden_states, 50, 30, output_len=frame_num) + hidden_states = self.merged_strategy(hidden_states1, mode=self.pooling_mode) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SpeechClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states1, + attentions=outputs.attentions, + ) diff --git a/models/encoder/wavlm.py b/models/encoder/wavlm.py new file mode 100644 index 0000000..0e39b9b --- /dev/null +++ b/models/encoder/wavlm.py @@ -0,0 +1,87 @@ +import numpy as np +import torch +from transformers import WavLMModel +from transformers.modeling_outputs import Wav2Vec2BaseModelOutput +from typing import Optional, Tuple, Union +import torch.nn.functional as F + +def linear_interpolation(features, output_len: int): + features = features.transpose(1, 2) + output_features = F.interpolate( + features, size=output_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + +# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501 +# initialize our encoder with the pre-trained wav2vec 2.0 weights. 
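Both audio encoders resample the roughly 50 Hz transformer features to the 30 fps animation rate with a linear interpolation along the time axis. A minimal sketch with dummy tensors (the 768 feature dimension is assumed), mirroring the output_len variant used in wavlm.py:

import torch
import torch.nn.functional as F

def linear_interpolation(features: torch.Tensor, output_len: int) -> torch.Tensor:
    # features: [batch, time, channels] -> interpolate along time -> [batch, output_len, channels]
    features = features.transpose(1, 2)
    features = F.interpolate(features, size=output_len, align_corners=True, mode="linear")
    return features.transpose(1, 2)

feats_50hz = torch.randn(1, 50, 768)          # about one second of wav2vec-style features
feats_30fps = linear_interpolation(feats_50hz, output_len=30)
print(feats_30fps.shape)                      # torch.Size([1, 30, 768])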
+ + +class WavLMModel(WavLMModel): + def __init__(self, config): + super().__init__(config) + + def _freeze_wav2vec2_parameters(self, do_freeze: bool = True): + for param in self.parameters(): + param.requires_grad = (not do_freeze) + + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + frame_num=None, + interpolate_pos: int = 0, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if interpolate_pos == 0: + extract_features = linear_interpolation( + extract_features, output_len=frame_num) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if interpolate_pos == 1: + hidden_states = linear_interpolation( + hidden_states, output_len=frame_num) + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) \ No newline at end of file diff --git a/models/losses/__init__.py b/models/losses/__init__.py new file mode 100644 index 0000000..782a0d3 --- /dev/null +++ b/models/losses/__init__.py @@ -0,0 +1,4 @@ +from .builder import build_criteria + +from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss +from .lovasz import LovaszLoss diff --git a/models/losses/builder.py b/models/losses/builder.py new file mode 100644 index 0000000..ec936be --- /dev/null +++ b/models/losses/builder.py @@ -0,0 +1,28 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +from utils.registry import Registry + +LOSSES = Registry("losses") + + +class Criteria(object): + def __init__(self, cfg=None): + self.cfg = cfg if cfg is not None else [] + self.criteria = [] + for loss_cfg in self.cfg: + self.criteria.append(LOSSES.build(cfg=loss_cfg)) + + def __call__(self, pred, target): + if len(self.criteria) == 0: + # loss computation occur in model + return pred + loss = 0 + for c in self.criteria: + loss += c(pred, target) + return loss + + +def build_criteria(cfg): + return Criteria(cfg) diff --git a/models/losses/lovasz.py b/models/losses/lovasz.py new file mode 100644 
index 0000000..dbdb844 --- /dev/null +++ b/models/losses/lovasz.py @@ -0,0 +1,253 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +from typing import Optional +from itertools import filterfalse +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + +from .builder import LOSSES + +BINARY_MODE: str = "binary" +MULTICLASS_MODE: str = "multiclass" +MULTILABEL_MODE: str = "multilabel" + + +def _lovasz_grad(gt_sorted): + """Compute gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1.0 - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def _lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Logits at each pixel (between -infinity and +infinity) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean( + _lovasz_hinge_flat( + *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore) + ) + for log, lab in zip(logits, labels) + ) + else: + loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore)) + return loss + + +def _lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss + Args: + logits: [P] Logits at each prediction (between -infinity and +infinity) + labels: [P] Tensor, binary ground truth labels (0 or 1) + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0.0 + signs = 2.0 * labels.float() - 1.0 + errors = 1.0 - logits * signs + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = _lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def _flatten_binary_scores(scores, labels, ignore=None): + """Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = labels != ignore + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +def _lovasz_softmax( + probas, labels, classes="present", class_seen=None, per_image=False, ignore=None +): + """Multi-class Lovasz-Softmax loss + Args: + @param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
+ @param per_image: compute the loss per image instead of per batch + @param ignore: void class labels + """ + if per_image: + loss = mean( + _lovasz_softmax_flat( + *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), + classes=classes + ) + for prob, lab in zip(probas, labels) + ) + else: + loss = _lovasz_softmax_flat( + *_flatten_probas(probas, labels, ignore), + classes=classes, + class_seen=class_seen + ) + return loss + + +def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None): + """Multi-class Lovasz-Softmax loss + Args: + @param probas: [P, C] Class probabilities at each prediction (between 0 and 1) + @param labels: [P] Tensor, ground truth labels (between 0 and C - 1) + @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0.0 + C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ["all", "present"] else classes + # for c in class_to_sum: + for c in labels.unique(): + if class_seen is None: + fg = (labels == c).type_as(probas) # foreground for class c + if classes == "present" and fg.sum() == 0: + continue + if C == 1: + if len(classes) > 1: + raise ValueError("Sigmoid output possible only with 1 class") + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) + else: + if c in class_seen: + fg = (labels == c).type_as(probas) # foreground for class c + if classes == "present" and fg.sum() == 0: + continue + if C == 1: + if len(classes) > 1: + raise ValueError("Sigmoid output possible only with 1 class") + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) + return mean(losses) + + +def _flatten_probas(probas, labels, ignore=None): + """Flattens predictions in the batch""" + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + + C = probas.size(1) + probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C] + probas = probas.contiguous().view(-1, C) # [P, C] + + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = labels != ignore + vprobas = probas[valid] + vlabels = labels[valid] + return vprobas, vlabels + + +def isnan(x): + return x != x + + +def mean(values, ignore_nan=False, empty=0): + """Nan-mean compatible with generators.""" + values = iter(values) + if ignore_nan: + values = filterfalse(isnan, values) + try: + n = 1 + acc = next(values) + except StopIteration: + if empty == "raise": + raise ValueError("Empty mean") + return empty + for n, v in enumerate(values, 2): + acc += v + if n == 1: + return acc + return acc / n + + +@LOSSES.register_module() +class LovaszLoss(_Loss): + def __init__( + self, + mode: str, + class_seen: Optional[int] = None, + per_image: bool = False, + ignore_index: Optional[int] = None, + loss_weight: float = 1.0, + ): + """Lovasz loss for segmentation task. 
+ It supports binary, multiclass and multilabel cases + Args: + mode: Loss mode 'binary', 'multiclass' or 'multilabel' + ignore_index: Label that indicates ignored pixels (does not contribute to loss) + per_image: If True loss computed per each image and then averaged, else computed per whole batch + Shape + - **y_pred** - torch.Tensor of shape (N, C, H, W) + - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) + Reference + https://github.com/BloodAxe/pytorch-toolbelt + """ + assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} + super().__init__() + + self.mode = mode + self.ignore_index = ignore_index + self.per_image = per_image + self.class_seen = class_seen + self.loss_weight = loss_weight + + def forward(self, y_pred, y_true): + if self.mode in {BINARY_MODE, MULTILABEL_MODE}: + loss = _lovasz_hinge( + y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index + ) + elif self.mode == MULTICLASS_MODE: + y_pred = y_pred.softmax(dim=1) + loss = _lovasz_softmax( + y_pred, + y_true, + class_seen=self.class_seen, + per_image=self.per_image, + ignore=self.ignore_index, + ) + else: + raise ValueError("Wrong mode {}.".format(self.mode)) + return loss * self.loss_weight diff --git a/models/losses/misc.py b/models/losses/misc.py new file mode 100644 index 0000000..48e26bb --- /dev/null +++ b/models/losses/misc.py @@ -0,0 +1,241 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .builder import LOSSES + + +@LOSSES.register_module() +class CrossEntropyLoss(nn.Module): + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction="mean", + label_smoothing=0.0, + loss_weight=1.0, + ignore_index=-1, + ): + super(CrossEntropyLoss, self).__init__() + weight = torch.tensor(weight).cuda() if weight is not None else None + self.loss_weight = loss_weight + self.loss = nn.CrossEntropyLoss( + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + label_smoothing=label_smoothing, + ) + + def forward(self, pred, target): + return self.loss(pred, target) * self.loss_weight + + +@LOSSES.register_module() +class L1Loss(nn.Module): + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction="mean", + label_smoothing=0.0, + loss_weight=1.0, + ignore_index=-1, + ): + super(L1Loss, self).__init__() + weight = torch.tensor(weight).cuda() if weight is not None else None + self.loss_weight = loss_weight + self.loss = nn.L1Loss(reduction='mean') + + def forward(self, pred, target): + return self.loss(pred, target[:,None]) * self.loss_weight + + +@LOSSES.register_module() +class SmoothCELoss(nn.Module): + def __init__(self, smoothing_ratio=0.1): + super(SmoothCELoss, self).__init__() + self.smoothing_ratio = smoothing_ratio + + def forward(self, pred, target): + eps = self.smoothing_ratio + n_class = pred.size(1) + one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, dim=1) + loss = -(one_hot * log_prb).total(dim=1) + loss = loss[torch.isfinite(loss)].mean() + return loss + + +@LOSSES.register_module() +class BinaryFocalLoss(nn.Module): + def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0): + """Binary Focal Loss + ` + """ + super(BinaryFocalLoss, self).__init__() + assert 0 < alpha < 1 + self.gamma = gamma + self.alpha = alpha + 
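A minimal usage sketch for the LovaszLoss module above; the batch size, class count and spatial shape are illustrative assumptions:

import torch

criterion = LovaszLoss(mode="multiclass", per_image=False, ignore_index=-1)
logits = torch.randn(2, 5, 64, 64)             # (N, C, H, W) raw network scores
labels = torch.randint(0, 5, (2, 64, 64))      # (N, H, W) integer class labels
loss = criterion(logits, labels)               # softmax is applied inside forward() for multiclass mode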
self.logits = logits + self.reduce = reduce + self.loss_weight = loss_weight + + def forward(self, pred, target, **kwargs): + """Forward function. + Args: + pred (torch.Tensor): The prediction with shape (N) + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤1, If containing class probabilities, + same shape as the input. + Returns: + torch.Tensor: The calculated loss + """ + if self.logits: + bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none") + else: + bce = F.binary_cross_entropy(pred, target, reduction="none") + pt = torch.exp(-bce) + alpha = self.alpha * target + (1 - self.alpha) * (1 - target) + focal_loss = alpha * (1 - pt) ** self.gamma * bce + + if self.reduce: + focal_loss = torch.mean(focal_loss) + return focal_loss * self.loss_weight + + +@LOSSES.register_module() +class FocalLoss(nn.Module): + def __init__( + self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1 + ): + """Focal Loss + ` + """ + super(FocalLoss, self).__init__() + assert reduction in ( + "mean", + "sum", + ), "AssertionError: reduction should be 'mean' or 'sum'" + assert isinstance( + alpha, (float, list) + ), "AssertionError: alpha should be of type float" + assert isinstance(gamma, float), "AssertionError: gamma should be of type float" + assert isinstance( + loss_weight, float + ), "AssertionError: loss_weight should be of type float" + assert isinstance(ignore_index, int), "ignore_index must be of type int" + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + self.ignore_index = ignore_index + + def forward(self, pred, target, **kwargs): + """Forward function. + Args: + pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes. + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤C−1, If containing class probabilities, + same shape as the input. + Returns: + torch.Tensor: The calculated loss + """ + # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k] + pred = pred.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + pred = pred.reshape(pred.size(0), -1) + # [C, N] -> [N, C] + pred = pred.transpose(0, 1).contiguous() + # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,) + target = target.view(-1).contiguous() + assert pred.size(0) == target.size( + 0 + ), "The shape of pred doesn't match the shape of target" + valid_mask = target != self.ignore_index + target = target[valid_mask] + pred = pred[valid_mask] + + if len(target) == 0: + return 0.0 + + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes) + + alpha = self.alpha + if isinstance(alpha, list): + alpha = pred.new_tensor(alpha) + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow( + self.gamma + ) + + loss = ( + F.binary_cross_entropy_with_logits(pred, target, reduction="none") + * focal_weight + ) + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.total() + return self.loss_weight * loss + + +@LOSSES.register_module() +class DiceLoss(nn.Module): + def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1): + """DiceLoss. 
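To illustrate the focal re-weighting applied by BinaryFocalLoss above, a toy calculation with the default gamma=2.0 and alpha=0.5 (the tensors are invented):

import torch
import torch.nn.functional as F

pred = torch.tensor([2.0, -1.0, 0.5])                      # logits
target = torch.tensor([1.0, 0.0, 1.0])
bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
pt = torch.exp(-bce)                                       # probability assigned to the true class
alpha = 0.5 * target + 0.5 * (1 - target)
focal = alpha * (1 - pt) ** 2.0 * bce                      # confident examples (pt near 1) are down-weighted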
+ This loss is proposed in `V-Net: Fully Convolutional Neural Networks for + Volumetric Medical Image Segmentation `_. + """ + super(DiceLoss, self).__init__() + self.smooth = smooth + self.exponent = exponent + self.loss_weight = loss_weight + self.ignore_index = ignore_index + + def forward(self, pred, target, **kwargs): + # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k] + pred = pred.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + pred = pred.reshape(pred.size(0), -1) + # [C, N] -> [N, C] + pred = pred.transpose(0, 1).contiguous() + # (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,) + target = target.view(-1).contiguous() + assert pred.size(0) == target.size( + 0 + ), "The shape of pred doesn't match the shape of target" + valid_mask = target != self.ignore_index + target = target[valid_mask] + pred = pred[valid_mask] + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes + ) + + total_loss = 0 + for i in range(num_classes): + if i != self.ignore_index: + num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth + den = ( + torch.sum( + pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent) + ) + + self.smooth + ) + dice_loss = 1 - num / den + total_loss += dice_loss + loss = total_loss / num_classes + return self.loss_weight * loss diff --git a/models/network.py b/models/network.py new file mode 100644 index 0000000..cdedbed --- /dev/null +++ b/models/network.py @@ -0,0 +1,646 @@ +import math +import os.path + +import torch + +import torch.nn as nn +import torch.nn.functional as F +import torchaudio as ta + +from models.encoder.wav2vec import Wav2Vec2Model +from models.encoder.wavlm import WavLMModel + +from models.builder import MODELS + +from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config + +@MODELS.register_module("Audio2Expression") +class Audio2Expression(nn.Module): + def __init__(self, + device: torch.device = None, + pretrained_encoder_type: str = 'wav2vec', + pretrained_encoder_path: str = '', + wav2vec2_config_path: str = '', + num_identity_classes: int = 0, + identity_feat_dim: int = 64, + hidden_dim: int = 512, + expression_dim: int = 52, + norm_type: str = 'ln', + decoder_depth: int = 3, + use_transformer: bool = False, + num_attention_heads: int = 8, + num_transformer_layers: int = 6, + ): + super().__init__() + + self.device = device + + # Initialize audio feature encoder + if pretrained_encoder_type == 'wav2vec': + if os.path.exists(pretrained_encoder_path): + self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path) + else: + config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path) + self.audio_encoder = Wav2Vec2Model(config) + encoder_output_dim = 768 + elif pretrained_encoder_type == 'wavlm': + self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path) + encoder_output_dim = 768 + else: + raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported") + + self.audio_encoder.feature_extractor._freeze_parameters() + self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim) + + self.identity_encoder = AudioIdentityEncoder( + hidden_dim, + num_identity_classes, + identity_feat_dim, + use_transformer, + num_attention_heads, + num_transformer_layers + ) + + self.decoder = nn.ModuleList([ + nn.Sequential(*[ + ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type) + for _ in range(decoder_depth) + ]) + ]) + + self.output_proj = 
nn.Linear(hidden_dim, expression_dim) + + def freeze_encoder_parameters(self, do_freeze=False): + + for name, param in self.audio_encoder.named_parameters(): + if('feature_extractor' in name): + param.requires_grad = False + else: + param.requires_grad = (not do_freeze) + + def forward(self, input_dict): + + if 'time_steps' not in input_dict: + audio_length = input_dict['input_audio_array'].shape[1] + time_steps = math.ceil(audio_length / 16000 * 30) + else: + time_steps = input_dict['time_steps'] + + # Process audio through encoder + audio_input = input_dict['input_audio_array'].flatten(start_dim=1) + hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state + + # Project features to hidden dimension + audio_features = self.feature_projection(hidden_states).transpose(1, 2) + + # Process identity-conditioned features + audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx']) + + # Refine features through decoder + audio_features = self.decoder[0](audio_features) + + # Generate output parameters + audio_features = audio_features.permute(0, 2, 1) + expression_params = self.output_proj(audio_features) + + return torch.sigmoid(expression_params) + + +class AudioIdentityEncoder(nn.Module): + def __init__(self, + hidden_dim, + num_identity_classes=0, + identity_feat_dim=64, + use_transformer=False, + num_attention_heads = 8, + num_transformer_layers = 6, + dropout_ratio=0.1, + ): + super().__init__() + + in_dim = hidden_dim + identity_feat_dim + self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1) + self.first_net = SeqTranslator1D(in_dim, hidden_dim, + min_layers_num=3, + residual=True, + norm='ln' + ) + self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True) + self.dropout = nn.Dropout(dropout_ratio) + + self.use_transformer = use_transformer + if(self.use_transformer): + encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_attention_heads, dim_feedforward= 2 * hidden_dim, batch_first=True) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers) + + def forward(self, + audio_features: torch.Tensor, + identity: torch.Tensor = None, + time_steps: int = None) -> tuple: + + audio_features = self.dropout(audio_features) + identity = identity.reshape(identity.shape[0], -1, 1).repeat(1, 1, audio_features.shape[2]).to(torch.float32) + identity = self.id_mlp(identity) + audio_features = torch.cat([audio_features, identity], dim=1) + + x = self.first_net(audio_features) + + if time_steps is not None: + x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear') + + if(self.use_transformer): + x = x.permute(0, 2, 1) + x = self.transformer_encoder(x) + x = x.permute(0, 2, 1) + + return x + +class ConvNormRelu(nn.Module): + ''' + (B,C_in,H,W) -> (B, C_out, H, W) + there exist some kernel size that makes the result is not H/s + ''' + + def __init__(self, + in_channels, + out_channels, + type='1d', + leaky=False, + downsample=False, + kernel_size=None, + stride=None, + padding=None, + p=0, + groups=1, + residual=False, + norm='bn'): + ''' + conv-bn-relu + ''' + super(ConvNormRelu, self).__init__() + self.residual = residual + self.norm_type = norm + # kernel_size = k + # stride = s + + if kernel_size is None and stride is None: + if not downsample: + kernel_size = 3 + stride = 1 + else: + kernel_size = 4 + stride = 2 + + if padding is None: + if isinstance(kernel_size, int) and isinstance(stride, tuple): + padding = tuple(int((kernel_size - st) / 
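The forward pass above infers the number of expression frames from the raw waveform whenever 'time_steps' is missing from the input dict; a quick sketch of that mapping, assuming 16 kHz audio and 30 fps output:

import math

audio_length = 48000                                   # 3 s of 16 kHz samples
time_steps = math.ceil(audio_length / 16000 * 30)      # -> 90 output frames at 30 fps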
2) for st in stride) + elif isinstance(kernel_size, tuple) and isinstance(stride, int): + padding = tuple(int((ks - stride) / 2) for ks in kernel_size) + elif isinstance(kernel_size, tuple) and isinstance(stride, tuple): + padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride)) + else: + padding = int((kernel_size - stride) / 2) + + if self.residual: + if downsample: + if type == '1d': + self.residual_layer = nn.Sequential( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + elif type == '2d': + self.residual_layer = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + else: + if in_channels == out_channels: + self.residual_layer = nn.Identity() + else: + if type == '1d': + self.residual_layer = nn.Sequential( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + elif type == '2d': + self.residual_layer = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + ) + + in_channels = in_channels * groups + out_channels = out_channels * groups + if type == '1d': + self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm1d(out_channels) + self.dropout = nn.Dropout(p=p) + elif type == '2d': + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + groups=groups) + self.norm = nn.BatchNorm2d(out_channels) + self.dropout = nn.Dropout2d(p=p) + if norm == 'gn': + self.norm = nn.GroupNorm(2, out_channels) + elif norm == 'ln': + self.norm = nn.LayerNorm(out_channels) + if leaky: + self.relu = nn.LeakyReLU(negative_slope=0.2) + else: + self.relu = nn.ReLU() + + def forward(self, x, **kwargs): + if self.norm_type == 'ln': + out = self.dropout(self.conv(x)) + out = self.norm(out.transpose(1,2)).transpose(1,2) + else: + out = self.norm(self.dropout(self.conv(x))) + if self.residual: + residual = self.residual_layer(x) + out += residual + return self.relu(out) + +""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """ +class SeqTranslator1D(nn.Module): + ''' + (B, C, T)->(B, C_out, T) + ''' + def __init__(self, + C_in, + C_out, + kernel_size=None, + stride=None, + min_layers_num=None, + residual=True, + norm='bn' + ): + super(SeqTranslator1D, self).__init__() + + conv_layers = nn.ModuleList([]) + conv_layers.append(ConvNormRelu( + in_channels=C_in, + out_channels=C_out, + type='1d', + kernel_size=kernel_size, + stride=stride, + residual=residual, + norm=norm + )) + self.num_layers = 1 + if min_layers_num is not None and self.num_layers < min_layers_num: + while self.num_layers < min_layers_num: + conv_layers.append(ConvNormRelu( + in_channels=C_out, + out_channels=C_out, + type='1d', + kernel_size=kernel_size, + stride=stride, + residual=residual, + norm=norm + )) + self.num_layers += 1 + self.conv_layers = nn.Sequential(*conv_layers) + + def forward(self, x): + return self.conv_layers(x) + + +def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000): + """ + :param audio: 1 x T tensor containing a 16kHz audio signal + :param frame_rate: frame rate for video (we need one 
audio chunk per video frame) + :param chunk_size: number of audio samples per chunk + :return: num_chunks x chunk_size tensor containing sliced audio + """ + samples_per_frame = 16000 // frame_rate + padding = (chunk_size - samples_per_frame) // 2 + audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0) + anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame)) + audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0) + return audio + +""" https://github.com/facebookresearch/meshtalk """ +class MeshtalkEncoder(nn.Module): + def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'): + """ + :param latent_dim: size of the latent audio embedding + :param model_name: name of the model, used to load and save the model + """ + super().__init__() + + self.melspec = ta.transforms.MelSpectrogram( + sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80 + ) + + conv_len = 5 + self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len) + self.weights_init(self.convert_dimensions) + self.receptive_field = conv_len + + convs = [] + for i in range(6): + dilation = 2 * (i % 3 + 1) + self.receptive_field += (conv_len - 1) * dilation + convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)] + self.weights_init(convs[-1]) + self.convs = torch.nn.ModuleList(convs) + self.code = torch.nn.Linear(128, latent_dim) + + self.apply(lambda x: self.weights_init(x)) + + def weights_init(self, m): + if isinstance(m, torch.nn.Conv1d): + torch.nn.init.xavier_uniform_(m.weight) + try: + torch.nn.init.constant_(m.bias, .01) + except: + pass + + def forward(self, audio: torch.Tensor): + """ + :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame + :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding + """ + B, T = audio.shape[0], audio.shape[1] + x = self.melspec(audio).squeeze(1) + x = torch.log(x.clamp(min=1e-10, max=None)) + if T == 1: + x = x.unsqueeze(1) + + # Convert to the right dimensionality + x = x.view(-1, x.shape[2], x.shape[3]) + x = F.leaky_relu(self.convert_dimensions(x), .2) + + # Process stacks + for conv in self.convs: + x_ = F.leaky_relu(conv(x), .2) + if self.training: + x_ = F.dropout(x_, .2) + l = (x.shape[2] - x_.shape[2]) // 2 + x = (x[:, :, l:-l] + x_) / 2 + + x = torch.mean(x, dim=-1) + x = x.view(B, T, x.shape[-1]) + x = self.code(x) + + return {"code": x} + +class PeriodicPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64): + super(PeriodicPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(period, d_model) + position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # (1, period, d_model) + repeat_num = (max_seq_len//period) + 1 + pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model) + self.register_buffer('pe', pe) + def forward(self, x): + # print(self.pe.shape, x.shape) + x = x + self.pe[:, :x.size(1), :] + return self.dropout(x) + + +class GeneratorTransformer(nn.Module): + def __init__(self, + n_poses, + each_dim: list, + dim_list: list, + training=True, + device=None, + identity=False, + num_classes=0, + ): + 
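A hedged sketch of what audio_chunking above returns, assuming the function is in scope; the input length is only an example:

import torch

audio = torch.randn(1, 32000)                               # 2 s of 16 kHz audio
chunks = audio_chunking(audio, frame_rate=30, chunk_size=16000)
# chunks.shape -> (60, 16000): one padded 1-second window per 30 fps video frame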
super().__init__() + + self.training = training + self.device = device + self.gen_length = n_poses + + norm = 'ln' + in_dim = 256 + out_dim = 256 + + self.encoder_choice = 'faceformer' + + self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h" + self.audio_encoder.feature_extractor._freeze_parameters() + self.audio_feature_map = nn.Linear(768, in_dim) + + self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes) + + self.dim_list = dim_list + + self.decoder = nn.ModuleList() + self.final_out = nn.ModuleList() + + self.hidden_size = 768 + self.transformer_de_layer = nn.TransformerDecoderLayer( + d_model=self.hidden_size, + nhead=4, + dim_feedforward=self.hidden_size*2, + batch_first=True + ) + self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4) + self.feature2face = nn.Linear(256, self.hidden_size) + + self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64) + self.id_maping = nn.Linear(12,self.hidden_size) + + + self.decoder.append(self.face_decoder) + self.final_out.append(nn.Linear(self.hidden_size, 32)) + + def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None): + if gt_poses is None: + time_steps = 64 + else: + time_steps = gt_poses.shape[1] + + # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps) + if self.encoder_choice == 'meshtalk': + in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000) + feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2) + elif self.encoder_choice == 'faceformer': + hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state + feature = self.audio_feature_map(hidden_states).transpose(1, 2) + else: + feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps) + + feature, _ = self.audio_middle(feature, id=None) + feature = self.feature2face(feature.permute(0,2,1)) + + id = id.unsqueeze(1).repeat(1,64,1).to(torch.float32) + id_feature = self.id_maping(id) + id_feature = self.position_embeddings(id_feature) + + for i in range(self.decoder.__len__()): + mid = self.decoder[i](tgt=id_feature, memory=feature) + out = self.final_out[i](mid) + + return out, None + +def linear_interpolation(features, output_len: int): + features = features.transpose(1, 2) + output_features = F.interpolate( + features, size=output_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + +def init_biased_mask(n_head, max_seq_len, period): + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = (2**(-2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2] + + slopes = torch.Tensor(get_slopes(n_head)) + bias = torch.div( + torch.arange(start=0, end=max_seq_len, + step=period).unsqueeze(1).repeat(1, period).view(-1), + period, + rounding_mode='floor') + bias = -torch.flip(bias, dims=[0]) + alibi = torch.zeros(max_seq_len, max_seq_len) + for i in range(max_seq_len): + alibi[i, :i + 1] = bias[-(i + 1):] + alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0) + mask = (torch.triu(torch.ones(max_seq_len, + 
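linear_interpolation above resamples a (B, T, C) feature sequence along time, e.g. to match audio-rate encoder features to the 30 fps output; a small assumed example, with the function taken from above:

import torch

feats = torch.randn(1, 100, 256)                           # (B, T, C) features
resampled = linear_interpolation(feats, output_len=60)     # -> (1, 60, 256)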
max_seq_len)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill( + mask == 1, float(0.0)) + mask = mask.unsqueeze(0) + alibi + return mask + + +# Alignment Bias +def enc_dec_mask(device, T, S): + mask = torch.ones(T, S) + for i in range(T): + mask[i, i] = 0 + return (mask == 1).to(device=device) + + +# Periodic Positional Encoding +class PeriodicPositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000): + super(PeriodicPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(period, d_model) + position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * + (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # (1, period, d_model) + repeat_num = (max_seq_len // period) + 1 + pe = pe.repeat(1, repeat_num, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1), :] + return self.dropout(x) + + +class BaseModel(nn.Module): + """Base class for all models.""" + + def __init__(self): + super(BaseModel, self).__init__() + # self.logger = logging.getLogger(self.__class__.__name__) + + def forward(self, *x): + """Forward pass logic. + + :return: Model output + """ + raise NotImplementedError + + def freeze_model(self, do_freeze: bool = True): + for param in self.parameters(): + param.requires_grad = (not do_freeze) + + def summary(self, logger, writer=None): + """Model summary.""" + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = sum([np.prod(p.size()) + for p in model_parameters]) / 1e6 # Unit is Mega + logger.info('===>Trainable parameters: %.3f M' % params) + if writer is not None: + writer.add_text('Model Summary', + 'Trainable parameters: %.3f M' % params) + + +"""https://github.com/X-niper/UniTalker""" +class UniTalkerDecoderTransformer(BaseModel): + + def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None: + super().__init__() + self.learnable_style_emb = nn.Embedding(identity_num, out_dim) + self.PPE = PeriodicPositionalEncoding( + out_dim, period=period, max_seq_len=3000) + self.biased_mask = init_biased_mask( + n_head=4, max_seq_len=3000, period=period) + decoder_layer = nn.TransformerDecoderLayer( + d_model=out_dim, + nhead=4, + dim_feedforward=2 * out_dim, + batch_first=True) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, num_layers=1) + self.interpolate_pos = interpolate_pos + + def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor, + frame_num: int): + style_idx = torch.argmax(style_idx, dim=1) + obj_embedding = self.learnable_style_emb(style_idx) + obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1) + style_input = self.PPE(obj_embedding) + tgt_mask = self.biased_mask.repeat(style_idx.shape[0], 1, 1)[:, :style_input.shape[1], :style_input. 
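For clarity, a tiny illustration of the alignment bias built by enc_dec_mask above (True entries are masked out; only the matching encoder frame stays visible):

import torch

mask = enc_dec_mask(torch.device("cpu"), T=3, S=3)
# tensor([[False,  True,  True],
#         [ True, False,  True],
#         [ True,  True, False]])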
+ shape[1]].clone().detach().to( + device=style_input.device) + memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1], + frame_num) + feat_out = self.transformer_decoder( + style_input, + hidden_states, + tgt_mask=tgt_mask, + memory_mask=memory_mask) + if self.interpolate_pos == 2: + feat_out = linear_interpolation(feat_out, output_len=frame_num) + return feat_out \ No newline at end of file diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000..4b15130 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,752 @@ +import json +import time +import warnings +import numpy as np +from typing import List, Optional,Tuple +from scipy.signal import savgol_filter + + +ARKitLeftRightPair = [ + ("jawLeft", "jawRight"), + ("mouthLeft", "mouthRight"), + ("mouthSmileLeft", "mouthSmileRight"), + ("mouthFrownLeft", "mouthFrownRight"), + ("mouthDimpleLeft", "mouthDimpleRight"), + ("mouthStretchLeft", "mouthStretchRight"), + ("mouthPressLeft", "mouthPressRight"), + ("mouthLowerDownLeft", "mouthLowerDownRight"), + ("mouthUpperUpLeft", "mouthUpperUpRight"), + ("cheekSquintLeft", "cheekSquintRight"), + ("noseSneerLeft", "noseSneerRight"), + ("browDownLeft", "browDownRight"), + ("browOuterUpLeft", "browOuterUpRight"), + ("eyeBlinkLeft","eyeBlinkRight"), + ("eyeLookDownLeft","eyeLookDownRight"), + ("eyeLookInLeft", "eyeLookInRight"), + ("eyeLookOutLeft","eyeLookOutRight"), + ("eyeLookUpLeft","eyeLookUpRight"), + ("eyeSquintLeft","eyeSquintRight"), + ("eyeWideLeft","eyeWideRight") + ] + +ARKitBlendShape =[ + "browDownLeft", + "browDownRight", + "browInnerUp", + "browOuterUpLeft", + "browOuterUpRight", + "cheekPuff", + "cheekSquintLeft", + "cheekSquintRight", + "eyeBlinkLeft", + "eyeBlinkRight", + "eyeLookDownLeft", + "eyeLookDownRight", + "eyeLookInLeft", + "eyeLookInRight", + "eyeLookOutLeft", + "eyeLookOutRight", + "eyeLookUpLeft", + "eyeLookUpRight", + "eyeSquintLeft", + "eyeSquintRight", + "eyeWideLeft", + "eyeWideRight", + "jawForward", + "jawLeft", + "jawOpen", + "jawRight", + "mouthClose", + "mouthDimpleLeft", + "mouthDimpleRight", + "mouthFrownLeft", + "mouthFrownRight", + "mouthFunnel", + "mouthLeft", + "mouthLowerDownLeft", + "mouthLowerDownRight", + "mouthPressLeft", + "mouthPressRight", + "mouthPucker", + "mouthRight", + "mouthRollLower", + "mouthRollUpper", + "mouthShrugLower", + "mouthShrugUpper", + "mouthSmileLeft", + "mouthSmileRight", + "mouthStretchLeft", + "mouthStretchRight", + "mouthUpperUpLeft", + "mouthUpperUpRight", + "noseSneerLeft", + "noseSneerRight", + "tongueOut" +] + +MOUTH_BLENDSHAPES = [ "mouthDimpleLeft", + "mouthDimpleRight", + "mouthFrownLeft", + "mouthFrownRight", + "mouthFunnel", + "mouthLeft", + "mouthLowerDownLeft", + "mouthLowerDownRight", + "mouthPressLeft", + "mouthPressRight", + "mouthPucker", + "mouthRight", + "mouthRollLower", + "mouthRollUpper", + "mouthShrugLower", + "mouthShrugUpper", + "mouthSmileLeft", + "mouthSmileRight", + "mouthStretchLeft", + "mouthStretchRight", + "mouthUpperUpLeft", + "mouthUpperUpRight", + "jawForward", + "jawLeft", + "jawOpen", + "jawRight", + "noseSneerLeft", + "noseSneerRight", + "cheekPuff", + ] + +DEFAULT_CONTEXT ={ + 'is_initial_input': True, + 'previous_audio': None, + 'previous_expression': None, + 'previous_volume': None, + 'previous_headpose': None, +} + +RETURN_CODE = { + "SUCCESS": 0, + "AUDIO_LENGTH_ERROR": 1, + "CHECKPOINT_PATH_ERROR":2, + "MODEL_INFERENCE_ERROR":3, +} + +DEFAULT_CONTEXTRETURN = { + "code": RETURN_CODE['SUCCESS'], + "expression": None, + "headpose": None, +} + 
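The post-processing helpers that follow address blendshapes by column index (symmetrize_blendshapes builds the same lookup internally); for reference:

name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)}
# name_to_idx["eyeBlinkLeft"] == 8 and name_to_idx["eyeBlinkRight"] == 9,
# which is why the blink helpers below write into columns 8 and 9.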
+BLINK_PATTERNS = [ + np.array([0.365, 0.950, 0.956, 0.917, 0.367, 0.119, 0.025]), + np.array([0.235, 0.910, 0.945, 0.778, 0.191, 0.235, 0.089]), + np.array([0.870, 0.950, 0.949, 0.696, 0.191, 0.073, 0.007]), + np.array([0.000, 0.557, 0.953, 0.942, 0.426, 0.148, 0.018]) +] + +# Postprocess +def symmetrize_blendshapes( + bs_params: np.ndarray, + mode: str = "average", + symmetric_pairs: list = ARKitLeftRightPair +) -> np.ndarray: + """ + Apply symmetrization to ARKit blendshape parameters (batched version) + + Args: + bs_params: numpy array of shape (N, 52), batch of ARKit parameters + mode: symmetrization mode ["average", "max", "min", "left_dominant", "right_dominant"] + symmetric_pairs: list of left-right parameter pairs + + Returns: + Symmetrized parameters with same shape (N, 52) + """ + + name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)} + + # Input validation + if bs_params.ndim != 2 or bs_params.shape[1] != 52: + raise ValueError("Input must be of shape (N, 52)") + + symmetric_bs = bs_params.copy() # Shape (N, 52) + + # Precompute valid index pairs + valid_pairs = [] + for left, right in symmetric_pairs: + left_idx = name_to_idx.get(left) + right_idx = name_to_idx.get(right) + if None not in (left_idx, right_idx): + valid_pairs.append((left_idx, right_idx)) + + # Vectorized processing + for l_idx, r_idx in valid_pairs: + left_col = symmetric_bs[:, l_idx] + right_col = symmetric_bs[:, r_idx] + + if mode == "average": + new_vals = (left_col + right_col) / 2 + elif mode == "max": + new_vals = np.maximum(left_col, right_col) + elif mode == "min": + new_vals = np.minimum(left_col, right_col) + elif mode == "left_dominant": + new_vals = left_col + elif mode == "right_dominant": + new_vals = right_col + else: + raise ValueError(f"Invalid mode: {mode}") + + # Update both columns simultaneously + symmetric_bs[:, l_idx] = new_vals + symmetric_bs[:, r_idx] = new_vals + + return symmetric_bs + + +def apply_random_eye_blinks( + input: np.ndarray, + blink_scale: tuple = (0.8, 1.0), + blink_interval: tuple = (60, 120), + blink_duration: int = 7 +) -> np.ndarray: + """ + Apply randomized eye blinks to blendshape parameters + + Args: + output: Input array of shape (N, 52) containing blendshape parameters + blink_scale: Tuple (min, max) for random blink intensity scaling + blink_interval: Tuple (min, max) for random blink spacing in frames + blink_duration: Number of frames for blink animation (fixed) + + Returns: + None (modifies output array in-place) + """ + # Define eye blink patterns (normalized 0-1) + + # Initialize parameters + n_frames = input.shape[0] + input[:,8:10] = np.zeros((n_frames,2)) + current_frame = 0 + + # Main blink application loop + while current_frame < n_frames - blink_duration: + # Randomize blink parameters + scale = np.random.uniform(*blink_scale) + pattern = BLINK_PATTERNS[np.random.randint(0, 4)] + + # Apply blink animation + blink_values = pattern * scale + input[current_frame:current_frame + blink_duration, 8] = blink_values + input[current_frame:current_frame + blink_duration, 9] = blink_values + + # Advance to next blink position + current_frame += blink_duration + np.random.randint(*blink_interval) + + return input + + +def apply_random_eye_blinks_context( + animation_params: np.ndarray, + processed_frames: int = 0, + intensity_range: tuple = (0.8, 1.0) +) -> np.ndarray: + """Applies random eye blink patterns to facial animation parameters. + + Args: + animation_params: Input facial animation parameters array with shape [num_frames, num_features]. 
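A brief usage sketch for symmetrize_blendshapes above; the random batch exists only for illustration:

import numpy as np

bs = np.random.rand(10, 52)                        # batch of ARKit parameters in [0, 1]
sym = symmetrize_blendshapes(bs, mode="average")
# every (left, right) pair in ARKitLeftRightPair now carries the mean of the two original columns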
+ Columns 8 and 9 typically represent left/right eye blink parameters. + processed_frames: Number of already processed frames that shouldn't be modified + intensity_range: Tuple defining (min, max) scaling for blink intensity + + Returns: + Modified animation parameters array with random eye blinks added to unprocessed frames + """ + remaining_frames = animation_params.shape[0] - processed_frames + + # Only apply blinks if there's enough remaining frames (blink pattern requires 7 frames) + if remaining_frames <= 7: + return animation_params + + # Configure blink timing parameters + min_blink_interval = 40 # Minimum frames between blinks + max_blink_interval = 100 # Maximum frames between blinks + + # Find last blink in previously processed frames (column 8 > 0.5 indicates blink) + previous_blink_indices = np.where(animation_params[:processed_frames, 8] > 0.5)[0] + last_processed_blink = previous_blink_indices[-1] - 7 if previous_blink_indices.size > 0 else processed_frames + + # Calculate first new blink position + blink_interval = np.random.randint(min_blink_interval, max_blink_interval) + first_blink_start = max(0, blink_interval - last_processed_blink) + + # Apply first blink if there's enough space + if first_blink_start <= (remaining_frames - 7): + # Randomly select blink pattern and intensity + blink_pattern = BLINK_PATTERNS[np.random.randint(0, 4)] + intensity = np.random.uniform(*intensity_range) + + # Calculate blink frame range + blink_start = processed_frames + first_blink_start + blink_end = blink_start + 7 + + # Apply pattern to both eyes + animation_params[blink_start:blink_end, 8] = blink_pattern * intensity + animation_params[blink_start:blink_end, 9] = blink_pattern * intensity + + # Check space for additional blink + remaining_after_blink = animation_params.shape[0] - blink_end + if remaining_after_blink > min_blink_interval: + # Calculate second blink position + second_intensity = np.random.uniform(*intensity_range) + second_interval = np.random.randint(min_blink_interval, max_blink_interval) + + if (remaining_after_blink - 7) > second_interval: + second_pattern = BLINK_PATTERNS[np.random.randint(0, 4)] + second_blink_start = blink_end + second_interval + second_blink_end = second_blink_start + 7 + + # Apply second blink + animation_params[second_blink_start:second_blink_end, 8] = second_pattern * second_intensity + animation_params[second_blink_start:second_blink_end, 9] = second_pattern * second_intensity + + return animation_params + + +def export_blendshape_animation( + blendshape_weights: np.ndarray, + output_path: str, + blendshape_names: List[str], + fps: float, + rotation_data: Optional[np.ndarray] = None +) -> None: + """ + Export blendshape animation data to JSON format compatible with ARKit. 
+ + Args: + blendshape_weights: 2D numpy array of shape (N, 52) containing animation frames + output_path: Full path for output JSON file (including .json extension) + blendshape_names: Ordered list of 52 ARKit-standard blendshape names + fps: Frame rate for timing calculations (frames per second) + rotation_data: Optional 3D rotation data array of shape (N, 3) + + Raises: + ValueError: If input dimensions are incompatible + IOError: If file writing fails + """ + # Validate input dimensions + if blendshape_weights.shape[1] != 52: + raise ValueError(f"Expected 52 blendshapes, got {blendshape_weights.shape[1]}") + if len(blendshape_names) != 52: + raise ValueError(f"Requires 52 blendshape names, got {len(blendshape_names)}") + if rotation_data is not None and len(rotation_data) != len(blendshape_weights): + raise ValueError("Rotation data length must match animation frames") + + # Build animation data structure + animation_data = { + "names":blendshape_names, + "metadata": { + "fps": fps, + "frame_count": len(blendshape_weights), + "blendshape_names": blendshape_names + }, + "frames": [] + } + + # Convert numpy array to serializable format + for frame_idx in range(blendshape_weights.shape[0]): + frame_data = { + "weights": blendshape_weights[frame_idx].tolist(), + "time": frame_idx / fps, + "rotation": rotation_data[frame_idx].tolist() if rotation_data else [] + } + animation_data["frames"].append(frame_data) + + # Safeguard against data loss + if not output_path.endswith('.json'): + output_path += '.json' + + # Write to file with error handling + try: + with open(output_path, 'w', encoding='utf-8') as json_file: + json.dump(animation_data, json_file, indent=2, ensure_ascii=False) + except Exception as e: + raise IOError(f"Failed to write animation data: {str(e)}") from e + + +def apply_savitzky_golay_smoothing( + input_data: np.ndarray, + window_length: int = 5, + polyorder: int = 2, + axis: int = 0, + validate: bool = True +) -> Tuple[np.ndarray, Optional[float]]: + """ + Apply Savitzky-Golay filter smoothing along specified axis of input data. 
+ + Args: + input_data: 2D numpy array of shape (n_samples, n_features) + window_length: Length of the filter window (must be odd and > polyorder) + polyorder: Order of the polynomial fit + axis: Axis along which to filter (0: column-wise, 1: row-wise) + validate: Enable input validation checks when True + + Returns: + tuple: (smoothed_data, processing_time) + - smoothed_data: Smoothed output array + - processing_time: Execution time in seconds (None in validation mode) + + Raises: + ValueError: For invalid input dimensions or filter parameters + """ + # Validation mode timing bypass + processing_time = None + + if validate: + # Input integrity checks + if input_data.ndim != 2: + raise ValueError(f"Expected 2D input, got {input_data.ndim}D array") + + if window_length % 2 == 0 or window_length < 3: + raise ValueError("Window length must be odd integer ≥ 3") + + if polyorder >= window_length: + raise ValueError("Polynomial order must be < window length") + + # Store original dtype and convert to float64 for numerical stability + original_dtype = input_data.dtype + working_data = input_data.astype(np.float64) + + # Start performance timer + timer_start = time.perf_counter() + + try: + # Vectorized Savitzky-Golay application + smoothed_data = savgol_filter(working_data, + window_length=window_length, + polyorder=polyorder, + axis=axis, + mode='mirror') + except Exception as e: + raise RuntimeError(f"Filtering failed: {str(e)}") from e + + # Stop timer and calculate duration + processing_time = time.perf_counter() - timer_start + + # Restore original data type with overflow protection + return ( + np.clip(smoothed_data, + 0.0, + 1.0 + ).astype(original_dtype), + processing_time + ) + + +def _blend_region_start( + array: np.ndarray, + region: np.ndarray, + processed_boundary: int, + blend_frames: int +) -> None: + """Applies linear blend between last active frame and silent region start.""" + blend_length = min(blend_frames, region[0] - processed_boundary) + if blend_length <= 0: + return + + pre_frame = array[region[0] - 1] + for i in range(blend_length): + weight = (i + 1) / (blend_length + 1) + array[region[0] + i] = pre_frame * (1 - weight) + array[region[0] + i] * weight + +def _blend_region_end( + array: np.ndarray, + region: np.ndarray, + blend_frames: int +) -> None: + """Applies linear blend between silent region end and next active frame.""" + blend_length = min(blend_frames, array.shape[0] - region[-1] - 1) + if blend_length <= 0: + return + + post_frame = array[region[-1] + 1] + for i in range(blend_length): + weight = (i + 1) / (blend_length + 1) + array[region[-1] - i] = post_frame * (1 - weight) + array[region[-1] - i] * weight + +def find_low_value_regions( + signal: np.ndarray, + threshold: float, + min_region_length: int = 5 +) -> list: + """Identifies contiguous regions in a signal where values fall below a threshold. 
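A hedged usage sketch of apply_savitzky_golay_smoothing above on a blendshape sequence, using the default window and polynomial order; the input data is random:

import numpy as np

raw = np.random.rand(120, 52).astype(np.float32)
smoothed, elapsed = apply_savitzky_golay_smoothing(raw, window_length=5, polyorder=2, axis=0)
# smoothed keeps raw's shape and dtype, clipped to [0, 1]; elapsed is the filter runtime in seconds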
+ + Args: + signal: Input 1D array of numerical values + threshold: Value threshold for identifying low regions + min_region_length: Minimum consecutive samples required to qualify as a region + + Returns: + List of numpy arrays, each containing indices for a qualifying low-value region + """ + low_value_indices = np.where(signal < threshold)[0] + contiguous_regions = [] + current_region_length = 0 + region_start_idx = 0 + + for i in range(1, len(low_value_indices)): + # Check if current index continues a consecutive sequence + if low_value_indices[i] != low_value_indices[i - 1] + 1: + # Finalize previous region if it meets length requirement + if current_region_length >= min_region_length: + contiguous_regions.append(low_value_indices[region_start_idx:i]) + # Reset tracking for new potential region + region_start_idx = i + current_region_length = 0 + current_region_length += 1 + + # Add the final region if it qualifies + if current_region_length >= min_region_length: + contiguous_regions.append(low_value_indices[region_start_idx:]) + + return contiguous_regions + + +def smooth_mouth_movements( + blend_shapes: np.ndarray, + processed_frames: int, + volume: np.ndarray = None, + silence_threshold: float = 0.001, + min_silence_duration: int = 7, + blend_window: int = 3 +) -> np.ndarray: + """Reduces jaw movement artifacts during silent periods in audio-driven animation. + + Args: + blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes] + processed_frames: Number of already processed frames that shouldn't be modified + volume: Audio volume array used to detect silent periods + silence_threshold: Volume threshold for considering a frame silent + min_silence_duration: Minimum consecutive silent frames to qualify for processing + blend_window: Number of frames to smooth at region boundaries + + Returns: + Modified blend shape array with reduced mouth movements during silence + """ + if volume is None: + return blend_shapes + + # Detect silence periods using volume data + silent_regions = find_low_value_regions( + volume, + threshold=silence_threshold, + min_region_length=min_silence_duration + ) + + for region_indices in silent_regions: + # Reduce mouth blend shapes in silent region + mouth_blend_indices = [ARKitBlendShape.index(name) for name in MOUTH_BLENDSHAPES] + for region_indice in region_indices.tolist(): + blend_shapes[region_indice, mouth_blend_indices] *= 0.1 + + try: + # Smooth transition into silent region + _blend_region_start( + blend_shapes, + region_indices, + processed_frames, + blend_window + ) + + # Smooth transition out of silent region + _blend_region_end( + blend_shapes, + region_indices, + blend_window + ) + except IndexError as e: + warnings.warn(f"Edge blending skipped at region {region_indices}: {str(e)}") + + return blend_shapes + + +def apply_frame_blending( + blend_shapes: np.ndarray, + processed_frames: int, + initial_blend_window: int = 3, + subsequent_blend_window: int = 5 +) -> np.ndarray: + """Smooths transitions between processed and unprocessed animation frames using linear blending. 
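A small illustration of find_low_value_regions above on a made-up volume track (the threshold and minimum length here are illustrative):

import numpy as np

volume = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.6, 0.0, 0.0])
regions = find_low_value_regions(volume, threshold=0.001, min_region_length=5)
# -> [array([0, 1, 2, 3, 4, 5])]; the two trailing silent frames are too short to count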
+ + Args: + blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes] + processed_frames: Number of already processed frames (0 means no previous processing) + initial_blend_window: Max frames to blend at sequence start + subsequent_blend_window: Max frames to blend between processed and new frames + + Returns: + Modified blend shape array with smoothed transitions + """ + if processed_frames > 0: + # Blend transition between existing and new animation + _blend_animation_segment( + blend_shapes, + transition_start=processed_frames, + blend_window=subsequent_blend_window, + reference_frame=blend_shapes[processed_frames - 1] + ) + else: + # Smooth initial frames from neutral expression (zeros) + _blend_animation_segment( + blend_shapes, + transition_start=0, + blend_window=initial_blend_window, + reference_frame=np.zeros_like(blend_shapes[0]) + ) + return blend_shapes + + +def _blend_animation_segment( + array: np.ndarray, + transition_start: int, + blend_window: int, + reference_frame: np.ndarray +) -> None: + """Applies linear interpolation between reference frame and target frames. + + Args: + array: Blend shape array to modify + transition_start: Starting index for blending + blend_window: Maximum number of frames to blend + reference_frame: The reference frame to blend from + """ + actual_blend_length = min(blend_window, array.shape[0] - transition_start) + + for frame_offset in range(actual_blend_length): + current_idx = transition_start + frame_offset + blend_weight = (frame_offset + 1) / (actual_blend_length + 1) + + # Linear interpolation: ref_frame * (1 - weight) + current_frame * weight + array[current_idx] = (reference_frame * (1 - blend_weight) + + array[current_idx] * blend_weight) + + +BROW1 = np.array([[0.05597309, 0.05727929, 0.07995935, 0. , 0. ], + [0.00757574, 0.00936678, 0.12242376, 0. , 0. ], + [0. , 0. , 0.14943372, 0.04535687, 0.04264118], + [0. , 0. , 0.18015374, 0.09019445, 0.08736137], + [0. , 0. , 0.20549579, 0.12802747, 0.12450772], + [0. , 0. , 0.21098022, 0.1369939 , 0.13343132], + [0. , 0. , 0.20904602, 0.13903855, 0.13562402], + [0. , 0. , 0.20365039, 0.13977394, 0.13653506], + [0. , 0. , 0.19714841, 0.14096624, 0.13805152], + [0. , 0. , 0.20325482, 0.17303431, 0.17028868], + [0. , 0. , 0.21990852, 0.20164253, 0.19818163], + [0. , 0. , 0.23858181, 0.21908803, 0.21540019], + [0. , 0. , 0.2567876 , 0.23762083, 0.23396946], + [0. , 0. , 0.34093422, 0.27898848, 0.27651772], + [0. , 0. , 0.45288125, 0.35008961, 0.34887788], + [0. , 0. , 0.48076251, 0.36878952, 0.36778417], + [0. , 0. , 0.47798249, 0.36362219, 0.36145973], + [0. , 0. , 0.46186113, 0.33865979, 0.33597934], + [0. , 0. , 0.45264384, 0.33152157, 0.32891783], + [0. , 0. , 0.40986338, 0.29646468, 0.2945672 ], + [0. , 0. , 0.35628179, 0.23356403, 0.23155804], + [0. , 0. , 0.30870566, 0.1780673 , 0.17637439], + [0. , 0. , 0.25293985, 0.10710219, 0.10622486], + [0. , 0. , 0.18743332, 0.03252602, 0.03244236], + [0.02340254, 0.02364671, 0.15736724, 0. , 0. ]]) + +BROW2 = np.array([ + [0. , 0. , 0.09799323, 0.05944436, 0.05002545], + [0. , 0. , 0.09780276, 0.07674237, 0.01636653], + [0. , 0. , 0.11136199, 0.1027964 , 0.04249811], + [0. , 0. , 0.26883412, 0.15861984, 0.15832305], + [0. , 0. , 0.42191629, 0.27038204, 0.27007768], + [0. , 0. , 0.3404977 , 0.21633868, 0.21597538], + [0. , 0. , 0.27301185, 0.17176409, 0.17134669], + [0. , 0. , 0.25960442, 0.15670464, 0.15622253], + [0. , 0. , 0.22877269, 0.11805892, 0.11754539], + [0. , 0. 
, 0.1451605 , 0.06389034, 0.0636282 ]]) + +BROW3 = np.array([ + [0. , 0. , 0.124 , 0.0295, 0.0295], + [0. , 0. , 0.267 , 0.184 , 0.184 ], + [0. , 0. , 0.359 , 0.2765, 0.2765], + [0. , 0. , 0.3945, 0.3125, 0.3125], + [0. , 0. , 0.4125, 0.331 , 0.331 ], + [0. , 0. , 0.4235, 0.3445, 0.3445], + [0. , 0. , 0.4085, 0.3305, 0.3305], + [0. , 0. , 0.3695, 0.294 , 0.294 ], + [0. , 0. , 0.2835, 0.213 , 0.213 ], + [0. , 0. , 0.1795, 0.1005, 0.1005], + [0. , 0. , 0.108 , 0.014 , 0.014 ]]) + + +import numpy as np +from scipy.ndimage import label + + +def apply_random_brow_movement(input_exp, volume): + FRAME_SEGMENT = 150 + HOLD_THRESHOLD = 10 + VOLUME_THRESHOLD = 0.08 + MIN_REGION_LENGTH = 6 + STRENGTH_RANGE = (0.7, 1.3) + + BROW_PEAKS = { + 0: np.argmax(BROW1[:, 2]), + 1: np.argmax(BROW2[:, 2]) + } + + for seg_start in range(0, len(volume), FRAME_SEGMENT): + seg_end = min(seg_start + FRAME_SEGMENT, len(volume)) + seg_volume = volume[seg_start:seg_end] + + candidate_regions = [] + + high_vol_mask = seg_volume > VOLUME_THRESHOLD + labeled_array, num_features = label(high_vol_mask) + + for i in range(1, num_features + 1): + region = (labeled_array == i) + region_indices = np.where(region)[0] + if len(region_indices) >= MIN_REGION_LENGTH: + candidate_regions.append(region_indices) + + if candidate_regions: + selected_region = candidate_regions[np.random.choice(len(candidate_regions))] + region_start = selected_region[0] + region_end = selected_region[-1] + region_length = region_end - region_start + 1 + + brow_idx = np.random.randint(0, 2) + base_brow = BROW1 if brow_idx == 0 else BROW2 + peak_idx = BROW_PEAKS[brow_idx] + + if region_length > HOLD_THRESHOLD: + local_max_pos = seg_volume[selected_region].argmax() + global_peak_frame = seg_start + selected_region[local_max_pos] + + rise_anim = base_brow[:peak_idx + 1] + hold_frame = base_brow[peak_idx:peak_idx + 1] + + insert_start = max(global_peak_frame - peak_idx, seg_start) + insert_end = min(global_peak_frame + (region_length - local_max_pos), seg_end) + + strength = np.random.uniform(*STRENGTH_RANGE) + + if insert_start + len(rise_anim) <= seg_end: + input_exp[insert_start:insert_start + len(rise_anim), :5] += rise_anim * strength + hold_duration = insert_end - (insert_start + len(rise_anim)) + if hold_duration > 0: + input_exp[insert_start + len(rise_anim):insert_end, :5] += np.tile(hold_frame * strength, + (hold_duration, 1)) + else: + anim_length = base_brow.shape[0] + insert_pos = seg_start + region_start + (region_length - anim_length) // 2 + insert_pos = max(seg_start, min(insert_pos, seg_end - anim_length)) + + if insert_pos + anim_length <= seg_end: + strength = np.random.uniform(*STRENGTH_RANGE) + input_exp[insert_pos:insert_pos + anim_length, :5] += base_brow * strength + + return np.clip(input_exp, 0, 1) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2e09298 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +spleeter==2.4.2 +opencv_python_headless==4.11.0.86 +gradio==5.25.2 +omegaconf==2.3.0 +addict==2.4.0 +yapf==0.40.1 +librosa==0.11.0 +transformers==4.36.2 +termcolor==3.0.1 +numpy==1.26.3 \ No newline at end of file diff --git a/scripts/install/install_cu118.sh b/scripts/install/install_cu118.sh new file mode 100644 index 0000000..0a16bc9 --- /dev/null +++ b/scripts/install/install_cu118.sh @@ -0,0 +1,9 @@ +# install torch 2.1.2 +# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia +pip install torch==2.1.2 
torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 + +# install dependencies +pip install -r requirements.txt + +# install H5-render +pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl \ No newline at end of file diff --git a/scripts/install/install_cu121.sh b/scripts/install/install_cu121.sh new file mode 100644 index 0000000..7c39e52 --- /dev/null +++ b/scripts/install/install_cu121.sh @@ -0,0 +1,9 @@ +# install torch 2.1.2 +# or conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia +pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121 + +# install dependencies +pip install -r requirements.txt + +# install H5-render +pip install wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/cache.py b/utils/cache.py new file mode 100644 index 0000000..ac8bc33 --- /dev/null +++ b/utils/cache.py @@ -0,0 +1,53 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import os +import SharedArray + +try: + from multiprocessing.shared_memory import ShareableList +except ImportError: + import warnings + + warnings.warn("Please update python version >= 3.8 to enable shared_memory") +import numpy as np + + +def shared_array(name, var=None): + if var is not None: + # check exist + if os.path.exists(f"/dev/shm/{name}"): + return SharedArray.attach(f"shm://{name}") + # create shared_array + data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype) + data[...] = var[...] + data.flags.writeable = False + else: + data = SharedArray.attach(f"shm://{name}").copy() + return data + + +def shared_dict(name, var=None): + name = str(name) + assert "." not in name # '.' is used as sep flag + data = {} + if var is not None: + assert isinstance(var, dict) + keys = var.keys() + # current version only cache np.array + keys_valid = [] + for key in keys: + if isinstance(var[key], np.ndarray): + keys_valid.append(key) + keys = keys_valid + + ShareableList(sequence=keys, name=name + ".keys") + for key in keys: + if isinstance(var[key], np.ndarray): + data[key] = shared_array(name=f"{name}.{key}", var=var[key]) + else: + keys = list(ShareableList(name=name + ".keys")) + for key in keys: + data[key] = shared_array(name=f"{name}.{key}") + return data diff --git a/utils/comm.py b/utils/comm.py new file mode 100644 index 0000000..23bec8e --- /dev/null +++ b/utils/comm.py @@ -0,0 +1,192 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import functools +import numpy as np +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. 
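A hedged sketch of the SharedArray-backed cache above, on a Linux host with the SharedArray package installed: the first call with var materialises the array under /dev/shm, later calls by name attach to it and return a copy (the name and data are made up):

import numpy as np

feats = np.random.rand(1000, 52).astype(np.float32)
cached = shared_array("demo_feats", var=feats)    # creates shm://demo_feats, marked read-only
again = shared_array("demo_feats")                # attaches to the existing block and copies it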
+ """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert ( + _LOCAL_PROCESS_GROUP is not None + ), "Local process group is not created! Please use launch() to spawn processes!" + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + if dist.get_backend() == dist.Backend.NCCL: + # This argument is needed to avoid warnings. + # It's valid only for NCCL backend. + dist.barrier(device_ids=[torch.cuda.current_device()]) + else: + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = ( + _get_global_gloo_group() + ) # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + world_size = dist.get_world_size(group=group) + if world_size == 1: + return [data] + rank = dist.get_rank(group=group) + + if rank == dst: + output = [None for _ in range(world_size)] + dist.gather_object(data, output, dst=dst, group=group) + return output + else: + dist.gather_object(data, None, dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2**31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 
+ average (bool): whether to do average or sum + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..0513d64 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,696 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import uuid +import warnings +from argparse import Action, ArgumentParser +from collections import abc +from importlib import import_module + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +from .misc import import_modules_from_strings +from .path import check_file_exist + +if platform.system() == "Windows": + import regex as re +else: + import re + +BASE_KEY = "_base_" +DELETE_KEY = "_delete_" +DEPRECATION_KEY = "_deprecation_" +RESERVED_KEYS = ["filename", "text", "pretty_text"] + + +class ConfigDict(Dict): + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError( + f"'{self.__class__.__name__}' object has no " f"attribute '{name}'" + ) + except Exception as e: + ex = e + else: + return value + raise ex + + +def add_args(parser, cfg, prefix=""): + for k, v in cfg.items(): + if isinstance(v, str): + parser.add_argument("--" + prefix + k) + elif isinstance(v, int): + parser.add_argument("--" + prefix + k, type=int) + elif isinstance(v, float): + parser.add_argument("--" + prefix + k, type=float) + elif isinstance(v, bool): + parser.add_argument("--" + prefix + k, action="store_true") + elif isinstance(v, dict): + add_args(parser, v, prefix + k + ".") + elif isinstance(v, abc.Iterable): + parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+") + else: + print(f"cannot parse key {prefix + k} of type {type(v)}") + return parser + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. 
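Illustrative aside (loss names are hypothetical; assumes a CUDA device and an initialized process group, otherwise the dict is returned unchanged): averaging scalar loss tensors across ranks with reduce_dict.

import torch
from utils.comm import reduce_dict, is_main_process

loss_dict = {
    "loss_ce": torch.tensor(0.73, device="cuda"),
    "loss_reg": torch.tensor(0.12, device="cuda"),
}
reduced = reduce_dict(loss_dict, average=True)  # rank 0 ends up with the per-rank mean
if is_main_process():
    print({k: round(v.item(), 4) for k, v in reduced.items()})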
+ + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError( + "There are syntax errors in config " f"file {filename}: {e}" + ) + + @staticmethod + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname, + ) + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r"\{\{\s*" + str(key) + r"\s*\}\}" + value = value.replace("\\", "/") + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}" + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}" + base_var_dict[randstr] = base_var + regexp = r"\{\{\s*" + BASE_KEY + r"\." 
+ base_var + r"\s*\}\}" + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split("."): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg + ) + elif isinstance(cfg, list): + cfg = [ + Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg + ] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split("."): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in [".py", ".json", ".yaml", ".yml"]: + raise IOError("Only py/yml/yaml/json type are supported now!") + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname + ) + if platform.system() == "Windows": + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name + ) + + if filename.endswith(".py"): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith("__") + } + # delete imported module + del sys.modules[temp_module_name] + elif filename.endswith((".yml", ".yaml", ".json")): + raise NotImplementedError + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = ( + f"The config file {filename} will be deprecated " "in the future." + ) + if "expected" in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' "instead." 
+ if "reference" in deprecation_info: + warning_msg += ( + " More information can be found at " + f'{deprecation_info["reference"]}' + ) + warnings.warn(warning_msg) + + cfg_text = filename + "\n" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = ( + base_filename if isinstance(base_filename, list) else [base_filename] + ) + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError( + "Duplicate key is not allowed among bases. " + f"Duplicate keys: {duplicate_keys}" + ) + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars( + cfg_dict, base_var_dict, base_cfg_dict + ) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = "\n".join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f"Index {k} exceeds the length of list {b}") + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f"{k}={v} in child config cannot inherit from base " + f"because {k} is a dict in the child config but is of " + f"type {type(b[k])} in base config. 
You may set " + f"`{DELETE_KEY}=True` to ignore the base config" + ) + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = v + return b + + @staticmethod + def fromfile(filename, use_predefined_variables=True, import_custom_modules=True): + cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) + if import_custom_modules and cfg_dict.get("custom_imports", None): + import_modules_from_strings(**cfg_dict["custom_imports"]) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + obj:`Config`: Config obj. + """ + if file_format not in [".py", ".json", ".yaml", ".yml"]: + raise IOError("Only py/yml/yaml/json type are supported now!") + if file_format != ".py" and "dict(" in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn('Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + "w", encoding="utf-8", suffix=file_format, delete=False + ) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.fromfile(temp_file.name) + os.remove(temp_file.name) + return cfg + + @staticmethod + def auto_argparser(description=None): + """Generate argparser from config file automatically (experimental)""" + partial_parser = ArgumentParser(description=description) + partial_parser.add_argument("config", help="config file path") + cfg_file = partial_parser.parse_known_args()[0].config + cfg = Config.fromfile(cfg_file) + parser = ArgumentParser(description=description) + parser.add_argument("config", help="config file path") + add_args(parser, cfg) + return parser, cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}") + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f"{key} is reserved for config file") + + super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict)) + super(Config, self).__setattr__("_filename", filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, "r") as f: + text = f.read() + else: + text = "" + super(Config, self).__setattr__("_text", text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = "[\n" + v_str += "\n".join( + f"dict({_indent(_format_dict(v_), indent)})," for v_ in v + ).rstrip(",") + if use_mapping: + k_str = f"'{k}'" if 
isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + "]" + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= not str(key_name).isidentifier() + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = "" + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += "{" + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = "" if outest_level or is_last else "," + if isinstance(v, dict): + v_str = "\n" + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: dict({v_str}" + else: + attr_str = f"{str(k)}=dict({v_str}" + attr_str = _indent(attr_str, indent) + ")" + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += "\n".join(s) + if use_mapping: + r += "}" + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style="pep8", + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True, + ) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}" + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__("_cfg_dict", _cfg_dict) + super(Config, self).__setattr__("_filename", _filename) + super(Config, self).__setattr__("_text", _text) + + def dump(self, file=None): + cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict() + if self.filename.endswith(".py"): + if file is None: + return self.pretty_text + else: + with open(file, "w", encoding="utf-8") as f: + f.write(self.pretty_text) + else: + import mmcv + + if file is None: + file_format = self.filename.split(".")[-1] + return mmcv.dump(cfg_dict, file_format=file_format) + else: + mmcv.dump(cfg_dict, file) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'models.backbone.depth': 50, + ... 'models.backbone.with_cp':True} + >>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... models=dict(backbone=dict(depth=50, with_cp=True))) + + # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... 
dict(type='LoadImage'), dict(type='LoadAnnotations')])) + >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Default: True. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split(".") + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(Config, self).__getattribute__("_cfg_dict") + super(Config, self).__setattr__( + "_cfg_dict", + Config._merge_a_into_b( + option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys + ), + ) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ["true", "false"]: + return True if val.lower() == "true" else False + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + + All elements inside '()' or '[]' are treated as iterable values. + + Args: + val (str): Value string. + + Returns: + list | tuple: The expanded list or tuple from the string. + + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b'], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. + """ + assert (string.count("(") == string.count(")")) and ( + string.count("[") == string.count("]") + ), f"Imbalanced brackets exist in {string}" + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if ( + (char == ",") + and (pre.count("(") == pre.count(")")) + and (pre.count("[") == pre.count("]")) + ): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. 
+ val = val.strip("'\"").replace(" ", "") + is_tuple = False + if val.startswith("(") and val.endswith(")"): + is_tuple = True + val = val[1:-1] + elif val.startswith("[") and val.endswith("]"): + val = val[1:-1] + elif "," not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1 :] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split("=", maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) diff --git a/utils/env.py b/utils/env.py new file mode 100644 index 0000000..802ed90 --- /dev/null +++ b/utils/env.py @@ -0,0 +1,33 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import os +import random +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +from datetime import datetime + + +def get_random_seed(): + seed = ( + os.getpid() + + int(datetime.now().strftime("%S%f")) + + int.from_bytes(os.urandom(2), "big") + ) + return seed + + +def set_seed(seed=None): + if seed is None: + seed = get_random_seed() + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + cudnn.benchmark = False + cudnn.deterministic = True + os.environ["PYTHONHASHSEED"] = str(seed) diff --git a/utils/events.py b/utils/events.py new file mode 100644 index 0000000..90412dd --- /dev/null +++ b/utils/events.py @@ -0,0 +1,585 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + + +import datetime +import json +import logging +import os +import time +import torch +import numpy as np + +from typing import List, Optional, Tuple +from collections import defaultdict +from contextlib import contextmanager + +__all__ = [ + "get_event_storage", + "JSONWriter", + "TensorboardXWriter", + "CommonMetricPrinter", + "EventStorage", +] + +_CURRENT_STORAGE_STACK = [] + + +def get_event_storage(): + """ + Returns: + The :class:`EventStorage` object that's currently being used. + Throws an error if no :class:`EventStorage` is currently enabled. + """ + assert len( + _CURRENT_STORAGE_STACK + ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!" + return _CURRENT_STORAGE_STACK[-1] + + +class EventWriter: + """ + Base class for writers that obtain events from :class:`EventStorage` and process them. + """ + + def write(self): + raise NotImplementedError + + def close(self): + pass + + +class JSONWriter(EventWriter): + """ + Write scalars to a json file. + It saves scalars as one json per line (instead of a big json) for easy parsing. 
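Illustrative aside: a typical command-line entry point built on the Config and DictAction utilities above; the config path and override keys are hypothetical.

from argparse import ArgumentParser
from utils.config import Config, DictAction

parser = ArgumentParser()
parser.add_argument("config", help="path to a python config file")
parser.add_argument(
    "--options", nargs="+", action=DictAction,
    help="override config entries, e.g. optimizer.lr=0.001 data.batch_size=8",
)
args = parser.parse_args()

cfg = Config.fromfile(args.config)     # resolves _base_ files and {{ ... }} placeholders
if args.options is not None:
    cfg.merge_from_dict(args.options)  # nested keys are dot-separated
print(cfg.pretty_text)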
+ Examples parsing such a json file: + :: + $ cat metrics.json | jq -s '.[0:2]' + [ + { + "data_time": 0.008433341979980469, + "iteration": 19, + "loss": 1.9228371381759644, + "loss_box_reg": 0.050025828182697296, + "loss_classifier": 0.5316952466964722, + "loss_mask": 0.7236229181289673, + "loss_rpn_box": 0.0856662318110466, + "loss_rpn_cls": 0.48198649287223816, + "lr": 0.007173333333333333, + "time": 0.25401854515075684 + }, + { + "data_time": 0.007216215133666992, + "iteration": 39, + "loss": 1.282649278640747, + "loss_box_reg": 0.06222952902317047, + "loss_classifier": 0.30682939291000366, + "loss_mask": 0.6970193982124329, + "loss_rpn_box": 0.038663312792778015, + "loss_rpn_cls": 0.1471673548221588, + "lr": 0.007706666666666667, + "time": 0.2490077018737793 + } + ] + $ cat metrics.json | jq '.loss_mask' + 0.7126231789588928 + 0.689423680305481 + 0.6776131987571716 + ... + """ + + def __init__(self, json_file, window_size=20): + """ + Args: + json_file (str): path to the json file. New data will be appended if the file exists. + window_size (int): the window size of median smoothing for the scalars whose + `smoothing_hint` are True. + """ + self._file_handle = open(json_file, "a") + self._window_size = window_size + self._last_write = -1 + + def write(self): + storage = get_event_storage() + to_save = defaultdict(dict) + + for k, (v, iter) in storage.latest_with_smoothing_hint( + self._window_size + ).items(): + # keep scalars that have not been written + if iter <= self._last_write: + continue + to_save[iter][k] = v + if len(to_save): + all_iters = sorted(to_save.keys()) + self._last_write = max(all_iters) + + for itr, scalars_per_iter in to_save.items(): + scalars_per_iter["iteration"] = itr + self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n") + self._file_handle.flush() + try: + os.fsync(self._file_handle.fileno()) + except AttributeError: + pass + + def close(self): + self._file_handle.close() + + +class TensorboardXWriter(EventWriter): + """ + Write all scalars to a tensorboard file. + """ + + def __init__(self, log_dir: str, window_size: int = 20, **kwargs): + """ + Args: + log_dir (str): the directory to save the output events + window_size (int): the scalars will be median-smoothed by this window size + kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` + """ + self._window_size = window_size + from torch.utils.tensorboard import SummaryWriter + + self._writer = SummaryWriter(log_dir, **kwargs) + self._last_write = -1 + + def write(self): + storage = get_event_storage() + new_last_write = self._last_write + for k, (v, iter) in storage.latest_with_smoothing_hint( + self._window_size + ).items(): + if iter > self._last_write: + self._writer.add_scalar(k, v, iter) + new_last_write = max(new_last_write, iter) + self._last_write = new_last_write + + # storage.put_{image,histogram} is only meant to be used by + # tensorboard writer. So we access its internal fields directly from here. + if len(storage._vis_data) >= 1: + for img_name, img, step_num in storage._vis_data: + self._writer.add_image(img_name, img, step_num) + # Storage stores all image data and rely on this writer to clear them. + # As a result it assumes only one writer will use its image data. + # An alternative design is to let storage store limited recent + # data (e.g. only the most recent image) that all writers can access. + # In that case a writer may not see all image data if its period is long. 
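Illustrative aside (hypothetical paths and metrics; TensorboardXWriter additionally needs the tensorboard package): how these writers are typically driven from a training loop, with EventStorage defined further down in this file.

from utils.events import EventStorage, JSONWriter, TensorboardXWriter

writers = [JSONWriter("metrics.json"), TensorboardXWriter("tb_logs")]
with EventStorage(start_iter=0) as storage:
    for it in range(100):
        storage.put_scalar("loss", 1.0 / (it + 1))           # made-up metric
        storage.put_scalar("lr", 1e-3, smoothing_hint=False)
        if (it + 1) % 20 == 0:
            for w in writers:
                w.write()
        storage.step()
for w in writers:
    w.close()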
+ storage.clear_images() + + if len(storage._histograms) >= 1: + for params in storage._histograms: + self._writer.add_histogram_raw(**params) + storage.clear_histograms() + + def close(self): + if hasattr(self, "_writer"): # doesn't exist when the code fails at import + self._writer.close() + + +class CommonMetricPrinter(EventWriter): + """ + Print **common** metrics to the terminal, including + iteration time, ETA, memory, all losses, and the learning rate. + It also applies smoothing using a window of 20 elements. + It's meant to print common metrics in common ways. + To print something in more customized ways, please implement a similar printer by yourself. + """ + + def __init__(self, max_iter: Optional[int] = None, window_size: int = 20): + """ + Args: + max_iter: the maximum number of iterations to train. + Used to compute ETA. If not given, ETA will not be printed. + window_size (int): the losses will be median-smoothed by this window size + """ + self.logger = logging.getLogger(__name__) + self._max_iter = max_iter + self._window_size = window_size + self._last_write = ( + None # (step, time) of last call to write(). Used to compute ETA + ) + + def _get_eta(self, storage) -> Optional[str]: + if self._max_iter is None: + return "" + iteration = storage.iter + try: + eta_seconds = storage.history("time").median(1000) * ( + self._max_iter - iteration - 1 + ) + storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False) + return str(datetime.timedelta(seconds=int(eta_seconds))) + except KeyError: + # estimate eta on our own - more noisy + eta_string = None + if self._last_write is not None: + estimate_iter_time = (time.perf_counter() - self._last_write[1]) / ( + iteration - self._last_write[0] + ) + eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + self._last_write = (iteration, time.perf_counter()) + return eta_string + + def write(self): + storage = get_event_storage() + iteration = storage.iter + if iteration == self._max_iter: + # This hook only reports training progress (loss, ETA, etc) but not other data, + # therefore do not write anything after training succeeds, even if this method + # is called. 
+ return + + try: + data_time = storage.history("data_time").avg(20) + except KeyError: + # they may not exist in the first few iterations (due to warmup) + # or when SimpleTrainer is not used + data_time = None + try: + iter_time = storage.history("time").global_avg() + except KeyError: + iter_time = None + try: + lr = "{:.5g}".format(storage.history("lr").latest()) + except KeyError: + lr = "N/A" + + eta_string = self._get_eta(storage) + + if torch.cuda.is_available(): + max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 + else: + max_mem_mb = None + + # NOTE: max_mem is parsed by grep in "dev/parse_results.sh" + self.logger.info( + " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format( + eta=f"eta: {eta_string} " if eta_string else "", + iter=iteration, + losses=" ".join( + [ + "{}: {:.4g}".format(k, v.median(self._window_size)) + for k, v in storage.histories().items() + if "loss" in k + ] + ), + time="time: {:.4f} ".format(iter_time) + if iter_time is not None + else "", + data_time="data_time: {:.4f} ".format(data_time) + if data_time is not None + else "", + lr=lr, + memory="max_mem: {:.0f}M".format(max_mem_mb) + if max_mem_mb is not None + else "", + ) + ) + + +class EventStorage: + """ + The user-facing class that provides metric storage functionalities. + In the future we may add support for storing / logging other types of data if needed. + """ + + def __init__(self, start_iter=0): + """ + Args: + start_iter (int): the iteration number to start with + """ + self._history = defaultdict(AverageMeter) + self._smoothing_hints = {} + self._latest_scalars = {} + self._iter = start_iter + self._current_prefix = "" + self._vis_data = [] + self._histograms = [] + + # def put_image(self, img_name, img_tensor): + # """ + # Add an `img_tensor` associated with `img_name`, to be shown on + # tensorboard. + # Args: + # img_name (str): The name of the image to put into tensorboard. + # img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` + # Tensor of shape `[channel, height, width]` where `channel` is + # 3. The image format should be RGB. The elements in img_tensor + # can either have values in [0, 1] (float32) or [0, 255] (uint8). + # The `img_tensor` will be visualized in tensorboard. + # """ + # self._vis_data.append((img_name, img_tensor, self._iter)) + + def put_scalar(self, name, value, n=1, smoothing_hint=False): + """ + Add a scalar `value` to the `HistoryBuffer` associated with `name`. + Args: + smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be + smoothed when logged. The hint will be accessible through + :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint + and apply custom smoothing rule. + It defaults to True because most scalars we save need to be smoothed to + provide any useful signal. + """ + name = self._current_prefix + name + history = self._history[name] + history.update(value, n) + self._latest_scalars[name] = (value, self._iter) + + existing_hint = self._smoothing_hints.get(name) + if existing_hint is not None: + assert ( + existing_hint == smoothing_hint + ), "Scalar {} was put with a different smoothing_hint!".format(name) + else: + self._smoothing_hints[name] = smoothing_hint + + # def put_scalars(self, *, smoothing_hint=True, **kwargs): + # """ + # Put multiple scalars from keyword arguments. 
+ # Examples: + # storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True) + # """ + # for k, v in kwargs.items(): + # self.put_scalar(k, v, smoothing_hint=smoothing_hint) + # + # def put_histogram(self, hist_name, hist_tensor, bins=1000): + # """ + # Create a histogram from a tensor. + # Args: + # hist_name (str): The name of the histogram to put into tensorboard. + # hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted + # into a histogram. + # bins (int): Number of histogram bins. + # """ + # ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item() + # + # # Create a histogram with PyTorch + # hist_counts = torch.histc(hist_tensor, bins=bins) + # hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32) + # + # # Parameter for the add_histogram_raw function of SummaryWriter + # hist_params = dict( + # tag=hist_name, + # min=ht_min, + # max=ht_max, + # num=len(hist_tensor), + # sum=float(hist_tensor.sum()), + # sum_squares=float(torch.sum(hist_tensor**2)), + # bucket_limits=hist_edges[1:].tolist(), + # bucket_counts=hist_counts.tolist(), + # global_step=self._iter, + # ) + # self._histograms.append(hist_params) + + def history(self, name): + """ + Returns: + AverageMeter: the history for name + """ + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + return ret + + def histories(self): + """ + Returns: + dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars + """ + return self._history + + def latest(self): + """ + Returns: + dict[str -> (float, int)]: mapping from the name of each scalar to the most + recent value and the iteration number its added. + """ + return self._latest_scalars + + def latest_with_smoothing_hint(self, window_size=20): + """ + Similar to :meth:`latest`, but the returned values + are either the un-smoothed original latest value, + or a median of the given window_size, + depend on whether the smoothing_hint is True. + This provides a default behavior that other writers can use. + """ + result = {} + for k, (v, itr) in self._latest_scalars.items(): + result[k] = ( + self._history[k].median(window_size) if self._smoothing_hints[k] else v, + itr, + ) + return result + + def smoothing_hints(self): + """ + Returns: + dict[name -> bool]: the user-provided hint on whether the scalar + is noisy and needs smoothing. + """ + return self._smoothing_hints + + def step(self): + """ + User should either: (1) Call this function to increment storage.iter when needed. Or + (2) Set `storage.iter` to the correct iteration number before each iteration. + The storage will then be able to associate the new data with an iteration number. + """ + self._iter += 1 + + @property + def iter(self): + """ + Returns: + int: The current iteration number. When used together with a trainer, + this is ensured to be the same as trainer.iter. + """ + return self._iter + + @iter.setter + def iter(self, val): + self._iter = int(val) + + @property + def iteration(self): + # for backward compatibility + return self._iter + + def __enter__(self): + _CURRENT_STORAGE_STACK.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert _CURRENT_STORAGE_STACK[-1] == self + _CURRENT_STORAGE_STACK.pop() + + @contextmanager + def name_scope(self, name): + """ + Yields: + A context within which all the events added to this storage + will be prefixed by the name scope. 
+ """ + old_prefix = self._current_prefix + self._current_prefix = name.rstrip("/") + "/" + yield + self._current_prefix = old_prefix + + def clear_images(self): + """ + Delete all the stored images for visualization. This should be called + after images are written to tensorboard. + """ + self._vis_data = [] + + def clear_histograms(self): + """ + Delete all the stored histograms for visualization. + This should be called after histograms are written to tensorboard. + """ + self._histograms = [] + + def reset_history(self, name): + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + ret.reset() + + def reset_histories(self): + for name in self._history.keys(): + self._history[name].reset() + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.total = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.total = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.total += val * n + self.count += n + self.avg = self.total / self.count + + +class HistoryBuffer: + """ + Track a series of scalar values and provide access to smoothed values over a + window or the global average of the series. + """ + + def __init__(self, max_length: int = 1000000) -> None: + """ + Args: + max_length: maximal number of values that can be stored in the + buffer. When the capacity of the buffer is exhausted, old + values will be removed. + """ + self._max_length: int = max_length + self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs + self._count: int = 0 + self._global_avg: float = 0 + + def update(self, value: float, iteration: Optional[float] = None) -> None: + """ + Add a new scalar value produced at certain iteration. If the length + of the buffer exceeds self._max_length, the oldest element will be + removed from the buffer. + """ + if iteration is None: + iteration = self._count + if len(self._data) == self._max_length: + self._data.pop(0) + self._data.append((value, iteration)) + + self._count += 1 + self._global_avg += (value - self._global_avg) / self._count + + def latest(self) -> float: + """ + Return the latest scalar value added to the buffer. + """ + return self._data[-1][0] + + def median(self, window_size: int) -> float: + """ + Return the median of the latest `window_size` values in the buffer. + """ + return np.median([x[0] for x in self._data[-window_size:]]) + + def avg(self, window_size: int) -> float: + """ + Return the mean of the latest `window_size` values in the buffer. + """ + return np.mean([x[0] for x in self._data[-window_size:]]) + + def global_avg(self) -> float: + """ + Return the mean of all the elements in the buffer. Note that this + includes those getting removed due to limited buffer storage. + """ + return self._global_avg + + def values(self) -> List[Tuple[float, float]]: + """ + Returns: + list[(number, iteration)]: content of the current buffer. 
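Illustrative aside (arbitrary numbers): the two small meters above in isolation.

from utils.events import AverageMeter, HistoryBuffer

meter = AverageMeter()
for v in (2.0, 4.0, 6.0):
    meter.update(v)
print(meter.avg)  # 4.0, running mean over all updates

buf = HistoryBuffer(max_length=1000)
for v in (1.0, 5.0, 9.0):
    buf.update(v)
print(buf.median(2), buf.avg(2), buf.global_avg())  # 7.0 7.0 5.0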
+ """ + return self._data diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..6e30c5d --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,167 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import logging +import torch +import torch.distributed as dist + +from termcolor import colored + +logger_initialized = {} +root_status = 0 + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'a'. + color (bool): Colorful log output. Defaults to True + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + logger.propagate = False + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + plain_formatter = logging.Formatter( + "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + ) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + ) + else: + formatter = plain_formatter + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. 
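Illustrative aside (hypothetical messages and log path): the logger helpers above in a single-process run; on rank 0 the output goes to both the stream handler and the log file.

import logging
from utils.logger import get_root_logger, print_log

logger = get_root_logger(log_file="train.log", log_level=logging.INFO)
logger.info("=> building model ...")
print_log("validation started", logger="pointcept", level=logging.INFO)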
+ logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == "silent": + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + "logger should be either a logging.Logger object, str, " + f'"silent" or None, but got {type(logger)}' + ) + + +def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name. + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + file_mode (str): File Mode of logger. (w or a) + + Returns: + logging.Logger: The root logger. + """ + logger = get_logger( + name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode + ) + return logger + + +def _log_api_usage(identifier: str): + """ + Internal function used to log the usage of different detectron2 components + inside facebook's infra. + """ + torch._C._log_api_usage_once("pointcept." + identifier) diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..dbd257e --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,156 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import os +import warnings +from collections import abc +import numpy as np +import torch +from importlib import import_module + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def intersection_and_union(output, target, K, ignore_index=-1): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. + assert output.ndim in [1, 2, 3] + assert output.shape == target.shape + output = output.reshape(output.size).copy() + target = target.reshape(target.size) + output[np.where(target == ignore_index)[0]] = ignore_index + intersection = output[np.where(output == target)[0]] + area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1)) + area_output, _ = np.histogram(output, bins=np.arange(K + 1)) + area_target, _ = np.histogram(target, bins=np.arange(K + 1)) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def intersection_and_union_gpu(output, target, k, ignore_index=-1): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
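Illustrative aside (made-up predictions and labels, K=3 classes): computing per-class IoU and mIoU from the numpy helper above.

import numpy as np
from utils.misc import intersection_and_union

pred = np.array([0, 1, 1, 2, 2, 2])
target = np.array([0, 1, 2, 2, 2, 1])
inter, union, _ = intersection_and_union(pred, target, K=3, ignore_index=-1)
iou = inter / (union + 1e-10)
print(iou, iou.mean())  # per-class IoU [1.0, 0.33, 0.5] and mIoU ~0.61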
+ assert output.dim() in [1, 2, 3] + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1) + area_output = torch.histc(output, bins=k, min=0, max=k - 1) + area_target = torch.histc(target, bins=k, min=0, max=k - 1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def make_dirs(dir_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + +def find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... 
['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError(f"custom_imports must be a list but got type {type(imports)}") + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f"{imp} failed to import and is ignored.", UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported diff --git a/utils/optimizer.py b/utils/optimizer.py new file mode 100644 index 0000000..2eb70a3 --- /dev/null +++ b/utils/optimizer.py @@ -0,0 +1,52 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import torch +from utils.logger import get_root_logger +from utils.registry import Registry + +OPTIMIZERS = Registry("optimizers") + + +OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD") +OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam") +OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW") + + +def build_optimizer(cfg, model, param_dicts=None): + if param_dicts is None: + cfg.params = model.parameters() + else: + cfg.params = [dict(names=[], params=[], lr=cfg.lr)] + for i in range(len(param_dicts)): + param_group = dict(names=[], params=[]) + if "lr" in param_dicts[i].keys(): + param_group["lr"] = param_dicts[i].lr + if "momentum" in param_dicts[i].keys(): + param_group["momentum"] = param_dicts[i].momentum + if "weight_decay" in param_dicts[i].keys(): + param_group["weight_decay"] = param_dicts[i].weight_decay + cfg.params.append(param_group) + + for n, p in model.named_parameters(): + flag = False + for i in range(len(param_dicts)): + if param_dicts[i].keyword in n: + cfg.params[i + 1]["names"].append(n) + cfg.params[i + 1]["params"].append(p) + flag = True + break + if not flag: + cfg.params[0]["names"].append(n) + cfg.params[0]["params"].append(p) + + logger = get_root_logger() + for i in range(len(cfg.params)): + param_names = cfg.params[i].pop("names") + message = "" + for key in cfg.params[i].keys(): + if key != "params": + message += f" {key}: {cfg.params[i][key]};" + logger.info(f"Params Group {i+1} -{message} Params: {param_names}.") + return OPTIMIZERS.build(cfg=cfg) diff --git a/utils/path.py b/utils/path.py new file mode 100644 index 0000000..5d1da76 --- /dev/null +++ b/utils/path.py @@ -0,0 +1,105 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" +import os +import os.path as osp +from pathlib import Path + +from .misc import is_str + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError("`filepath` should be a string or a Path") + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, 
dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = ( + suffix.lower() + if isinstance(suffix, str) + else tuple(item.lower() for item in suffix) + ) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_vcs_root(path, markers=(".git",)): + """Finds the root directory (including itself) of specified markers. + + Args: + path (str): Path of directory or file. + markers (list[str], optional): List of file or directory names. + + Returns: + The directory contained one of the markers or None if not found. + """ + if osp.isfile(path): + path = osp.dirname(path) + + prev, cur = None, osp.abspath(osp.expanduser(path)) + while cur != prev: + if any(osp.exists(osp.join(cur, marker)) for marker in markers): + return cur + prev, cur = cur, osp.split(cur)[0] + return None diff --git a/utils/registry.py b/utils/registry.py new file mode 100644 index 0000000..bd0e55c --- /dev/null +++ b/utils/registry.py @@ -0,0 +1,318 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" +import inspect +import warnings +from functools import partial + +from .misc import is_seq_of + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from configs dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. 
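Illustrative aside (hypothetical model and hyper-parameters; a sketch assuming the default param_dicts=None path): registry-driven construction as described here, using the OPTIMIZERS registry from utils/optimizer.py above.

import torch.nn as nn
from utils.config import ConfigDict
from utils.optimizer import build_optimizer

model = nn.Linear(8, 2)
optim_cfg = ConfigDict(dict(type="AdamW", lr=1e-3, weight_decay=0.05))
# "type" is resolved against the OPTIMIZERS registry; the remaining keys are
# forwarded to torch.optim.AdamW by build_from_cfg.
optimizer = build_optimizer(optim_cfg, model)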
+ """ + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be a dict, but got {type(cfg)}") + if "type" not in cfg: + if default_args is None or "type" not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f"but got {cfg}\n{default_args}" + ) + if not isinstance(registry, Registry): + raise TypeError( + "registry must be an mmcv.Registry object, " f"but got {type(registry)}" + ) + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError( + "default_args must be a dict or None, " f"but got {type(default_args)}" + ) + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") + + +class Registry: + """A registry to map strings to classes. + + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = ( + self.__class__.__name__ + f"(name={self._name}, " + f"items={self._module_dict})" + ) + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. 
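Illustrative aside (ToyBackbone is a made-up class; run from a module or script so that scope inference works): registering and building through a fresh registry, mirroring the decorator pattern shown in the class docstring above.

from utils.registry import Registry

BACKBONES = Registry("backbones")  # scope is inferred from the defining package

@BACKBONES.register_module()
class ToyBackbone:
    def __init__(self, depth=18):
        self.depth = depth

backbone = BACKBONES.build(dict(type="ToyBackbone", depth=50))
assert isinstance(backbone, ToyBackbone) and backbone.depth == 50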
+ + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() traces where this function is called; index 2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split(".") + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from the key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Returns: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # go to the root registry + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add a child registry. + + The ``registry`` will be added as a child based on its scope. + The parent registry can build objects from its child registries. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert ( + registry.scope not in self.children + ), f"scope {registry.scope} exists in {self.name} registry" + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError("module must be a class, " f"but got {type(module_class)}") + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered " f"in {self.name}") + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + "The old API of register_module(module, force=False) " + "is deprecated and will be removed, please use the new API " + "register_module(name=None, force=False, module=None) instead."
+ ) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and whose value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a workaround to stay compatible with the old API, + # though it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + "name must be None, an instance of str, or a sequence" + f" of str, but got {type(name)}" + ) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/utils/scheduler.py b/utils/scheduler.py new file mode 100644 index 0000000..bb31459 --- /dev/null +++ b/utils/scheduler.py @@ -0,0 +1,144 @@ +""" +The code is based on https://github.com/Pointcept/Pointcept +""" + +import torch.optim.lr_scheduler as lr_scheduler +from .registry import Registry + +SCHEDULERS = Registry("schedulers") + + +@SCHEDULERS.register_module() +class MultiStepLR(lr_scheduler.MultiStepLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + last_epoch=-1, + verbose=False, + ): + super().__init__( + optimizer=optimizer, + milestones=[rate * total_steps for rate in milestones], + gamma=gamma, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class MultiStepWithWarmupLR(lr_scheduler.LambdaLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + warmup_rate=0.05, + warmup_scale=1e-6, + last_epoch=-1, + verbose=False, + ): + milestones = [rate * total_steps for rate in milestones] + + def multi_step_with_warmup(s): + factor = 1.0 + for i in range(len(milestones)): + if s < milestones[i]: + break + factor *= gamma + + if s <= warmup_rate * total_steps: + warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * ( + 1 - warmup_scale + ) + else: + warmup_coefficient = 1.0 + return warmup_coefficient * factor + + super().__init__( + optimizer=optimizer, + lr_lambda=multi_step_with_warmup, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class PolyLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, power=0.9,
last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class ExpLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: gamma ** (s / total_steps), + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR): + def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + T_max=total_steps, + eta_min=eta_min, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class OneCycleLR(lr_scheduler.OneCycleLR): + r""" + A wrapper around torch.optim.lr_scheduler.OneCycleLR, driven by ``total_steps``. + """ + + def __init__( + self, + optimizer, + max_lr, + total_steps=None, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False, + ): + super().__init__( + optimizer=optimizer, + max_lr=max_lr, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + three_phase=three_phase, + last_epoch=last_epoch, + verbose=verbose, + ) + + +def build_scheduler(cfg, optimizer): + cfg.optimizer = optimizer + return SCHEDULERS.build(cfg=cfg) diff --git a/utils/timer.py b/utils/timer.py new file mode 100644 index 0000000..7b7e9cb --- /dev/null +++ b/utils/timer.py @@ -0,0 +1,71 @@ +""" +The code is based on https://github.com/Pointcept/Pointcept +""" + +from time import perf_counter +from typing import Optional + + +class Timer: + """ + A timer which computes the time elapsed since the start/reset of the timer. + """ + + def __init__(self) -> None: + self.reset() + + def reset(self) -> None: + """ + Reset the timer. + """ + self._start = perf_counter() + self._paused: Optional[float] = None + self._total_paused = 0 + self._count_start = 1 + + def pause(self) -> None: + """ + Pause the timer. + """ + if self._paused is not None: + raise ValueError("Trying to pause a Timer that is already paused!") + self._paused = perf_counter() + + def is_paused(self) -> bool: + """ + Returns: + bool: whether the timer is currently paused + """ + return self._paused is not None + + def resume(self) -> None: + """ + Resume the timer. + """ + if self._paused is None: + raise ValueError("Trying to resume a Timer that is not paused!") + # pyre-fixme[58]: `-` is not supported for operand types `float` and + # `Optional[float]`. + self._total_paused += perf_counter() - self._paused + self._paused = None + self._count_start += 1 + + def seconds(self) -> float: + """ + Returns: + (float): the total number of seconds since the start/reset of the + timer, excluding the time when the timer is paused. + """ + if self._paused is not None: + end_time: float = self._paused # type: ignore + else: + end_time = perf_counter() + return end_time - self._start - self._total_paused + + def avg_seconds(self) -> float: + """ + Returns: + (float): the average number of seconds between every start/reset and + pause.
+ """ + return self.seconds() / self._count_start diff --git a/utils/visualization.py b/utils/visualization.py new file mode 100644 index 0000000..053cb64 --- /dev/null +++ b/utils/visualization.py @@ -0,0 +1,86 @@ +""" +The code is base on https://github.com/Pointcept/Pointcept +""" + +import os +import open3d as o3d +import numpy as np +import torch + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + x = x.clone().detach().cpu().numpy() + assert isinstance(x, np.ndarray) + return x + + +def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + coord = to_numpy(coord) + if color is not None: + color = to_numpy(color) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(coord) + pcd.colors = o3d.utility.Vector3dVector( + np.ones_like(coord) if color is None else color + ) + o3d.io.write_point_cloud(file_path, pcd) + if logger is not None: + logger.info(f"Save Point Cloud to: {file_path}") + + +def save_bounding_boxes( + bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None +): + bboxes_corners = to_numpy(bboxes_corners) + # point list + points = bboxes_corners.reshape(-1, 3) + # line list + box_lines = np.array( + [ + [0, 1], + [1, 2], + [2, 3], + [3, 0], + [4, 5], + [5, 6], + [6, 7], + [7, 0], + [0, 4], + [1, 5], + [2, 6], + [3, 7], + ] + ) + lines = [] + for i, _ in enumerate(bboxes_corners): + lines.append(box_lines + i * 8) + lines = np.concatenate(lines) + # color list + color = np.array([color for _ in range(len(lines))]) + # generate line set + line_set = o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.colors = o3d.utility.Vector3dVector(color) + o3d.io.write_line_set(file_path, line_set) + + if logger is not None: + logger.info(f"Save Boxes to: {file_path}") + + +def save_lines( + points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None +): + points = to_numpy(points) + lines = to_numpy(lines) + colors = np.array([color for _ in range(len(lines))]) + line_set = o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.colors = o3d.utility.Vector3dVector(colors) + o3d.io.write_line_set(file_path, line_set) + + if logger is not None: + logger.info(f"Save Lines to: {file_path}") diff --git a/wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl b/wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl new file mode 100644 index 0000000..afd5c61 Binary files /dev/null and b/wheels/gradio_gaussian_render-0.0.2-py3-none-any.whl differ