mirror of https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 01:49:23 +08:00

feat: Initial commit
0 engines/__init__.py Normal file
147 engines/defaults.py Normal file
@@ -0,0 +1,147 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import sys
import argparse
import multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

import utils.comm as comm
from utils.env import get_random_seed, set_seed
from utils.config import Config, DictAction


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.
    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """
    if comm.get_world_size() == 1:
        return model
    # kwargs['find_unused_parameters'] = True
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    if "output_device" not in kwargs:
        # output_device expects a device index, not a list
        kwargs["output_device"] = comm.get_local_rank()
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
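For illustration (editorial, not part of the committed file), a minimal usage sketch of create_ddp_model; engines/infer.py later in this commit calls it the same way:

    model = build_model(cfg.model)  # any torch.nn.Module factory
    model = create_ddp_model(
        model.cuda(),
        broadcast_buffers=False,
        find_unused_parameters=cfg.find_unused_parameters,
    )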
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Worker init func for dataloader.

    The seed of each worker equals num_workers * rank + worker_id + user_seed.

    Args:
        worker_id (int): Worker id.
        num_workers (int): Number of workers.
        rank (int): The rank of current process.
        seed (int): The random seed to use.
    """
    worker_seed = num_workers * rank + worker_id + seed
    set_seed(worker_seed)
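As a usage sketch (editorial; assumes a `dataset` object and the cfg fields prepared by default_setup below), the trailing arguments are typically bound with functools.partial so every DataLoader worker derives its own deterministic seed:

    from functools import partial

    init_fn = partial(
        worker_init_fn,
        num_workers=cfg.num_worker_per_gpu,
        rank=comm.get_rank(),
        seed=cfg.seed,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size_per_gpu,
        num_workers=cfg.num_worker_per_gpu,
        worker_init_fn=init_fn,
    )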
def default_argument_parser(epilog=None):
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        # default="tcp://127.0.0.1:{}".format(port),
        default="auto",
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "--options", nargs="+", action=DictAction, help="custom options"
    )
    return parser
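A sketch of the intended entry-point flow (editorial; the launcher appears in engines/launch.py at the end of this commit):

    args = default_argument_parser().parse_args()
    cfg = default_config_parser(args.config_file, args.options)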
def default_config_parser(file_path, options):
    # config name protocol: dataset_name/model_name-exp_name
    if os.path.isfile(file_path):
        cfg = Config.fromfile(file_path)
    else:
        sep = file_path.find("-")
        cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :]))

    if options is not None:
        cfg.merge_from_dict(options)

    if cfg.seed is None:
        cfg.seed = get_random_seed()

    cfg.data.train.loop = cfg.epoch // cfg.eval_epoch

    os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True)
    if not cfg.resume:
        cfg.dump(os.path.join(cfg.save_path, "config.py"))
    return cfg
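A worked instance of the name protocol above (editorial; the names are hypothetical):

    # "audio2exp/base-demo" is not a file on disk, so the first "-" splits it
    # and the config is loaded from the joined path instead:
    cfg = default_config_parser("audio2exp/base-demo", options={"save_path": "exp/demo"})
    # equivalent to Config.fromfile(os.path.join("audio2exp/base", "demo"))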
def default_setup(cfg):
    # scale by world size
    world_size = comm.get_world_size()
    cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count()
    cfg.num_worker_per_gpu = cfg.num_worker // world_size
    assert cfg.batch_size % world_size == 0
    assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0
    assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0
    cfg.batch_size_per_gpu = cfg.batch_size // world_size
    cfg.batch_size_val_per_gpu = (
        cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1
    )
    cfg.batch_size_test_per_gpu = (
        cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1
    )
    # update data loop
    assert cfg.epoch % cfg.eval_epoch == 0
    # settle random seed
    rank = comm.get_rank()
    seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank
    set_seed(seed)
    return cfg
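A worked instance of the arithmetic above (editorial):

    # with world_size = 4, batch_size = 16, num_worker = 8 and seed = 123:
    #   batch_size_per_gpu = 16 // 4 = 4
    #   num_worker_per_gpu = 8 // 4 = 2
    #   per-rank seed      = 123 * 2 + rank  -> 246, 247, 248, 249
    # so every process (and, via worker_init_fn, every loader worker)
    # draws a distinct but reproducible random stream.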
5 engines/hooks/__init__.py Normal file
@@ -0,0 +1,5 @@
from .default import HookBase
from .misc import *
from .evaluator import *

from .builder import build_hooks
15 engines/hooks/builder.py Normal file
@@ -0,0 +1,15 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from utils.registry import Registry


HOOKS = Registry("hooks")


def build_hooks(cfg):
    hooks = []
    for hook_cfg in cfg:
        hooks.append(HOOKS.build(hook_cfg))
    return hooks
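For illustration (editorial), the cfg passed to build_hooks is a list of dicts whose "type" keys name classes registered with @HOOKS.register_module(), e.g. the hooks defined in engines/hooks/misc.py below:

    hooks = build_hooks([
        dict(type="IterationTimer", warmup_iter=2),
        dict(type="InformationWriter"),
        dict(type="CheckpointSaver", save_freq=None),
    ])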
29 engines/hooks/default.py Normal file
@@ -0,0 +1,29 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""


class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.
    """

    trainer = None  # A weak reference to the trainer object.

    def before_train(self):
        pass

    def before_epoch(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        pass

    def after_epoch(self):
        pass

    def after_train(self):
        pass
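A sketch of a custom hook (editorial; the class is a hypothetical example): subclass HookBase, override only the events you need, and register it so build_hooks can construct it from config:

    from .builder import HOOKS

    @HOOKS.register_module()
    class EpochLogger(HookBase):  # hypothetical example hook
        def after_epoch(self):
            self.trainer.logger.info(f"Finished epoch {self.trainer.epoch + 1}")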
577 engines/hooks/evaluator.py Normal file
@@ -0,0 +1,577 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import numpy as np
import torch
import torch.distributed as dist
import pointops  # needed by knn_query below (Pointcept's point-ops library)
from uuid import uuid4

import utils.comm as comm
from utils.misc import intersection_and_union_gpu

from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class ClsEvaluator(HookBase):
    def after_epoch(self):
        if self.trainer.cfg.evaluate:
            self.eval()

    def eval(self):
        self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        self.trainer.model.eval()
        for i, input_dict in enumerate(self.trainer.val_loader):
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                output_dict = self.trainer.model(input_dict)
            output = output_dict["cls_logits"]
            loss = output_dict["loss"]
            pred = output.max(1)[1]
            label = input_dict["category"]
            intersection, union, target = intersection_and_union_gpu(
                pred,
                label,
                self.trainer.cfg.data.num_classes,
                self.trainer.cfg.data.ignore_index,
            )
            if comm.get_world_size() > 1:
                dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
                    target
                )
            intersection, union, target = (
                intersection.cpu().numpy(),
                union.cpu().numpy(),
                target.cpu().numpy(),
            )
            # Here there is no need to sync since sync happened in dist.all_reduce
            self.trainer.storage.put_scalar("val_intersection", intersection)
            self.trainer.storage.put_scalar("val_union", union)
            self.trainer.storage.put_scalar("val_target", target)
            self.trainer.storage.put_scalar("val_loss", loss.item())
            self.trainer.logger.info(
                "Test: [{iter}/{max_iter}] "
                "Loss {loss:.4f} ".format(
                    iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
                )
            )
        loss_avg = self.trainer.storage.history("val_loss").avg
        intersection = self.trainer.storage.history("val_intersection").total
        union = self.trainer.storage.history("val_union").total
        target = self.trainer.storage.history("val_target").total
        iou_class = intersection / (union + 1e-10)
        acc_class = intersection / (target + 1e-10)
        m_iou = np.mean(iou_class)
        m_acc = np.mean(acc_class)
        all_acc = sum(intersection) / (sum(target) + 1e-10)
        self.trainer.logger.info(
            "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
                m_iou, m_acc, all_acc
            )
        )
        for i in range(self.trainer.cfg.data.num_classes):
            self.trainer.logger.info(
                "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                    idx=i,
                    name=self.trainer.cfg.data.names[i],
                    iou=iou_class[i],
                    accuracy=acc_class[i],
                )
            )
        current_epoch = self.trainer.epoch + 1
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
            self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
            self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
            self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
        self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
        self.trainer.comm_info["current_metric_value"] = all_acc  # save for saver
        self.trainer.comm_info["current_metric_name"] = "allAcc"  # save for saver

    def after_train(self):
        self.trainer.logger.info(
            "Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
        )
@HOOKS.register_module()
class SemSegEvaluator(HookBase):
    def after_epoch(self):
        if self.trainer.cfg.evaluate:
            self.eval()

    def eval(self):
        self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        self.trainer.model.eval()
        for i, input_dict in enumerate(self.trainer.val_loader):
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                output_dict = self.trainer.model(input_dict)
            output = output_dict["seg_logits"]
            loss = output_dict["loss"]
            pred = output.max(1)[1]
            segment = input_dict["segment"]
            if "origin_coord" in input_dict.keys():
                idx, _ = pointops.knn_query(
                    1,
                    input_dict["coord"].float(),
                    input_dict["offset"].int(),
                    input_dict["origin_coord"].float(),
                    input_dict["origin_offset"].int(),
                )
                pred = pred[idx.flatten().long()]
                segment = input_dict["origin_segment"]
            intersection, union, target = intersection_and_union_gpu(
                pred,
                segment,
                self.trainer.cfg.data.num_classes,
                self.trainer.cfg.data.ignore_index,
            )
            if comm.get_world_size() > 1:
                dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
                    target
                )
            intersection, union, target = (
                intersection.cpu().numpy(),
                union.cpu().numpy(),
                target.cpu().numpy(),
            )
            # Here there is no need to sync since sync happened in dist.all_reduce
            self.trainer.storage.put_scalar("val_intersection", intersection)
            self.trainer.storage.put_scalar("val_union", union)
            self.trainer.storage.put_scalar("val_target", target)
            self.trainer.storage.put_scalar("val_loss", loss.item())
            info = "Test: [{iter}/{max_iter}] ".format(
                iter=i + 1, max_iter=len(self.trainer.val_loader)
            )
            if "origin_coord" in input_dict.keys():
                info = "Interp. " + info
            self.trainer.logger.info(
                info + "Loss {loss:.4f} ".format(loss=loss.item())
            )
        loss_avg = self.trainer.storage.history("val_loss").avg
        intersection = self.trainer.storage.history("val_intersection").total
        union = self.trainer.storage.history("val_union").total
        target = self.trainer.storage.history("val_target").total
        iou_class = intersection / (union + 1e-10)
        acc_class = intersection / (target + 1e-10)
        m_iou = np.mean(iou_class)
        m_acc = np.mean(acc_class)
        all_acc = sum(intersection) / (sum(target) + 1e-10)
        self.trainer.logger.info(
            "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
                m_iou, m_acc, all_acc
            )
        )
        for i in range(self.trainer.cfg.data.num_classes):
            self.trainer.logger.info(
                "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
                    idx=i,
                    name=self.trainer.cfg.data.names[i],
                    iou=iou_class[i],
                    accuracy=acc_class[i],
                )
            )
        current_epoch = self.trainer.epoch + 1
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
            self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
            self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
            self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
        self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
        self.trainer.comm_info["current_metric_value"] = m_iou  # save for saver
        self.trainer.comm_info["current_metric_name"] = "mIoU"  # save for saver

    def after_train(self):
        self.trainer.logger.info(
            "Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
        )
@HOOKS.register_module()
class InsSegEvaluator(HookBase):
    def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
        self.segment_ignore_index = segment_ignore_index
        self.instance_ignore_index = instance_ignore_index

        self.valid_class_names = None  # update in before train
        self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
        self.min_region_sizes = 100
        self.distance_threshes = float("inf")
        self.distance_confs = -float("inf")

    def before_train(self):
        self.valid_class_names = [
            self.trainer.cfg.data.names[i]
            for i in range(self.trainer.cfg.data.num_classes)
            if i not in self.segment_ignore_index
        ]

    def after_epoch(self):
        if self.trainer.cfg.evaluate:
            self.eval()

    def associate_instances(self, pred, segment, instance):
        segment = segment.cpu().numpy()
        instance = instance.cpu().numpy()
        void_mask = np.in1d(segment, self.segment_ignore_index)

        assert (
            pred["pred_classes"].shape[0]
            == pred["pred_scores"].shape[0]
            == pred["pred_masks"].shape[0]
        )
        assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
        # get gt instances
        gt_instances = dict()
        for i in range(self.trainer.cfg.data.num_classes):
            if i not in self.segment_ignore_index:
                gt_instances[self.trainer.cfg.data.names[i]] = []
        instance_ids, idx, counts = np.unique(
            instance, return_index=True, return_counts=True
        )
        segment_ids = segment[idx]
        for i in range(len(instance_ids)):
            if instance_ids[i] == self.instance_ignore_index:
                continue
            if segment_ids[i] in self.segment_ignore_index:
                continue
            gt_inst = dict()
            gt_inst["instance_id"] = instance_ids[i]
            gt_inst["segment_id"] = segment_ids[i]
            gt_inst["dist_conf"] = 0.0
            gt_inst["med_dist"] = -1.0
            gt_inst["vert_count"] = counts[i]
            gt_inst["matched_pred"] = []
            gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)

        # get pred instances and associate with gt
        pred_instances = dict()
        for i in range(self.trainer.cfg.data.num_classes):
            if i not in self.segment_ignore_index:
                pred_instances[self.trainer.cfg.data.names[i]] = []
        instance_id = 0
        for i in range(len(pred["pred_classes"])):
            if pred["pred_classes"][i] in self.segment_ignore_index:
                continue
            pred_inst = dict()
            pred_inst["uuid"] = uuid4()
            pred_inst["instance_id"] = instance_id
            pred_inst["segment_id"] = pred["pred_classes"][i]
            pred_inst["confidence"] = pred["pred_scores"][i]
            pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
            pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
            pred_inst["void_intersection"] = np.count_nonzero(
                np.logical_and(void_mask, pred_inst["mask"])
            )
            if pred_inst["vert_count"] < self.min_region_sizes:
                continue  # skip if empty
            segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
            matched_gt = []
            for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
                intersection = np.count_nonzero(
                    np.logical_and(
                        instance == gt_inst["instance_id"], pred_inst["mask"]
                    )
                )
                if intersection > 0:
                    gt_inst_ = gt_inst.copy()
                    pred_inst_ = pred_inst.copy()
                    gt_inst_["intersection"] = intersection
                    pred_inst_["intersection"] = intersection
                    matched_gt.append(gt_inst_)
                    gt_inst["matched_pred"].append(pred_inst_)
            pred_inst["matched_gt"] = matched_gt
            pred_instances[segment_name].append(pred_inst)
            instance_id += 1
        return gt_instances, pred_instances

    def evaluate_matches(self, scenes):
        overlaps = self.overlaps
        min_region_sizes = [self.min_region_sizes]
        dist_threshes = [self.distance_threshes]
        dist_confs = [self.distance_confs]

        # results: class x overlap
        ap_table = np.zeros(
            (len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
        )
        for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
            zip(min_region_sizes, dist_threshes, dist_confs)
        ):
            for oi, overlap_th in enumerate(overlaps):
                pred_visited = {}
                for scene in scenes:
                    for _ in scene["pred"]:
                        for label_name in self.valid_class_names:
                            for p in scene["pred"][label_name]:
                                if "uuid" in p:
                                    pred_visited[p["uuid"]] = False
                for li, label_name in enumerate(self.valid_class_names):
                    y_true = np.empty(0)
                    y_score = np.empty(0)
                    hard_false_negatives = 0
                    has_gt = False
                    has_pred = False
                    for scene in scenes:
                        pred_instances = scene["pred"][label_name]
                        gt_instances = scene["gt"][label_name]
                        # filter groups in ground truth
                        gt_instances = [
                            gt
                            for gt in gt_instances
                            if gt["vert_count"] >= min_region_size
                            and gt["med_dist"] <= distance_thresh
                            and gt["dist_conf"] >= distance_conf
                        ]
                        if gt_instances:
                            has_gt = True
                        if pred_instances:
                            has_pred = True

                        cur_true = np.ones(len(gt_instances))
                        cur_score = np.ones(len(gt_instances)) * (-float("inf"))
                        cur_match = np.zeros(len(gt_instances), dtype=bool)
                        # collect matches
                        for gti, gt in enumerate(gt_instances):
                            found_match = False
                            for pred in gt["matched_pred"]:
                                # greedy assignments
                                if pred_visited[pred["uuid"]]:
                                    continue
                                overlap = float(pred["intersection"]) / (
                                    gt["vert_count"]
                                    + pred["vert_count"]
                                    - pred["intersection"]
                                )
                                if overlap > overlap_th:
                                    confidence = pred["confidence"]
                                    # if already have a prediction for this gt,
                                    # the prediction with the lower score is automatically a false positive
                                    if cur_match[gti]:
                                        max_score = max(cur_score[gti], confidence)
                                        min_score = min(cur_score[gti], confidence)
                                        cur_score[gti] = max_score
                                        # append false positive
                                        cur_true = np.append(cur_true, 0)
                                        cur_score = np.append(cur_score, min_score)
                                        cur_match = np.append(cur_match, True)
                                    # otherwise set score
                                    else:
                                        found_match = True
                                        cur_match[gti] = True
                                        cur_score[gti] = confidence
                                        pred_visited[pred["uuid"]] = True
                            if not found_match:
                                hard_false_negatives += 1
                        # remove non-matched ground truth instances
                        cur_true = cur_true[cur_match]
                        cur_score = cur_score[cur_match]

                        # collect non-matched predictions as false positive
                        for pred in pred_instances:
                            found_gt = False
                            for gt in pred["matched_gt"]:
                                overlap = float(gt["intersection"]) / (
                                    gt["vert_count"]
                                    + pred["vert_count"]
                                    - gt["intersection"]
                                )
                                if overlap > overlap_th:
                                    found_gt = True
                                    break
                            if not found_gt:
                                num_ignore = pred["void_intersection"]
                                for gt in pred["matched_gt"]:
                                    if gt["segment_id"] in self.segment_ignore_index:
                                        num_ignore += gt["intersection"]
                                    # small ground truth instances
                                    if (
                                        gt["vert_count"] < min_region_size
                                        or gt["med_dist"] > distance_thresh
                                        or gt["dist_conf"] < distance_conf
                                    ):
                                        num_ignore += gt["intersection"]
                                proportion_ignore = (
                                    float(num_ignore) / pred["vert_count"]
                                )
                                # if not ignored append false positive
                                if proportion_ignore <= overlap_th:
                                    cur_true = np.append(cur_true, 0)
                                    confidence = pred["confidence"]
                                    cur_score = np.append(cur_score, confidence)

                        # append to overall results
                        y_true = np.append(y_true, cur_true)
                        y_score = np.append(y_score, cur_score)

                    # compute average precision
                    if has_gt and has_pred:
                        # compute precision recall curve first

                        # sorting and cumsum
                        score_arg_sort = np.argsort(y_score)
                        y_score_sorted = y_score[score_arg_sort]
                        y_true_sorted = y_true[score_arg_sort]
                        y_true_sorted_cumsum = np.cumsum(y_true_sorted)

                        # unique thresholds
                        (thresholds, unique_indices) = np.unique(
                            y_score_sorted, return_index=True
                        )
                        num_prec_recall = len(unique_indices) + 1

                        # prepare precision recall
                        num_examples = len(y_score_sorted)
                        # https://github.com/ScanNet/ScanNet/pull/26
                        # all predictions are non-matched but also all of them are ignored and not counted as FP
                        # y_true_sorted_cumsum is empty
                        # num_true_examples = y_true_sorted_cumsum[-1]
                        num_true_examples = (
                            y_true_sorted_cumsum[-1]
                            if len(y_true_sorted_cumsum) > 0
                            else 0
                        )
                        precision = np.zeros(num_prec_recall)
                        recall = np.zeros(num_prec_recall)

                        # deal with the first point
                        y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
                        # deal with remaining
                        for idx_res, idx_scores in enumerate(unique_indices):
                            cumsum = y_true_sorted_cumsum[idx_scores - 1]
                            tp = num_true_examples - cumsum
                            fp = num_examples - idx_scores - tp
                            fn = cumsum + hard_false_negatives
                            p = float(tp) / (tp + fp)
                            r = float(tp) / (tp + fn)
                            precision[idx_res] = p
                            recall[idx_res] = r

                        # first point in curve is artificial
                        precision[-1] = 1.0
                        recall[-1] = 0.0

                        # compute average of precision-recall curve
                        recall_for_conv = np.copy(recall)
                        recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
                        recall_for_conv = np.append(recall_for_conv, 0.0)

                        stepWidths = np.convolve(
                            recall_for_conv, [-0.5, 0, 0.5], "valid"
                        )
                        # integrate is now simply a dot product
                        ap_current = np.dot(precision, stepWidths)

                    elif has_gt:
                        ap_current = 0.0
                    else:
                        ap_current = float("nan")
                    ap_table[di, li, oi] = ap_current
        d_inf = 0
        o50 = np.where(np.isclose(self.overlaps, 0.5))
        o25 = np.where(np.isclose(self.overlaps, 0.25))
        oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
        ap_scores = dict()
        ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
        ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
        ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
        ap_scores["classes"] = {}
        for li, label_name in enumerate(self.valid_class_names):
            ap_scores["classes"][label_name] = {}
            ap_scores["classes"][label_name]["ap"] = np.average(
                ap_table[d_inf, li, oAllBut25]
            )
            ap_scores["classes"][label_name]["ap50%"] = np.average(
                ap_table[d_inf, li, o50]
            )
            ap_scores["classes"][label_name]["ap25%"] = np.average(
                ap_table[d_inf, li, o25]
            )
        return ap_scores

    def eval(self):
        self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        self.trainer.model.eval()
        scenes = []
        for i, input_dict in enumerate(self.trainer.val_loader):
            assert (
                len(input_dict["offset"]) == 1
            )  # currently only support bs 1 for each GPU
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                output_dict = self.trainer.model(input_dict)

            loss = output_dict["loss"]

            segment = input_dict["segment"]
            instance = input_dict["instance"]
            # map to origin
            if "origin_coord" in input_dict.keys():
                idx, _ = pointops.knn_query(
                    1,
                    input_dict["coord"].float(),
                    input_dict["offset"].int(),
                    input_dict["origin_coord"].float(),
                    input_dict["origin_offset"].int(),
                )
                idx = idx.cpu().flatten().long()
                output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
                segment = input_dict["origin_segment"]
                instance = input_dict["origin_instance"]

            gt_instances, pred_instance = self.associate_instances(
                output_dict, segment, instance
            )
            scenes.append(dict(gt=gt_instances, pred=pred_instance))

            self.trainer.storage.put_scalar("val_loss", loss.item())
            self.trainer.logger.info(
                "Test: [{iter}/{max_iter}] "
                "Loss {loss:.4f} ".format(
                    iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
                )
            )

        loss_avg = self.trainer.storage.history("val_loss").avg
        comm.synchronize()
        scenes_sync = comm.gather(scenes, dst=0)
        scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
        ap_scores = self.evaluate_matches(scenes)
        all_ap = ap_scores["all_ap"]
        all_ap_50 = ap_scores["all_ap_50%"]
        all_ap_25 = ap_scores["all_ap_25%"]
        self.trainer.logger.info(
            "Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
                all_ap, all_ap_50, all_ap_25
            )
        )
        for i, label_name in enumerate(self.valid_class_names):
            ap = ap_scores["classes"][label_name]["ap"]
            ap_50 = ap_scores["classes"][label_name]["ap50%"]
            ap_25 = ap_scores["classes"][label_name]["ap25%"]
            self.trainer.logger.info(
                "Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
                    idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
                )
            )
        current_epoch = self.trainer.epoch + 1
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
            self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
            self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
            self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
        self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
        self.trainer.comm_info["current_metric_value"] = all_ap_50  # save for saver
        self.trainer.comm_info["current_metric_name"] = "AP50"  # save for saver
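A numeric mini-example of the metric reduction shared by the evaluators above (editorial): intersection/union/target are per-class counts summed over the epoch:

    import numpy as np

    intersection = np.array([90, 40])
    union = np.array([100, 80])
    target = np.array([95, 50])
    m_iou = np.mean(intersection / (union + 1e-10))        # (0.9 + 0.5) / 2 = 0.7
    all_acc = intersection.sum() / (target.sum() + 1e-10)  # 130 / 145 ≈ 0.8966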
460 engines/hooks/misc.py Normal file
@@ -0,0 +1,460 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict

if sys.version_info >= (3, 10):
    from collections.abc import Sequence
else:
    from collections import Sequence
from utils.timer import Timer
from utils.comm import is_main_process, synchronize, get_world_size
from utils.cache import shared_dict

import utils.comm as comm
from engines.test import TESTERS

from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class IterationTimer(HookBase):
    def __init__(self, warmup_iter=1):
        self._warmup_iter = warmup_iter
        self._start_time = time.perf_counter()
        self._iter_timer = Timer()
        self._remain_iter = 0

    def before_train(self):
        self._start_time = time.perf_counter()
        self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)

    def before_epoch(self):
        self._iter_timer.reset()

    def before_step(self):
        data_time = self._iter_timer.seconds()
        self.trainer.storage.put_scalar("data_time", data_time)

    def after_step(self):
        batch_time = self._iter_timer.seconds()
        self._iter_timer.reset()
        self.trainer.storage.put_scalar("batch_time", batch_time)
        self._remain_iter -= 1
        remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
        if "iter_info" in self.trainer.comm_info.keys():
            info = (
                "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
                "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
                "Remain {remain_time} ".format(
                    data_time_val=self.trainer.storage.history("data_time").val,
                    data_time_avg=self.trainer.storage.history("data_time").avg,
                    batch_time_val=self.trainer.storage.history("batch_time").val,
                    batch_time_avg=self.trainer.storage.history("batch_time").avg,
                    remain_time=remain_time,
                )
            )
            self.trainer.comm_info["iter_info"] += info
        if self.trainer.comm_info["iter"] <= self._warmup_iter:
            self.trainer.storage.history("data_time").reset()
            self.trainer.storage.history("batch_time").reset()
@HOOKS.register_module()
class InformationWriter(HookBase):
    def __init__(self):
        self.curr_iter = 0
        self.model_output_keys = []

    def before_train(self):
        self.trainer.comm_info["iter_info"] = ""
        self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)

    def before_step(self):
        self.curr_iter += 1
        # MSC pretraining has no offset information, so the block below is
        # commented out to keep MSC supported.
        # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
        #        "Scan {batch_size} ({points_num}) ".format(
        #     epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
        #     iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
        #     batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
        #     points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
        # )
        info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
            epoch=self.trainer.epoch + 1,
            max_epoch=self.trainer.max_epoch,
            iter=self.trainer.comm_info["iter"] + 1,
            max_iter=len(self.trainer.train_loader),
        )
        self.trainer.comm_info["iter_info"] += info

    def after_step(self):
        if "model_output_dict" in self.trainer.comm_info.keys():
            model_output_dict = self.trainer.comm_info["model_output_dict"]
            self.model_output_keys = model_output_dict.keys()
            for key in self.model_output_keys:
                self.trainer.storage.put_scalar(key, model_output_dict[key].item())

        for key in self.model_output_keys:
            self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).val
            )
        lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
        self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
        self.trainer.logger.info(self.trainer.comm_info["iter_info"])
        self.trainer.comm_info["iter_info"] = ""  # reset iter info
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train_batch/" + key,
                    self.trainer.storage.history(key).val,
                    self.curr_iter,
                )

    def after_epoch(self):
        epoch_info = "Train result: "
        for key in self.model_output_keys:
            epoch_info += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).avg
            )
        self.trainer.logger.info(epoch_info)
        if self.trainer.writer is not None:
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train/" + key,
                    self.trainer.storage.history(key).avg,
                    self.trainer.epoch + 1,
                )
@HOOKS.register_module()
class CheckpointSaver(HookBase):
    def __init__(self, save_freq=None):
        self.save_freq = save_freq  # None or int; None indicates only saving the last model

    def after_epoch(self):
        if is_main_process():
            is_best = False
            if self.trainer.cfg.evaluate:
                current_metric_value = self.trainer.comm_info["current_metric_value"]
                current_metric_name = self.trainer.comm_info["current_metric_name"]
                if current_metric_value > self.trainer.best_metric_value:
                    self.trainer.best_metric_value = current_metric_value
                    is_best = True
                    self.trainer.logger.info(
                        "Best validation {} updated to: {:.4f}".format(
                            current_metric_name, current_metric_value
                        )
                    )
                self.trainer.logger.info(
                    "Currently Best {}: {:.4f}".format(
                        current_metric_name, self.trainer.best_metric_value
                    )
                )

            filename = os.path.join(
                self.trainer.cfg.save_path, "model", "model_last.pth"
            )
            self.trainer.logger.info("Saving checkpoint to: " + filename)
            torch.save(
                {
                    "epoch": self.trainer.epoch + 1,
                    "state_dict": self.trainer.model.state_dict(),
                    "optimizer": self.trainer.optimizer.state_dict(),
                    "scheduler": self.trainer.scheduler.state_dict(),
                    "scaler": self.trainer.scaler.state_dict()
                    if self.trainer.cfg.enable_amp
                    else None,
                    "best_metric_value": self.trainer.best_metric_value,
                },
                filename + ".tmp",
            )
            os.replace(filename + ".tmp", filename)
            if is_best:
                shutil.copyfile(
                    filename,
                    os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
                )
            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
                shutil.copyfile(
                    filename,
                    os.path.join(
                        self.trainer.cfg.save_path,
                        "model",
                        f"epoch_{self.trainer.epoch + 1}.pth",
                    ),
                )
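For illustration (editorial; the path is hypothetical), the checkpoint written above restores symmetrically; CheckpointLoader below does the same with extra key remapping for DDP:

    ckpt = torch.load("exp/model/model_last.pth", map_location="cpu")  # illustrative path
    model.load_state_dict(ckpt["state_dict"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    start_epoch = ckpt["epoch"]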
@HOOKS.register_module()
class CheckpointLoader(HookBase):
    def __init__(self, keywords="", replacement=None, strict=False):
        self.keywords = keywords
        self.replacement = replacement if replacement is not None else keywords
        self.strict = strict

    def before_train(self):
        self.trainer.logger.info("=> Loading checkpoint & weight ...")
        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
            checkpoint = torch.load(
                self.trainer.cfg.weight,
                map_location=lambda storage, loc: storage.cuda(),
            )
            self.trainer.logger.info(
                f"Loading layer weights with keyword: {self.keywords}, "
                f"replace keyword with: {self.replacement}"
            )
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("module."):
                    if comm.get_world_size() > 1:
                        key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # In the DDP case all keys now carry the "module." prefix.
                if self.keywords in key:
                    key = key.replace(self.keywords, self.replacement)
                if comm.get_world_size() == 1:
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                weight[key] = value
            load_state_info = self.trainer.model.load_state_dict(
                weight, strict=self.strict
            )
            self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
            if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming train at eval epoch: {checkpoint['epoch']}"
                )
                self.trainer.start_epoch = checkpoint["epoch"]
                self.trainer.best_metric_value = checkpoint["best_metric_value"]
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
                if self.trainer.cfg.enable_amp:
                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
        else:
            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
@HOOKS.register_module()
class PreciseEvaluator(HookBase):
    def __init__(self, test_last=False):
        self.test_last = test_last

    def after_train(self):
        self.trainer.logger.info(
            ">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
        )
        torch.cuda.empty_cache()
        cfg = self.trainer.cfg
        tester = TESTERS.build(
            dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
        )
        if self.test_last:
            self.trainer.logger.info("=> Testing on model_last ...")
        else:
            self.trainer.logger.info("=> Testing on model_best ...")
            best_path = os.path.join(
                self.trainer.cfg.save_path, "model", "model_best.pth"
            )
            checkpoint = torch.load(best_path)
            state_dict = checkpoint["state_dict"]
            tester.model.load_state_dict(state_dict, strict=True)
        tester.test()
@HOOKS.register_module()
class DataCacheOperator(HookBase):
    def __init__(self, data_root, split):
        self.data_root = data_root
        self.split = split
        self.data_list = self.get_data_list()

    def get_data_list(self):
        if isinstance(self.split, str):
            data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
        else:
            raise NotImplementedError
        return data_list

    def get_cache_name(self, data_path):
        data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
        return "pointcept" + data_name.replace(os.path.sep, "-")

    def before_train(self):
        self.trainer.logger.info(
            f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
        )
        if is_main_process():
            for data_path in self.data_list:
                cache_name = self.get_cache_name(data_path)
                data = torch.load(data_path)
                shared_dict(cache_name, data)
        synchronize()
@HOOKS.register_module()
class RuntimeProfiler(HookBase):
    def __init__(
        self,
        forward=True,
        backward=True,
        interrupt=False,
        warm_up=2,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.forward = forward
        self.backward = backward
        self.interrupt = interrupt
        self.warm_up = warm_up
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import profile, record_function, ProfilerActivity

        for i, input_dict in enumerate(self.trainer.train_loader):
            if i == self.warm_up + 1:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            if self.forward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as forward_prof:
                    with record_function("model_inference"):
                        output_dict = self.trainer.model(input_dict)
            else:
                output_dict = self.trainer.model(input_dict)
            loss = output_dict["loss"]
            if self.backward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as backward_prof:
                    with record_function("model_inference"):
                        loss.backward()
            self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
        if self.forward:
            self.trainer.logger.info(
                "Forward profile: \n"
                + str(
                    forward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            forward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
            )

        if self.backward:
            self.trainer.logger.info(
                "Backward profile: \n"
                + str(
                    backward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            backward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
            )
        if self.interrupt:
            sys.exit(0)
@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
    def __init__(
        self,
        interrupt=False,
        wait=1,
        warmup=1,
        active=10,
        repeat=1,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.interrupt = interrupt
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import (
            profile,
            record_function,
            ProfilerActivity,
            schedule,
            tensorboard_trace_handler,
        )

        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
        for i, input_dict in enumerate(self.trainer.train_loader):
            if i >= (self.wait + self.warmup + self.active) * self.repeat:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with record_function("model_forward"):
                output_dict = self.trainer.model(input_dict)
                loss = output_dict["loss"]
            with record_function("model_backward"):
                loss.backward()
            prof.step()
            self.trainer.logger.info(
                f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
            )
        self.trainer.logger.info(
            "Profile: \n"
            + str(
                prof.key_averages().table(
                    sort_by=self.sort_by, row_limit=self.row_limit
                )
            )
        )
        prof.stop()

        if self.interrupt:
            sys.exit(0)
285 engines/infer.py Normal file
@@ -0,0 +1,285 @@
"""
Copyright 2024-2025 The Alibaba 3DAIGC Team Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import math
import time
import librosa
import numpy as np
from collections import OrderedDict

import torch
import torch.utils.data
import torch.nn.functional as F

from .defaults import create_ddp_model
import utils.comm as comm
from models import build_model
from utils.logger import get_root_logger
from utils.registry import Registry
from utils.misc import (
    AverageMeter,
)

from models.utils import smooth_mouth_movements, apply_frame_blending, apply_savitzky_golay_smoothing, apply_random_brow_movement, \
    symmetrize_blendshapes, apply_random_eye_blinks, apply_random_eye_blinks_context, export_blendshape_animation, \
    RETURN_CODE, DEFAULT_CONTEXT, ARKitBlendShape

INFER = Registry("infer")


class InferBase:
    def __init__(self, cfg, model=None, verbose=False) -> None:
        torch.multiprocessing.set_sharing_strategy("file_system")
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "infer.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.verbose = verbose
        if self.verbose:
            self.logger.info(f"Save path: {cfg.save_path}")
            self.logger.info(f"Config:\n{cfg.pretty_text}")
        if model is None:
            self.logger.info("=> Building model ...")
            self.model = self.build_model()
        else:
            self.model = model

    def build_model(self):
        model = build_model(self.cfg.model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        if os.path.isfile(self.cfg.weight):
            self.logger.info(f"Loading weight at: {self.cfg.weight}")
            checkpoint = torch.load(self.cfg.weight)
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if key.startswith("module."):
                    if comm.get_world_size() == 1:
                        key = key[7:]  # module.xxx.xxx -> xxx.xxx
                else:
                    if comm.get_world_size() > 1:
                        key = "module." + key  # xxx.xxx -> module.xxx.xxx
                weight[key] = value
            model.load_state_dict(weight, strict=True)
            self.logger.info(
                "=> Loaded weight '{}'".format(
                    self.cfg.weight
                )
            )
        else:
            raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
        return model

    def infer(self):
        raise NotImplementedError
@INFER.register_module()
class Audio2ExpressionInfer(InferBase):
    def infer(self):
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Inference >>>>>>>>>>>>>>>>")
        batch_time = AverageMeter()
        self.model.eval()

        # process audio input
        assert os.path.exists(self.cfg.audio_input)
        if self.cfg.ex_vol:
            logger.info("Extract vocals ...")
            vocal_path = self.extract_vocal_track(self.cfg.audio_input)
            logger.info("=> Extract vocals at: {}".format(vocal_path if os.path.exists(vocal_path) else '... Failed'))
            if os.path.exists(vocal_path):
                self.cfg.audio_input = vocal_path

        with torch.no_grad():
            input_dict = {}
            input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
                                             self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None, ...]
            speech_array, ssr = librosa.load(self.cfg.audio_input, sr=16000)
            input_dict['input_audio_array'] = torch.FloatTensor(speech_array).cuda(non_blocking=True)[None, ...]

            end = time.time()
            output_dict = self.model(input_dict)
            batch_time.update(time.time() - end)

        logger.info(
            "Infer: [{}] "
            "Running Time: {batch_time.avg:.3f} ".format(
                self.cfg.audio_input,
                batch_time=batch_time,
            )
        )

        out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()

        frame_length = math.ceil(speech_array.shape[0] / ssr * 30)
        volume = librosa.feature.rms(y=speech_array, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
        if volume.shape[0] > frame_length:
            volume = volume[:frame_length]

        if self.cfg.movement_smooth:
            out_exp = smooth_mouth_movements(out_exp, 0, volume)

        if self.cfg.brow_movement:
            out_exp = apply_random_brow_movement(out_exp, volume)

        pred_exp = self.blendshape_postprocess(out_exp)

        if self.cfg.save_json_path is not None:
            export_blendshape_animation(pred_exp,
                                        self.cfg.save_json_path,
                                        ARKitBlendShape,
                                        fps=self.cfg.fps)

        logger.info("<<<<<<<<<<<<<<<<< End Inference <<<<<<<<<<<<<<<<<")
    def infer_streaming_audio(self,
                              audio: np.ndarray,
                              ssr: float,
                              context: dict):

        if context is None:
            context = DEFAULT_CONTEXT.copy()
        max_frame_length = 64

        frame_length = math.ceil(audio.shape[0] / ssr * 30)
        output_context = DEFAULT_CONTEXT.copy()

        volume = librosa.feature.rms(y=audio, frame_length=int(1 / 30 * ssr), hop_length=int(1 / 30 * ssr))[0]
        if volume.shape[0] > frame_length:
            volume = volume[:frame_length]

        # resample audio
        if ssr != self.cfg.audio_sr:
            in_audio = librosa.resample(audio.astype(np.float32), orig_sr=ssr, target_sr=self.cfg.audio_sr)
        else:
            in_audio = audio.copy()

        start_frame = int(max_frame_length - in_audio.shape[0] / self.cfg.audio_sr * 30)

        if context['is_initial_input'] or (context['previous_audio'] is None):
            blank_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
            blank_audio = np.zeros(blank_audio_length, dtype=np.float32)

            # pre-append
            input_audio = np.concatenate([blank_audio, in_audio])
            output_context['previous_audio'] = input_audio
        else:
            clip_pre_audio_length = self.cfg.audio_sr * max_frame_length // 30 - in_audio.shape[0]
            clip_pre_audio = context['previous_audio'][-clip_pre_audio_length:]
            input_audio = np.concatenate([clip_pre_audio, in_audio])
            output_context['previous_audio'] = input_audio

        with torch.no_grad():
            try:
                input_dict = {}
                input_dict['id_idx'] = F.one_hot(torch.tensor(self.cfg.id_idx),
                                                 self.cfg.model.backbone.num_identity_classes).cuda(non_blocking=True)[None, ...]
                input_dict['input_audio_array'] = torch.FloatTensor(input_audio).cuda(non_blocking=True)[None, ...]
                output_dict = self.model(input_dict)
                out_exp = output_dict['pred_exp'].squeeze().cpu().numpy()[start_frame:, :]
            except Exception:
                self.logger.error('Error: failed to predict expression.')
                output_dict = {'pred_exp': torch.zeros((max_frame_length, 52)).float()}
                # NOTE: this early return yields None rather than the usual
                # (result, context) tuple; callers should guard for it.
                return

        # post-process
        if context['previous_expression'] is None:
            out_exp = self.apply_expression_postprocessing(out_exp, audio_volume=volume)
        else:
            previous_length = context['previous_expression'].shape[0]
            out_exp = self.apply_expression_postprocessing(expression_params=np.concatenate([context['previous_expression'], out_exp], axis=0),
                                                           audio_volume=np.concatenate([context['previous_volume'], volume], axis=0),
                                                           processed_frames=previous_length)[previous_length:, :]

        if context['previous_expression'] is not None:
            output_context['previous_expression'] = np.concatenate([context['previous_expression'], out_exp], axis=0)[-max_frame_length:, :]
            output_context['previous_volume'] = np.concatenate([context['previous_volume'], volume], axis=0)[-max_frame_length:]
        else:
            output_context['previous_expression'] = out_exp.copy()
            output_context['previous_volume'] = volume.copy()

        output_context['first_input_flag'] = False

        return {"code": RETURN_CODE['SUCCESS'],
                "expression": out_exp,
                "headpose": None}, output_context
def apply_expression_postprocessing(
|
||||
self,
|
||||
expression_params: np.ndarray,
|
||||
processed_frames: int = 0,
|
||||
audio_volume: np.ndarray = None
|
||||
) -> np.ndarray:
|
||||
"""Applies full post-processing pipeline to facial expression parameters.
|
||||
|
||||
Args:
|
||||
expression_params: Raw output from animation model [num_frames, num_parameters]
|
||||
processed_frames: Number of frames already processed in previous batches
|
||||
audio_volume: Optional volume array for audio-visual synchronization
|
||||
|
||||
Returns:
|
||||
Processed expression parameters ready for animation synthesis
|
||||
"""
|
||||
# Pipeline execution order matters - maintain sequence
|
||||
expression_params = smooth_mouth_movements(expression_params, processed_frames, audio_volume)
|
||||
expression_params = apply_frame_blending(expression_params, processed_frames)
|
||||
expression_params, _ = apply_savitzky_golay_smoothing(expression_params, window_length=5)
|
||||
expression_params = symmetrize_blendshapes(expression_params)
|
||||
expression_params = apply_random_eye_blinks_context(expression_params, processed_frames=processed_frames)
|
||||
|
||||
return expression_params

def extract_vocal_track(
        self,
        input_audio_path: str
) -> str:
    """Isolates the vocal track from an audio file using source separation.

    Args:
        input_audio_path: Path to an input audio file containing vocals + accompaniment.

    Returns:
        Path to the isolated vocal track in WAV format.
    """
    # Note: os.system passes the paths through the shell unquoted; paths with
    # spaces or shell metacharacters should be quoted by the caller.
    separation_command = f'spleeter separate -p spleeter:2stems -o {self.cfg.save_path} {input_audio_path}'
    os.system(separation_command)

    base_name = os.path.splitext(os.path.basename(input_audio_path))[0]
    return os.path.join(self.cfg.save_path, base_name, 'vocals.wav')
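
# Example (illustrative): spleeter's 2stems model writes
# <save_path>/<basename>/vocals.wav and accompaniment.wav, so:
#
#   vocal_path = self.extract_vocal_track('/data/song.wav')
#   # -> os.path.join(self.cfg.save_path, 'song', 'vocals.wav')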

def blendshape_postprocess(self,
                           bs_array: np.ndarray
                           ) -> np.ndarray:
    """Applies smoothing, symmetrization, and random eye blinks to a blendshape sequence."""
    bs_array, _ = apply_savitzky_golay_smoothing(bs_array, window_length=5)
    bs_array = symmetrize_blendshapes(bs_array)
    bs_array = apply_random_eye_blinks(bs_array)

    return bs_array
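
# A plausible sketch of symmetrize_blendshapes (assumption: the real helper
# may differ): average each left/right ARKit channel pair so the two sides
# of the face move together.
#
#   def symmetrize_blendshapes(bs, pairs=LR_PAIRS):  # LR_PAIRS is hypothetical
#       for left, right in pairs:                    # e.g. (eyeBlinkLeft, eyeBlinkRight)
#           mean = 0.5 * (bs[:, left] + bs[:, right])
#           bs[:, left] = bs[:, right] = mean
#       return bs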

135
engines/launch.py
Normal file
@@ -0,0 +1,135 @@
"""
Launcher

Modified from detectron2 (https://github.com/facebookresearch/detectron2)
"""

import logging
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us.
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by another process.
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    cfg=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
    Args:
        main_func: a function that will be called by `main_func(*cfg)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        cfg (tuple): arguments passed to main_func
        timeout (timedelta): timeout of the distributed workers
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert (
                num_machines == 1
            ), "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                cfg,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*cfg)
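
# Example (illustrative) single-machine launch:
#
#   def main_worker(cfg):
#       ...  # build a trainer from cfg and run it
#
#   launch(main_worker, num_gpus_per_machine=4, dist_url="auto", cfg=(cfg,))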


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    cfg,
    timeout=DEFAULT_TIMEOUT,
):
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="nccl",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise

    # Set up the local process group (which contains ranks within the same machine).
    # NOTE: dist.new_group() is a collective, so every rank must create every
    # group, including the ones it does not belong to.
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*cfg)

299
engines/train.py
Normal file
@@ -0,0 +1,299 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import os
import weakref
import torch
import torch.nn as nn
import torch.utils.data
from functools import partial

# collections.Iterator was removed in Python 3.10; the collections.abc alias
# works on every supported version, so import it unconditionally.
from collections.abc import Iterator
from tensorboardX import SummaryWriter

from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import utils.comm as comm
from datasets import build_dataset, point_collate_fn, collate_fn
from models import build_model
from utils.logger import get_root_logger
from utils.optimizer import build_optimizer
from utils.scheduler import build_scheduler
from utils.events import EventStorage
from utils.registry import Registry


TRAINERS = Registry("trainers")
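
# Illustrative registry usage (assumption: this Registry follows the
# Pointcept-style register_module/build interface):
#
#   @TRAINERS.register_module("MyTrainer")
#   class MyTrainer(TrainerBase):
#       ...
#
#   trainer = TRAINERS.build(dict(type="DefaultTrainer", cfg=cfg))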


class TrainerBase:
    def __init__(self) -> None:
        self.hooks = []
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = 0
        self.max_iter = 0
        self.comm_info = dict()
        self.data_iterator: Iterator = enumerate([])
        self.storage: EventStorage
        self.writer: SummaryWriter

    def register_hooks(self, hooks) -> None:
        hooks = build_hooks(hooks)
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular references, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause a memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self.hooks.extend(hooks)

    def train(self):
        with EventStorage() as self.storage:
            # => before train
            self.before_train()
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train
            self.after_train()

    def before_train(self):
        for h in self.hooks:
            h.before_train()

    def before_epoch(self):
        for h in self.hooks:
            h.before_epoch()

    def before_step(self):
        for h in self.hooks:
            h.before_step()

    def run_step(self):
        raise NotImplementedError

    def after_step(self):
        for h in self.hooks:
            h.after_step()

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()

    def after_train(self):
        # Sync GPUs before running train hooks
        comm.synchronize()
        for h in self.hooks:
            h.after_train()
        if comm.is_main_process():
            self.writer.close()
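
# Illustrative custom hook (assumption: HookBase provides no-op lifecycle
# methods, and hooks reach the trainer via the weakref set in register_hooks):
#
#   class LossLogger(HookBase):
#       def after_step(self):
#           loss = self.trainer.comm_info["model_output_dict"]["loss"]
#           self.trainer.storage.put_scalar("train/loss", loss.item())  # put_scalar is assumed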


@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = cfg.eval_epoch
        self.best_metric_value = -torch.inf
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "train.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.logger.info(f"Save path: {cfg.save_path}")
        self.logger.info(f"Config:\n{cfg.pretty_text}")
        self.logger.info("=> Building model ...")
        self.model = self.build_model()
        self.logger.info("=> Building writer ...")
        self.writer = self.build_writer()
        self.logger.info("=> Building train dataset & dataloader ...")
        self.train_loader = self.build_train_loader()
        self.logger.info("=> Building val dataset & dataloader ...")
        self.val_loader = self.build_val_loader()
        self.logger.info("=> Building optimizer, scheduler, scaler(amp) ...")
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()
        self.scaler = self.build_scaler()
        self.logger.info("=> Building hooks ...")
        self.register_hooks(self.cfg.hooks)

    def train(self):
        with EventStorage() as self.storage:
            # => before train
            self.before_train()
            self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                # TODO: optimize to iteration based
                if comm.get_world_size() > 1:
                    self.train_loader.sampler.set_epoch(self.epoch)
                self.model.train()
                self.data_iterator = enumerate(self.train_loader)
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train
            self.after_train()

    def run_step(self):
        input_dict = self.comm_info["input_dict"]
        for key in input_dict.keys():
            if isinstance(input_dict[key], torch.Tensor):
                input_dict[key] = input_dict[key].cuda(non_blocking=True)
        with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
            output_dict = self.model(input_dict)
        loss = output_dict["loss"]
        self.optimizer.zero_grad()
        if self.cfg.enable_amp:
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)

            # When AMP is enabled, optimizer.step() is skipped if the loss
            # scaling factor overflowed; see the reference sketch below.
            # This also fixes the torch warning about calling scheduler.step()
            # before optimizer.step().
            scale = self.scaler.get_scale()
            self.scaler.update()
            if scale <= self.scaler.get_scale():
                self.scheduler.step()
        else:
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
        if self.cfg.empty_cache:
            torch.cuda.empty_cache()
        self.comm_info["model_output_dict"] = output_dict

    def build_model(self):
        model = build_model(self.cfg.model)
        if self.cfg.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # logger.info(f"Model: \n{self.model}")
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        return model

    def build_writer(self):
        writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
        self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
        return writer

    def build_train_loader(self):
        train_data = build_dataset(self.cfg.data.train)

        if comm.get_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
        else:
            train_sampler = None

        init_fn = (
            partial(
                worker_init_fn,
                num_workers=self.cfg.num_worker_per_gpu,
                rank=comm.get_rank(),
                seed=self.cfg.seed,
            )
            if self.cfg.seed is not None
            else None
        )

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.cfg.batch_size_per_gpu,
            shuffle=(train_sampler is None),
            # NOTE: num_workers is hard-coded to 0 (data is loaded in the main
            # process), so worker_init_fn only takes effect if this is raised
            # to cfg.num_worker_per_gpu.
            num_workers=0,
            sampler=train_sampler,
            collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
            pin_memory=True,
            worker_init_fn=init_fn,
            drop_last=True,
            # persistent_workers=True,
        )
        return train_loader

    def build_val_loader(self):
        val_loader = None
        if self.cfg.evaluate:
            val_data = build_dataset(self.cfg.data.val)
            if comm.get_world_size() > 1:
                val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
            else:
                val_sampler = None
            val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.cfg.batch_size_val_per_gpu,
                shuffle=False,
                num_workers=self.cfg.num_worker_per_gpu,
                pin_memory=True,
                sampler=val_sampler,
                collate_fn=collate_fn,
            )
        return val_loader

    def build_optimizer(self):
        return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)

    def build_scheduler(self):
        assert hasattr(self, "optimizer")
        assert hasattr(self, "train_loader")
        self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
        return build_scheduler(self.cfg.scheduler, self.optimizer)

    def build_scaler(self):
        scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
        return scaler


@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
    def build_train_loader(self):
        from datasets import MultiDatasetDataloader

        train_data = build_dataset(self.cfg.data.train)
        train_loader = MultiDatasetDataloader(
            train_data,
            self.cfg.batch_size_per_gpu,
            self.cfg.num_worker_per_gpu,
            self.cfg.mix_prob,
            self.cfg.seed,
        )
        self.comm_info["iter_per_epoch"] = len(train_loader)
        return train_loader