mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 01:49:23 +08:00
feat: Initial commit
This commit is contained in:
92
configs/lam_audio2exp_config.py
Normal file
92
configs/lam_audio2exp_config.py
Normal file
@@ -0,0 +1,92 @@
|
||||
weight = 'pretrained_models/lam_audio2exp.tar' # path to model weight
|
||||
ex_vol = True # Isolates vocal track from audio file
|
||||
audio_input = './assets/sample_audio/BarackObama.wav'
|
||||
save_json_path = 'bsData.json'
|
||||
|
||||
audio_sr = 16000
|
||||
fps = 30.0
|
||||
|
||||
movement_smooth = True
|
||||
brow_movement = True
|
||||
id_idx = 153
|
||||
|
||||
resume = False # whether to resume training process
|
||||
evaluate = True # evaluate after each epoch training process
|
||||
test_only = False # test process
|
||||
|
||||
seed = None # train process will init a random seed and record
|
||||
save_path = "exp/audio2exp"
|
||||
num_worker = 16 # total worker in all gpu
|
||||
batch_size = 16 # total batch size in all gpu
|
||||
batch_size_val = None # auto adapt to bs 1 for each gpu
|
||||
batch_size_test = None # auto adapt to bs 1 for each gpu
|
||||
epoch = 100 # total epoch, data loop = epoch // eval_epoch
|
||||
eval_epoch = 100 # sche total eval & checkpoint epoch
|
||||
|
||||
sync_bn = False
|
||||
enable_amp = False
|
||||
empty_cache = False
|
||||
find_unused_parameters = False
|
||||
|
||||
mix_prob = 0
|
||||
param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type="DefaultEstimator",
|
||||
backbone=dict(
|
||||
type="Audio2Expression",
|
||||
pretrained_encoder_type='wav2vec',
|
||||
pretrained_encoder_path='facebook/wav2vec2-base-960h',
|
||||
wav2vec2_config_path = 'configs/wav2vec2_config.json',
|
||||
num_identity_classes=5016,
|
||||
identity_feat_dim=64,
|
||||
hidden_dim=512,
|
||||
expression_dim=52,
|
||||
norm_type='ln',
|
||||
use_transformer=True,
|
||||
num_attention_heads=8,
|
||||
num_transformer_layers=6,
|
||||
),
|
||||
criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
|
||||
)
|
||||
|
||||
dataset_type = 'audio2exp'
|
||||
data_root = './'
|
||||
data = dict(
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
split="train",
|
||||
data_root=data_root,
|
||||
test_mode=False,
|
||||
),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
split="val",
|
||||
data_root=data_root,
|
||||
test_mode=False,
|
||||
),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
split="val",
|
||||
data_root=data_root,
|
||||
test_mode=True
|
||||
),
|
||||
)
|
||||
|
||||
# hook
|
||||
hooks = [
|
||||
dict(type="CheckpointLoader"),
|
||||
dict(type="IterationTimer", warmup_iter=2),
|
||||
dict(type="InformationWriter"),
|
||||
dict(type="SemSegEvaluator"),
|
||||
dict(type="CheckpointSaver", save_freq=None),
|
||||
dict(type="PreciseEvaluator", test_last=False),
|
||||
]
|
||||
|
||||
# Trainer
|
||||
train = dict(type="DefaultTrainer")
|
||||
|
||||
# Tester
|
||||
infer = dict(type="Audio2ExpressionInfer",
|
||||
verbose=True)
|
||||
92
configs/lam_audio2exp_config_streaming.py
Normal file
92
configs/lam_audio2exp_config_streaming.py
Normal file
@@ -0,0 +1,92 @@
|
||||
weight = 'pretrained_models/lam_audio2exp_streaming.tar' # path to model weight
|
||||
ex_vol = True # extract
|
||||
audio_input = './assets/sample_audio/BarackObama.wav'
|
||||
save_json_path = 'bsData.json'
|
||||
|
||||
audio_sr = 16000
|
||||
fps = 30.0
|
||||
|
||||
movement_smooth = False
|
||||
brow_movement = False
|
||||
id_idx = 0
|
||||
|
||||
resume = False # whether to resume training process
|
||||
evaluate = True # evaluate after each epoch training process
|
||||
test_only = False # test process
|
||||
|
||||
seed = None # train process will init a random seed and record
|
||||
save_path = "exp/audio2exp"
|
||||
num_worker = 16 # total worker in all gpu
|
||||
batch_size = 16 # total batch size in all gpu
|
||||
batch_size_val = None # auto adapt to bs 1 for each gpu
|
||||
batch_size_test = None # auto adapt to bs 1 for each gpu
|
||||
epoch = 100 # total epoch, data loop = epoch // eval_epoch
|
||||
eval_epoch = 100 # sche total eval & checkpoint epoch
|
||||
|
||||
sync_bn = False
|
||||
enable_amp = False
|
||||
empty_cache = False
|
||||
find_unused_parameters = False
|
||||
|
||||
mix_prob = 0
|
||||
param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type="DefaultEstimator",
|
||||
backbone=dict(
|
||||
type="Audio2Expression",
|
||||
pretrained_encoder_type='wav2vec',
|
||||
pretrained_encoder_path='facebook/wav2vec2-base-960h',
|
||||
wav2vec2_config_path = 'configs/wav2vec2_config.json',
|
||||
num_identity_classes=12,
|
||||
identity_feat_dim=64,
|
||||
hidden_dim=512,
|
||||
expression_dim=52,
|
||||
norm_type='ln',
|
||||
use_transformer=False,
|
||||
num_attention_heads=8,
|
||||
num_transformer_layers=6,
|
||||
),
|
||||
criteria=[dict(type="L1Loss", loss_weight=1.0, ignore_index=-1)],
|
||||
)
|
||||
|
||||
dataset_type = 'audio2exp'
|
||||
data_root = './'
|
||||
data = dict(
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
split="train",
|
||||
data_root=data_root,
|
||||
test_mode=False,
|
||||
),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
split="val",
|
||||
data_root=data_root,
|
||||
test_mode=False,
|
||||
),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
split="val",
|
||||
data_root=data_root,
|
||||
test_mode=True
|
||||
),
|
||||
)
|
||||
|
||||
# hook
|
||||
hooks = [
|
||||
dict(type="CheckpointLoader"),
|
||||
dict(type="IterationTimer", warmup_iter=2),
|
||||
dict(type="InformationWriter"),
|
||||
dict(type="SemSegEvaluator"),
|
||||
dict(type="CheckpointSaver", save_freq=None),
|
||||
dict(type="PreciseEvaluator", test_last=False),
|
||||
]
|
||||
|
||||
# Trainer
|
||||
train = dict(type="DefaultTrainer")
|
||||
|
||||
# Tester
|
||||
infer = dict(type="Audio2ExpressionInfer",
|
||||
verbose=True)
|
||||
77
configs/wav2vec2_config.json
Normal file
77
configs/wav2vec2_config.json
Normal file
@@ -0,0 +1,77 @@
|
||||
{
|
||||
"_name_or_path": "facebook/wav2vec2-base-960h",
|
||||
"activation_dropout": 0.1,
|
||||
"apply_spec_augment": true,
|
||||
"architectures": [
|
||||
"Wav2Vec2ForCTC"
|
||||
],
|
||||
"attention_dropout": 0.1,
|
||||
"bos_token_id": 1,
|
||||
"codevector_dim": 256,
|
||||
"contrastive_logits_temperature": 0.1,
|
||||
"conv_bias": false,
|
||||
"conv_dim": [
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"conv_kernel": [
|
||||
10,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"conv_stride": [
|
||||
5,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"ctc_loss_reduction": "sum",
|
||||
"ctc_zero_infinity": false,
|
||||
"diversity_loss_weight": 0.1,
|
||||
"do_stable_layer_norm": false,
|
||||
"eos_token_id": 2,
|
||||
"feat_extract_activation": "gelu",
|
||||
"feat_extract_dropout": 0.0,
|
||||
"feat_extract_norm": "group",
|
||||
"feat_proj_dropout": 0.1,
|
||||
"feat_quantizer_dropout": 0.0,
|
||||
"final_dropout": 0.1,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout": 0.1,
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 768,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"layerdrop": 0.1,
|
||||
"mask_feature_length": 10,
|
||||
"mask_feature_prob": 0.0,
|
||||
"mask_time_length": 10,
|
||||
"mask_time_prob": 0.05,
|
||||
"model_type": "wav2vec2",
|
||||
"num_attention_heads": 12,
|
||||
"num_codevector_groups": 2,
|
||||
"num_codevectors_per_group": 320,
|
||||
"num_conv_pos_embedding_groups": 16,
|
||||
"num_conv_pos_embeddings": 128,
|
||||
"num_feat_extract_layers": 7,
|
||||
"num_hidden_layers": 12,
|
||||
"num_negatives": 100,
|
||||
"pad_token_id": 0,
|
||||
"proj_codevector_dim": 256,
|
||||
"transformers_version": "4.7.0.dev0",
|
||||
"vocab_size": 32
|
||||
}
|
||||
Reference in New Issue
Block a user