"""
|
|
The code is base on https://github.com/Pointcept/Pointcept
|
|
"""

import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict

if sys.version_info >= (3, 10):
    from collections.abc import Sequence
else:
    from collections import Sequence

from utils.timer import Timer
from utils.comm import is_main_process, synchronize, get_world_size
from utils.cache import shared_dict

import utils.comm as comm
from engines.test import TESTERS

from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class IterationTimer(HookBase):
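    """Record per-iteration data/compute times and append an ETA string to the logged iteration info."""
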
    def __init__(self, warmup_iter=1):
        self._warmup_iter = warmup_iter
        self._start_time = time.perf_counter()
        self._iter_timer = Timer()
        self._remain_iter = 0

    def before_train(self):
        self._start_time = time.perf_counter()
        self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)

    def before_epoch(self):
        self._iter_timer.reset()

    def before_step(self):
        data_time = self._iter_timer.seconds()
        self.trainer.storage.put_scalar("data_time", data_time)

    def after_step(self):
        batch_time = self._iter_timer.seconds()
        self._iter_timer.reset()
        self.trainer.storage.put_scalar("batch_time", batch_time)
        self._remain_iter -= 1
        remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
        if "iter_info" in self.trainer.comm_info.keys():
            info = (
                "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
                "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
                "Remain {remain_time} ".format(
                    data_time_val=self.trainer.storage.history("data_time").val,
                    data_time_avg=self.trainer.storage.history("data_time").avg,
                    batch_time_val=self.trainer.storage.history("batch_time").val,
                    batch_time_avg=self.trainer.storage.history("batch_time").avg,
                    remain_time=remain_time,
                )
            )
            self.trainer.comm_info["iter_info"] += info
        if self.trainer.comm_info["iter"] <= self._warmup_iter:
            self.trainer.storage.history("data_time").reset()
            self.trainer.storage.history("batch_time").reset()


@HOOKS.register_module()
class InformationWriter(HookBase):
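    """Compose the per-iteration training log line and write loss/metric scalars and the learning rate to the trainer's writer."""
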
    def __init__(self):
        self.curr_iter = 0
        self.model_output_keys = []

    def before_train(self):
        self.trainer.comm_info["iter_info"] = ""
        self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)

    def before_step(self):
        self.curr_iter += 1
        # MSC pretraining provides no offset information, so the block below
        # stays commented out to keep MSC supported.
        # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
        #        "Scan {batch_size} ({points_num}) ".format(
        #     epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
        #     iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
        #     batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
        #     points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
        # )
        info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
            epoch=self.trainer.epoch + 1,
            max_epoch=self.trainer.max_epoch,
            iter=self.trainer.comm_info["iter"] + 1,
            max_iter=len(self.trainer.train_loader),
        )
        self.trainer.comm_info["iter_info"] += info

    def after_step(self):
        if "model_output_dict" in self.trainer.comm_info.keys():
            model_output_dict = self.trainer.comm_info["model_output_dict"]
            self.model_output_keys = model_output_dict.keys()
            for key in self.model_output_keys:
                self.trainer.storage.put_scalar(key, model_output_dict[key].item())

        for key in self.model_output_keys:
            self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).val
            )
        lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
        self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
        self.trainer.logger.info(self.trainer.comm_info["iter_info"])
        self.trainer.comm_info["iter_info"] = ""  # reset iter info
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train_batch/" + key,
                    self.trainer.storage.history(key).val,
                    self.curr_iter,
                )

    def after_epoch(self):
        epoch_info = "Train result: "
        for key in self.model_output_keys:
            epoch_info += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).avg
            )
        self.trainer.logger.info(epoch_info)
        if self.trainer.writer is not None:
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train/" + key,
                    self.trainer.storage.history(key).avg,
                    self.trainer.epoch + 1,
                )


@HOOKS.register_module()
class CheckpointSaver(HookBase):
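    """Save model_last.pth after every epoch on the main process, copy it to model_best.pth
    when the tracked metric improves, and optionally keep per-epoch copies every `save_freq` epochs."""
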
    def __init__(self, save_freq=None):
        self.save_freq = save_freq  # None or int; None means only the latest model is kept

    def after_epoch(self):
        if is_main_process():
            is_best = False
            if self.trainer.cfg.evaluate:
                current_metric_value = self.trainer.comm_info["current_metric_value"]
                current_metric_name = self.trainer.comm_info["current_metric_name"]
                if current_metric_value > self.trainer.best_metric_value:
                    self.trainer.best_metric_value = current_metric_value
                    is_best = True
                    self.trainer.logger.info(
                        "Best validation {} updated to: {:.4f}".format(
                            current_metric_name, current_metric_value
                        )
                    )
                self.trainer.logger.info(
                    "Current best {}: {:.4f}".format(
                        current_metric_name, self.trainer.best_metric_value
                    )
                )

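            # Write the checkpoint to a temporary file first, then atomically
            # replace model_last.pth so a crash mid-save cannot corrupt it.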
            filename = os.path.join(
                self.trainer.cfg.save_path, "model", "model_last.pth"
            )
            self.trainer.logger.info("Saving checkpoint to: " + filename)
            torch.save(
                {
                    "epoch": self.trainer.epoch + 1,
                    "state_dict": self.trainer.model.state_dict(),
                    "optimizer": self.trainer.optimizer.state_dict(),
                    "scheduler": self.trainer.scheduler.state_dict(),
                    "scaler": self.trainer.scaler.state_dict()
                    if self.trainer.cfg.enable_amp
                    else None,
                    "best_metric_value": self.trainer.best_metric_value,
                },
                filename + ".tmp",
            )
            os.replace(filename + ".tmp", filename)
            if is_best:
                shutil.copyfile(
                    filename,
                    os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
                )
            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
                shutil.copyfile(
                    filename,
                    os.path.join(
                        self.trainer.cfg.save_path,
                        "model",
                        f"epoch_{self.trainer.epoch + 1}.pth",
                    ),
                )


@HOOKS.register_module()
class CheckpointLoader(HookBase):
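    """Load weights from cfg.weight before training, remapping state-dict keys as needed,
    and restore optimizer/scheduler/scaler state plus the epoch counter when cfg.resume is set."""
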
    def __init__(self, keywords="", replacement=None, strict=False):
        self.keywords = keywords
        self.replacement = replacement if replacement is not None else keywords
        self.strict = strict

    def before_train(self):
        self.trainer.logger.info("=> Loading checkpoint & weight ...")
        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
            checkpoint = torch.load(
                self.trainer.cfg.weight,
                map_location=lambda storage, loc: storage.cuda(),
            )
            self.trainer.logger.info(
                f"Loading layer weights with keyword: {self.keywords}, "
                f"replace keyword with: {self.replacement}"
            )
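            # Normalize state-dict keys: add or strip the DDP "module." prefix to match
            # the current world size and apply the keyword -> replacement renaming.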
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("module."):
                    if comm.get_world_size() > 1:
                        key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # Now all keys carry the "module." prefix, whether DDP is used or not.
                if self.keywords in key:
                    key = key.replace(self.keywords, self.replacement)
                if comm.get_world_size() == 1:
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                weight[key] = value
            load_state_info = self.trainer.model.load_state_dict(
                weight, strict=self.strict
            )
            self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
            if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming training at epoch: {checkpoint['epoch']}"
                )
                self.trainer.start_epoch = checkpoint["epoch"]
                self.trainer.best_metric_value = checkpoint["best_metric_value"]
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
                if self.trainer.cfg.enable_amp:
                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
        else:
            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")


@HOOKS.register_module()
class PreciseEvaluator(HookBase):
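    """Run the configured tester after training finishes, either on the in-memory (last)
    model or on the saved model_best.pth checkpoint."""
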
    def __init__(self, test_last=False):
        self.test_last = test_last

    def after_train(self):
        self.trainer.logger.info(
            ">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
        )
        torch.cuda.empty_cache()
        cfg = self.trainer.cfg
        tester = TESTERS.build(
            dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
        )
        if self.test_last:
            self.trainer.logger.info("=> Testing on model_last ...")
        else:
            self.trainer.logger.info("=> Testing on model_best ...")
            best_path = os.path.join(
                self.trainer.cfg.save_path, "model", "model_best.pth"
            )
            checkpoint = torch.load(best_path)
            state_dict = checkpoint["state_dict"]
            tester.model.load_state_dict(state_dict, strict=True)
        tester.test()


@HOOKS.register_module()
class DataCacheOperator(HookBase):
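    """Cache every *.pth sample of the configured split(s) into shared memory (via
    shared_dict) on the main process before training starts."""
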
    def __init__(self, data_root, split):
        self.data_root = data_root
        self.split = split
        self.data_list = self.get_data_list()

    def get_data_list(self):
        if isinstance(self.split, str):
            data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
        else:
            raise NotImplementedError
        return data_list

    def get_cache_name(self, data_path):
        data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
        return "pointcept" + data_name.replace(os.path.sep, "-")

    def before_train(self):
        self.trainer.logger.info(
            f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
        )
        if is_main_process():
            for data_path in self.data_list:
                cache_name = self.get_cache_name(data_path)
                data = torch.load(data_path)
                shared_dict(cache_name, data)
        synchronize()


@HOOKS.register_module()
class RuntimeProfiler(HookBase):
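    """Profile forward and/or backward passes of a few training iterations with
    torch.profiler, log the summary tables, and export Chrome traces to cfg.save_path."""
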
    def __init__(
        self,
        forward=True,
        backward=True,
        interrupt=False,
        warm_up=2,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.forward = forward
        self.backward = backward
        self.interrupt = interrupt
        self.warm_up = warm_up
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import profile, record_function, ProfilerActivity

        for i, input_dict in enumerate(self.trainer.train_loader):
            if i == self.warm_up + 1:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            if self.forward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as forward_prof:
                    with record_function("model_inference"):
                        output_dict = self.trainer.model(input_dict)
            else:
                output_dict = self.trainer.model(input_dict)
            loss = output_dict["loss"]
            if self.backward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as backward_prof:
                    with record_function("model_inference"):
                        loss.backward()
            self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
        if self.forward:
            self.trainer.logger.info(
                "Forward profile: \n"
                + str(
                    forward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            forward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
            )

        if self.backward:
            self.trainer.logger.info(
                "Backward profile: \n"
                + str(
                    backward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            backward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
            )
        if self.interrupt:
            sys.exit(0)


@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
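    """Profile training iterations with a torch.profiler schedule (wait/warmup/active/repeat)
    and emit traces through a TensorBoard trace handler in cfg.save_path."""
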
    def __init__(
        self,
        interrupt=False,
        wait=1,
        warmup=1,
        active=10,
        repeat=1,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.interrupt = interrupt
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import (
            profile,
            record_function,
            ProfilerActivity,
            schedule,
            tensorboard_trace_handler,
        )

        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
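        # The schedule needs (wait + warmup + active) * repeat profiler steps in total,
        # so only that many training iterations are run here.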
        for i, input_dict in enumerate(self.trainer.train_loader):
            if i >= (self.wait + self.warmup + self.active) * self.repeat:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with record_function("model_forward"):
                output_dict = self.trainer.model(input_dict)
                loss = output_dict["loss"]
            with record_function("model_backward"):
                loss.backward()
            prof.step()
            self.trainer.logger.info(
                f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
            )
        self.trainer.logger.info(
            "Profile: \n"
            + str(
                prof.key_averages().table(
                    sort_by=self.sort_by, row_limit=self.row_limit
                )
            )
        )
        prof.stop()

        if self.interrupt:
            sys.exit(0)