"""
The code is base on https://github.com/Pointcept/Pointcept
"""
import os
import weakref
import torch
import torch.nn as nn
import torch.utils.data
from functools import partial
# collections.Iterator was removed in Python 3.10; collections.abc.Iterator
# has been the canonical location since Python 3.3, so no version check is needed.
from collections.abc import Iterator
from tensorboardX import SummaryWriter
from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import utils.comm as comm
from datasets import build_dataset, point_collate_fn, collate_fn
from models import build_model
from utils.logger import get_root_logger
from utils.optimizer import build_optimizer
from utils.scheduler import build_scheduler
from utils.events import EventStorage
from utils.registry import Registry
TRAINERS = Registry("trainers")
class TrainerBase:
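    """Hook-driven base trainer.

    Subclasses implement run_step(); the before/after_{train,epoch,step}
    callbacks fan out to every registered hook.
    """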
def __init__(self) -> None:
self.hooks = []
self.epoch = 0
self.start_epoch = 0
self.max_epoch = 0
self.max_iter = 0
self.comm_info = dict()
self.data_iterator: Iterator = enumerate([])
self.storage: EventStorage
self.writer: SummaryWriter
def register_hooks(self, hooks) -> None:
hooks = build_hooks(hooks)
for h in hooks:
assert isinstance(h, HookBase)
            # To avoid circular references, hooks and the trainer cannot own
            # each other. This normally does not matter, but will cause a
            # memory leak if the involved objects define __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self.hooks.extend(hooks)
def train(self):
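        """Outer training loop.

        self.data_iterator defaults to an empty iterator, so a subclass
        (see Trainer.train) must populate it each epoch.
        """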
with EventStorage() as self.storage:
# => before train
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def before_train(self):
for h in self.hooks:
h.before_train()
def before_epoch(self):
for h in self.hooks:
h.before_epoch()
def before_step(self):
for h in self.hooks:
h.before_step()
def run_step(self):
raise NotImplementedError
def after_step(self):
for h in self.hooks:
h.after_step()
def after_epoch(self):
for h in self.hooks:
h.after_epoch()
self.storage.reset_histories()
def after_train(self):
# Sync GPU before running train hooks
comm.synchronize()
for h in self.hooks:
h.after_train()
if comm.is_main_process():
self.writer.close()
@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
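    """Default trainer: builds the model, dataloaders, optimizer, scheduler,
    AMP scaler, and hooks from a single config object (cfg).
    """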
def __init__(self, cfg):
        super().__init__()
self.epoch = 0
self.start_epoch = 0
self.max_epoch = cfg.eval_epoch
self.best_metric_value = -torch.inf
self.logger = get_root_logger(
log_file=os.path.join(cfg.save_path, "train.log"),
file_mode="a" if cfg.resume else "w",
)
self.logger.info("=> Loading config ...")
self.cfg = cfg
self.logger.info(f"Save path: {cfg.save_path}")
self.logger.info(f"Config:\n{cfg.pretty_text}")
self.logger.info("=> Building model ...")
self.model = self.build_model()
self.logger.info("=> Building writer ...")
self.writer = self.build_writer()
self.logger.info("=> Building train dataset & dataloader ...")
self.train_loader = self.build_train_loader()
self.logger.info("=> Building val dataset & dataloader ...")
self.val_loader = self.build_val_loader()
self.logger.info("=> Building optimize, scheduler, scaler(amp) ...")
self.optimizer = self.build_optimizer()
self.scheduler = self.build_scheduler()
self.scaler = self.build_scaler()
self.logger.info("=> Building hooks ...")
self.register_hooks(self.cfg.hooks)
def train(self):
with EventStorage() as self.storage:
# => before train
self.before_train()
self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
# TODO: optimize to iteration based
if comm.get_world_size() > 1:
self.train_loader.sampler.set_epoch(self.epoch)
self.model.train()
self.data_iterator = enumerate(self.train_loader)
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def run_step(self):
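        """Run a single optimization step: move the batch to GPU, forward
        under autocast, then backward / optimizer / scheduler (AMP-aware).
        """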
input_dict = self.comm_info["input_dict"]
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
output_dict = self.model(input_dict)
loss = output_dict["loss"]
self.optimizer.zero_grad()
if self.cfg.enable_amp:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
            # With AMP enabled, optimizer.step() is skipped when the loss scale
            # is too large; only step the scheduler if the scale did not shrink,
            # which avoids the torch warning about calling scheduler.step()
            # before optimizer.step().
            scale = self.scaler.get_scale()
            self.scaler.update()
            if scale <= self.scaler.get_scale():
                self.scheduler.step()
else:
loss.backward()
self.optimizer.step()
self.scheduler.step()
if self.cfg.empty_cache:
torch.cuda.empty_cache()
self.comm_info["model_output_dict"] = output_dict
def build_model(self):
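        """Build the model from cfg, optionally convert BatchNorm to
        SyncBatchNorm, and wrap it for distributed data parallel training.
        """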
model = build_model(self.cfg.model)
if self.cfg.sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logger.info(f"Model: \n{self.model}")
self.logger.info(f"Num params: {n_parameters}")
model = create_ddp_model(
model.cuda(),
broadcast_buffers=False,
find_unused_parameters=self.cfg.find_unused_parameters,
)
return model
def build_writer(self):
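        """Create a TensorBoard SummaryWriter on the main process only."""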
writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
return writer
def build_train_loader(self):
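        """Build the train dataloader, with a DistributedSampler when running
        multi-GPU and a seeded worker_init_fn when cfg.seed is set.
        """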
train_data = build_dataset(self.cfg.data.train)
if comm.get_world_size() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
else:
train_sampler = None
init_fn = (
partial(
worker_init_fn,
num_workers=self.cfg.num_worker_per_gpu,
rank=comm.get_rank(),
seed=self.cfg.seed,
)
if self.cfg.seed is not None
else None
)
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=self.cfg.batch_size_per_gpu,
shuffle=(train_sampler is None),
            num_workers=0,  # data loading runs in the main process; worker_init_fn is unused with 0 workers
sampler=train_sampler,
collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
pin_memory=True,
worker_init_fn=init_fn,
drop_last=True,
# persistent_workers=True,
)
return train_loader
def build_val_loader(self):
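        """Build the validation dataloader (only when cfg.evaluate is set)."""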
val_loader = None
if self.cfg.evaluate:
val_data = build_dataset(self.cfg.data.val)
if comm.get_world_size() > 1:
val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
else:
val_sampler = None
val_loader = torch.utils.data.DataLoader(
val_data,
batch_size=self.cfg.batch_size_val_per_gpu,
shuffle=False,
num_workers=self.cfg.num_worker_per_gpu,
pin_memory=True,
sampler=val_sampler,
collate_fn=collate_fn,
)
return val_loader
def build_optimizer(self):
return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
def build_scheduler(self):
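        """Build the LR scheduler; it needs the total number of steps
        (iterations per epoch * eval_epoch) known up front.
        """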
assert hasattr(self, "optimizer")
assert hasattr(self, "train_loader")
self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
return build_scheduler(self.cfg.scheduler, self.optimizer)
def build_scaler(self):
scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
return scaler
@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
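    """Trainer variant that draws batches from multiple datasets through
    MultiDatasetDataloader and records iterations per epoch in comm_info.
    """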
def build_train_loader(self):
from datasets import MultiDatasetDataloader
train_data = build_dataset(self.cfg.data.train)
train_loader = MultiDatasetDataloader(
train_data,
self.cfg.batch_size_per_gpu,
self.cfg.num_worker_per_gpu,
self.cfg.mix_prob,
self.cfg.seed,
)
self.comm_info["iter_per_epoch"] = len(train_loader)
return train_loader