mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-05 09:59:21 +08:00
feat: Initial commit
This commit is contained in:
5
engines/hooks/__init__.py
Normal file
5
engines/hooks/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .default import HookBase
|
||||
from .misc import *
|
||||
from .evaluator import *
|
||||
|
||||
from .builder import build_hooks
|
||||
15
engines/hooks/builder.py
Normal file
15
engines/hooks/builder.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
from utils.registry import Registry
|
||||
|
||||
|
||||
HOOKS = Registry("hooks")
|
||||
|
||||
|
||||
def build_hooks(cfg):
|
||||
hooks = []
|
||||
for hook_cfg in cfg:
|
||||
hooks.append(HOOKS.build(hook_cfg))
|
||||
return hooks
|
||||
29
engines/hooks/default.py
Normal file
29
engines/hooks/default.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
|
||||
class HookBase:
|
||||
"""
|
||||
Base class for hooks that can be registered with :class:`TrainerBase`.
|
||||
"""
|
||||
|
||||
trainer = None # A weak reference to the trainer object.
|
||||
|
||||
def before_train(self):
|
||||
pass
|
||||
|
||||
def before_epoch(self):
|
||||
pass
|
||||
|
||||
def before_step(self):
|
||||
pass
|
||||
|
||||
def after_step(self):
|
||||
pass
|
||||
|
||||
def after_epoch(self):
|
||||
pass
|
||||
|
||||
def after_train(self):
|
||||
pass
|
||||
577
engines/hooks/evaluator.py
Normal file
577
engines/hooks/evaluator.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from uuid import uuid4
|
||||
|
||||
import utils.comm as comm
|
||||
from utils.misc import intersection_and_union_gpu
|
||||
|
||||
from .default import HookBase
|
||||
from .builder import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ClsEvaluator(HookBase):
|
||||
def after_epoch(self):
|
||||
if self.trainer.cfg.evaluate:
|
||||
self.eval()
|
||||
|
||||
def eval(self):
|
||||
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
|
||||
self.trainer.model.eval()
|
||||
for i, input_dict in enumerate(self.trainer.val_loader):
|
||||
for key in input_dict.keys():
|
||||
if isinstance(input_dict[key], torch.Tensor):
|
||||
input_dict[key] = input_dict[key].cuda(non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output_dict = self.trainer.model(input_dict)
|
||||
output = output_dict["cls_logits"]
|
||||
loss = output_dict["loss"]
|
||||
pred = output.max(1)[1]
|
||||
label = input_dict["category"]
|
||||
intersection, union, target = intersection_and_union_gpu(
|
||||
pred,
|
||||
label,
|
||||
self.trainer.cfg.data.num_classes,
|
||||
self.trainer.cfg.data.ignore_index,
|
||||
)
|
||||
if comm.get_world_size() > 1:
|
||||
dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
|
||||
target
|
||||
)
|
||||
intersection, union, target = (
|
||||
intersection.cpu().numpy(),
|
||||
union.cpu().numpy(),
|
||||
target.cpu().numpy(),
|
||||
)
|
||||
# Here there is no need to sync since sync happened in dist.all_reduce
|
||||
self.trainer.storage.put_scalar("val_intersection", intersection)
|
||||
self.trainer.storage.put_scalar("val_union", union)
|
||||
self.trainer.storage.put_scalar("val_target", target)
|
||||
self.trainer.storage.put_scalar("val_loss", loss.item())
|
||||
self.trainer.logger.info(
|
||||
"Test: [{iter}/{max_iter}] "
|
||||
"Loss {loss:.4f} ".format(
|
||||
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
|
||||
)
|
||||
)
|
||||
loss_avg = self.trainer.storage.history("val_loss").avg
|
||||
intersection = self.trainer.storage.history("val_intersection").total
|
||||
union = self.trainer.storage.history("val_union").total
|
||||
target = self.trainer.storage.history("val_target").total
|
||||
iou_class = intersection / (union + 1e-10)
|
||||
acc_class = intersection / (target + 1e-10)
|
||||
m_iou = np.mean(iou_class)
|
||||
m_acc = np.mean(acc_class)
|
||||
all_acc = sum(intersection) / (sum(target) + 1e-10)
|
||||
self.trainer.logger.info(
|
||||
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
|
||||
m_iou, m_acc, all_acc
|
||||
)
|
||||
)
|
||||
for i in range(self.trainer.cfg.data.num_classes):
|
||||
self.trainer.logger.info(
|
||||
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
|
||||
idx=i,
|
||||
name=self.trainer.cfg.data.names[i],
|
||||
iou=iou_class[i],
|
||||
accuracy=acc_class[i],
|
||||
)
|
||||
)
|
||||
current_epoch = self.trainer.epoch + 1
|
||||
if self.trainer.writer is not None:
|
||||
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
|
||||
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
|
||||
self.trainer.comm_info["current_metric_value"] = all_acc # save for saver
|
||||
self.trainer.comm_info["current_metric_name"] = "allAcc" # save for saver
|
||||
|
||||
def after_train(self):
|
||||
self.trainer.logger.info(
|
||||
"Best {}: {:.4f}".format("allAcc", self.trainer.best_metric_value)
|
||||
)
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class SemSegEvaluator(HookBase):
|
||||
def after_epoch(self):
|
||||
if self.trainer.cfg.evaluate:
|
||||
self.eval()
|
||||
|
||||
def eval(self):
|
||||
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
|
||||
self.trainer.model.eval()
|
||||
for i, input_dict in enumerate(self.trainer.val_loader):
|
||||
for key in input_dict.keys():
|
||||
if isinstance(input_dict[key], torch.Tensor):
|
||||
input_dict[key] = input_dict[key].cuda(non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output_dict = self.trainer.model(input_dict)
|
||||
output = output_dict["seg_logits"]
|
||||
loss = output_dict["loss"]
|
||||
pred = output.max(1)[1]
|
||||
segment = input_dict["segment"]
|
||||
if "origin_coord" in input_dict.keys():
|
||||
idx, _ = pointops.knn_query(
|
||||
1,
|
||||
input_dict["coord"].float(),
|
||||
input_dict["offset"].int(),
|
||||
input_dict["origin_coord"].float(),
|
||||
input_dict["origin_offset"].int(),
|
||||
)
|
||||
pred = pred[idx.flatten().long()]
|
||||
segment = input_dict["origin_segment"]
|
||||
intersection, union, target = intersection_and_union_gpu(
|
||||
pred,
|
||||
segment,
|
||||
self.trainer.cfg.data.num_classes,
|
||||
self.trainer.cfg.data.ignore_index,
|
||||
)
|
||||
if comm.get_world_size() > 1:
|
||||
dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(
|
||||
target
|
||||
)
|
||||
intersection, union, target = (
|
||||
intersection.cpu().numpy(),
|
||||
union.cpu().numpy(),
|
||||
target.cpu().numpy(),
|
||||
)
|
||||
# Here there is no need to sync since sync happened in dist.all_reduce
|
||||
self.trainer.storage.put_scalar("val_intersection", intersection)
|
||||
self.trainer.storage.put_scalar("val_union", union)
|
||||
self.trainer.storage.put_scalar("val_target", target)
|
||||
self.trainer.storage.put_scalar("val_loss", loss.item())
|
||||
info = "Test: [{iter}/{max_iter}] ".format(
|
||||
iter=i + 1, max_iter=len(self.trainer.val_loader)
|
||||
)
|
||||
if "origin_coord" in input_dict.keys():
|
||||
info = "Interp. " + info
|
||||
self.trainer.logger.info(
|
||||
info
|
||||
+ "Loss {loss:.4f} ".format(
|
||||
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
|
||||
)
|
||||
)
|
||||
loss_avg = self.trainer.storage.history("val_loss").avg
|
||||
intersection = self.trainer.storage.history("val_intersection").total
|
||||
union = self.trainer.storage.history("val_union").total
|
||||
target = self.trainer.storage.history("val_target").total
|
||||
iou_class = intersection / (union + 1e-10)
|
||||
acc_class = intersection / (target + 1e-10)
|
||||
m_iou = np.mean(iou_class)
|
||||
m_acc = np.mean(acc_class)
|
||||
all_acc = sum(intersection) / (sum(target) + 1e-10)
|
||||
self.trainer.logger.info(
|
||||
"Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(
|
||||
m_iou, m_acc, all_acc
|
||||
)
|
||||
)
|
||||
for i in range(self.trainer.cfg.data.num_classes):
|
||||
self.trainer.logger.info(
|
||||
"Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format(
|
||||
idx=i,
|
||||
name=self.trainer.cfg.data.names[i],
|
||||
iou=iou_class[i],
|
||||
accuracy=acc_class[i],
|
||||
)
|
||||
)
|
||||
current_epoch = self.trainer.epoch + 1
|
||||
if self.trainer.writer is not None:
|
||||
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch)
|
||||
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
|
||||
self.trainer.comm_info["current_metric_value"] = m_iou # save for saver
|
||||
self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver
|
||||
|
||||
def after_train(self):
|
||||
self.trainer.logger.info(
|
||||
"Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)
|
||||
)
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class InsSegEvaluator(HookBase):
|
||||
def __init__(self, segment_ignore_index=(-1,), instance_ignore_index=-1):
|
||||
self.segment_ignore_index = segment_ignore_index
|
||||
self.instance_ignore_index = instance_ignore_index
|
||||
|
||||
self.valid_class_names = None # update in before train
|
||||
self.overlaps = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
|
||||
self.min_region_sizes = 100
|
||||
self.distance_threshes = float("inf")
|
||||
self.distance_confs = -float("inf")
|
||||
|
||||
def before_train(self):
|
||||
self.valid_class_names = [
|
||||
self.trainer.cfg.data.names[i]
|
||||
for i in range(self.trainer.cfg.data.num_classes)
|
||||
if i not in self.segment_ignore_index
|
||||
]
|
||||
|
||||
def after_epoch(self):
|
||||
if self.trainer.cfg.evaluate:
|
||||
self.eval()
|
||||
|
||||
def associate_instances(self, pred, segment, instance):
|
||||
segment = segment.cpu().numpy()
|
||||
instance = instance.cpu().numpy()
|
||||
void_mask = np.in1d(segment, self.segment_ignore_index)
|
||||
|
||||
assert (
|
||||
pred["pred_classes"].shape[0]
|
||||
== pred["pred_scores"].shape[0]
|
||||
== pred["pred_masks"].shape[0]
|
||||
)
|
||||
assert pred["pred_masks"].shape[1] == segment.shape[0] == instance.shape[0]
|
||||
# get gt instances
|
||||
gt_instances = dict()
|
||||
for i in range(self.trainer.cfg.data.num_classes):
|
||||
if i not in self.segment_ignore_index:
|
||||
gt_instances[self.trainer.cfg.data.names[i]] = []
|
||||
instance_ids, idx, counts = np.unique(
|
||||
instance, return_index=True, return_counts=True
|
||||
)
|
||||
segment_ids = segment[idx]
|
||||
for i in range(len(instance_ids)):
|
||||
if instance_ids[i] == self.instance_ignore_index:
|
||||
continue
|
||||
if segment_ids[i] in self.segment_ignore_index:
|
||||
continue
|
||||
gt_inst = dict()
|
||||
gt_inst["instance_id"] = instance_ids[i]
|
||||
gt_inst["segment_id"] = segment_ids[i]
|
||||
gt_inst["dist_conf"] = 0.0
|
||||
gt_inst["med_dist"] = -1.0
|
||||
gt_inst["vert_count"] = counts[i]
|
||||
gt_inst["matched_pred"] = []
|
||||
gt_instances[self.trainer.cfg.data.names[segment_ids[i]]].append(gt_inst)
|
||||
|
||||
# get pred instances and associate with gt
|
||||
pred_instances = dict()
|
||||
for i in range(self.trainer.cfg.data.num_classes):
|
||||
if i not in self.segment_ignore_index:
|
||||
pred_instances[self.trainer.cfg.data.names[i]] = []
|
||||
instance_id = 0
|
||||
for i in range(len(pred["pred_classes"])):
|
||||
if pred["pred_classes"][i] in self.segment_ignore_index:
|
||||
continue
|
||||
pred_inst = dict()
|
||||
pred_inst["uuid"] = uuid4()
|
||||
pred_inst["instance_id"] = instance_id
|
||||
pred_inst["segment_id"] = pred["pred_classes"][i]
|
||||
pred_inst["confidence"] = pred["pred_scores"][i]
|
||||
pred_inst["mask"] = np.not_equal(pred["pred_masks"][i], 0)
|
||||
pred_inst["vert_count"] = np.count_nonzero(pred_inst["mask"])
|
||||
pred_inst["void_intersection"] = np.count_nonzero(
|
||||
np.logical_and(void_mask, pred_inst["mask"])
|
||||
)
|
||||
if pred_inst["vert_count"] < self.min_region_sizes:
|
||||
continue # skip if empty
|
||||
segment_name = self.trainer.cfg.data.names[pred_inst["segment_id"]]
|
||||
matched_gt = []
|
||||
for gt_idx, gt_inst in enumerate(gt_instances[segment_name]):
|
||||
intersection = np.count_nonzero(
|
||||
np.logical_and(
|
||||
instance == gt_inst["instance_id"], pred_inst["mask"]
|
||||
)
|
||||
)
|
||||
if intersection > 0:
|
||||
gt_inst_ = gt_inst.copy()
|
||||
pred_inst_ = pred_inst.copy()
|
||||
gt_inst_["intersection"] = intersection
|
||||
pred_inst_["intersection"] = intersection
|
||||
matched_gt.append(gt_inst_)
|
||||
gt_inst["matched_pred"].append(pred_inst_)
|
||||
pred_inst["matched_gt"] = matched_gt
|
||||
pred_instances[segment_name].append(pred_inst)
|
||||
instance_id += 1
|
||||
return gt_instances, pred_instances
|
||||
|
||||
def evaluate_matches(self, scenes):
|
||||
overlaps = self.overlaps
|
||||
min_region_sizes = [self.min_region_sizes]
|
||||
dist_threshes = [self.distance_threshes]
|
||||
dist_confs = [self.distance_confs]
|
||||
|
||||
# results: class x overlap
|
||||
ap_table = np.zeros(
|
||||
(len(dist_threshes), len(self.valid_class_names), len(overlaps)), float
|
||||
)
|
||||
for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
|
||||
zip(min_region_sizes, dist_threshes, dist_confs)
|
||||
):
|
||||
for oi, overlap_th in enumerate(overlaps):
|
||||
pred_visited = {}
|
||||
for scene in scenes:
|
||||
for _ in scene["pred"]:
|
||||
for label_name in self.valid_class_names:
|
||||
for p in scene["pred"][label_name]:
|
||||
if "uuid" in p:
|
||||
pred_visited[p["uuid"]] = False
|
||||
for li, label_name in enumerate(self.valid_class_names):
|
||||
y_true = np.empty(0)
|
||||
y_score = np.empty(0)
|
||||
hard_false_negatives = 0
|
||||
has_gt = False
|
||||
has_pred = False
|
||||
for scene in scenes:
|
||||
pred_instances = scene["pred"][label_name]
|
||||
gt_instances = scene["gt"][label_name]
|
||||
# filter groups in ground truth
|
||||
gt_instances = [
|
||||
gt
|
||||
for gt in gt_instances
|
||||
if gt["vert_count"] >= min_region_size
|
||||
and gt["med_dist"] <= distance_thresh
|
||||
and gt["dist_conf"] >= distance_conf
|
||||
]
|
||||
if gt_instances:
|
||||
has_gt = True
|
||||
if pred_instances:
|
||||
has_pred = True
|
||||
|
||||
cur_true = np.ones(len(gt_instances))
|
||||
cur_score = np.ones(len(gt_instances)) * (-float("inf"))
|
||||
cur_match = np.zeros(len(gt_instances), dtype=bool)
|
||||
# collect matches
|
||||
for gti, gt in enumerate(gt_instances):
|
||||
found_match = False
|
||||
for pred in gt["matched_pred"]:
|
||||
# greedy assignments
|
||||
if pred_visited[pred["uuid"]]:
|
||||
continue
|
||||
overlap = float(pred["intersection"]) / (
|
||||
gt["vert_count"]
|
||||
+ pred["vert_count"]
|
||||
- pred["intersection"]
|
||||
)
|
||||
if overlap > overlap_th:
|
||||
confidence = pred["confidence"]
|
||||
# if already have a prediction for this gt,
|
||||
# the prediction with the lower score is automatically a false positive
|
||||
if cur_match[gti]:
|
||||
max_score = max(cur_score[gti], confidence)
|
||||
min_score = min(cur_score[gti], confidence)
|
||||
cur_score[gti] = max_score
|
||||
# append false positive
|
||||
cur_true = np.append(cur_true, 0)
|
||||
cur_score = np.append(cur_score, min_score)
|
||||
cur_match = np.append(cur_match, True)
|
||||
# otherwise set score
|
||||
else:
|
||||
found_match = True
|
||||
cur_match[gti] = True
|
||||
cur_score[gti] = confidence
|
||||
pred_visited[pred["uuid"]] = True
|
||||
if not found_match:
|
||||
hard_false_negatives += 1
|
||||
# remove non-matched ground truth instances
|
||||
cur_true = cur_true[cur_match]
|
||||
cur_score = cur_score[cur_match]
|
||||
|
||||
# collect non-matched predictions as false positive
|
||||
for pred in pred_instances:
|
||||
found_gt = False
|
||||
for gt in pred["matched_gt"]:
|
||||
overlap = float(gt["intersection"]) / (
|
||||
gt["vert_count"]
|
||||
+ pred["vert_count"]
|
||||
- gt["intersection"]
|
||||
)
|
||||
if overlap > overlap_th:
|
||||
found_gt = True
|
||||
break
|
||||
if not found_gt:
|
||||
num_ignore = pred["void_intersection"]
|
||||
for gt in pred["matched_gt"]:
|
||||
if gt["segment_id"] in self.segment_ignore_index:
|
||||
num_ignore += gt["intersection"]
|
||||
# small ground truth instances
|
||||
if (
|
||||
gt["vert_count"] < min_region_size
|
||||
or gt["med_dist"] > distance_thresh
|
||||
or gt["dist_conf"] < distance_conf
|
||||
):
|
||||
num_ignore += gt["intersection"]
|
||||
proportion_ignore = (
|
||||
float(num_ignore) / pred["vert_count"]
|
||||
)
|
||||
# if not ignored append false positive
|
||||
if proportion_ignore <= overlap_th:
|
||||
cur_true = np.append(cur_true, 0)
|
||||
confidence = pred["confidence"]
|
||||
cur_score = np.append(cur_score, confidence)
|
||||
|
||||
# append to overall results
|
||||
y_true = np.append(y_true, cur_true)
|
||||
y_score = np.append(y_score, cur_score)
|
||||
|
||||
# compute average precision
|
||||
if has_gt and has_pred:
|
||||
# compute precision recall curve first
|
||||
|
||||
# sorting and cumsum
|
||||
score_arg_sort = np.argsort(y_score)
|
||||
y_score_sorted = y_score[score_arg_sort]
|
||||
y_true_sorted = y_true[score_arg_sort]
|
||||
y_true_sorted_cumsum = np.cumsum(y_true_sorted)
|
||||
|
||||
# unique thresholds
|
||||
(thresholds, unique_indices) = np.unique(
|
||||
y_score_sorted, return_index=True
|
||||
)
|
||||
num_prec_recall = len(unique_indices) + 1
|
||||
|
||||
# prepare precision recall
|
||||
num_examples = len(y_score_sorted)
|
||||
# https://github.com/ScanNet/ScanNet/pull/26
|
||||
# all predictions are non-matched but also all of them are ignored and not counted as FP
|
||||
# y_true_sorted_cumsum is empty
|
||||
# num_true_examples = y_true_sorted_cumsum[-1]
|
||||
num_true_examples = (
|
||||
y_true_sorted_cumsum[-1]
|
||||
if len(y_true_sorted_cumsum) > 0
|
||||
else 0
|
||||
)
|
||||
precision = np.zeros(num_prec_recall)
|
||||
recall = np.zeros(num_prec_recall)
|
||||
|
||||
# deal with the first point
|
||||
y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
|
||||
# deal with remaining
|
||||
for idx_res, idx_scores in enumerate(unique_indices):
|
||||
cumsum = y_true_sorted_cumsum[idx_scores - 1]
|
||||
tp = num_true_examples - cumsum
|
||||
fp = num_examples - idx_scores - tp
|
||||
fn = cumsum + hard_false_negatives
|
||||
p = float(tp) / (tp + fp)
|
||||
r = float(tp) / (tp + fn)
|
||||
precision[idx_res] = p
|
||||
recall[idx_res] = r
|
||||
|
||||
# first point in curve is artificial
|
||||
precision[-1] = 1.0
|
||||
recall[-1] = 0.0
|
||||
|
||||
# compute average of precision-recall curve
|
||||
recall_for_conv = np.copy(recall)
|
||||
recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
|
||||
recall_for_conv = np.append(recall_for_conv, 0.0)
|
||||
|
||||
stepWidths = np.convolve(
|
||||
recall_for_conv, [-0.5, 0, 0.5], "valid"
|
||||
)
|
||||
# integrate is now simply a dot product
|
||||
ap_current = np.dot(precision, stepWidths)
|
||||
|
||||
elif has_gt:
|
||||
ap_current = 0.0
|
||||
else:
|
||||
ap_current = float("nan")
|
||||
ap_table[di, li, oi] = ap_current
|
||||
d_inf = 0
|
||||
o50 = np.where(np.isclose(self.overlaps, 0.5))
|
||||
o25 = np.where(np.isclose(self.overlaps, 0.25))
|
||||
oAllBut25 = np.where(np.logical_not(np.isclose(self.overlaps, 0.25)))
|
||||
ap_scores = dict()
|
||||
ap_scores["all_ap"] = np.nanmean(ap_table[d_inf, :, oAllBut25])
|
||||
ap_scores["all_ap_50%"] = np.nanmean(ap_table[d_inf, :, o50])
|
||||
ap_scores["all_ap_25%"] = np.nanmean(ap_table[d_inf, :, o25])
|
||||
ap_scores["classes"] = {}
|
||||
for li, label_name in enumerate(self.valid_class_names):
|
||||
ap_scores["classes"][label_name] = {}
|
||||
ap_scores["classes"][label_name]["ap"] = np.average(
|
||||
ap_table[d_inf, li, oAllBut25]
|
||||
)
|
||||
ap_scores["classes"][label_name]["ap50%"] = np.average(
|
||||
ap_table[d_inf, li, o50]
|
||||
)
|
||||
ap_scores["classes"][label_name]["ap25%"] = np.average(
|
||||
ap_table[d_inf, li, o25]
|
||||
)
|
||||
return ap_scores
|
||||
|
||||
def eval(self):
|
||||
self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
|
||||
self.trainer.model.eval()
|
||||
scenes = []
|
||||
for i, input_dict in enumerate(self.trainer.val_loader):
|
||||
assert (
|
||||
len(input_dict["offset"]) == 1
|
||||
) # currently only support bs 1 for each GPU
|
||||
for key in input_dict.keys():
|
||||
if isinstance(input_dict[key], torch.Tensor):
|
||||
input_dict[key] = input_dict[key].cuda(non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output_dict = self.trainer.model(input_dict)
|
||||
|
||||
loss = output_dict["loss"]
|
||||
|
||||
segment = input_dict["segment"]
|
||||
instance = input_dict["instance"]
|
||||
# map to origin
|
||||
if "origin_coord" in input_dict.keys():
|
||||
idx, _ = pointops.knn_query(
|
||||
1,
|
||||
input_dict["coord"].float(),
|
||||
input_dict["offset"].int(),
|
||||
input_dict["origin_coord"].float(),
|
||||
input_dict["origin_offset"].int(),
|
||||
)
|
||||
idx = idx.cpu().flatten().long()
|
||||
output_dict["pred_masks"] = output_dict["pred_masks"][:, idx]
|
||||
segment = input_dict["origin_segment"]
|
||||
instance = input_dict["origin_instance"]
|
||||
|
||||
gt_instances, pred_instance = self.associate_instances(
|
||||
output_dict, segment, instance
|
||||
)
|
||||
scenes.append(dict(gt=gt_instances, pred=pred_instance))
|
||||
|
||||
self.trainer.storage.put_scalar("val_loss", loss.item())
|
||||
self.trainer.logger.info(
|
||||
"Test: [{iter}/{max_iter}] "
|
||||
"Loss {loss:.4f} ".format(
|
||||
iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()
|
||||
)
|
||||
)
|
||||
|
||||
loss_avg = self.trainer.storage.history("val_loss").avg
|
||||
comm.synchronize()
|
||||
scenes_sync = comm.gather(scenes, dst=0)
|
||||
scenes = [scene for scenes_ in scenes_sync for scene in scenes_]
|
||||
ap_scores = self.evaluate_matches(scenes)
|
||||
all_ap = ap_scores["all_ap"]
|
||||
all_ap_50 = ap_scores["all_ap_50%"]
|
||||
all_ap_25 = ap_scores["all_ap_25%"]
|
||||
self.trainer.logger.info(
|
||||
"Val result: mAP/AP50/AP25 {:.4f}/{:.4f}/{:.4f}.".format(
|
||||
all_ap, all_ap_50, all_ap_25
|
||||
)
|
||||
)
|
||||
for i, label_name in enumerate(self.valid_class_names):
|
||||
ap = ap_scores["classes"][label_name]["ap"]
|
||||
ap_50 = ap_scores["classes"][label_name]["ap50%"]
|
||||
ap_25 = ap_scores["classes"][label_name]["ap25%"]
|
||||
self.trainer.logger.info(
|
||||
"Class_{idx}-{name} Result: AP/AP50/AP25 {AP:.4f}/{AP50:.4f}/{AP25:.4f}".format(
|
||||
idx=i, name=label_name, AP=ap, AP50=ap_50, AP25=ap_25
|
||||
)
|
||||
)
|
||||
current_epoch = self.trainer.epoch + 1
|
||||
if self.trainer.writer is not None:
|
||||
self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/mAP", all_ap, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/AP50", all_ap_50, current_epoch)
|
||||
self.trainer.writer.add_scalar("val/AP25", all_ap_25, current_epoch)
|
||||
self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
|
||||
self.trainer.comm_info["current_metric_value"] = all_ap_50 # save for saver
|
||||
self.trainer.comm_info["current_metric_name"] = "AP50" # save for saver
|
||||
460
engines/hooks/misc.py
Normal file
460
engines/hooks/misc.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import sys
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from collections import OrderedDict
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from collections.abc import Sequence
|
||||
else:
|
||||
from collections import Sequence
|
||||
from utils.timer import Timer
|
||||
from utils.comm import is_main_process, synchronize, get_world_size
|
||||
from utils.cache import shared_dict
|
||||
|
||||
import utils.comm as comm
|
||||
from engines.test import TESTERS
|
||||
|
||||
from .default import HookBase
|
||||
from .builder import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class IterationTimer(HookBase):
|
||||
def __init__(self, warmup_iter=1):
|
||||
self._warmup_iter = warmup_iter
|
||||
self._start_time = time.perf_counter()
|
||||
self._iter_timer = Timer()
|
||||
self._remain_iter = 0
|
||||
|
||||
def before_train(self):
|
||||
self._start_time = time.perf_counter()
|
||||
self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)
|
||||
|
||||
def before_epoch(self):
|
||||
self._iter_timer.reset()
|
||||
|
||||
def before_step(self):
|
||||
data_time = self._iter_timer.seconds()
|
||||
self.trainer.storage.put_scalar("data_time", data_time)
|
||||
|
||||
def after_step(self):
|
||||
batch_time = self._iter_timer.seconds()
|
||||
self._iter_timer.reset()
|
||||
self.trainer.storage.put_scalar("batch_time", batch_time)
|
||||
self._remain_iter -= 1
|
||||
remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
|
||||
t_m, t_s = divmod(remain_time, 60)
|
||||
t_h, t_m = divmod(t_m, 60)
|
||||
remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
|
||||
if "iter_info" in self.trainer.comm_info.keys():
|
||||
info = (
|
||||
"Data {data_time_val:.3f} ({data_time_avg:.3f}) "
|
||||
"Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
|
||||
"Remain {remain_time} ".format(
|
||||
data_time_val=self.trainer.storage.history("data_time").val,
|
||||
data_time_avg=self.trainer.storage.history("data_time").avg,
|
||||
batch_time_val=self.trainer.storage.history("batch_time").val,
|
||||
batch_time_avg=self.trainer.storage.history("batch_time").avg,
|
||||
remain_time=remain_time,
|
||||
)
|
||||
)
|
||||
self.trainer.comm_info["iter_info"] += info
|
||||
if self.trainer.comm_info["iter"] <= self._warmup_iter:
|
||||
self.trainer.storage.history("data_time").reset()
|
||||
self.trainer.storage.history("batch_time").reset()
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class InformationWriter(HookBase):
|
||||
def __init__(self):
|
||||
self.curr_iter = 0
|
||||
self.model_output_keys = []
|
||||
|
||||
def before_train(self):
|
||||
self.trainer.comm_info["iter_info"] = ""
|
||||
self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)
|
||||
|
||||
def before_step(self):
|
||||
self.curr_iter += 1
|
||||
# MSC pretrain do not have offset information. Comment the code for support MSC
|
||||
# info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
|
||||
# "Scan {batch_size} ({points_num}) ".format(
|
||||
# epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
|
||||
# iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
|
||||
# batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
|
||||
# points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
|
||||
# )
|
||||
info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
|
||||
epoch=self.trainer.epoch + 1,
|
||||
max_epoch=self.trainer.max_epoch,
|
||||
iter=self.trainer.comm_info["iter"] + 1,
|
||||
max_iter=len(self.trainer.train_loader),
|
||||
)
|
||||
self.trainer.comm_info["iter_info"] += info
|
||||
|
||||
def after_step(self):
|
||||
if "model_output_dict" in self.trainer.comm_info.keys():
|
||||
model_output_dict = self.trainer.comm_info["model_output_dict"]
|
||||
self.model_output_keys = model_output_dict.keys()
|
||||
for key in self.model_output_keys:
|
||||
self.trainer.storage.put_scalar(key, model_output_dict[key].item())
|
||||
|
||||
for key in self.model_output_keys:
|
||||
self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
|
||||
key=key, value=self.trainer.storage.history(key).val
|
||||
)
|
||||
lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
|
||||
self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
|
||||
self.trainer.logger.info(self.trainer.comm_info["iter_info"])
|
||||
self.trainer.comm_info["iter_info"] = "" # reset iter info
|
||||
if self.trainer.writer is not None:
|
||||
self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
|
||||
for key in self.model_output_keys:
|
||||
self.trainer.writer.add_scalar(
|
||||
"train_batch/" + key,
|
||||
self.trainer.storage.history(key).val,
|
||||
self.curr_iter,
|
||||
)
|
||||
|
||||
def after_epoch(self):
|
||||
epoch_info = "Train result: "
|
||||
for key in self.model_output_keys:
|
||||
epoch_info += "{key}: {value:.4f} ".format(
|
||||
key=key, value=self.trainer.storage.history(key).avg
|
||||
)
|
||||
self.trainer.logger.info(epoch_info)
|
||||
if self.trainer.writer is not None:
|
||||
for key in self.model_output_keys:
|
||||
self.trainer.writer.add_scalar(
|
||||
"train/" + key,
|
||||
self.trainer.storage.history(key).avg,
|
||||
self.trainer.epoch + 1,
|
||||
)
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class CheckpointSaver(HookBase):
|
||||
def __init__(self, save_freq=None):
|
||||
self.save_freq = save_freq # None or int, None indicate only save model last
|
||||
|
||||
def after_epoch(self):
|
||||
if is_main_process():
|
||||
is_best = False
|
||||
if self.trainer.cfg.evaluate:
|
||||
current_metric_value = self.trainer.comm_info["current_metric_value"]
|
||||
current_metric_name = self.trainer.comm_info["current_metric_name"]
|
||||
if current_metric_value > self.trainer.best_metric_value:
|
||||
self.trainer.best_metric_value = current_metric_value
|
||||
is_best = True
|
||||
self.trainer.logger.info(
|
||||
"Best validation {} updated to: {:.4f}".format(
|
||||
current_metric_name, current_metric_value
|
||||
)
|
||||
)
|
||||
self.trainer.logger.info(
|
||||
"Currently Best {}: {:.4f}".format(
|
||||
current_metric_name, self.trainer.best_metric_value
|
||||
)
|
||||
)
|
||||
|
||||
filename = os.path.join(
|
||||
self.trainer.cfg.save_path, "model", "model_last.pth"
|
||||
)
|
||||
self.trainer.logger.info("Saving checkpoint to: " + filename)
|
||||
torch.save(
|
||||
{
|
||||
"epoch": self.trainer.epoch + 1,
|
||||
"state_dict": self.trainer.model.state_dict(),
|
||||
"optimizer": self.trainer.optimizer.state_dict(),
|
||||
"scheduler": self.trainer.scheduler.state_dict(),
|
||||
"scaler": self.trainer.scaler.state_dict()
|
||||
if self.trainer.cfg.enable_amp
|
||||
else None,
|
||||
"best_metric_value": self.trainer.best_metric_value,
|
||||
},
|
||||
filename + ".tmp",
|
||||
)
|
||||
os.replace(filename + ".tmp", filename)
|
||||
if is_best:
|
||||
shutil.copyfile(
|
||||
filename,
|
||||
os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
|
||||
)
|
||||
if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
|
||||
shutil.copyfile(
|
||||
filename,
|
||||
os.path.join(
|
||||
self.trainer.cfg.save_path,
|
||||
"model",
|
||||
f"epoch_{self.trainer.epoch + 1}.pth",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class CheckpointLoader(HookBase):
|
||||
def __init__(self, keywords="", replacement=None, strict=False):
|
||||
self.keywords = keywords
|
||||
self.replacement = replacement if replacement is not None else keywords
|
||||
self.strict = strict
|
||||
|
||||
def before_train(self):
|
||||
self.trainer.logger.info("=> Loading checkpoint & weight ...")
|
||||
if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
|
||||
self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
|
||||
checkpoint = torch.load(
|
||||
self.trainer.cfg.weight,
|
||||
map_location=lambda storage, loc: storage.cuda(),
|
||||
)
|
||||
self.trainer.logger.info(
|
||||
f"Loading layer weights with keyword: {self.keywords}, "
|
||||
f"replace keyword with: {self.replacement}"
|
||||
)
|
||||
weight = OrderedDict()
|
||||
for key, value in checkpoint["state_dict"].items():
|
||||
if not key.startswith("module."):
|
||||
if comm.get_world_size() > 1:
|
||||
key = "module." + key # xxx.xxx -> module.xxx.xxx
|
||||
# Now all keys contain "module." no matter DDP or not.
|
||||
if self.keywords in key:
|
||||
key = key.replace(self.keywords, self.replacement)
|
||||
if comm.get_world_size() == 1:
|
||||
key = key[7:] # module.xxx.xxx -> xxx.xxx
|
||||
weight[key] = value
|
||||
load_state_info = self.trainer.model.load_state_dict(
|
||||
weight, strict=self.strict
|
||||
)
|
||||
self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
|
||||
if self.trainer.cfg.resume:
|
||||
self.trainer.logger.info(
|
||||
f"Resuming train at eval epoch: {checkpoint['epoch']}"
|
||||
)
|
||||
self.trainer.start_epoch = checkpoint["epoch"]
|
||||
self.trainer.best_metric_value = checkpoint["best_metric_value"]
|
||||
self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
|
||||
if self.trainer.cfg.enable_amp:
|
||||
self.trainer.scaler.load_state_dict(checkpoint["scaler"])
|
||||
else:
|
||||
self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class PreciseEvaluator(HookBase):
|
||||
def __init__(self, test_last=False):
|
||||
self.test_last = test_last
|
||||
|
||||
def after_train(self):
|
||||
self.trainer.logger.info(
|
||||
">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
cfg = self.trainer.cfg
|
||||
tester = TESTERS.build(
|
||||
dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
|
||||
)
|
||||
if self.test_last:
|
||||
self.trainer.logger.info("=> Testing on model_last ...")
|
||||
else:
|
||||
self.trainer.logger.info("=> Testing on model_best ...")
|
||||
best_path = os.path.join(
|
||||
self.trainer.cfg.save_path, "model", "model_best.pth"
|
||||
)
|
||||
checkpoint = torch.load(best_path)
|
||||
state_dict = checkpoint["state_dict"]
|
||||
tester.model.load_state_dict(state_dict, strict=True)
|
||||
tester.test()
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class DataCacheOperator(HookBase):
|
||||
def __init__(self, data_root, split):
|
||||
self.data_root = data_root
|
||||
self.split = split
|
||||
self.data_list = self.get_data_list()
|
||||
|
||||
def get_data_list(self):
|
||||
if isinstance(self.split, str):
|
||||
data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
|
||||
elif isinstance(self.split, Sequence):
|
||||
data_list = []
|
||||
for split in self.split:
|
||||
data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return data_list
|
||||
|
||||
def get_cache_name(self, data_path):
|
||||
data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
|
||||
return "pointcept" + data_name.replace(os.path.sep, "-")
|
||||
|
||||
def before_train(self):
|
||||
self.trainer.logger.info(
|
||||
f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
|
||||
)
|
||||
if is_main_process():
|
||||
for data_path in self.data_list:
|
||||
cache_name = self.get_cache_name(data_path)
|
||||
data = torch.load(data_path)
|
||||
shared_dict(cache_name, data)
|
||||
synchronize()
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class RuntimeProfiler(HookBase):
|
||||
def __init__(
|
||||
self,
|
||||
forward=True,
|
||||
backward=True,
|
||||
interrupt=False,
|
||||
warm_up=2,
|
||||
sort_by="cuda_time_total",
|
||||
row_limit=30,
|
||||
):
|
||||
self.forward = forward
|
||||
self.backward = backward
|
||||
self.interrupt = interrupt
|
||||
self.warm_up = warm_up
|
||||
self.sort_by = sort_by
|
||||
self.row_limit = row_limit
|
||||
|
||||
def before_train(self):
|
||||
self.trainer.logger.info("Profiling runtime ...")
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
for i, input_dict in enumerate(self.trainer.train_loader):
|
||||
if i == self.warm_up + 1:
|
||||
break
|
||||
for key in input_dict.keys():
|
||||
if isinstance(input_dict[key], torch.Tensor):
|
||||
input_dict[key] = input_dict[key].cuda(non_blocking=True)
|
||||
if self.forward:
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
) as forward_prof:
|
||||
with record_function("model_inference"):
|
||||
output_dict = self.trainer.model(input_dict)
|
||||
else:
|
||||
output_dict = self.trainer.model(input_dict)
|
||||
loss = output_dict["loss"]
|
||||
if self.backward:
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
) as backward_prof:
|
||||
with record_function("model_inference"):
|
||||
loss.backward()
|
||||
self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
|
||||
if self.forward:
|
||||
self.trainer.logger.info(
|
||||
"Forward profile: \n"
|
||||
+ str(
|
||||
forward_prof.key_averages().table(
|
||||
sort_by=self.sort_by, row_limit=self.row_limit
|
||||
)
|
||||
)
|
||||
)
|
||||
forward_prof.export_chrome_trace(
|
||||
os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
|
||||
)
|
||||
|
||||
if self.backward:
|
||||
self.trainer.logger.info(
|
||||
"Backward profile: \n"
|
||||
+ str(
|
||||
backward_prof.key_averages().table(
|
||||
sort_by=self.sort_by, row_limit=self.row_limit
|
||||
)
|
||||
)
|
||||
)
|
||||
backward_prof.export_chrome_trace(
|
||||
os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
|
||||
)
|
||||
if self.interrupt:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class RuntimeProfilerV2(HookBase):
|
||||
def __init__(
|
||||
self,
|
||||
interrupt=False,
|
||||
wait=1,
|
||||
warmup=1,
|
||||
active=10,
|
||||
repeat=1,
|
||||
sort_by="cuda_time_total",
|
||||
row_limit=30,
|
||||
):
|
||||
self.interrupt = interrupt
|
||||
self.wait = wait
|
||||
self.warmup = warmup
|
||||
self.active = active
|
||||
self.repeat = repeat
|
||||
self.sort_by = sort_by
|
||||
self.row_limit = row_limit
|
||||
|
||||
def before_train(self):
|
||||
self.trainer.logger.info("Profiling runtime ...")
|
||||
from torch.profiler import (
|
||||
profile,
|
||||
record_function,
|
||||
ProfilerActivity,
|
||||
schedule,
|
||||
tensorboard_trace_handler,
|
||||
)
|
||||
|
||||
prof = profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=schedule(
|
||||
wait=self.wait,
|
||||
warmup=self.warmup,
|
||||
active=self.active,
|
||||
repeat=self.repeat,
|
||||
),
|
||||
on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
prof.start()
|
||||
for i, input_dict in enumerate(self.trainer.train_loader):
|
||||
if i >= (self.wait + self.warmup + self.active) * self.repeat:
|
||||
break
|
||||
for key in input_dict.keys():
|
||||
if isinstance(input_dict[key], torch.Tensor):
|
||||
input_dict[key] = input_dict[key].cuda(non_blocking=True)
|
||||
with record_function("model_forward"):
|
||||
output_dict = self.trainer.model(input_dict)
|
||||
loss = output_dict["loss"]
|
||||
with record_function("model_backward"):
|
||||
loss.backward()
|
||||
prof.step()
|
||||
self.trainer.logger.info(
|
||||
f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
|
||||
)
|
||||
self.trainer.logger.info(
|
||||
"Profile: \n"
|
||||
+ str(
|
||||
prof.key_averages().table(
|
||||
sort_by=self.sort_by, row_limit=self.row_limit
|
||||
)
|
||||
)
|
||||
)
|
||||
prof.stop()
|
||||
|
||||
if self.interrupt:
|
||||
sys.exit(0)
|
||||
Reference in New Issue
Block a user