feat: Initial commit

fdyuandong
2025-04-17 23:14:24 +08:00
commit ca93dd0572
51 changed files with 7904 additions and 0 deletions

engines/__init__.py Normal file

engines/defaults.py Normal file

@@ -0,0 +1,147 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import utils.comm as comm
from utils.env import get_random_seed, set_seed
from utils.config import Config, DictAction
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
"""
Create a DistributedDataParallel model if there are >1 processes.
Args:
model: a torch.nn.Module
fp16_compression: add fp16 compression hooks to the ddp object.
See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
kwargs: other arguments of :class:`torch.nn.parallel.DistributedDataParallel`.
"""
if comm.get_world_size() == 1:
return model
# kwargs['find_unused_parameters'] = True
if "device_ids" not in kwargs:
kwargs["device_ids"] = [comm.get_local_rank()]
if "output_device" not in kwargs:
kwargs["output_device"] = [comm.get_local_rank()]
ddp = DistributedDataParallel(model, **kwargs)
if fp16_compression:
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
return ddp
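A minimal usage sketch (the model here is a hypothetical stand-in; under a single process the helper returns the model unchanged, so this also runs without any distributed setup):

import torch.nn as nn
from engines.defaults import create_ddp_model

model = nn.Linear(8, 2)
# wrapped with DistributedDataParallel only when comm.get_world_size() > 1
model = create_ddp_model(model, broadcast_buffers=False)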
def worker_init_fn(worker_id, num_workers, rank, seed):
"""Worker init func for dataloader.
The seed of each worker equals num_workers * rank + worker_id + user_seed
Args:
worker_id (int): Worker id.
num_workers (int): Number of workers.
rank (int): The rank of current process.
seed (int): The random seed to use.
"""
worker_seed = num_workers * rank + worker_id + seed
set_seed(worker_seed)
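For context, this is how the function gets bound into a DataLoader later in this commit (see Trainer.build_train_loader); `dataset` below is a placeholder for any torch.utils.data.Dataset:

from functools import partial
import torch.utils.data
import utils.comm as comm

init_fn = partial(worker_init_fn, num_workers=4, rank=comm.get_rank(), seed=42)
loader = torch.utils.data.DataLoader(
    dataset,  # placeholder dataset
    batch_size=2,
    num_workers=4,
    worker_init_fn=init_fn,  # each worker gets seed 4 * rank + worker_id + 42
)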
def default_argument_parser(epilog=None):
parser = argparse.ArgumentParser(
epilog=epilog
or f"""
Examples:
Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--config-file", default="", metavar="FILE", help="path to config file"
)
parser.add_argument(
"--num-gpus", type=int, default=1, help="number of gpus *per machine*"
)
parser.add_argument(
"--num-machines", type=int, default=1, help="total number of machines"
)
parser.add_argument(
"--machine-rank",
type=int,
default=0,
help="the rank of this machine (unique per machine)",
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
# port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
parser.add_argument(
"--dist-url",
# default="tcp://127.0.0.1:{}".format(port),
default="auto",
help="initialization URL for pytorch distributed backend. See "
"https://pytorch.org/docs/stable/distributed.html for details.",
)
parser.add_argument(
"--options", nargs="+", action=DictAction, help="custom options"
)
return parser
def default_config_parser(file_path, options):
# config name protocol: dataset_name/model_name-exp_name
if os.path.isfile(file_path):
cfg = Config.fromfile(file_path)
else:
sep = file_path.find("-")
cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))
if options is not None:
cfg.merge_from_dict(options)
if cfg.seed is None:
cfg.seed = get_random_seed()
cfg.data.train.loop = cfg.epoch // cfg.eval_epoch
os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
if not cfg.resume:
cfg.dump(os.path.join(cfg.save_path, "config.py"))
return cfg
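A quick illustration of the fallback naming protocol (the name is hypothetical, not an actual config in this commit): when `file_path` is not an existing file, it is split at the first `-`, so `dataset_name/model_name-exp_name` resolves to `dataset_name/model_name/exp_name`:

file_path = "scannet/semseg-exp1"  # hypothetical
sep = file_path.find("-")
assert os.path.join(file_path[:sep], file_path[sep + 1:]) == "scannet/semseg/exp1"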
def default_setup(cfg):
# scale per-GPU settings by world size
world_size = comm.get_world_size()
cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
cfg.num_worker_per_gpu = cfg.num_worker // world_size
assert cfg.batch_size % world_size == 0
assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
cfg.batch_size_per_gpu = cfg.batch_size // world_size
cfg.batch_size_val_per_gpu = (
cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
)
cfg.batch_size_test_per_gpu = (
cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
)
# update data loop
assert cfg.epoch % cfg.eval_epoch == 0
# set random seed (distinct per rank)
rank = comm.get_rank()
seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
set_seed(seed)
return cfg
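A worked example of the scaling above (assumed numbers): with `batch_size=16`, `num_worker=32`, `world_size=4` and `cfg.seed=100`, each GPU gets `batch_size_per_gpu=4` and `num_worker_per_gpu=8`, and rank 3 seeds itself with `100 * 8 + 3 = 803`, so every process draws a distinct random stream.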

engines/hooks/__init__.py Normal file

@@ -0,0 +1,5 @@
from .default import HookBase
from .misc import *
from .evaluator import *
from .builder import build_hooks

engines/hooks/builder.py Normal file

@@ -0,0 +1,15 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from utils.registry import Registry
HOOKS = Registry("hooks")
def build_hooks(cfg):
hooks = []
for hook_cfg in cfg:
hooks.append(HOOKS.build(hook_cfg))
return hooks
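A sketch of how a hook list might be configured and built (field values are illustrative; the hook types are the ones registered later in this commit):

hooks_cfg = [
    dict(type="CheckpointLoader"),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
]
hooks = build_hooks(hooks_cfg)  # instantiated hook objects, in order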

engines/hooks/default.py Normal file

@@ -0,0 +1,29 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
"""
trainer = None # A weak reference to the trainer object.
def before_train(self):
pass
def before_epoch(self):
pass
def before_step(self):
pass
def after_step(self):
pass
def after_epoch(self):
pass
def after_train(self):
pass
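A minimal custom hook following these conventions (an illustrative sketch, not part of the commit):

from engines.hooks.builder import HOOKS
from engines.hooks.default import HookBase

@HOOKS.register_module()
class EpochLogger(HookBase):
    # hypothetical hook: report epoch boundaries through the trainer's logger
    def before_epoch(self):
        self.trainer.logger.info(f"Starting epoch {self.trainer.epoch + 1}")

    def after_epoch(self):
        self.trainer.logger.info(f"Finished epoch {self.trainer.epoch + 1}")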

engines/hooks/evaluator.py Normal file

@@ -0,0 +1,577 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import numpy as np
import torch
import torch.distributed as dist
from uuid import uuid4
import pointops  # knn_query, used below to map predictions back to full-resolution points
import utils.comm as comm
from utils.misc import intersection_and_union_gpu
from .default import HookBase
from .builder import HOOKS
@HOOKS.register_module()
class ClsEvaluator(HookBase):
def after_epoch(self):
if self.trainer.cfg.evaluate:
self.eval()
def eval(self):
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
self.trainer.model.eval()
for i, input_dict in enumerate(self.trainer.val_loader):
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.no_grad():
output_dict = self.trainer.model(input_dict)
output = output_dict["cls_logits"]
loss = output_dict["loss"]
pred = output.max(1)[1]
label = input_dict["category"]
intersection, union, target = intersection_and_union_gpu(
pred,
label,
self.trainer.cfg.data.num_classes,
self.trainer.cfg.data.ignore_index,
)
if comm.get_world_size() > 1:
dist.all_reduce(intersection)
dist.all_reduce(union)
dist.all_reduce(target)
intersection, union, target = (
intersection.cpu().numpy(),
union.cpu().numpy(),
target.cpu().numpy(),
)
# Here there is no need to sync since sync happened in dist.all_reduce
self.trainer.storage.put_scalar("val_intersection", intersection)
self.trainer.storage.put_scalar("val_union", union)
self.trainer.storage.put_scalar("val_target", target)
self.trainer.storage.put_scalar("val_loss", loss.item())
self.trainer.logger.info(
"Test: [{iter}/{max_iter}] "
"Loss {loss:.4f} ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
)
)
loss_avg = self.trainer.storage.history("val_loss").avg
intersection = self.trainer.storage.history("val_intersection").total
union = self.trainer.storage.history("val_union").total
target = self.trainer.storage.history("val_target").total
iou_class = intersection / (union + 1e-10)
acc_class = intersection / (target + 1e-10)
m_iou = np.mean(iou_class)
m_acc = np.mean(acc_class)
all_acc = sum(intersection) / (sum(target) + 1e-10)
self.trainer.logger.info(
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
m_iou, m_acc, all_acc
)
)
for i in range(self.trainer.cfg.data.num_classes):
self.trainer.logger.info(
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
idx=i,
name=self.trainer.cfg.data.names[i],
iou=iou_class[i],
accuracy=acc_class[i],
)
)
current_epoch = self.trainer.epoch + 1
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
self.trainer.comm_info["current_metric_value"] = all_acc # save for saver
self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver
def after_train(self):
self.trainer.logger.info(
"Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
)
@HOOKS.register_module()
class SemSegEvaluator(HookBase):
def after_epoch(self):
if self.trainer.cfg.evaluate:
self.eval()
def eval(self):
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
self.trainer.model.eval()
for i, input_dict in enumerate(self.trainer.val_loader):
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.no_grad():
output_dict = self.trainer.model(input_dict)
output = output_dict["seg_logits"]
loss = output_dict["loss"]
pred = output.max(1)[1]
segment = input_dict["segment"]
if "origin_coord" in input_dict.keys():
idx, _ = pointops.knn_query(
1,
input_dict["coord"].float(),
input_dict["offset"].int(),
input_dict["origin_coord"].float(),
input_dict["origin_offset"].int(),
)
pred = pred[idx.flatten().long()]
segment = input_dict["origin_segment"]
intersection, union, target = intersection_and_union_gpu(
pred,
segment,
self.trainer.cfg.data.num_classes,
self.trainer.cfg.data.ignore_index,
)
if comm.get_world_size() > 1:
dist.all_reduce(intersection)
dist.all_reduce(union)
dist.all_reduce(target)
intersection, union, target = (
intersection.cpu().numpy(),
union.cpu().numpy(),
target.cpu().numpy(),
)
# Here there is no need to sync since sync happened in dist.all_reduce
self.trainer.storage.put_scalar("val_intersection", intersection)
self.trainer.storage.put_scalar("val_union", union)
self.trainer.storage.put_scalar("val_target", target)
self.trainer.storage.put_scalar("val_loss", loss.item())
info = "Test: [{iter}/{max_iter}] ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader)
)
if "origin_coord" in input_dict.keys():
info = "Interp. " + info
self.trainer.logger.info(
info + "Loss {loss:.4f} ".format(loss=loss.item())
)
loss_avg = self.trainer.storage.history("val_loss").avg
intersection = self.trainer.storage.history("val_intersection").total
union = self.trainer.storage.history("val_union").total
target = self.trainer.storage.history("val_target").total
iou_class = intersection / (union + 1e-10)
acc_class = intersection / (target + 1e-10)
m_iou = np.mean(iou_class)
m_acc = np.mean(acc_class)
all_acc = sum(intersection) / (sum(target) + 1e-10)
self.trainer.logger.info(
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
m_iou, m_acc, all_acc
)
)
for i in range(self.trainer.cfg.data.num_classes):
self.trainer.logger.info(
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
idx=i,
name=self.trainer.cfg.data.names[i],
iou=iou_class[i],
accuracy=acc_class[i],
)
)
current_epoch = self.trainer.epoch + 1
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
self.trainer.comm_info["current_metric_value"] = m_iou # save for saver
self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver
def after_train(self):
self.trainer.logger.info(
"Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
)
@HOOKS.register_module()
class InsSegEvaluator(HookBase):
def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
self.segment_ignore_index = segment_ignore_index
self.instance_ignore_index = instance_ignore_index
self.valid_class_names = None # update in before train
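# IoU overlap thresholds: 0.50 to 0.90 in steps of 0.05, plus 0.25 (ScanNet-style protocol)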
self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
self.min_region_sizes = 100
self.distance_threshes = float("inf")
self.distance_confs = -float("inf")
def before_train(self):
self.valid_class_names = [
self.trainer.cfg.data.names[i]
for i in range(self.trainer.cfg.data.num_classes)
if i not in self.segment_ignore_index
]
def after_epoch(self):
if self.trainer.cfg.evaluate:
self.eval()
def associate_instances(self, pred, segment, instance):
segment = segment.cpu().numpy()
instance = instance.cpu().numpy()
void_mask = np.in1d(segment, self.segment_ignore_index)
assert (
pred["pred_classes"].shape[0]
== pred["pred_scores"].shape[0]
== pred["pred_masks"].shape[0]
)
assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
# get gt instances
gt_instances = dict()
for i in range(self.trainer.cfg.data.num_classes):
if i not in self.segment_ignore_index:
gt_instances[self.trainer.cfg.data.names[i]] = []
instance_ids, idx, counts = np.unique(
instance, return_index=True, return_counts=True
)
segment_ids = segment[idx]
for i in range(len(instance_ids)):
if instance_ids[i] == self.instance_ignore_index:
continue
if segment_ids[i] in self.segment_ignore_index:
continue
gt_inst = dict()
gt_inst["instance_id"] = instance_ids[i]
gt_inst["segment_id"] = segment_ids[i]
gt_inst["dist_conf"] = 0.0
gt_inst["med_dist"] = -1.0
gt_inst["vert_count"] = counts[i]
gt_inst["matched_pred"] = []
gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)
# get pred instances and associate with gt
pred_instances = dict()
for i in range(self.trainer.cfg.data.num_classes):
if i not in self.segment_ignore_index:
pred_instances[self.trainer.cfg.data.names[i]] = []
instance_id = 0
for i in range(len(pred["pred_classes"])):
if pred["pred_classes"][i] in self.segment_ignore_index:
continue
pred_inst = dict()
pred_inst["uuid"] = uuid4()
pred_inst["instance_id"] = instance_id
pred_inst["segment_id"] = pred["pred_classes"][i]
pred_inst["confidence"] = pred["pred_scores"][i]
pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
pred_inst["void_intersection"] = np.count_nonzero(
np.logical_and(void_mask, pred_inst["mask"])
)
if pred_inst["vert_count"] < self.min_region_sizes:
continue # skip if empty
segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
matched_gt = []
for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
intersection = np.count_nonzero(
np.logical_and(
instance == gt_inst["instance_id"], pred_inst["mask"]
)
)
if intersection > 0:
gt_inst_ = gt_inst.copy()
pred_inst_ = pred_inst.copy()
gt_inst_["intersection"] = intersection
pred_inst_["intersection"] = intersection
matched_gt.append(gt_inst_)
gt_inst["matched_pred"].append(pred_inst_)
pred_inst["matched_gt"] = matched_gt
pred_instances[segment_name].append(pred_inst)
instance_id += 1
return gt_instances, pred_instances
def evaluate_matches(self, scenes):
overlaps = self.overlaps
min_region_sizes = [self.min_region_sizes]
dist_threshes = [self.distance_threshes]
dist_confs = [self.distance_confs]
# results: class x overlap
ap_table = np.zeros(
(len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
)
for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
zip(min_region_sizes, dist_threshes, dist_confs)
):
for oi, overlap_th in enumerate(overlaps):
pred_visited = {}
for scene in scenes:
for _ in scene["pred"]:
for label_name in self.valid_class_names:
for p in scene["pred"][label_name]:
if "uuid" in p:
pred_visited[p["uuid"]] = False
for li, label_name in enumerate(self.valid_class_names):
y_true = np.empty(0)
y_score = np.empty(0)
hard_false_negatives = 0
has_gt = False
has_pred = False
for scene in scenes:
pred_instances = scene["pred"][label_name]
gt_instances = scene["gt"][label_name]
# filter groups in ground truth
gt_instances = [
gt
for gt in gt_instances
if gt["vert_count"] >= min_region_size
and gt["med_dist"] <= distance_thresh
and gt["dist_conf"] >= distance_conf
]
if gt_instances:
has_gt = True
if pred_instances:
has_pred = True
cur_true = np.ones(len(gt_instances))
cur_score = np.ones(len(gt_instances)) * (-float("inf"))
cur_match = np.zeros(len(gt_instances), dtype=bool)
# collect matches
for gti, gt in enumerate(gt_instances):
found_match = False
for pred in gt["matched_pred"]:
# greedy assignments
if pred_visited[pred["uuid"]]:
continue
overlap = float(pred["intersection"]) / (
gt["vert_count"]
+ pred["vert_count"]
- pred["intersection"]
)
if overlap > overlap_th:
confidence = pred["confidence"]
# if already have a prediction for this gt,
# the prediction with the lower score is automatically a false positive
if cur_match[gti]:
max_score = max(cur_score[gti], confidence)
min_score = min(cur_score[gti], confidence)
cur_score[gti] = max_score
# append false positive
cur_true = np.append(cur_true, 0)
cur_score = np.append(cur_score, min_score)
cur_match = np.append(cur_match, True)
# otherwise set score
else:
found_match = True
cur_match[gti] = True
cur_score[gti] = confidence
pred_visited[pred["uuid"]] = True
if not found_match:
hard_false_negatives += 1
# remove non-matched ground truth instances
cur_true = cur_true[cur_match]
cur_score = cur_score[cur_match]
# collect non-matched predictions as false positive
for pred in pred_instances:
found_gt = False
for gt in pred["matched_gt"]:
overlap = float(gt["intersection"]) / (
gt["vert_count"]
+ pred["vert_count"]
- gt["intersection"]
)
if overlap > overlap_th:
found_gt = True
break
if not found_gt:
num_ignore = pred["void_intersection"]
for gt in pred["matched_gt"]:
if gt["segment_id"] in self.segment_ignore_index:
num_ignore += gt["intersection"]
# small ground truth instances
if (
gt["vert_count"] < min_region_size
or gt["med_dist"] > distance_thresh
or gt["dist_conf"] < distance_conf
):
num_ignore += gt["intersection"]
proportion_ignore = (
float(num_ignore) / pred["vert_count"]
)
# if not ignored append false positive
if proportion_ignore <= overlap_th:
cur_true = np.append(cur_true, 0)
confidence = pred["confidence"]
cur_score = np.append(cur_score, confidence)
# append to overall results
y_true = np.append(y_true, cur_true)
y_score = np.append(y_score, cur_score)
# compute average precision
if has_gt and has_pred:
# compute precision recall curve first
# sorting and cumsum
score_arg_sort = np.argsort(y_score)
y_score_sorted = y_score[score_arg_sort]
y_true_sorted = y_true[score_arg_sort]
y_true_sorted_cumsum = np.cumsum(y_true_sorted)
# unique thresholds
(thresholds, unique_indices) = np.unique(
y_score_sorted, return_index=True
)
num_prec_recall = len(unique_indices) + 1
# prepare precision recall
num_examples = len(y_score_sorted)
# https://github.com/ScanNet/ScanNet/pull/26
# all predictions are non-matched but also all of them are ignored and not counted as FP
# y_true_sorted_cumsum is empty
# num_true_examples = y_true_sorted_cumsum[-1]
num_true_examples = (
y_true_sorted_cumsum[-1]
if len(y_true_sorted_cumsum) > 0
else 0
)
precision = np.zeros(num_prec_recall)
recall = np.zeros(num_prec_recall)
# deal with the first point
y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
# deal with remaining
for idx_res, idx_scores in enumerate(unique_indices):
cumsum = y_true_sorted_cumsum[idx_scores - 1]
tp = num_true_examples - cumsum
fp = num_examples - idx_scores - tp
fn = cumsum + hard_false_negatives
p = float(tp) / (tp + fp)
r = float(tp) / (tp + fn)
precision[idx_res] = p
recall[idx_res] = r
# first point in curve is artificial
precision[-1] = 1.0
recall[-1] = 0.0
# compute average of precision-recall curve
recall_for_conv = np.copy(recall)
recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
recall_for_conv = np.append(recall_for_conv, 0.0)
stepWidths = np.convolve(
recall_for_conv, [-0.5, 0, 0.5], "valid"
)
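# the convolution yields (r[i-1] - r[i+1]) / 2 at each interior point, i.e. the
# midpoint-rule width of every recall step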
# integrate is now simply a dot product
ap_current = np.dot(precision, stepWidths)
elif has_gt:
ap_current = 0.0
else:
ap_current = float("nan")
ap_table[di, li, oi] = ap_current
d_inf = 0
o50 = np.where(np.isclose(self.overlaps, 0.5))
o25 = np.where(np.isclose(self.overlaps, 0.25))
oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
ap_scores = dict()
ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
ap_scores["classes"] = {}
for li, label_name in enumerate(self.valid_class_names):
ap_scores["classes"][label_name] = {}
ap_scores["classes"][label_name]["ap"] = np.average(
ap_table[d_inf, li, oAllBut25]
)
ap_scores["classes"][label_name]["ap50%"] = np.average(
ap_table[d_inf, li, o50]
)
ap_scores["classes"][label_name]["ap25%"] = np.average(
ap_table[d_inf, li, o25]
)
return ap_scores
def eval(self):
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
self.trainer.model.eval()
scenes = []
for i, input_dict in enumerate(self.trainer.val_loader):
assert (
len(input_dict["offset"]) == 1
) # currently only support bs 1 for each GPU
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.no_grad():
output_dict = self.trainer.model(input_dict)
loss = output_dict["loss"]
segment = input_dict["segment"]
instance = input_dict["instance"]
# map to origin
if "origin_coord" in input_dict.keys():
idx, _ = pointops.knn_query(
1,
input_dict["coord"].float(),
input_dict["offset"].int(),
input_dict["origin_coord"].float(),
input_dict["origin_offset"].int(),
)
idx = idx.cpu().flatten().long()
output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
segment = input_dict["origin_segment"]
instance = input_dict["origin_instance"]
gt_instances, pred_instance = self.associate_instances(
output_dict, segment, instance
)
scenes.append(dict(gt=gt_instances, pred=pred_instance))
self.trainer.storage.put_scalar("val_loss", loss.item())
self.trainer.logger.info(
"Test: [{iter}/{max_iter}] "
"Loss {loss:.4f} ".format(
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
)
)
loss_avg = self.trainer.storage.history("val_loss").avg
comm.synchronize()
scenes_sync = comm.gather(scenes, dst=0)
scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
ap_scores = self.evaluate_matches(scenes)
all_ap = ap_scores["all_ap"]
all_ap_50 = ap_scores["all_ap_50%"]
all_ap_25 = ap_scores["all_ap_25%"]
self.trainer.logger.info(
"Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
all_ap, all_ap_50, all_ap_25
)
)
for i, label_name in enumerate(self.valid_class_names):
ap = ap_scores["classes"][label_name]["ap"]
ap_50 = ap_scores["classes"][label_name]["ap50%"]
ap_25 = ap_scores["classes"][label_name]["ap25%"]
self.trainer.logger.info(
"Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
)
)
current_epoch = self.trainer.epoch + 1
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver
self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver

engines/hooks/misc.py Normal file

@@ -0,0 +1,460 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict
if sys.version_info >= (3, 10):
from collections.abc import Sequence
else:
from collections import Sequence
from utils.timer import Timer
from utils.comm import is_main_process, synchronize, get_world_size
from utils.cache import shared_dict
import utils.comm as comm
from engines.test import TESTERS
from .default import HookBase
from .builder import HOOKS
@HOOKS.register_module()
class IterationTimer(HookBase):
def __init__(self, warmup_iter=1):
self._warmup_iter = warmup_iter
self._start_time = time.perf_counter()
self._iter_timer = Timer()
self._remain_iter = 0
def before_train(self):
self._start_time = time.perf_counter()
self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)
def before_epoch(self):
self._iter_timer.reset()
def before_step(self):
data_time = self._iter_timer.seconds()
self.trainer.storage.put_scalar("data_time", data_time)
def after_step(self):
batch_time = self._iter_timer.seconds()
self._iter_timer.reset()
self.trainer.storage.put_scalar("batch_time", batch_time)
self._remain_iter -= 1
remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
t_m, t_s = divmod(remain_time, 60)
t_h, t_m = divmod(t_m, 60)
remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
if "iter_info" in self.trainer.comm_info.keys():
info = (
"Data {data_time_val:.3f} ({data_time_avg:.3f}) "
"Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
"Remain {remain_time} ".format(
data_time_val=self.trainer.storage.history("data_time").val,
data_time_avg=self.trainer.storage.history("data_time").avg,
batch_time_val=self.trainer.storage.history("batch_time").val,
batch_time_avg=self.trainer.storage.history("batch_time").avg,
remain_time=remain_time,
)
)
self.trainer.comm_info["iter_info"] += info
if self.trainer.comm_info["iter"] <= self._warmup_iter:
self.trainer.storage.history("data_time").reset()
self.trainer.storage.history("batch_time").reset()
@HOOKS.register_module()
class InformationWriter(HookBase):
def __init__(self):
self.curr_iter = 0
self.model_output_keys = []
def before_train(self):
self.trainer.comm_info["iter_info"] = ""
self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)
def before_step(self):
self.curr_iter += 1
# MSC pretraining does not provide offset information; the block below is commented out to keep MSC supported.
# info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
# "Scan {batch_size} ({points_num}) ".format(
# epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
# iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
# batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
# points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
# )
info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
epoch=self.trainer.epoch + 1,
max_epoch=self.trainer.max_epoch,
iter=self.trainer.comm_info["iter"] + 1,
max_iter=len(self.trainer.train_loader),
)
self.trainer.comm_info["iter_info"] += info
def after_step(self):
if "model_output_dict" in self.trainer.comm_info.keys():
model_output_dict = self.trainer.comm_info["model_output_dict"]
self.model_output_keys = model_output_dict.keys()
for key in self.model_output_keys:
self.trainer.storage.put_scalar(key, model_output_dict[key].item())
for key in self.model_output_keys:
self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
key=key, value=self.trainer.storage.history(key).val
)
lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
self.trainer.logger.info(self.trainer.comm_info["iter_info"])
self.trainer.comm_info["iter_info"] = "" # reset iter info
if self.trainer.writer is not None:
self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
for key in self.model_output_keys:
self.trainer.writer.add_scalar(
"train_batch/" + key,
self.trainer.storage.history(key).val,
self.curr_iter,
)
def after_epoch(self):
epoch_info = "Train result: "
for key in self.model_output_keys:
epoch_info += "{key}: {value:.4f} ".format(
key=key, value=self.trainer.storage.history(key).avg
)
self.trainer.logger.info(epoch_info)
if self.trainer.writer is not None:
for key in self.model_output_keys:
self.trainer.writer.add_scalar(
"train/" + key,
self.trainer.storage.history(key).avg,
self.trainer.epoch + 1,
)
@HOOKS.register_module()
class CheckpointSaver(HookBase):
def __init__(self, save_freq=None):
self.save_freq = save_freq # None or int; None means only the last model is saved
def after_epoch(self):
if is_main_process():
is_best = False
if self.trainer.cfg.evaluate:
current_metric_value = self.trainer.comm_info["current_metric_value"]
current_metric_name = self.trainer.comm_info["current_metric_name"]
if current_metric_value > self.trainer.best_metric_value:
self.trainer.best_metric_value = current_metric_value
is_best = True
self.trainer.logger.info(
"Best validation {} updated to: {:.4f}".format(
current_metric_name, current_metric_value
)
)
self.trainer.logger.info(
"Currently Best {}: {:.4f}".format(
current_metric_name, self.trainer.best_metric_value
)
)
filename = os.path.join(
self.trainer.cfg.save_path, "model", "model_last.pth"
)
self.trainer.logger.info("Saving checkpoint to: " + filename)
torch.save(
{
"epoch": self.trainer.epoch + 1,
"state_dict": self.trainer.model.state_dict(),
"optimizer": self.trainer.optimizer.state_dict(),
"scheduler": self.trainer.scheduler.state_dict(),
"scaler": self.trainer.scaler.state_dict()
if self.trainer.cfg.enable_amp
else None,
"best_metric_value": self.trainer.best_metric_value,
},
filename + ".tmp",
)
os.replace(filename + ".tmp", filename)
if is_best:
shutil.copyfile(
filename,
os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
)
if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
shutil.copyfile(
filename,
os.path.join(
self.trainer.cfg.save_path,
"model",
f"epoch_{self.trainer.epoch + 1}.pth",
),
)
@HOOKS.register_module()
class CheckpointLoader(HookBase):
def __init__(self, keywords="", replacement=None, strict=False):
self.keywords = keywords
self.replacement = replacement if replacement is not None else keywords
self.strict = strict
def before_train(self):
self.trainer.logger.info("=> Loading checkpoint & weight ...")
if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
checkpoint = torch.load(
self.trainer.cfg.weight,
map_location=lambda storage, loc: storage.cuda(),
)
self.trainer.logger.info(
f"Loading layer weights with keyword: {self.keywords}, "
f"replace keyword with: {self.replacement}"
)
weight = OrderedDict()
for key, value in checkpoint["state_dict"].items():
if not key.startswith("module."):
if comm.get_world_size() > 1:
key = "module." + key # xxx.xxx -> module.xxx.xxx
# Now all keys contain "module.", whether DDP is used or not.
if self.keywords in key:
key = key.replace(self.keywords, self.replacement)
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
weight[key] = value
load_state_info = self.trainer.model.load_state_dict(
weight, strict=self.strict
)
self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
if self.trainer.cfg.resume:
self.trainer.logger.info(
f"Resuming train at eval epoch: {checkpoint['epoch']}"
)
self.trainer.start_epoch = checkpoint["epoch"]
self.trainer.best_metric_value = checkpoint["best_metric_value"]
self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
if self.trainer.cfg.enable_amp:
self.trainer.scaler.load_state_dict(checkpoint["scaler"])
else:
self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
@HOOKS.register_module()
class PreciseEvaluator(HookBase):
def __init__(self, test_last=False):
self.test_last = test_last
def after_train(self):
self.trainer.logger.info(
">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
)
torch.cuda.empty_cache()
cfg = self.trainer.cfg
tester = TESTERS.build(
dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
)
if self.test_last:
self.trainer.logger.info("=> Testing on model_last ...")
else:
self.trainer.logger.info("=> Testing on model_best ...")
best_path = os.path.join(
self.trainer.cfg.save_path, "model", "model_best.pth"
)
checkpoint = torch.load(best_path)
state_dict = checkpoint["state_dict"]
tester.model.load_state_dict(state_dict, strict=True)
tester.test()
@HOOKS.register_module()
class DataCacheOperator(HookBase):
def __init__(self, data_root, split):
self.data_root = data_root
self.split = split
self.data_list = self.get_data_list()
def get_data_list(self):
if isinstance(self.split, str):
data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
elif isinstance(self.split, Sequence):
data_list = []
for split in self.split:
data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
else:
raise NotImplementedError
return data_list
def get_cache_name(self, data_path):
data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
return "pointcept" + data_name.replace(os.path.sep, "-")
def before_train(self):
self.trainer.logger.info(
f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
)
if is_main_process():
for data_path in self.data_list:
cache_name = self.get_cache_name(data_path)
data = torch.load(data_path)
shared_dict(cache_name, data)
synchronize()
@HOOKS.register_module()
class RuntimeProfiler(HookBase):
def __init__(
self,
forward=True,
backward=True,
interrupt=False,
warm_up=2,
sort_by="cuda_time_total",
row_limit=30,
):
self.forward = forward
self.backward = backward
self.interrupt = interrupt
self.warm_up = warm_up
self.sort_by = sort_by
self.row_limit = row_limit
def before_train(self):
self.trainer.logger.info("Profiling runtime ...")
from torch.profiler import profile, record_function, ProfilerActivity
for i, input_dict in enumerate(self.trainer.train_loader):
if i == self.warm_up + 1:
break
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
if self.forward:
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
) as forward_prof:
with record_function("model_inference"):
output_dict = self.trainer.model(input_dict)
else:
output_dict = self.trainer.model(input_dict)
loss = output_dict["loss"]
if self.backward:
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
) as backward_prof:
with record_function("model_inference"):
loss.backward()
self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
if self.forward:
self.trainer.logger.info(
"Forward profile: \n"
+ str(
forward_prof.key_averages().table(
sort_by=self.sort_by, row_limit=self.row_limit
)
)
)
forward_prof.export_chrome_trace(
os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
)
if self.backward:
self.trainer.logger.info(
"Backward profile: \n"
+ str(
backward_prof.key_averages().table(
sort_by=self.sort_by, row_limit=self.row_limit
)
)
)
backward_prof.export_chrome_trace(
os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
)
if self.interrupt:
sys.exit(0)
@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
def __init__(
self,
interrupt=False,
wait=1,
warmup=1,
active=10,
repeat=1,
sort_by="cuda_time_total",
row_limit=30,
):
self.interrupt = interrupt
self.wait = wait
self.warmup = warmup
self.active = active
self.repeat = repeat
self.sort_by = sort_by
self.row_limit = row_limit
def before_train(self):
self.trainer.logger.info("Profiling runtime ...")
from torch.profiler import (
profile,
record_function,
ProfilerActivity,
schedule,
tensorboard_trace_handler,
)
prof = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(
wait=self.wait,
warmup=self.warmup,
active=self.active,
repeat=self.repeat,
),
on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
for i, input_dict in enumerate(self.trainer.train_loader):
if i >= (self.wait + self.warmup + self.active) * self.repeat:
break
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with record_function("model_forward"):
output_dict = self.trainer.model(input_dict)
loss = output_dict["loss"]
with record_function("model_backward"):
loss.backward()
prof.step()
self.trainer.logger.info(
f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
)
self.trainer.logger.info(
"Profile: \n"
+ str(
prof.key_averages().table(
sort_by=self.sort_by, row_limit=self.row_limit
)
)
)
prof.stop()
if self.interrupt:
sys.exit(0)

engines/infer.py Normal file

@@ -0,0 +1,285 @@
"""
Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import math
import time
import librosa
import numpy as np
from collections import OrderedDict
import torch
import torch.utils.data
import torch.nn.functional as F
from .defaults import create_ddp_model
import utils.comm as comm
from models import build_model
from utils.logger import get_root_logger
from utils.registry import Registry
from utils.misc import (
AverageMeter,
)
from models.utils import (
smooth_mouth_movements, apply_frame_blending, apply_savitzky_golay_smoothing,
apply_random_brow_movement, symmetrize_blendshapes, apply_random_eye_blinks,
apply_random_eye_blinks_context, export_blendshape_animation,
RETURN_CODE, DEFAULT_CONTEXT, ARKitBlendShape,
)
INFER = Registry("infer")
class InferBase:
def __init__(self, cfg, model=None, verbose=False) -> None:
torch.multiprocessing.set_sharing_strategy("file_system")
self.logger = get_root_logger(
log_file=os.path.join(cfg.save_path, "infer.log"),
file_mode="a" if cfg.resume else "w",
)
self.logger.info("=> Loading config ...")
self.cfg = cfg
self.verbose = verbose
if self.verbose:
self.logger.info(f"Save path: {cfg.save_path}")
self.logger.info(f"Config:\n{cfg.pretty_text}")
if model is None:
self.logger.info("=> Building model ...")
self.model = self.build_model()
else:
self.model = model
def build_model(self):
model = build_model(self.cfg.model)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.logger.info(f"Num params: {n_parameters}")
model = create_ddp_model(
model.cuda(),
broadcast_buffers=False,
find_unused_parameters=self.cfg.find_unused_parameters,
)
if os.path.isfile(self.cfg.weight):
self.logger.info(f"Loading weight at: {self.cfg.weight}")
checkpoint = torch.load(self.cfg.weight)
weight = OrderedDict()
for key, value in checkpoint["state_dict"].items():
if key.startswith("module."):
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
else:
if comm.get_world_size() > 1:
key = "module." + key # xxx.xxx -> module.xxx.xxx
weight[key] = value
model.load_state_dict(weight, strict=True)
self.logger.info(
"=> Loaded weight '{}'".format(
self.cfg.weight
)
)
else:
raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
return model
def infer(self):
raise NotImplementedError
@INFER.register_module()
class Audio2ExpressionInfer(InferBase):
def infer(self):
logger = get_root_logger()
logger.info(">>>>>>>>>>>>>>>> Start Inference >>>>>>>>>>>>>>>>")
batch_time = AverageMeter()
self.model.eval()
# process audio-input
assert os.path.exists(self.cfg.audio_input)
if self.cfg.ex_vol:
logger.info("Extracting vocals ...")
vocal_path = self.extract_vocal_track(self.cfg.audio_input)
logger.info("=> Extracted vocals at: {}".format(vocal_path if os.path.exists(vocal_path) else '... Failed'))
if os.path.exists(vocal_path):
self.cfg.audio_input = vocal_path
with torch.no_grad():
input_dict = {}
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None,...]
speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None,...]
end = time.time()
output_dict = self.model(input_dict)
batch_time.update(time.time() - end)
logger.info(
"Infer: [{}] "
"Running Time: {batch_time.avg:.3f} ".format(
self.cfg.audio_input,
batch_time=batch_time,
)
)
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()
frame_length = math.ceil(speech_array.shape[0] / ssr * 30)
volume = librosa.feature.rms(y=speech_array, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
if volume.shape[0] > frame_length:
volume = volume[:frame_length]
if self.cfg.movement_smooth:
out_exp = smooth_mouth_movements(out_exp, 0, volume)
if self.cfg.brow_movement:
out_exp = apply_random_brow_movement(out_exp, volume)
pred_exp = self.blendshape_postprocess(out_exp)
if self.cfg.save_json_path is not None:
export_blendshape_animation(pred_exp,
self.cfg.save_json_path,
ARKitBlendShape,
fps=self.cfg.fps)
logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
def infer_streaming_audio(self,
audio: np.ndarray,
ssr: float,
context: dict):
if context is None:
context = DEFAULT_CONTEXT.copy()
max_frame_length = 64
frame_length = math.ceil(audio.shape[0] / ssr * 30)
output_context = DEFAULT_CONTEXT.copy()
volume = librosa.feature.rms(y=audio, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
if volume.shape[0] > frame_length:
volume = volume[:frame_length]
# resample audio
if ssr != self.cfg.audio_sr:
in_audio = librosa.resample(audio.astype(np.float32), orig_sr=ssr, target_sr=self.cfg.audio_sr)
else:
in_audio = audio.copy()
start_frame = int(max_frame_length - in_audio.shape[0] / self.cfg.audio_sr * 30)
if context['is_initial_input'] or context['previous_audio'] is None:
blank_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
blank_audio = np.zeros(blank_audio_length, dtype=np.float32)
# pre-append
input_audio = np.concatenate([blank_audio, in_audio])
output_context['previous_audio'] = input_audio
else:
clip_pre_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
clip_pre_audio = context['previous_audio'][-clip_pre_audio_length:]
input_audio = np.concatenate([clip_pre_audio, in_audio])
output_context['previous_audio'] = input_audio
with torch.no_grad():
try:
input_dict = {}
input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None, ...]
input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
output_dict = self.model(input_dict)
out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
except Exception:
self.logger.error('Error: failed to predict expression.')
# Fall back to a neutral (all-zero) expression for this chunk so streaming can continue.
out_exp = np.zeros((max_frame_length - start_frame, 52), dtype=np.float32)
# post-process
if context['previous_expression'] is None:
out_exp = self.apply_expression_postprocessing(out_exp, audio_volume=volume)
else:
previous_length = context['previous_expression'].shape[0]
out_exp = self.apply_expression_postprocessing(
expression_params=np.concatenate([context['previous_expression'], out_exp], axis=0),
audio_volume=np.concatenate([context['previous_volume'], volume], axis=0),
processed_frames=previous_length)[previous_length:, :]
if context['previous_expression'] is not None:
output_context['previous_expression'] = np.concatenate([context['previous_expression'], out_exp], axis=0)[
-max_frame_length:, :]
output_context['previous_volume'] = np.concatenate([context['previous_volume'], volume], axis=0)[-max_frame_length:]
else:
output_context['previous_expression'] = out_exp.copy()
output_context['previous_volume'] = volume.copy()
output_context['is_initial_input'] = False  # subsequent chunks are no longer the initial input
return {"code": RETURN_CODE['SUCCESS'],
"expression": out_exp,
"headpose": None}, output_context
def apply_expression_postprocessing(
self,
expression_params: np.ndarray,
processed_frames: int = 0,
audio_volume: np.ndarray = None
) -> np.ndarray:
"""Applies full post-processing pipeline to facial expression parameters.
Args:
expression_params: Raw output from animation model [num_frames, num_parameters]
processed_frames: Number of frames already processed in previous batches
audio_volume: Optional volume array for audio-visual synchronization
Returns:
Processed expression parameters ready for animation synthesis
"""
# Pipeline execution order matters - maintain sequence
expression_params = smooth_mouth_movements(expression_params, processed_frames, audio_volume)
expression_params = apply_frame_blending(expression_params, processed_frames)
expression_params, _ = apply_savitzky_golay_smoothing(expression_params, window_length=5)
expression_params = symmetrize_blendshapes(expression_params)
expression_params = apply_random_eye_blinks_context(expression_params, processed_frames=processed_frames)
return expression_params
def extract_vocal_track(
self,
input_audio_path: str
) -> str:
"""Isolates vocal track from audio file using source separation.
Args:
input_audio_path: Path to input audio file containing vocals+accompaniment
Returns:
Path to isolated vocal track in WAV format
"""
separation_command = f'spleeter separate -p spleeter:2stems -o {self.cfg.save_path} {input_audio_path}'
os.system(separation_command)
base_name = os.path.splitext(os.path.basename(input_audio_path))[0]
return os.path.join(self.cfg.save_path, base_name, 'vocals.wav')
def blendshape_postprocess(self, bs_array: np.ndarray) -> np.ndarray:
bs_array, _ = apply_savitzky_golay_smoothing(bs_array, window_length=5)
bs_array = symmetrize_blendshapes(bs_array)
bs_array = apply_random_eye_blinks(bs_array)
return bs_array

engines/launch.py Normal file

@@ -0,0 +1,135 @@
"""
Launcher
modified from detectron2(https://github.com/facebookresearch/detectron2)
"""
import os
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from utils import comm
__all__ = ["DEFAULT_TIMEOUT", "launch"]
DEFAULT_TIMEOUT = timedelta(minutes=30)
def _find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def launch(
main_func,
num_gpus_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
cfg=(),
timeout=DEFAULT_TIMEOUT,
):
"""
Launch multi-gpu or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
Args:
main_func: a function that will be called by `main_func(*cfg)`
num_gpus_per_machine (int): number of GPUs per machine
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
dist_url (str): url to connect to for distributed jobs, including protocol
e.g. "tcp://127.0.0.1:8686".
Can be set to "auto" to automatically select a free port on localhost
timeout (timedelta): timeout of the distributed workers
cfg (tuple): arguments passed to main_func
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
if dist_url == "auto":
assert (
num_machines == 1
), "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
mp.spawn(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
cfg,
timeout,
),
daemon=False,
)
else:
main_func(*cfg)
def _distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
cfg,
timeout=DEFAULT_TIMEOUT,
):
assert (
torch.cuda.is_available()
), "cuda is not available. Please check your installation."
global_rank = machine_rank * num_gpus_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL",
init_method=dist_url,
world_size=world_size,
rank=global_rank,
timeout=timeout,
)
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Process group URL: {}".format(dist_url))
raise e
# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_gpus_per_machine
for i in range(num_machines):
ranks_on_i = list(
range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
)
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg
assert num_gpus_per_machine <= torch.cuda.device_count()
torch.cuda.set_device(local_rank)
# synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
main_func(*cfg)
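A minimal launch sketch (illustrative; it assumes four local GPUs, and `main_worker` is a hypothetical entry function receiving the unpacked `cfg` tuple):

from engines.launch import launch

def main_worker(cfg):
    print("worker sees cfg:", cfg)

if __name__ == "__main__":
    # spawn 4 processes on one machine; "auto" picks a free localhost port
    launch(
        main_worker,
        num_gpus_per_machine=4,
        num_machines=1,
        machine_rank=0,
        dist_url="auto",
        cfg=(dict(),),
    )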

engines/train.py Normal file

@@ -0,0 +1,299 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import os
import sys
import weakref
import torch
import torch.nn as nn
import torch.utils.data
from functools import partial
if sys.version_info >= (3, 10):
from collections.abc import Iterator
else:
from collections import Iterator
from tensorboardX import SummaryWriter
from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import utils.comm as comm
from datasets import build_dataset, point_collate_fn, collate_fn
from models import build_model
from utils.logger import get_root_logger
from utils.optimizer import build_optimizer
from utils.scheduler import build_scheduler
from utils.events import EventStorage
from utils.registry import Registry
TRAINERS = Registry("trainers")
class TrainerBase:
def __init__(self) -> None:
self.hooks = []
self.epoch = 0
self.start_epoch = 0
self.max_epoch = 0
self.max_iter = 0
self.comm_info = dict()
self.data_iterator: Iterator = enumerate([])
self.storage: EventStorage
self.writer: SummaryWriter
def register_hooks(self, hooks) -> None:
hooks = build_hooks(hooks)
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self.hooks.extend(hooks)
def train(self):
with EventStorage() as self.storage:
# => before train
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def before_train(self):
for h in self.hooks:
h.before_train()
def before_epoch(self):
for h in self.hooks:
h.before_epoch()
def before_step(self):
for h in self.hooks:
h.before_step()
def run_step(self):
raise NotImplementedError
def after_step(self):
for h in self.hooks:
h.after_step()
def after_epoch(self):
for h in self.hooks:
h.after_epoch()
self.storage.reset_histories()
def after_train(self):
# Sync GPU before running train hooks
comm.synchronize()
for h in self.hooks:
h.after_train()
if comm.is_main_process():
self.writer.close()
@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
def __init__(self, cfg):
super(Trainer, self).__init__()
self.epoch = 0
self.start_epoch = 0
self.max_epoch = cfg.eval_epoch
self.best_metric_value = -torch.inf
self.logger = get_root_logger(
log_file=os.path.join(cfg.save_path, "train.log"),
file_mode="a" if cfg.resume else "w",
)
self.logger.info("=> Loading config ...")
self.cfg = cfg
self.logger.info(f"Save path: {cfg.save_path}")
self.logger.info(f"Config:\n{cfg.pretty_text}")
self.logger.info("=> Building model ...")
self.model = self.build_model()
self.logger.info("=> Building writer ...")
self.writer = self.build_writer()
self.logger.info("=> Building train dataset & dataloader ...")
self.train_loader = self.build_train_loader()
self.logger.info("=> Building val dataset & dataloader ...")
self.val_loader = self.build_val_loader()
self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
self.optimizer = self.build_optimizer()
self.scheduler = self.build_scheduler()
self.scaler = self.build_scaler()
self.logger.info("=> Building hooks ...")
self.register_hooks(self.cfg.hooks)
def train(self):
with EventStorage() as self.storage:
# => before train
self.before_train()
self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
# TODO: optimize to iteration based
if comm.get_world_size() > 1:
self.train_loader.sampler.set_epoch(self.epoch)
self.model.train()
self.data_iterator = enumerate(self.train_loader)
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def run_step(self):
input_dict = self.comm_info["input_dict"]
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
output_dict = self.model(input_dict)
loss = output_dict["loss"]
self.optimizer.zero_grad()
if self.cfg.enable_amp:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
# When AMP is enabled, optimizer.step() is skipped if the gradients overflow and the loss scale shrinks.
# Comparing the scale before/after update() avoids the "scheduler step before optimizer step" warning.
scaler = self.scaler.get_scale()
self.scaler.update()
if scaler <= self.scaler.get_scale():
self.scheduler.step()
else:
loss.backward()
self.optimizer.step()
self.scheduler.step()
if self.cfg.empty_cache:
torch.cuda.empty_cache()
self.comm_info["model_output_dict"] = output_dict
def build_model(self):
model = build_model(self.cfg.model)
if self.cfg.sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logger.info(f"Model: \n{self.model}")
self.logger.info(f"Num params: {n_parameters}")
model = create_ddp_model(
model.cuda(),
broadcast_buffers=False,
find_unused_parameters=self.cfg.find_unused_parameters,
)
return model
def build_writer(self):
writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
return writer
def build_train_loader(self):
train_data = build_dataset(self.cfg.data.train)
if comm.get_world_size() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
else:
train_sampler = None
init_fn = (
partial(
worker_init_fn,
num_workers=self.cfg.num_worker_per_gpu,
rank=comm.get_rank(),
seed=self.cfg.seed,
)
if self.cfg.seed is not None
else None
)
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=self.cfg.batch_size_per_gpu,
shuffle=(train_sampler is None),
num_workers=0,
sampler=train_sampler,
collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
pin_memory=True,
worker_init_fn=init_fn,
drop_last=True,
# persistent_workers=True,
)
return train_loader
def build_val_loader(self):
val_loader = None
if self.cfg.evaluate:
val_data = build_dataset(self.cfg.data.val)
if comm.get_world_size() > 1:
val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
else:
val_sampler = None
val_loader = torch.utils.data.DataLoader(
val_data,
batch_size=self.cfg.batch_size_val_per_gpu,
shuffle=False,
num_workers=self.cfg.num_worker_per_gpu,
pin_memory=True,
sampler=val_sampler,
collate_fn=collate_fn,
)
return val_loader
def build_optimizer(self):
return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
def build_scheduler(self):
assert hasattr(self, "optimizer")
assert hasattr(self, "train_loader")
self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
return build_scheduler(self.cfg.scheduler, self.optimizer)
def build_scaler(self):
scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
return scaler
@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
def build_train_loader(self):
from datasets import MultiDatasetDataloader
train_data = build_dataset(self.cfg.data.train)
train_loader = MultiDatasetDataloader(
train_data,
self.cfg.batch_size_per_gpu,
self.cfg.num_worker_per_gpu,
self.cfg.mix_prob,
self.cfg.seed,
)
self.comm_info["iter_per_epoch"] = len(train_loader)
return train_loader
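Putting the pieces together, a minimal training entry script under these APIs could look like the sketch below (illustrative, not part of the commit; it assumes the config defines `cfg.train.type`, e.g. "DefaultTrainer"):

from engines.defaults import default_argument_parser, default_config_parser, default_setup
from engines.launch import launch
from engines.train import TRAINERS

def main_worker(cfg):
    cfg = default_setup(cfg)
    trainer = TRAINERS.build(dict(type=cfg.train.type, cfg=cfg))
    trainer.train()

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    cfg = default_config_parser(args.config_file, args.options)
    launch(
        main_worker,
        num_gpus_per_machine=args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        cfg=(cfg,),
    )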