mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-04 09:29:24 +08:00
feat: Initial commit
This commit is contained in:
25
models/default.py
Normal file
25
models/default.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from models.losses import build_criteria
|
||||
from .builder import MODELS, build_model
|
||||
|
||||
@MODELS.register_module()
|
||||
class DefaultEstimator(nn.Module):
|
||||
def __init__(self, backbone=None, criteria=None):
|
||||
super().__init__()
|
||||
self.backbone = build_model(backbone)
|
||||
self.criteria = build_criteria(criteria)
|
||||
|
||||
def forward(self, input_dict):
|
||||
pred_exp = self.backbone(input_dict)
|
||||
# train
|
||||
if self.training:
|
||||
loss = self.criteria(pred_exp, input_dict["gt_exp"])
|
||||
return dict(loss=loss)
|
||||
# eval
|
||||
elif "gt_exp" in input_dict.keys():
|
||||
loss = self.criteria(pred_exp, input_dict["gt_exp"])
|
||||
return dict(loss=loss, pred_exp=pred_exp)
|
||||
# infer
|
||||
else:
|
||||
return dict(pred_exp=pred_exp)
|
||||
Reference in New Issue
Block a user