Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
0	funasr_local/train/__init__.py	Normal file
55	funasr_local/train/abs_espnet_model.py	Normal file
@@ -0,0 +1,55 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Tuple

import torch


class AbsESPnetModel(torch.nn.Module, ABC):
    """The common abstract class shared among all tasks.

    "ESPnetModel" refers to a class that inherits torch.nn.Module,
    holds the DNN modules as member fields and forwards to them
    (a.k.a. the delegate pattern),
    and defines "loss", "stats", and "weight" for the task.

    If you intend to implement a new task in ESPnet,
    the model must inherit this class.
    In other words, the only "mediator" objects between
    our training system and your task class are
    these three values: loss, stats, and weight.

    Example:
        >>> from funasr_local.tasks.abs_task import AbsTask
        >>> class YourESPnetModel(AbsESPnetModel):
        ...     def forward(self, input, input_lengths):
        ...         ...
        ...         return loss, stats, weight
        >>> class YourTask(AbsTask):
        ...     @classmethod
        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
    """

    def __init__(self):
        super().__init__()
        self.num_updates = 0

    @abstractmethod
    def forward(
        self, **batch: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def get_num_updates(self):
        return self.num_updates
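The following is an illustrative sketch, not part of this commit: a minimal toy subclass of AbsESPnetModel that returns the (loss, stats, weight) triple the trainer expects. The layer sizes, the `feats`/`target` batch keys, and the importability of `funasr_local` are assumptions for illustration only.

# Hypothetical usage sketch (not part of this commit): a toy regression
# "task model" that delegates to an inner module and returns
# (loss, stats, weight) as required by AbsESPnetModel.
import torch
from typing import Dict, Tuple
from funasr_local.train.abs_espnet_model import AbsESPnetModel

class ToyESPnetModel(AbsESPnetModel):
    def __init__(self, idim: int = 80, odim: int = 4):
        super().__init__()
        self.linear = torch.nn.Linear(idim, odim)  # the delegated DNN module

    def forward(
        self, feats: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        pred = self.linear(feats)
        loss = torch.nn.functional.mse_loss(pred, target)
        stats = dict(loss=loss.detach())
        weight = torch.tensor(feats.size(0))  # batch size used as the weight
        return loss, stats, weight

    def collect_feats(self, feats: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        return {"feats": feats}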
192	funasr_local/train/abs_model.py	Normal file
@@ -0,0 +1,192 @@
from abc import ABC
from abc import abstractmethod

from typing import Dict
from typing import Optional
from typing import Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel

from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface


class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
    """The abstract punctuation model class.

    To share the loss calculation among different models,
    we use the delegate pattern here:
    the instance of this class should be passed to "LanguageModel".

    This "model" is one of the mediator objects for the "Task" class.

    """

    @abstractmethod
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def with_vad(self) -> bool:
        raise NotImplementedError


class PunctuationModel(AbsESPnetModel):

    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
        assert check_argument_types()
        super().__init__()
        self.punc_model = punc_model
        self.punc_weight = torch.Tensor(punc_weight)
        self.sos = 1
        self.eos = 2

        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id
        # if self.punc_model.with_vad():
        #     print("This is a vad punctuation model.")

    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood (nll).

        Normally, this function is called in batchify_nll.
        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            max_length: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, :text_lengths.max()]
            punc = punc[:, :text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]

        if self.punc_model.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
            y, _ = self.punc_model(text, text_lengths, vad_indexes)
        else:
            # Should be TargetDelayTransformer
            y, _ = self.punc_model(text, text_lengths)

        # Calc negative log likelihood
        # nll: (BxL,)
        if not self.training:
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            from sklearn.metrics import f1_score
            # Use a separate name so the imported function is not shadowed.
            score = f1_score(punc.view(-1).detach().cpu().numpy(),
                             indices.squeeze(-1).detach().cpu().numpy(),
                             average='micro')
            nll = torch.Tensor([score]).repeat(text_lengths.sum())
            return nll, text_lengths
        else:
            self.punc_weight = self.punc_weight.to(punc.device)
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
                                  ignore_index=self.ignore_id)
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
        else:
            nll.masked_fill_(
                make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                0.0,
            )
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, text_lengths

    def batchify_nll(self,
                     text: torch.Tensor,
                     punc: torch.Tensor,
                     text_lengths: torch.Tensor,
                     punc_lengths: torch.Tensor,
                     batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood (nll) from a transformer language model.

        To avoid OOM, this function separates the input into batches.
        It then calls nll for each batch, combines the results, and returns them.
        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            batch_size: int, the number of samples each batch contains when
                computing nll; you may change this to avoid OOM or to speed up.

        """
        total_num = text.size(0)
        if total_num <= batch_size:
            nll, x_lengths = self.nll(text, punc, text_lengths)
        else:
            nlls = []
            x_lengths = []
            max_length = text_lengths.max()

            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_text = text[start_idx:end_idx, :]
                batch_punc = punc[start_idx:end_idx, :]
                batch_text_lengths = text_lengths[start_idx:end_idx]
                # batch_nll: [B * T]
                batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
                nlls.append(batch_nll)
                x_lengths.append(batch_x_lengths)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nlls)
            x_lengths = torch.cat(x_lengths)
            assert nll.size(0) == total_num
            assert x_lengths.size(0) == total_num
        return nll, x_lengths

    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {}

    def inference(self,
                  text: torch.Tensor,
                  text_lengths: torch.Tensor,
                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
        if self.punc_model.with_vad():
            assert vad_indexes is not None
            return self.punc_model(text, text_lengths, vad_indexes)
        else:
            return self.punc_model(text, text_lengths)
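A minimal sketch (an assumption for illustration, not part of this commit) of the chunking pattern batchify_nll uses to avoid OOM: walk the batch dimension in fixed-size slices, compute per-chunk results, and stitch them back together. The helper name `chunked_apply` is hypothetical.

# Hypothetical sketch of batchify_nll's chunking strategy in isolation.
import torch

def chunked_apply(fn, x: torch.Tensor, chunk: int = 100) -> torch.Tensor:
    outs = []
    start = 0
    while start < x.size(0):
        end = min(start + chunk, x.size(0))
        outs.append(fn(x[start:end]))  # e.g. a per-chunk nll computation
        start = end
    return torch.cat(outs)

# Usage: summing rows in chunks of 3 matches doing it in a single pass.
x = torch.arange(10.0).reshape(5, 2)
assert torch.allclose(chunked_apply(lambda t: t.sum(dim=1), x, chunk=3), x.sum(dim=1))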
95	funasr_local/train/class_choices.py	Normal file
@@ -0,0 +1,95 @@
from typing import Mapping
from typing import Optional
from typing import Tuple

from typeguard import check_argument_types
from typeguard import check_return_type

from funasr_local.utils.nested_dict_action import NestedDictAction
from funasr_local.utils.types import str_or_none


class ClassChoices:
    """Helper class to manage the options for variable objects and their configurations.

    Example:

        >>> class A:
        ...     def __init__(self, foo=3): pass
        >>> class B:
        ...     def __init__(self, bar="aaaa"): pass
        >>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> choices.add_arguments(parser)
        >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4"])
        >>> args.var
        a
        >>> args.var_conf
        {"foo": 4}
        >>> class_obj = choices.get_class(args.var)
        >>> a_object = class_obj(**args.var_conf)

    """

    def __init__(
        self,
        name: str,
        classes: Mapping[str, type],
        type_check: type = None,
        default: str = None,
        optional: bool = False,
    ):
        assert check_argument_types()
        self.name = name
        self.base_type = type_check
        self.classes = {k.lower(): v for k, v in classes.items()}
        if "none" in self.classes or "nil" in self.classes or "null" in self.classes:
            raise ValueError('"none", "nil", and "null" are reserved.')
        if type_check is not None:
            for v in self.classes.values():
                if not issubclass(v, type_check):
                    raise ValueError(f"must be {type_check.__name__}, but got {v}")

        self.optional = optional
        self.default = default
        if default is None:
            self.optional = True

    def choices(self) -> Tuple[Optional[str], ...]:
        retval = tuple(self.classes)
        if self.optional:
            return retval + (None,)
        else:
            return retval

    def get_class(self, name: Optional[str]) -> Optional[type]:
        assert check_argument_types()
        if name is None or (self.optional and name.lower() in ("none", "null", "nil")):
            retval = None
        elif name.lower() in self.classes:
            class_obj = self.classes[name.lower()]
            assert check_return_type(class_obj)
            retval = class_obj
        else:
            raise ValueError(
                f"--{self.name} must be one of {self.choices()}: "
                f"--{self.name} {name.lower()}"
            )

        return retval

    def add_arguments(self, parser):
        parser.add_argument(
            f"--{self.name}",
            type=lambda x: str_or_none(x.lower()),
            default=self.default,
            choices=self.choices(),
            help=f"The {self.name} type",
        )
        parser.add_argument(
            f"--{self.name}_conf",
            action=NestedDictAction,
            default=dict(),
            help=f"The keyword arguments for {self.name}",
        )
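A hypothetical end-to-end use of ClassChoices (not part of this commit), mirroring the docstring example with an explicit build step. It assumes `funasr_local` is importable and that NestedDictAction parses `key=value` pairs into a dict; the dummy classes and option name are made up for illustration.

# Hypothetical usage sketch of ClassChoices with argparse.
import argparse
from funasr_local.train.class_choices import ClassChoices

class SGDLike:
    def __init__(self, lr: float = 0.1):
        self.lr = lr

class AdamLike:
    def __init__(self, lr: float = 1e-3, betas=(0.9, 0.999)):
        self.lr, self.betas = lr, betas

optim_choices = ClassChoices("optim", dict(sgd=SGDLike, adam=AdamLike), default="sgd")
parser = argparse.ArgumentParser()
optim_choices.add_arguments(parser)  # adds --optim and --optim_conf
args = parser.parse_args(["--optim", "adam", "--optim_conf", "lr=0.001"])
optim = optim_choices.get_class(args.optim)(**args.optim_conf)
print(type(optim).__name__, optim.lr)  # expected: AdamLike 0.001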
380	funasr_local/train/distributed_utils.py	Normal file
@@ -0,0 +1,380 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import dataclasses
import logging
import os
import socket
from typing import Optional

import torch
import torch.distributed


@dataclasses.dataclass
class DistributedOption:
    # Enable distributed training
    distributed: bool = False
    # torch.distributed.Backend: "nccl", "mpi", "gloo", or "tcp"
    dist_backend: str = "nccl"
    # If init_method == "env://", the env values of "MASTER_PORT", "MASTER_ADDR",
    # "WORLD_SIZE", and "RANK" are referred to.
    dist_init_method: str = "env://"
    dist_world_size: Optional[int] = None
    dist_rank: Optional[int] = None
    local_rank: Optional[int] = None
    ngpu: int = 0
    dist_master_addr: Optional[str] = None
    dist_master_port: Optional[int] = None
    dist_launcher: Optional[str] = None
    multiprocessing_distributed: bool = True

    def init_options(self):
        if self.distributed:
            if self.dist_init_method == "env://":
                if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
                    raise RuntimeError(
                        "--dist_master_addr or MASTER_ADDR must be set "
                        "if --dist_init_method == 'env://'"
                    )
                if get_master_port(self.dist_master_port) is None:
                    raise RuntimeError(
                        "--dist_master_port or MASTER_PORT must be set "
                        "if --dist_init_method == 'env://'"
                    )

    def init_torch_distributed(self, args):
        if self.distributed:
            # See:
            # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
            os.environ.setdefault("NCCL_DEBUG", "INFO")

            # See:
            # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
            os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")

            torch.distributed.init_process_group(backend=self.dist_backend,
                                                 init_method=self.dist_init_method,
                                                 world_size=args.dist_world_size,
                                                 rank=args.dist_rank)
            self.dist_rank = torch.distributed.get_rank()
            self.dist_world_size = torch.distributed.get_world_size()
            self.local_rank = args.local_rank

    def init_options_pai(self):
        if self.distributed:
            if self.dist_init_method == "env://":
                if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
                    raise RuntimeError(
                        "--dist_master_addr or MASTER_ADDR must be set "
                        "if --dist_init_method == 'env://'"
                    )
                if get_master_port(self.dist_master_port) is None:
                    raise RuntimeError(
                        "--dist_master_port or MASTER_PORT must be set "
                        "if --dist_init_method == 'env://'"
                    )

            self.dist_rank = get_rank(self.dist_rank, self.dist_launcher)
            self.dist_world_size = get_world_size(
                self.dist_world_size, self.dist_launcher
            )
            self.local_rank = get_local_rank(self.local_rank, self.dist_launcher)

            if (
                self.dist_rank is not None
                and self.dist_world_size is not None
                and self.dist_rank >= self.dist_world_size
            ):
                raise RuntimeError(
                    f"RANK >= WORLD_SIZE: {self.dist_rank} >= {self.dist_world_size}"
                )

            if self.dist_init_method == "env://":
                self.dist_master_addr = get_master_addr(
                    self.dist_master_addr, self.dist_launcher
                )
                self.dist_master_port = get_master_port(self.dist_master_port)
                if (
                    self.dist_master_addr is not None
                    and self.dist_master_port is not None
                ):
                    self.dist_init_method = (
                        f"tcp://{self.dist_master_addr}:{self.dist_master_port}"
                    )

    def init_torch_distributed_pai(self, args):
        if self.distributed:
            # See:
            # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
            os.environ.setdefault("NCCL_DEBUG", "INFO")

            # See:
            # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
            os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")

            torch.distributed.init_process_group(backend=self.dist_backend, init_method='env://')
            self.dist_rank = torch.distributed.get_rank()
            self.dist_world_size = torch.distributed.get_world_size()
            self.local_rank = args.local_rank


def resolve_distributed_mode(args):
    # Note that args.distributed is set only by this function,
    # and ArgumentParser doesn't have such an option.

    if args.multiprocessing_distributed:
        num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher)
        # a. multi-node
        if num_nodes > 1:
            args.distributed = True
        # b. single-node and multi-gpu with multiprocessing_distributed mode
        elif args.ngpu > 1:
            args.distributed = True
        # c. single-node and single-gpu
        else:
            args.distributed = False

        if args.ngpu <= 1:
            # Disable multiprocessing_distributed mode if 1 process per node or cpu mode
            args.multiprocessing_distributed = False
            if args.ngpu == 1:
                # If the number of GPUs equals 1 with multiprocessing_distributed mode,
                # LOCAL_RANK is always 0
                args.local_rank = 0

        if num_nodes > 1 and get_node_rank(args.dist_rank, args.dist_launcher) is None:
            raise RuntimeError(
                "--dist_rank or RANK must be set "
                "if --multiprocessing_distributed == true"
            )

        # Note that RANK, LOCAL_RANK, and WORLD_SIZE are automatically set,
        # so we don't need to check them here.
    else:
        # d. multiprocess and multi-gpu with external launcher
        #    e.g. torch.distributed.launch
        if get_world_size(args.dist_world_size, args.dist_launcher) > 1:
            args.distributed = True
        # e. single-process
        else:
            args.distributed = False

        if args.distributed and args.ngpu > 0:
            if get_local_rank(args.local_rank, args.dist_launcher) is None:
                raise RuntimeError(
                    "--local_rank or LOCAL_RANK must be set "
                    "if --multiprocessing_distributed == false"
                )
        if args.distributed:
            if get_node_rank(args.dist_rank, args.dist_launcher) is None:
                raise RuntimeError(
                    "--dist_rank or RANK must be set "
                    "if --multiprocessing_distributed == false"
                )
    if args.distributed and args.dist_launcher == "slurm" and not is_in_slurm_step():
        raise RuntimeError("Launch by 'srun' command if --dist_launcher='slurm'")


def is_in_slurm_job() -> bool:
    return "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ


def is_in_slurm_step() -> bool:
    return (
        is_in_slurm_job()
        and "SLURM_STEP_NUM_NODES" in os.environ
        and "SLURM_STEP_NODELIST" in os.environ
    )


def _int_or_none(x: Optional[str]) -> Optional[int]:
    if x is None:
        return x
    return int(x)


def free_port():
    """Find a free port using bind().

    There is some interval between finding this port and using it,
    and another process might grab the port in that time.
    Thus it is not guaranteed that the port is really free.

    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def get_rank(prior=None, launcher: str = None) -> Optional[int]:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = os.environ["SLURM_PROCID"]
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    else:
        # prior is None and RANK is None -> RANK = None
        return _int_or_none(os.environ.get("RANK"))


def get_world_size(prior=None, launcher: str = None) -> int:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = int(os.environ["SLURM_NTASKS"])
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    else:
        # prior is None and WORLD_SIZE is None -> WORLD_SIZE = 1
        return int(os.environ.get("WORLD_SIZE", "1"))


def get_local_rank(prior=None, launcher: str = None) -> Optional[int]:
    # LOCAL_RANK is the same as the GPU device id

    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")

            prior = int(os.environ["SLURM_LOCALID"])
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)

    elif "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])

    elif "CUDA_VISIBLE_DEVICES" in os.environ:
        # There are two possibilities:
        # - "CUDA_VISIBLE_DEVICES" is set to multiple GPU ids, e.g. "0,1,2"
        #   => This intends to specify exactly which devices are to be used,
        #      and the local_rank information is possibly insufficient.
        # - "CUDA_VISIBLE_DEVICES" is set to a single id, e.g. "1"
        #   => This could be used as LOCAL_RANK
        cvd = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        if len(cvd) == 1 and "LOCAL_RANK" not in os.environ:
            # If CUDA_VISIBLE_DEVICES is set and LOCAL_RANK is not set,
            # then use it as LOCAL_RANK.

            # Unset CUDA_VISIBLE_DEVICES
            # because the other devices must be visible to communicate
            return int(os.environ.pop("CUDA_VISIBLE_DEVICES"))
        else:
            return None
    else:
        return None


def get_master_addr(prior=None, launcher: str = None) -> Optional[str]:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")

            # e.g. nodelist = foo[1-10],bar[3-8] or foo4,bar[2-10]
            nodelist = os.environ["SLURM_STEP_NODELIST"]
            prior = nodelist.split(",")[0].split("-")[0].replace("[", "")

    if prior is not None:
        return str(prior)
    else:
        return os.environ.get("MASTER_ADDR")


def get_master_port(prior=None) -> Optional[int]:
    if prior is not None:
        return prior
    else:
        return _int_or_none(os.environ.get("MASTER_PORT"))


def get_node_rank(prior=None, launcher: str = None) -> Optional[int]:
    """Get the node rank.

    Used for "multiprocessing distributed" mode.
    The initial RANK equals the node id in this case, and
    the real rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.

    """
    if prior is not None:
        return prior
    elif launcher == "slurm":
        if not is_in_slurm_step():
            raise RuntimeError("This process seems not to be launched by 'srun'")

        # Assume ntasks_per_node == 1
        if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
            raise RuntimeError(
                "Run with --ntasks_per_node=1 if multiprocessing_distributed=true"
            )
        return int(os.environ["SLURM_NODEID"])
    elif launcher == "mpi":
        # Use mpi4py only for initialization, not for communication
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        # Assume ntasks_per_node == 1 (we can't check whether it is or not)
        return comm.Get_rank()
    elif launcher is not None:
        raise RuntimeError(f"launcher='{launcher}' is not supported")
    else:
        return _int_or_none(os.environ.get("RANK"))


def get_num_nodes(prior=None, launcher: str = None) -> Optional[int]:
    """Get the number of nodes.

    Used for "multiprocessing distributed" mode.
    RANK equals the node id in this case, and
    the real rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.

    """
    if prior is not None:
        return prior
    elif launcher == "slurm":
        if not is_in_slurm_step():
            raise RuntimeError("This process seems not to be launched by 'srun'")

        # Assume ntasks_per_node == 1
        if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
            raise RuntimeError(
                "Run with --ntasks_per_node=1 if multiprocessing_distributed=true"
            )
        return int(os.environ["SLURM_STEP_NUM_NODES"])
    elif launcher == "mpi":
        # Use mpi4py only for initialization, not for communication
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        # Assume ntasks_per_node == 1 (we can't check whether it is or not)
        return comm.Get_size()
    elif launcher is not None:
        raise RuntimeError(f"launcher='{launcher}' is not supported")
    else:
        # Fall back to WORLD_SIZE (defaults to 1)
        return int(os.environ.get("WORLD_SIZE", 1))
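A hypothetical single-node sketch (not part of this commit) showing how these helpers read the environment variables that launchers such as torchrun export. The concrete env values are assumptions for illustration.

# Hypothetical usage sketch: fill a DistributedOption from RANK/WORLD_SIZE/LOCAL_RANK.
import os
from funasr_local.train.distributed_utils import (
    DistributedOption, get_rank, get_world_size, get_local_rank, free_port,
)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", str(free_port()))
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")

opt = DistributedOption(distributed=False, ngpu=0)  # single process, CPU only
opt.dist_rank = get_rank()              # -> 0, read from RANK
opt.dist_world_size = get_world_size()  # -> 1, read from WORLD_SIZE
opt.local_rank = get_local_rank()       # -> 0, read from LOCAL_RANK
print(opt.dist_rank, opt.dist_world_size, opt.local_rank)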
540	funasr_local/train/reporter.py	Normal file
@@ -0,0 +1,540 @@
"""Reporter module."""
import dataclasses
import datetime
import logging
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import ContextManager
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

Num = Union[float, int, complex, torch.Tensor, np.ndarray]

_reserved = {"time", "total_count"}


def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue":
    assert check_argument_types()
    if isinstance(v, (torch.Tensor, np.ndarray)):
        if np.prod(v.shape) != 1:
            raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}")
        v = v.item()

    if isinstance(weight, (torch.Tensor, np.ndarray)):
        if np.prod(weight.shape) != 1:
            raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}")
        weight = weight.item()

    if weight is not None:
        retval = WeightedAverage(v, weight)
    else:
        retval = Average(v)
    assert check_return_type(retval)
    return retval


def aggregate(values: Sequence["ReportedValue"]) -> Num:
    assert check_argument_types()

    for v in values:
        if not isinstance(v, type(values[0])):
            raise ValueError(
                f"Can't use different Reported type together: "
                f"{type(v)} != {type(values[0])}"
            )

    if len(values) == 0:
        warnings.warn("No stats found")
        retval = np.nan

    elif isinstance(values[0], Average):
        retval = np.nanmean([v.value for v in values])

    elif isinstance(values[0], WeightedAverage):
        # Exclude non-finite values
        invalid_indices = set()
        for i, v in enumerate(values):
            if not np.isfinite(v.value) or not np.isfinite(v.weight):
                invalid_indices.add(i)
        values = [v for i, v in enumerate(values) if i not in invalid_indices]

        if len(values) != 0:
            # Calc weighted average. Weights are normalized to sum to 1.
            sum_weights = sum(v.weight for i, v in enumerate(values))
            sum_value = sum(v.value * v.weight for i, v in enumerate(values))
            if sum_weights == 0:
                warnings.warn("weight is zero")
                retval = np.nan
            else:
                retval = sum_value / sum_weights
        else:
            warnings.warn("No valid stats found")
            retval = np.nan

    else:
        raise NotImplementedError(f"type={type(values[0])}")
    assert check_return_type(retval)
    return retval


def wandb_get_prefix(key: str):
    if key.startswith("valid"):
        return "valid/"
    if key.startswith("train"):
        return "train/"
    if key.startswith("attn"):
        return "attn/"
    return "metrics/"


class ReportedValue:
    pass


@dataclasses.dataclass(frozen=True)
class Average(ReportedValue):
    value: Num


@dataclasses.dataclass(frozen=True)
class WeightedAverage(ReportedValue):
    value: Tuple[Num, Num]
    weight: Num


class SubReporter:
    """This class is used in Reporter.

    See the docstring of Reporter for the usage.
    """

    def __init__(self, key: str, epoch: int, total_count: int):
        assert check_argument_types()
        self.key = key
        self.epoch = epoch
        self.start_time = time.perf_counter()
        self.stats = defaultdict(list)
        self._finished = False
        self.total_count = total_count
        self.count = 0
        self._seen_keys_in_the_step = set()

    def get_total_count(self) -> int:
        """Returns the number of iterations over all epochs."""
        return self.total_count

    def get_epoch(self) -> int:
        return self.epoch

    def next(self):
        """Close up this step and reset the state for the next step."""
        for key, stats_list in self.stats.items():
            if key not in self._seen_keys_in_the_step:
                # Fill a nan value if the key was not registered in this step
                if isinstance(stats_list[0], WeightedAverage):
                    stats_list.append(to_reported_value(np.nan, 0))
                elif isinstance(stats_list[0], Average):
                    stats_list.append(to_reported_value(np.nan))
                else:
                    raise NotImplementedError(f"type={type(stats_list[0])}")

            assert len(stats_list) == self.count, (len(stats_list), self.count)

        self._seen_keys_in_the_step = set()

    def register(
        self,
        stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]],
        weight: Num = None,
    ) -> None:
        assert check_argument_types()
        if self._finished:
            raise RuntimeError("Already finished")
        if len(self._seen_keys_in_the_step) == 0:
            # Increment count at the first register in this step
            self.total_count += 1
            self.count += 1

        for key2, v in stats.items():
            if key2 in _reserved:
                raise RuntimeError(f"{key2} is reserved.")
            if key2 in self._seen_keys_in_the_step:
                raise RuntimeError(f"{key2} is registered twice.")
            if v is None:
                v = np.nan
            r = to_reported_value(v, weight)

            if key2 not in self.stats:
                # If it's the first time the key is registered,
                # prepend nan values in front of the value
                # to make it the same length as the other stats,
                # e.g.
                # stat A: [0.4, 0.3, 0.5]
                # stat B: [nan, nan, 0.2]
                nan = to_reported_value(np.nan, None if weight is None else 0)
                self.stats[key2].extend(
                    r if i == self.count - 1 else nan for i in range(self.count)
                )
            else:
                self.stats[key2].append(r)
            self._seen_keys_in_the_step.add(key2)

    def log_message(self, start: int = None, end: int = None, num_updates: int = None) -> str:
        if self._finished:
            raise RuntimeError("Already finished")
        if start is None:
            start = 0
        if start < 0:
            start = self.count + start
        if end is None:
            end = self.count

        if self.count == 0 or start == end:
            return ""

        message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch:"
        if num_updates is not None:
            message += f"{num_updates}num_updates: "

        for idx, (key2, stats_list) in enumerate(self.stats.items()):
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:end]
            if idx != 0 and idx != len(stats_list):
                message += ", "

            v = aggregate(values)
            if abs(v) > 1.0e3:
                message += f"{key2}={v:.3e}"
            elif abs(v) > 1.0e-3:
                message += f"{key2}={v:.3f}"
            else:
                message += f"{key2}={v:.3e}"
        return message

    def tensorboard_add_scalar(self, summary_writer, start: int = None):
        if start is None:
            start = 0
        if start < 0:
            start = self.count + start

        for key2, stats_list in self.stats.items():
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:]
            v = aggregate(values)
            summary_writer.add_scalar(f"{key2}", v, self.total_count)

    def wandb_log(self, start: int = None):
        import wandb

        if start is None:
            start = 0
        if start < 0:
            start = self.count + start

        d = {}
        for key2, stats_list in self.stats.items():
            assert len(stats_list) == self.count, (len(stats_list), self.count)
            # values: List[ReportedValue]
            values = stats_list[start:]
            v = aggregate(values)
            d[wandb_get_prefix(key2) + key2] = v
        d["iteration"] = self.total_count
        wandb.log(d)

    def finished(self) -> None:
        self._finished = True

    @contextmanager
    def measure_time(self, name: str):
        start = time.perf_counter()
        yield start
        t = time.perf_counter() - start
        self.register({name: t})

    def measure_iter_time(self, iterable, name: str):
        iterator = iter(iterable)
        while True:
            try:
                start = time.perf_counter()
                retval = next(iterator)
                t = time.perf_counter() - start
                self.register({name: t})
                yield retval
            except StopIteration:
                break


class Reporter:
    """Reporter class.

    Examples:

        >>> reporter = Reporter()
        >>> with reporter.observe('train') as sub_reporter:
        ...     for batch in iterator:
        ...         stats = dict(loss=0.2)
        ...         sub_reporter.register(stats)

    """

    def __init__(self, epoch: int = 0):
        assert check_argument_types()
        if epoch < 0:
            raise ValueError(f"epoch must be 0 or more: {epoch}")
        self.epoch = epoch
        # stats: Dict[int, Dict[str, Dict[str, float]]]
        # e.g. self.stats[epoch]['train']['loss']
        self.stats = {}

    def get_epoch(self) -> int:
        return self.epoch

    def set_epoch(self, epoch: int) -> None:
        if epoch < 0:
            raise ValueError(f"epoch must be 0 or more: {epoch}")
        self.epoch = epoch

    @contextmanager
    def observe(self, key: str, epoch: int = None) -> ContextManager[SubReporter]:
        sub_reporter = self.start_epoch(key, epoch)
        yield sub_reporter
        # Receive the stats from sub_reporter
        self.finish_epoch(sub_reporter)

    def start_epoch(self, key: str, epoch: int = None) -> SubReporter:
        if epoch is not None:
            if epoch < 0:
                raise ValueError(f"epoch must be 0 or more: {epoch}")
            self.epoch = epoch

        if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]:
            # If the previous epoch doesn't exist for some reason,
            # maybe due to a bug, this case also indicates 0-count.
            if self.epoch - 1 != 0:
                warnings.warn(
                    f"The stats of the previous epoch={self.epoch - 1} "
                    f"don't exist."
                )
            total_count = 0
        else:
            total_count = self.stats[self.epoch - 1][key]["total_count"]

        sub_reporter = SubReporter(key, self.epoch, total_count)
        # Clear the stats for the next epoch if they exist
        self.stats.pop(epoch, None)
        return sub_reporter

    def finish_epoch(self, sub_reporter: SubReporter) -> None:
        if self.epoch != sub_reporter.epoch:
            raise RuntimeError(
                f"Don't change epoch during observation: "
                f"{self.epoch} != {sub_reporter.epoch}"
            )

        # Calc the mean of the current stats and set it as this epoch's stats
        stats = {}
        for key2, values in sub_reporter.stats.items():
            v = aggregate(values)
            stats[key2] = v

        stats["time"] = datetime.timedelta(
            seconds=time.perf_counter() - sub_reporter.start_time
        )
        stats["total_count"] = sub_reporter.total_count
        if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
            if torch.cuda.is_initialized():
                stats["gpu_max_cached_mem_GB"] = (
                    torch.cuda.max_memory_reserved() / 2 ** 30
                )
        else:
            if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0:
                stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30

        self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats
        sub_reporter.finished()

    def sort_epochs_and_values(
        self, key: str, key2: str, mode: str
    ) -> List[Tuple[int, float]]:
        """Return (epoch, value) pairs sorted from best to worst.

        Example:
            >>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min')
            >>> e_1best, v_1best = val[0]
            >>> e_2best, v_2best = val[1]
        """
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be min or max: {mode}")
        if not self.has(key, key2):
            raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}")

        # iterate from the last epoch
        values = [(e, self.stats[e][key][key2]) for e in self.stats]

        if mode == "min":
            values = sorted(values, key=lambda x: x[1])
        else:
            values = sorted(values, key=lambda x: -x[1])
        return values

    def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]:
        return [e for e, v in self.sort_epochs_and_values(key, key2, mode)]

    def sort_values(self, key: str, key2: str, mode: str) -> List[float]:
        return [v for e, v in self.sort_epochs_and_values(key, key2, mode)]

    def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int:
        return self.sort_epochs(key, key2, mode)[nbest]

    def check_early_stopping(
        self,
        patience: int,
        key1: str,
        key2: str,
        mode: str,
        epoch: int = None,
        logger=None,
    ) -> bool:
        if logger is None:
            logger = logging
        if epoch is None:
            epoch = self.get_epoch()

        best_epoch = self.get_best_epoch(key1, key2, mode)
        if epoch - best_epoch > patience:
            logger.info(
                f"[Early stopping] {key1}.{key2} has not been "
                f"improved {epoch - best_epoch} epochs continuously. "
                f"The training was stopped at {epoch}epoch"
            )
            return True
        else:
            return False

    def has(self, key: str, key2: str, epoch: int = None) -> bool:
        if epoch is None:
            epoch = self.get_epoch()
        return (
            epoch in self.stats
            and key in self.stats[epoch]
            and key2 in self.stats[epoch][key]
        )

    def log_message(self, epoch: int = None) -> str:
        if epoch is None:
            epoch = self.get_epoch()

        message = ""
        for key, d in self.stats[epoch].items():
            _message = ""
            for key2, v in d.items():
                if v is not None:
                    if len(_message) != 0:
                        _message += ", "
                    if isinstance(v, float):
                        if abs(v) > 1.0e3:
                            _message += f"{key2}={v:.3e}"
                        elif abs(v) > 1.0e-3:
                            _message += f"{key2}={v:.3f}"
                        else:
                            _message += f"{key2}={v:.3e}"
                    elif isinstance(v, datetime.timedelta):
                        _v = humanfriendly.format_timespan(v)
                        _message += f"{key2}={_v}"
                    else:
                        _message += f"{key2}={v}"
            if len(_message) != 0:
                if len(message) == 0:
                    message += f"{epoch}epoch results: "
                else:
                    message += ", "
                message += f"[{key}] {_message}"
        return message

    def get_value(self, key: str, key2: str, epoch: int = None):
        if not self.has(key, key2):
            raise KeyError(f"{key}.{key2} is not found in stats: {self.get_all_keys()}")
        if epoch is None:
            epoch = self.get_epoch()
        return self.stats[epoch][key][key2]

    def get_keys(self, epoch: int = None) -> Tuple[str, ...]:
        """Returns keys1, e.g. train, eval."""
        if epoch is None:
            epoch = self.get_epoch()
        return tuple(self.stats[epoch])

    def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]:
        """Returns keys2, e.g. loss, acc."""
        if epoch is None:
            epoch = self.get_epoch()
        d = self.stats[epoch][key]
        keys2 = tuple(k for k in d if k not in ("time", "total_count"))
        return keys2

    def get_all_keys(self, epoch: int = None) -> Tuple[Tuple[str, str], ...]:
        if epoch is None:
            epoch = self.get_epoch()
        all_keys = []
        for key in self.stats[epoch]:
            for key2 in self.stats[epoch][key]:
                all_keys.append((key, key2))
        return tuple(all_keys)

    def tensorboard_add_scalar(
        self, summary_writer, epoch: int = None, key1: str = None
    ):
        if epoch is None:
            epoch = self.get_epoch()
        total_count = self.stats[epoch]["train"]["total_count"]
        if key1 == "train":
            summary_writer.add_scalar("iter_epoch", epoch, total_count)

        if key1 is not None:
            key1_iterator = tuple([key1])
        else:
            key1_iterator = self.get_keys(epoch)

        for key1 in key1_iterator:
            for key2 in self.get_keys2(key1):
                summary_writer.add_scalar(
                    f"{key2}", self.stats[epoch][key1][key2], total_count
                )

    def wandb_log(self, epoch: int = None):
        import wandb

        if epoch is None:
            epoch = self.get_epoch()

        d = {}
        for key1 in self.get_keys(epoch):
            for key2 in self.stats[epoch][key1]:
                if key2 in ("time", "total_count"):
                    continue
                key = f"{key1}_{key2}_epoch"
                d[wandb_get_prefix(key) + key] = self.stats[epoch][key1][key2]
        d["epoch"] = epoch
        wandb.log(d)

    def state_dict(self):
        return {"stats": self.stats, "epoch": self.epoch}

    def load_state_dict(self, state_dict: dict):
        self.epoch = state_dict["epoch"]
        self.stats = state_dict["stats"]
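A hypothetical usage sketch (not part of this commit) showing how Reporter and SubReporter cooperate, based on the class docstrings above: register weighted per-batch stats during one observed epoch, then read the epoch summary. The loss values and weight are made up, and the printed output is only approximate.

# Hypothetical usage sketch of Reporter/SubReporter.
from funasr_local.train.reporter import Reporter

reporter = Reporter()
reporter.set_epoch(1)
with reporter.observe("train") as sub:
    for loss in (0.9, 0.7, 0.5):                # pretend per-batch losses
        sub.register({"loss": loss}, weight=8)  # weight, e.g. the batch size
        sub.next()                              # close this step

# Roughly: "1epoch results: [train] loss=0.700, time=..., total_count=3"
print(reporter.log_message())
print(reporter.get_best_epoch("train", "loss", "min"))  # -> 1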
840	funasr_local/train/trainer.py	Normal file
@@ -0,0 +1,840 @@
|
||||
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Trainer module."""
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
import dataclasses
|
||||
from dataclasses import is_dataclass
|
||||
from distutils.version import LooseVersion
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import humanfriendly
|
||||
import oss2
|
||||
from io import BytesIO
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn
|
||||
import torch.optim
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
|
||||
from funasr_local.main_funcs.average_nbest_models import average_nbest_models
|
||||
from funasr_local.main_funcs.calculate_all_attentions import calculate_all_attentions
|
||||
from funasr_local.schedulers.abs_scheduler import AbsBatchStepScheduler
|
||||
from funasr_local.schedulers.abs_scheduler import AbsEpochStepScheduler
|
||||
from funasr_local.schedulers.abs_scheduler import AbsScheduler
|
||||
from funasr_local.schedulers.abs_scheduler import AbsValEpochStepScheduler
|
||||
from funasr_local.torch_utils.add_gradient_noise import add_gradient_noise
|
||||
from funasr_local.torch_utils.device_funcs import to_device
|
||||
from funasr_local.torch_utils.recursive_op import recursive_average
|
||||
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr_local.train.abs_espnet_model import AbsESPnetModel
|
||||
from funasr_local.train.distributed_utils import DistributedOption
|
||||
from funasr_local.train.reporter import Reporter
|
||||
from funasr_local.train.reporter import SubReporter
|
||||
from funasr_local.utils.build_dataclass import build_dataclass
|
||||
|
||||
if torch.distributed.is_available():
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
else:
|
||||
# Nothing to do if torch<1.6.0
|
||||
@contextmanager
|
||||
def autocast(enabled=True):
|
||||
yield
|
||||
|
||||
GradScaler = None
|
||||
|
||||
try:
|
||||
import fairscale
|
||||
except ImportError:
|
||||
fairscale = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainerOptions:
|
||||
ngpu: int
|
||||
resume: bool
|
||||
use_amp: bool
|
||||
train_dtype: str
|
||||
grad_noise: bool
|
||||
accum_grad: int
|
||||
grad_clip: float
|
||||
grad_clip_type: float
|
||||
log_interval: Optional[int]
|
||||
no_forward_run: bool
|
||||
use_tensorboard: bool
|
||||
use_wandb: bool
|
||||
output_dir: Union[Path, str]
|
||||
max_epoch: int
|
||||
max_update: int
|
||||
seed: int
|
||||
sharded_ddp: bool
|
||||
patience: Optional[int]
|
||||
keep_nbest_models: Union[int, List[int]]
|
||||
nbest_averaging_interval: int
|
||||
early_stopping_criterion: Sequence[str]
|
||||
best_model_criterion: Sequence[Sequence[str]]
|
||||
val_scheduler_criterion: Sequence[str]
|
||||
unused_parameters: bool
|
||||
wandb_model_log_interval: int
|
||||
use_pai: bool
|
||||
oss_bucket: Union[oss2.Bucket, None]
|
||||
batch_interval: int
|
||||
|
||||
class Trainer:
|
||||
"""Trainer having a optimizer.
|
||||
|
||||
If you'd like to use multiple optimizers, then inherit this class
|
||||
and override the methods if necessary - at least "train_one_epoch()"
|
||||
|
||||
>>> class TwoOptimizerTrainer(Trainer):
|
||||
... @classmethod
|
||||
... def add_arguments(cls, parser):
|
||||
... ...
|
||||
...
|
||||
... @classmethod
|
||||
... def train_one_epoch(cls, model, optimizers, ...):
|
||||
... loss1 = model.model1(...)
|
||||
... loss1.backward()
|
||||
... optimizers[0].step()
|
||||
...
|
||||
... loss2 = model.model2(...)
|
||||
... loss2.backward()
|
||||
... optimizers[1].step()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("This class can't be instantiated.")
|
||||
|
||||
@classmethod
|
||||
def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
|
||||
"""Build options consumed by train(), eval()"""
|
||||
assert check_argument_types()
|
||||
return build_dataclass(TrainerOptions, args)
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
"""Reserved for future development of another Trainer"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def resume(
|
||||
checkpoint: Union[str, Path],
|
||||
model: torch.nn.Module,
|
||||
reporter: Reporter,
|
||||
optimizers: Sequence[torch.optim.Optimizer],
|
||||
schedulers: Sequence[Optional[AbsScheduler]],
|
||||
scaler: Optional[GradScaler],
|
||||
ngpu: int = 0,
|
||||
):
|
||||
states = torch.load(
|
||||
checkpoint,
|
||||
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
|
||||
)
|
||||
model.load_state_dict(states["model"])
|
||||
reporter.load_state_dict(states["reporter"])
|
||||
for optimizer, state in zip(optimizers, states["optimizers"]):
|
||||
optimizer.load_state_dict(state)
|
||||
for scheduler, state in zip(schedulers, states["schedulers"]):
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(state)
|
||||
if scaler is not None:
|
||||
if states["scaler"] is None:
|
||||
logging.warning("scaler state is not found")
|
||||
else:
|
||||
scaler.load_state_dict(states["scaler"])
|
||||
|
||||
logging.info(f"The training was resumed using {checkpoint}")
|
||||
|
||||
@classmethod
|
||||
def run(
|
||||
cls,
|
||||
model: AbsESPnetModel,
|
||||
optimizers: Sequence[torch.optim.Optimizer],
|
||||
schedulers: Sequence[Optional[AbsScheduler]],
|
||||
train_iter_factory: AbsIterFactory,
|
||||
valid_iter_factory: AbsIterFactory,
|
||||
trainer_options,
|
||||
distributed_option: DistributedOption,
|
||||
) -> None:
|
||||
"""Perform training. This method performs the main process of training."""
|
||||
assert check_argument_types()
|
||||
# NOTE(kamo): Don't check the type more strictly as far trainer_options
|
||||
assert is_dataclass(trainer_options), type(trainer_options)
|
||||
assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
|
||||
|
||||
if isinstance(trainer_options.keep_nbest_models, int):
|
||||
keep_nbest_models = [trainer_options.keep_nbest_models]
|
||||
else:
|
||||
if len(trainer_options.keep_nbest_models) == 0:
|
||||
logging.warning("No keep_nbest_models is given. Change to [1]")
|
||||
trainer_options.keep_nbest_models = [1]
|
||||
keep_nbest_models = trainer_options.keep_nbest_models
|
||||
|
||||
output_dir = Path(trainer_options.output_dir)
|
||||
reporter = Reporter()
|
||||
if trainer_options.use_amp:
|
||||
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
|
||||
raise RuntimeError(
|
||||
"Require torch>=1.6.0 for Automatic Mixed Precision"
|
||||
)
|
||||
if trainer_options.sharded_ddp:
|
||||
if fairscale is None:
|
||||
raise RuntimeError(
|
||||
"Requiring fairscale. Do 'pip install fairscale'"
|
||||
)
|
||||
scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
|
||||
else:
|
||||
scaler = GradScaler()
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
|
||||
cls.resume(
|
||||
checkpoint=output_dir / "checkpoint.pb",
|
||||
model=model,
|
||||
optimizers=optimizers,
|
||||
schedulers=schedulers,
|
||||
reporter=reporter,
|
||||
scaler=scaler,
|
||||
ngpu=trainer_options.ngpu,
|
||||
)
|
||||
|
||||
start_epoch = reporter.get_epoch() + 1
|
||||
if start_epoch == trainer_options.max_epoch + 1:
|
||||
logging.warning(
|
||||
f"The training has already reached at max_epoch: {start_epoch}"
|
||||
)
|
||||
|
||||
if distributed_option.distributed:
|
||||
if trainer_options.sharded_ddp:
|
||||
dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
|
||||
module=model,
|
||||
sharded_optimizer=optimizers,
|
||||
)
|
||||
else:
|
||||
dp_model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, find_unused_parameters=trainer_options.unused_parameters)
|
||||
elif distributed_option.ngpu > 1:
|
||||
dp_model = torch.nn.parallel.DataParallel(
|
||||
model,
|
||||
device_ids=list(range(distributed_option.ngpu)),
|
||||
)
|
||||
else:
|
||||
# NOTE(kamo): DataParallel also should work with ngpu=1,
|
||||
# but for debuggability it's better to keep this block.
|
||||
dp_model = model
|
||||
|
||||
if trainer_options.use_tensorboard and (
|
||||
not distributed_option.distributed or distributed_option.dist_rank == 0
|
||||
):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
if trainer_options.use_pai:
|
||||
train_summary_writer = SummaryWriter(
|
||||
os.path.join(trainer_options.output_dir, "tensorboard/train")
|
||||
)
|
||||
valid_summary_writer = SummaryWriter(
|
||||
os.path.join(trainer_options.output_dir, "tensorboard/valid")
|
||||
)
|
||||
else:
|
||||
train_summary_writer = SummaryWriter(
|
||||
str(output_dir / "tensorboard" / "train")
|
||||
)
|
||||
valid_summary_writer = SummaryWriter(
|
||||
str(output_dir / "tensorboard" / "valid")
|
||||
)
|
||||
else:
|
||||
train_summary_writer = None
|
||||
|
||||
start_time = time.perf_counter()
|
||||
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
|
||||
if iepoch != start_epoch:
|
||||
logging.info(
|
||||
"{}/{}epoch started. Estimated time to finish: {}".format(
|
||||
iepoch,
|
||||
trainer_options.max_epoch,
|
||||
humanfriendly.format_timespan(
|
||||
(time.perf_counter() - start_time)
|
||||
/ (iepoch - start_epoch)
|
||||
* (trainer_options.max_epoch - iepoch + 1)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
|
||||
set_all_random_seed(trainer_options.seed + iepoch)
|
||||
|
||||
reporter.set_epoch(iepoch)
|
||||
# 1. Train and validation for one-epoch
|
||||
with reporter.observe("train") as sub_reporter:
|
||||
all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
|
||||
model=dp_model,
|
||||
optimizers=optimizers,
|
||||
schedulers=schedulers,
|
||||
iterator=train_iter_factory.build_iter(iepoch),
|
||||
reporter=sub_reporter,
|
||||
scaler=scaler,
|
||||
summary_writer=train_summary_writer,
|
||||
options=trainer_options,
|
||||
distributed_option=distributed_option,
|
||||
)
|
||||
|
||||
with reporter.observe("valid") as sub_reporter:
|
||||
cls.validate_one_epoch(
|
||||
model=dp_model,
|
||||
iterator=valid_iter_factory.build_iter(iepoch),
|
||||
reporter=sub_reporter,
|
||||
options=trainer_options,
|
||||
distributed_option=distributed_option,
|
||||
)
|
||||
|
||||
# 2. LR Scheduler step
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, AbsValEpochStepScheduler):
|
||||
scheduler.step(
|
||||
reporter.get_value(*trainer_options.val_scheduler_criterion)
|
||||
)
|
||||
elif isinstance(scheduler, AbsEpochStepScheduler):
|
||||
scheduler.step()
|
||||
if trainer_options.sharded_ddp:
|
||||
for optimizer in optimizers:
|
||||
if isinstance(optimizer, fairscale.optim.oss.OSS):
|
||||
optimizer.consolidate_state_dict()
|
||||
|
||||

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
                if trainer_options.use_wandb:
                    reporter.wandb_log()

                # Save the TensorBoard event files to OSS
                if trainer_options.use_pai and train_summary_writer is not None:
                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
                        file_list = []
                        for root, dirs, files in os.walk(summary_writer_path, topdown=False):
                            for name in files:
                                file_full_path = os.path.join(root, name)
                                file_list.append(file_full_path)

                        for file_full_path in file_list:
                            with open(file_full_path, "rb") as f:
                                oss_bucket.put_object(file_full_path, f)

                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/train"),
                        trainer_options.oss_bucket,
                    )
                    write_tensorboard_summary(
                        os.path.join(trainer_options.output_dir, "tensorboard/valid"),
                        trainer_options.oss_bucket,
                    )

                # 4. Save/Update the checkpoint
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                            "ema_model": model.encoder.ema.model.state_dict()
                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None
                            else None,
                        },
                        buffer,
                    )
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, "checkpoint.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "reporter": reporter.state_dict(),
                            "optimizers": [o.state_dict() for o in optimizers],
                            "schedulers": [
                                s.state_dict() if s is not None else None
                                for s in schedulers
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                        },
                        output_dir / "checkpoint.pb",
                    )
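
                # NOTE: a minimal sketch of how this checkpoint could be reloaded to
                # resume training, assuming the same keys as saved above (the actual
                # resume logic lives elsewhere in the training pipeline and may differ):
                #
                #     states = torch.load(output_dir / "checkpoint.pb", map_location="cpu")
                #     model.load_state_dict(states["model"])
                #     reporter.load_state_dict(states["reporter"])
                #     for o, s in zip(optimizers, states["optimizers"]):
                #         o.load_state_dict(s)
                #     for sch, s in zip(schedulers, states["schedulers"]):
                #         if sch is not None and s is not None:
                #             sch.load_state_dict(s)
                #     if scaler is not None and states["scaler"] is not None:
                #         scaler.load_state_dict(states["scaler"])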

                # 5. Save and log the model and update the link to the best model
                if trainer_options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    trainer_options.oss_bucket.put_object(
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        buffer.getvalue(),
                    )
                else:
                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")

                # Create a symlink latest.pb -> {iepoch}epoch.pb (copied rather than linked on OSS)
                if trainer_options.use_pai:
                    p = os.path.join(trainer_options.output_dir, "latest.pb")
                    if trainer_options.oss_bucket.object_exists(p):
                        trainer_options.oss_bucket.delete_object(p)
                    trainer_options.oss_bucket.copy_object(
                        trainer_options.oss_bucket.bucket_name,
                        os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                        p,
                    )
                else:
                    p = output_dir / "latest.pb"
                    if p.is_symlink() or p.exists():
                        p.unlink()
                    p.symlink_to(f"{iepoch}epoch.pb")

                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "train", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Create a symlink if this is the best result so far
                        if best_epoch == iepoch:
                            if trainer_options.use_pai:
                                p = os.path.join(
                                    trainer_options.output_dir, f"{_phase}.{k}.best.pb"
                                )
                                if trainer_options.oss_bucket.object_exists(p):
                                    trainer_options.oss_bucket.delete_object(p)
                                trainer_options.oss_bucket.copy_object(
                                    trainer_options.oss_bucket.bucket_name,
                                    os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                                    p,
                                )
                            else:
                                p = output_dir / f"{_phase}.{k}.best.pb"
                                if p.is_symlink() or p.exists():
                                    p.unlink()
                                p.symlink_to(f"{iepoch}epoch.pb")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )

                log_model = (
                    trainer_options.wandb_model_log_interval > 0
                    and iepoch % trainer_options.wandb_model_log_interval == 0
                )
                if log_model and trainer_options.use_wandb:
                    import wandb

                    logging.info("Logging Model on this epoch :::::")
                    artifact = wandb.Artifact(
                        name=f"model_{wandb.run.id}",
                        type="model",
                        metadata={"improved": _improved},
                    )
                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                    aliases = [
                        f"epoch-{iepoch}",
                        "best" if best_epoch == iepoch else "",
                    ]
                    wandb.log_artifact(artifact, aliases=aliases)

                # 6. Remove the model files excluding n-best epoch and latest epoch
                _removed = []
                # Get the union set of the n-best among multiple criterion
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )

                # Generate the n-best averaged model
                if (
                    trainer_options.nbest_averaging_interval > 0
                    and iepoch % trainer_options.nbest_averaging_interval == 0
                ):
                    average_nbest_models(
                        reporter=reporter,
                        output_dir=output_dir,
                        best_model_criterion=trainer_options.best_model_criterion,
                        nbest=keep_nbest_models,
                        suffix=f"till{iepoch}epoch",
                        oss_bucket=trainer_options.oss_bucket,
                        pai_output_dir=trainer_options.output_dir,
                    )

                for e in range(1, iepoch):
                    if trainer_options.use_pai:
                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
                            trainer_options.oss_bucket.delete_object(p)
                            _removed.append(str(p))
                    else:
                        p = output_dir / f"{e}epoch.pb"
                        if p.exists() and e not in nbests:
                            p.unlink()
                            _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))

            # 7. If no update happened at all, stop the training
            if all_steps_are_invalid:
                logging.warning(
                    "The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            if max_update_stop:
                logging.info(
                    "Stopping training due to "
                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break

        else:
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs "
            )

        # Generate the n-best averaged model
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
                oss_bucket=trainer_options.oss_bucket,
                pai_output_dir=trainer_options.output_dir,
            )
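
    # NOTE: average_nbest_models (called above) combines the n best
    # "{epoch}epoch.pb" snapshots selected by best_model_criterion.
    # Conceptually the averaging is a key-wise mean over the saved state dicts;
    # a rough sketch of that idea (not the actual implementation) is:
    #
    #     avg = {k: sum(sd[k] for sd in state_dicts) / len(state_dicts)
    #            for k in state_dicts[0]}
    #
    # where state_dicts are the loaded n-best checkpoints.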

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        distributed = distributed_option.distributed

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        max_update_stop = False
        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop-flag to the other processes if the iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        # Get the rank
        rank = distributed_option.dist_rank
        # Number of batch updates performed so far
        num_batch_updates = 0
        # Output dir
        output_dir = Path(options.output_dir)
        # Interval (in updates) between intermediate step checkpoints
        batch_interval = options.batch_interval

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
                if hasattr(model, "num_updates") or (
                    hasattr(model, "module") and hasattr(model.module, "num_updates")
                ):
                    num_batch_updates = (
                        model.get_num_updates()
                        if hasattr(model, "num_updates")
                        else model.module.get_num_updates()
                    )
                    if num_batch_updates % batch_interval == 0:
                        if options.use_pai and options.oss_bucket is not None:
                            buffer = BytesIO()
                            if hasattr(model, "module"):
                                torch.save(model.module.state_dict(), buffer)
                            else:
                                torch.save(model.state_dict(), buffer)
                            options.oss_bucket.put_object(
                                os.path.join(output_dir, f"{num_batch_updates}step.pb"),
                                buffer.getvalue(),
                            )
                        else:
                            if hasattr(model, "module"):
                                torch.save(
                                    model.module.state_dict(),
                                    os.path.join(output_dir, f"{num_batch_updates}step.pb"),
                                )
                            else:
                                torch.save(
                                    model.state_dict(),
                                    os.path.join(output_dir, f"{num_batch_updates}step.pb"),
                                )

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                # Note(kamo):
                # Supporting two patterns for the returned value from the model
                # a. dict type
                if isinstance(retval, dict):
                    loss = retval["loss"]
                    stats = retval["stats"]
                    weight = retval["weight"]
                    optim_idx = retval.get("optim_idx")
                    if optim_idx is not None and not isinstance(optim_idx, int):
                        if not isinstance(optim_idx, torch.Tensor):
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {type(optim_idx)}"
                            )
                        if optim_idx.dim() >= 2:
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {optim_idx.dim()}dim tensor"
                            )
                        if optim_idx.dim() == 1:
                            for v in optim_idx:
                                if v != optim_idx[0]:
                                    raise RuntimeError(
                                        "optim_idx must be 1dim tensor "
                                        "having same values for all entries"
                                    )
                            optim_idx = optim_idx[0].item()
                        else:
                            optim_idx = optim_idx.item()

                # b. tuple or list type
                else:
                    loss, stats, weight = retval
                    optim_idx = None
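
                # For illustration, the two accepted return shapes are, e.g.:
                #     return {"loss": loss, "stats": {"loss": loss.detach()}, "weight": batch_size}
                # or
                #     return loss, {"loss": loss.detach()}, batch_size
                # (an optional "optim_idx" entry in the dict selects which optimizer to step).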

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                    # Now weight is summation over all workers
                    loss /= weight
                if distributed:
                    # NOTE(kamo): Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= torch.distributed.get_world_size()
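                # For example, with two workers contributing (loss=1.0, weight=8) and
                # (loss=2.0, weight=4), the weighted average above gives
                # (1.0 * 8 + 2.0 * 4) / (8 + 4) ≈ 1.33 before the world_size scaling,
                # which DistributedDataParallel's gradient averaging cancels out again.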

                loss /= accum_grad

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )

                    # Must invoke scaler.update() if unscale_() is used in the iteration
                    # to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()

                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

                # update num_updates
                if distributed:
                    if hasattr(model.module, "num_updates"):
                        model.module.set_num_updates(model.module.get_num_updates() + 1)
                        options.num_updates = model.module.get_num_updates()
                        if model.module.get_num_updates() >= options.max_update:
                            max_update_stop = True
                else:
                    if hasattr(model, "num_updates"):
                        model.set_num_updates(model.get_num_updates() + 1)
                        options.num_updates = model.get_num_updates()
                        if model.get_num_updates() >= options.max_update:
                            max_update_stop = True

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                num_updates = options.num_updates if hasattr(options, "num_updates") else None
                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

            if max_update_stop:
                break

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid, max_update_stop

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop-flag to the other processes if the iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                # Apply weighted averaging for stats.
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
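

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the Trainer above): a self-contained,
# minimal version of the update pattern that train_one_epoch implements --
# gradient accumulation, AMP loss scaling, gradient clipping, and skipping the
# optimizer step when the gradient norm is not finite. The helper name
# `_example_accum_step` is hypothetical, and unlike the models used by the
# Trainer it assumes `model(**batch)` returns a plain scalar loss.
# ---------------------------------------------------------------------------
def _example_accum_step(model, optimizer, scaler, batches, accum_grad=2, grad_clip=5.0):
    model.train()
    for iiter, batch in enumerate(batches, 1):
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss = model(**batch)
        # Spread one optimizer step over accum_grad mini-batches
        loss = loss / accum_grad
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if iiter % accum_grad == 0:
            if scaler is not None:
                # Unscale before clipping so the norm is measured in real units
                scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            if torch.isfinite(torch.as_tensor(grad_norm)):
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
            elif scaler is not None:
                # scaler.step() skips optimizer.step() for inf/nan gradients,
                # but update() must still be called after unscale_().
                scaler.step(optimizer)
                scaler.update()
            optimizer.zero_grad()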