feat: Initial commit

This commit is contained in:
fdyuandong
2025-04-17 23:14:24 +08:00
commit ca93dd0572
51 changed files with 7904 additions and 0 deletions

25
models/default.py Normal file
View File

@@ -0,0 +1,25 @@
import torch.nn as nn
from models.losses import build_criteria
from .builder import MODELS, build_model
@MODELS.register_module()
class DefaultEstimator(nn.Module):
def __init__(self, backbone=None, criteria=None):
super().__init__()
self.backbone = build_model(backbone)
self.criteria = build_criteria(criteria)
def forward(self, input_dict):
pred_exp = self.backbone(input_dict)
# train
if self.training:
loss = self.criteria(pred_exp, input_dict["gt_exp"])
return dict(loss=loss)
# eval
elif "gt_exp" in input_dict.keys():
loss = self.criteria(pred_exp, input_dict["gt_exp"])
return dict(loss=loss, pred_exp=pred_exp)
# infer
else:
return dict(pred_exp=pred_exp)