add files

烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions

@@ -0,0 +1,55 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Tuple
import torch
class AbsESPnetModel(torch.nn.Module, ABC):
"""The common abstract class among each tasks
"ESPnetModel" is referred to a class which inherits torch.nn.Module,
and makes the dnn-models forward as its member field,
a.k.a delegate pattern,
and defines "loss", "stats", and "weight" for the task.
If you intend to implement new task in ESPNet,
the model must inherit this class.
In other words, the "mediator" objects between
our training system and the your task class are
just only these three values, loss, stats, and weight.
Example:
>>> from funasr_local.tasks.abs_task import AbsTask
>>> class YourESPnetModel(AbsESPnetModel):
... def forward(self, input, input_lengths):
... ...
... return loss, stats, weight
>>> class YourTask(AbsTask):
... @classmethod
... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
"""
def __init__(self):
super().__init__()
self.num_updates = 0
@abstractmethod
def forward(
self, **batch: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
raise NotImplementedError
@abstractmethod
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def set_num_updates(self, num_updates):
self.num_updates = num_updates
def get_num_updates(self):
return self.num_updates
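
To make the loss/stats/weight contract above concrete, here is a minimal sketch of a subclass. It is not part of this commit: the linear layer, the feats/labels batch keys, and the cross-entropy loss are placeholder choices, and it assumes the funasr_local package is importable (the import path is the same one used by the punctuation model below).

import torch
from funasr_local.train.abs_espnet_model import AbsESPnetModel

class ToyESPnetModel(AbsESPnetModel):
    def __init__(self, idim: int = 40, odim: int = 10):
        super().__init__()
        self.linear = torch.nn.Linear(idim, odim)

    def forward(self, feats: torch.Tensor, labels: torch.Tensor):
        logits = self.linear(feats)                              # (Batch, odim)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        stats = {"loss": loss.detach()}                          # values to report
        weight = torch.tensor(float(feats.size(0)))              # batch size used as the weight
        return loss, stats, weight

    def collect_feats(self, feats: torch.Tensor, labels: torch.Tensor):
        return {"feats": feats}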

@@ -0,0 +1,192 @@
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.torch_utils.device_funcs import force_gatherable
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract class
To share the loss calculation way among different models,
We uses delegate pattern here:
The instance of this class should be passed to "LanguageModel"
This "model" is one of mediator objects for "Task" class.
"""
@abstractmethod
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def with_vad(self) -> bool:
raise NotImplementedError
class PunctuationModel(AbsESPnetModel):
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
self.punc_model = punc_model
self.punc_weight = torch.Tensor(punc_weight)
self.sos = 1
self.eos = 2
# ignore_id may be assumed to be 0, shared with the CTC blank symbol for ASR.
self.ignore_id = ignore_id
# if self.punc_model.with_vad():
# print("This is a vad puncuation model.")
def nll(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
max_length: Optional[int] = None,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
punc: (Batch, Length)
text_lengths: (Batch,)
max_lengths: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, :text_lengths.max()]
punc = punc[:, :text_lengths.max()]
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
if self.punc_model.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
y, _ = self.punc_model(text, text_lengths, vad_indexes)
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_model(text, text_lengths)
# Calc negative log likelihood
# nll: (BxL,)
if not self.training:
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
from sklearn.metrics import f1_score
# During evaluation, report micro-F1 instead of nll (avoid shadowing the imported function).
f1 = f1_score(punc.view(-1).detach().cpu().numpy(),
              indices.squeeze(-1).detach().cpu().numpy(),
              average='micro')
nll = torch.Tensor([f1]).repeat(text_lengths.sum())
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(text_lengths, maxlen=max_length).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
def batchify_nll(self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll) from transformer language model
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
text: (Batch, Length)
punc: (Batch, Length)
text_lengths: (Batch,)
batch_size: int, the number of samples each batch contains when computing nll;
you may change this to avoid OOM or to speed things up
"""
total_num = text.size(0)
if total_num <= batch_size:
nll, x_lengths = self.nll(text, punc, text_lengths, punc_lengths)
else:
nlls = []
x_lengths = []
max_length = text_lengths.max()
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_text = text[start_idx:end_idx, :]
batch_punc = punc[start_idx:end_idx, :]
batch_text_lengths = text_lengths[start_idx:end_idx]
batch_punc_lengths = punc_lengths[start_idx:end_idx]
# batch_nll: [B * T]
batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, batch_punc_lengths, max_length=max_length)
nlls.append(batch_nll)
x_lengths.append(batch_x_lengths)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nlls)
x_lengths = torch.cat(x_lengths)
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
return {}
def inference(self,
text: torch.Tensor,
text_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
if self.punc_model.with_vad():
assert vad_indexes is not None
return self.punc_model(text, text_lengths, vad_indexes)
else:
return self.punc_model(text, text_lengths)
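
batchify_nll above chunks a large batch along the first dimension, scores each chunk, and concatenates the results to keep memory bounded. The following is a standalone sketch of that chunking pattern in plain PyTorch, not part of this commit; the chunked_scores helper and the dummy scorer are hypothetical names used only for illustration.

import torch

def chunked_scores(score_fn, text: torch.Tensor, lengths: torch.Tensor, batch_size: int = 100):
    """Apply score_fn to slices of at most batch_size rows and concatenate the outputs."""
    total = text.size(0)
    if total <= batch_size:
        return score_fn(text, lengths)
    outs = []
    start = 0
    while start < total:
        end = min(start + batch_size, total)
        outs.append(score_fn(text[start:end], lengths[start:end]))
        start = end
    return torch.cat(outs, dim=0)

# Dummy scorer: returns one value per utterance, here just its length.
scores = chunked_scores(lambda t, l: l.float(), torch.zeros(250, 7), torch.full((250,), 7))
assert scores.shape == (250,)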

@@ -0,0 +1,95 @@
from typing import Mapping
from typing import Optional
from typing import Tuple
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.utils.nested_dict_action import NestedDictAction
from funasr_local.utils.types import str_or_none
class ClassChoices:
"""Helper class to manage the options for variable objects and its configuration.
Example:
>>> class A:
... def __init__(self, foo=3): pass
>>> class B:
... def __init__(self, bar="aaaa"): pass
>>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> choices.add_arguments(parser)
>>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4"])
>>> args.var
'a'
>>> args.var_conf
{'foo': 4}
>>> class_obj = choices.get_class(args.var)
>>> a_object = class_obj(**args.var_conf)
"""
def __init__(
self,
name: str,
classes: Mapping[str, type],
type_check: type = None,
default: str = None,
optional: bool = False,
):
assert check_argument_types()
self.name = name
self.base_type = type_check
self.classes = {k.lower(): v for k, v in classes.items()}
if "none" in self.classes or "nil" in self.classes or "null" in self.classes:
raise ValueError('"none", "nil", and "null" are reserved.')
if type_check is not None:
for v in self.classes.values():
if not issubclass(v, type_check):
raise ValueError(f"must be {type_check.__name__}, but got {v}")
self.optional = optional
self.default = default
if default is None:
self.optional = True
def choices(self) -> Tuple[Optional[str], ...]:
retval = tuple(self.classes)
if self.optional:
return retval + (None,)
else:
return retval
def get_class(self, name: Optional[str]) -> Optional[type]:
assert check_argument_types()
if name is None or (self.optional and name.lower() in ("none", "null", "nil")):
retval = None
elif name.lower() in self.classes:
class_obj = self.classes[name.lower()]
assert check_return_type(class_obj)
retval = class_obj
else:
raise ValueError(
f"--{self.name} must be one of {self.choices()}: "
f"--{self.name} {name.lower()}"
)
return retval
def add_arguments(self, parser):
parser.add_argument(
f"--{self.name}",
type=lambda x: str_or_none(x.lower()),
default=self.default,
choices=self.choices(),
help=f"The {self.name} type",
)
parser.add_argument(
f"--{self.name}_conf",
action=NestedDictAction,
default=dict(),
help=f"The keyword arguments for {self.name}",
)
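
A hedged usage sketch of ClassChoices, not part of this commit: the SGDLike/AdamLike classes are placeholders, the module path funasr_local.train.class_choices is a guess (file names are not shown in this view), and it assumes NestedDictAction parses "lr=0.0005" into a dict as in the docstring example above.

import argparse
from funasr_local.train.class_choices import ClassChoices  # module path assumed

class SGDLike:
    def __init__(self, lr: float = 0.1):
        self.lr = lr

class AdamLike:
    def __init__(self, lr: float = 0.001):
        self.lr = lr

# Declare the selectable classes once, let them extend an argparse parser,
# then build whichever class the command line selected.
optim_choices = ClassChoices("optim", dict(sgd=SGDLike, adam=AdamLike), default="sgd")
parser = argparse.ArgumentParser()
optim_choices.add_arguments(parser)
args = parser.parse_args(["--optim", "adam", "--optim_conf", "lr=0.0005"])
optimizer = optim_choices.get_class(args.optim)(**args.optim_conf)  # -> AdamLike(lr=0.0005)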

@@ -0,0 +1,380 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import dataclasses
import logging
import os
import socket
from typing import Optional
import torch
import torch.distributed
@dataclasses.dataclass
class DistributedOption:
# Enable distributed Training
distributed: bool = False
# torch.distributed.Backend: "nccl", "mpi", "gloo", or "tcp"
dist_backend: str = "nccl"
# if init_method="env://",
# env values of "MASTER_PORT", "MASTER_ADDR", "WORLD_SIZE", and "RANK" are referred.
dist_init_method: str = "env://"
dist_world_size: Optional[int] = None
dist_rank: Optional[int] = None
local_rank: Optional[int] = None
ngpu: int = 0
dist_master_addr: Optional[str] = None
dist_master_port: Optional[int] = None
dist_launcher: Optional[str] = None
multiprocessing_distributed: bool = True
def init_options(self):
if self.distributed:
if self.dist_init_method == "env://":
if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
raise RuntimeError(
"--dist_master_addr or MASTER_ADDR must be set "
"if --dist_init_method == 'env://'"
)
if get_master_port(self.dist_master_port) is None:
raise RuntimeError(
"--dist_master_port or MASTER_PORT must be set "
"if --dist_init_port == 'env://'"
)
def init_torch_distributed(self, args):
if self.distributed:
# See:
# https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
os.environ.setdefault("NCCL_DEBUG", "INFO")
# See:
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")
torch.distributed.init_process_group(backend=self.dist_backend,
init_method=self.dist_init_method,
world_size=args.dist_world_size,
rank=args.dist_rank)
self.dist_rank = torch.distributed.get_rank()
self.dist_world_size = torch.distributed.get_world_size()
self.local_rank = args.local_rank
def init_options_pai(self):
if self.distributed:
if self.dist_init_method == "env://":
if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
raise RuntimeError(
"--dist_master_addr or MASTER_ADDR must be set "
"if --dist_init_method == 'env://'"
)
if get_master_port(self.dist_master_port) is None:
raise RuntimeError(
"--dist_master_port or MASTER_PORT must be set "
"if --dist_init_port == 'env://'"
)
self.dist_rank = get_rank(self.dist_rank, self.dist_launcher)
self.dist_world_size = get_world_size(
self.dist_world_size, self.dist_launcher
)
self.local_rank = get_local_rank(self.local_rank, self.dist_launcher)
if (
self.dist_rank is not None
and self.dist_world_size is not None
and self.dist_rank >= self.dist_world_size
):
raise RuntimeError(
f"RANK >= WORLD_SIZE: {self.dist_rank} >= {self.dist_world_size}"
)
if self.dist_init_method == "env://":
self.dist_master_addr = get_master_addr(
self.dist_master_addr, self.dist_launcher
)
self.dist_master_port = get_master_port(self.dist_master_port)
if (
self.dist_master_addr is not None
and self.dist_master_port is not None
):
self.dist_init_method = (
f"tcp://{self.dist_master_addr}:{self.dist_master_port}"
)
def init_torch_distributed_pai(self, args):
if self.distributed:
# See:
# https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
os.environ.setdefault("NCCL_DEBUG", "INFO")
# See:
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")
torch.distributed.init_process_group(backend=self.dist_backend, init_method='env://')
self.dist_rank = torch.distributed.get_rank()
self.dist_world_size = torch.distributed.get_world_size()
self.local_rank = args.local_rank
def resolve_distributed_mode(args):
# Note that args.distributed is set only by this function,
# and ArgumentParser doesn't have such an option.
if args.multiprocessing_distributed:
num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher)
# a. multi-node
if num_nodes > 1:
args.distributed = True
# b. single-node and multi-gpu with multiprocessing_distributed mode
elif args.ngpu > 1:
args.distributed = True
# c. single-node and single-gpu
else:
args.distributed = False
if args.ngpu <= 1:
# Disable multiprocessing_distributed mode if there is 1 process per node or CPU mode
args.multiprocessing_distributed = False
if args.ngpu == 1:
# If the number of GPUs equals 1 with multiprocessing_distributed mode,
# LOCAL_RANK is always 0
args.local_rank = 0
if num_nodes > 1 and get_node_rank(args.dist_rank, args.dist_launcher) is None:
raise RuntimeError(
"--dist_rank or RANK must be set "
"if --multiprocessing_distributed == true"
)
# Note that RANK, LOCAL_RANK, and WORLD_SIZE are automatically set,
# so we don't need to check here
else:
# d. multiprocess and multi-gpu with external launcher
# e.g. torch.distributed.launch
if get_world_size(args.dist_world_size, args.dist_launcher) > 1:
args.distributed = True
# e. single-process
else:
args.distributed = False
if args.distributed and args.ngpu > 0:
if get_local_rank(args.local_rank, args.dist_launcher) is None:
raise RuntimeError(
"--local_rank or LOCAL_RANK must be set "
"if --multiprocessing_distributed == false"
)
if args.distributed:
if get_node_rank(args.dist_rank, args.dist_launcher) is None:
raise RuntimeError(
"--dist_rank or RANK must be set "
"if --multiprocessing_distributed == false"
)
if args.distributed and args.dist_launcher == "slurm" and not is_in_slurm_step():
raise RuntimeError("Launch by 'srun' command if --dist_launcher='slurm'")
def is_in_slurm_job() -> bool:
return "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ
def is_in_slurm_step() -> bool:
return (
is_in_slurm_job()
and "SLURM_STEP_NUM_NODES" in os.environ
and "SLURM_STEP_NODELIST" in os.environ
)
def _int_or_none(x: Optional[str]) -> Optional[int]:
if x is None:
return x
return int(x)
def free_port():
"""Find free port using bind().
There are some interval between finding this port and using it
and the other process might catch the port by that time.
Thus it is not guaranteed that the port is really empty.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
def get_rank(prior=None, launcher: str = None) -> Optional[int]:
if prior is None:
if launcher == "slurm":
if not is_in_slurm_step():
raise RuntimeError("This process seems not to be launched by 'srun'")
prior = os.environ["SLURM_PROCID"]
elif launcher == "mpi":
raise RuntimeError(
"launcher=mpi is used for 'multiprocessing-distributed' mode"
)
elif launcher is not None:
raise RuntimeError(f"launcher='{launcher}' is not supported")
if prior is not None:
return int(prior)
else:
# prior is None and RANK is None -> RANK = None
return _int_or_none(os.environ.get("RANK"))
def get_world_size(prior=None, launcher: str = None) -> int:
if prior is None:
if launcher == "slurm":
if not is_in_slurm_step():
raise RuntimeError("This process seems not to be launched by 'srun'")
prior = int(os.environ["SLURM_NTASKS"])
elif launcher == "mpi":
raise RuntimeError(
"launcher=mpi is used for 'multiprocessing-distributed' mode"
)
elif launcher is not None:
raise RuntimeError(f"launcher='{launcher}' is not supported")
if prior is not None:
return int(prior)
else:
# prior is None and WORLD_SIZE is None -> WORLD_SIZE = 1
return int(os.environ.get("WORLD_SIZE", "1"))
def get_local_rank(prior=None, launcher: str = None) -> Optional[int]:
# LOCAL_RANK is the same as the GPU device id
if prior is None:
if launcher == "slurm":
if not is_in_slurm_step():
raise RuntimeError("This process seems not to be launched by 'srun'")
prior = int(os.environ["SLURM_LOCALID"])
elif launcher == "mpi":
raise RuntimeError(
"launcher=mpi is used for 'multiprocessing-distributed' mode"
)
elif launcher is not None:
raise RuntimeError(f"launcher='{launcher}' is not supported")
if prior is not None:
return int(prior)
elif "LOCAL_RANK" in os.environ:
return int(os.environ["LOCAL_RANK"])
elif "CUDA_VISIBLE_DEVICES" in os.environ:
# There are two possibilities:
# - "CUDA_VISIBLE_DEVICES" is set to multiple GPU ids, e.g. "0,1,2"
#   => This intends to specify exactly which devices to use,
#      and the local_rank information is possibly insufficient.
# - "CUDA_VISIBLE_DEVICES" is set to a single id, e.g. "1"
#   => This could be used as LOCAL_RANK
cvd = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
if len(cvd) == 1 and "LOCAL_RANK" not in os.environ:
# If CUDA_VISIBLE_DEVICES is set and LOCAL_RANK is not set,
# then use it as LOCAL_RANK.
# Unset CUDA_VISIBLE_DEVICES
# because the other device must be visible to communicate
return int(os.environ.pop("CUDA_VISIBLE_DEVICES"))
else:
return None
else:
return None
def get_master_addr(prior=None, launcher: str = None) -> Optional[str]:
if prior is None:
if launcher == "slurm":
if not is_in_slurm_step():
raise RuntimeError("This process seems not to be launched by 'srun'")
# e.g. nodelist = foo[1-10],bar[3-8] or foo4,bar[2-10]
nodelist = os.environ["SLURM_STEP_NODELIST"]
prior = nodelist.split(",")[0].split("-")[0].replace("[", "")
if prior is not None:
return str(prior)
else:
return os.environ.get("MASTER_ADDR")
def get_master_port(prior=None) -> Optional[int]:
if prior is not None:
return prior
else:
return _int_or_none(os.environ.get("MASTER_PORT"))
def get_node_rank(prior=None, launcher: str = None) -> Optional[int]:
"""Get Node Rank.
Use for "multiprocessing distributed" mode.
The initial RANK equals to the Node id in this case and
the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.
"""
if prior is not None:
return prior
elif launcher == "slurm":
if not is_in_slurm_step():
raise RuntimeError("This process seems not to be launched by 'srun'")
# Assume ntasks_per_node == 1
if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
raise RuntimeError(
"Run with --ntasks_per_node=1 if mutliprocessing_distributed=true"
)
return int(os.environ["SLURM_NODEID"])
elif launcher == "mpi":
# Use mpi4py only for initialization and not using for communication
from mpi4py import MPI
comm = MPI.COMM_WORLD
# Assume ntasks_per_node == 1 (We can't check whether it is or not)
return comm.Get_rank()
elif launcher is not None:
raise RuntimeError(f"launcher='{launcher}' is not supported")
else:
return _int_or_none(os.environ.get("RANK"))
def get_num_nodes(prior=None, launcher: str = None) -> Optional[int]:
"""Get the number of nodes.
Use for "multiprocessing distributed" mode.
RANK equals to the Node id in this case and
the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.
"""
if prior is not None:
return prior
elif launcher == "slurm":
if not is_in_slurm_step():
raise RuntimeError("This process seems not to be launched by 'srun'")
# Assume ntasks_per_node == 1
if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
raise RuntimeError(
"Run with --ntasks_per_node=1 if mutliprocessing_distributed=true"
)
return int(os.environ["SLURM_STEP_NUM_NODES"])
elif launcher == "mpi":
# Use mpi4py only for initialization and not using for communication
from mpi4py import MPI
comm = MPI.COMM_WORLD
# Assume ntasks_per_node == 1 (We can't check whether it is or not)
return comm.Get_size()
elif launcher is not None:
raise RuntimeError(f"launcher='{launcher}' is not supported")
else:
# prior is None -> NUM_NODES = 1
return int(os.environ.get("WORLD_SIZE", 1))
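
For reference, the env:// rendezvous that DistributedOption relies on can be exercised in a single CPU process. This is a minimal sketch, not part of this commit; the address and port values are illustrative.

import os
import torch.distributed as dist

# With init_method="env://", torch reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group(backend="gloo", init_method="env://")  # gloo works without a GPU
print(dist.get_rank(), dist.get_world_size())                  # -> 0 1
dist.destroy_process_group()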

@@ -0,0 +1,540 @@
"""Reporter module."""
import dataclasses
import datetime
import logging
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import ContextManager
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import humanfriendly
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
Num = Union[float, int, complex, torch.Tensor, np.ndarray]
_reserved = {"time", "total_count"}
def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue":
assert check_argument_types()
if isinstance(v, (torch.Tensor, np.ndarray)):
if np.prod(v.shape) != 1:
raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}")
v = v.item()
if isinstance(weight, (torch.Tensor, np.ndarray)):
if np.prod(weight.shape) != 1:
raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}")
weight = weight.item()
if weight is not None:
retval = WeightedAverage(v, weight)
else:
retval = Average(v)
assert check_return_type(retval)
return retval
def aggregate(values: Sequence["ReportedValue"]) -> Num:
assert check_argument_types()
for v in values:
if not isinstance(v, type(values[0])):
raise ValueError(
f"Can't use different Reported type together: "
f"{type(v)} != {type(values[0])}"
)
if len(values) == 0:
warnings.warn("No stats found")
retval = np.nan
elif isinstance(values[0], Average):
retval = np.nanmean([v.value for v in values])
elif isinstance(values[0], WeightedAverage):
# Excludes non finite values
invalid_indices = set()
for i, v in enumerate(values):
if not np.isfinite(v.value) or not np.isfinite(v.weight):
invalid_indices.add(i)
values = [v for i, v in enumerate(values) if i not in invalid_indices]
if len(values) != 0:
# Calc weighted average; weights are normalized to sum to 1.
sum_weights = sum(v.weight for v in values)
sum_value = sum(v.value * v.weight for v in values)
if sum_weights == 0:
warnings.warn("weight is zero")
retval = np.nan
else:
retval = sum_value / sum_weights
else:
warnings.warn("No valid stats found")
retval = np.nan
else:
raise NotImplementedError(f"type={type(values[0])}")
assert check_return_type(retval)
return retval
def wandb_get_prefix(key: str):
if key.startswith("valid"):
return "valid/"
if key.startswith("train"):
return "train/"
if key.startswith("attn"):
return "attn/"
return "metrics/"
class ReportedValue:
pass
@dataclasses.dataclass(frozen=True)
class Average(ReportedValue):
value: Num
@dataclasses.dataclass(frozen=True)
class WeightedAverage(ReportedValue):
value: Num
weight: Num
class SubReporter:
"""This class is used in Reporter.
See the docstring of Reporter for the usage.
"""
def __init__(self, key: str, epoch: int, total_count: int):
assert check_argument_types()
self.key = key
self.epoch = epoch
self.start_time = time.perf_counter()
self.stats = defaultdict(list)
self._finished = False
self.total_count = total_count
self.count = 0
self._seen_keys_in_the_step = set()
def get_total_count(self) -> int:
"""Returns the number of iterations over all epochs."""
return self.total_count
def get_epoch(self) -> int:
return self.epoch
def next(self):
"""Close up this step and reset state for the next step"""
for key, stats_list in self.stats.items():
if key not in self._seen_keys_in_the_step:
# Fill nan value if the key is not registered in this step
if isinstance(stats_list[0], WeightedAverage):
stats_list.append(to_reported_value(np.nan, 0))
elif isinstance(stats_list[0], Average):
stats_list.append(to_reported_value(np.nan))
else:
raise NotImplementedError(f"type={type(stats_list[0])}")
assert len(stats_list) == self.count, (len(stats_list), self.count)
self._seen_keys_in_the_step = set()
def register(
self,
stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]],
weight: Num = None,
) -> None:
assert check_argument_types()
if self._finished:
raise RuntimeError("Already finished")
if len(self._seen_keys_in_the_step) == 0:
# Increment count as the first register in this step
self.total_count += 1
self.count += 1
for key2, v in stats.items():
if key2 in _reserved:
raise RuntimeError(f"{key2} is reserved.")
if key2 in self._seen_keys_in_the_step:
raise RuntimeError(f"{key2} is registered twice.")
if v is None:
v = np.nan
r = to_reported_value(v, weight)
if key2 not in self.stats:
# If it's the first time registering this key,
# prepend nan values so that its list has the same
# length as the other stats
# e.g.
# stat A: [0.4, 0.3, 0.5]
# stat B: [nan, nan, 0.2]
nan = to_reported_value(np.nan, None if weight is None else 0)
self.stats[key2].extend(
r if i == self.count - 1 else nan for i in range(self.count)
)
else:
self.stats[key2].append(r)
self._seen_keys_in_the_step.add(key2)
def log_message(self, start: int = None, end: int = None, num_updates: int = None) -> str:
if self._finished:
raise RuntimeError("Already finished")
if start is None:
start = 0
if start < 0:
start = self.count + start
if end is None:
end = self.count
if self.count == 0 or start == end:
return ""
message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch:"
if num_updates is not None:
message += f"{num_updates}num_updates: "
for idx, (key2, stats_list) in enumerate(self.stats.items()):
assert len(stats_list) == self.count, (len(stats_list), self.count)
# values: List[ReportValue]
values = stats_list[start:end]
if idx != 0 and idx != len(stats_list):
message += ", "
v = aggregate(values)
if abs(v) > 1.0e3:
message += f"{key2}={v:.3e}"
elif abs(v) > 1.0e-3:
message += f"{key2}={v:.3f}"
else:
message += f"{key2}={v:.3e}"
return message
def tensorboard_add_scalar(self, summary_writer, start: int = None):
if start is None:
start = 0
if start < 0:
start = self.count + start
for key2, stats_list in self.stats.items():
assert len(stats_list) == self.count, (len(stats_list), self.count)
# values: List[ReportValue]
values = stats_list[start:]
v = aggregate(values)
summary_writer.add_scalar(f"{key2}", v, self.total_count)
def wandb_log(self, start: int = None):
import wandb
if start is None:
start = 0
if start < 0:
start = self.count + start
d = {}
for key2, stats_list in self.stats.items():
assert len(stats_list) == self.count, (len(stats_list), self.count)
# values: List[ReportValue]
values = stats_list[start:]
v = aggregate(values)
d[wandb_get_prefix(key2) + key2] = v
d["iteration"] = self.total_count
wandb.log(d)
def finished(self) -> None:
self._finished = True
@contextmanager
def measure_time(self, name: str):
start = time.perf_counter()
yield start
t = time.perf_counter() - start
self.register({name: t})
def measure_iter_time(self, iterable, name: str):
iterator = iter(iterable)
while True:
try:
start = time.perf_counter()
retval = next(iterator)
t = time.perf_counter() - start
self.register({name: t})
yield retval
except StopIteration:
break
class Reporter:
"""Reporter class.
Examples:
>>> reporter = Reporter()
>>> with reporter.observe('train') as sub_reporter:
... for batch in iterator:
... stats = dict(loss=0.2)
... sub_reporter.register(stats)
"""
def __init__(self, epoch: int = 0):
assert check_argument_types()
if epoch < 0:
raise ValueError(f"epoch must be 0 or more: {epoch}")
self.epoch = epoch
# stats: Dict[int, Dict[str, Dict[str, float]]]
# e.g. self.stats[epoch]['train']['loss']
self.stats = {}
def get_epoch(self) -> int:
return self.epoch
def set_epoch(self, epoch: int) -> None:
if epoch < 0:
raise ValueError(f"epoch must be 0 or more: {epoch}")
self.epoch = epoch
@contextmanager
def observe(self, key: str, epoch: int = None) -> ContextManager[SubReporter]:
sub_reporter = self.start_epoch(key, epoch)
yield sub_reporter
# Receive the stats from sub_reporter
self.finish_epoch(sub_reporter)
def start_epoch(self, key: str, epoch: int = None) -> SubReporter:
if epoch is not None:
if epoch < 0:
raise ValueError(f"epoch must be 0 or more: {epoch}")
self.epoch = epoch
if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]:
# If the previous epoch doesn't exist for some reason,
# maybe due to bug, this case also indicates 0-count.
if self.epoch - 1 != 0:
warnings.warn(
f"The stats of the previous epoch={self.epoch - 1}"
f"doesn't exist."
)
total_count = 0
else:
total_count = self.stats[self.epoch - 1][key]["total_count"]
sub_reporter = SubReporter(key, self.epoch, total_count)
# Clear the stats for the next epoch if it exists
self.stats.pop(epoch, None)
return sub_reporter
def finish_epoch(self, sub_reporter: SubReporter) -> None:
if self.epoch != sub_reporter.epoch:
raise RuntimeError(
f"Don't change epoch during observation: "
f"{self.epoch} != {sub_reporter.epoch}"
)
# Calc the mean of the current stats and store it as this epoch's stats
stats = {}
for key2, values in sub_reporter.stats.items():
v = aggregate(values)
stats[key2] = v
stats["time"] = datetime.timedelta(
seconds=time.perf_counter() - sub_reporter.start_time
)
stats["total_count"] = sub_reporter.total_count
if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
if torch.cuda.is_initialized():
stats["gpu_max_cached_mem_GB"] = (
torch.cuda.max_memory_reserved() / 2 ** 30
)
else:
if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0:
stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30
self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats
sub_reporter.finished()
def sort_epochs_and_values(
self, key: str, key2: str, mode: str
) -> List[Tuple[int, float]]:
"""Return the epoch which resulted the best value.
Example:
>>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min')
>>> e_1best, v_1best = val[0]
>>> e_2best, v_2best = val[1]
"""
if mode not in ("min", "max"):
raise ValueError(f"mode must min or max: {mode}")
if not self.has(key, key2):
raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}")
# iterate from the last epoch
values = [(e, self.stats[e][key][key2]) for e in self.stats]
if mode == "min":
values = sorted(values, key=lambda x: x[1])
else:
values = sorted(values, key=lambda x: -x[1])
return values
def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]:
return [e for e, v in self.sort_epochs_and_values(key, key2, mode)]
def sort_values(self, key: str, key2: str, mode: str) -> List[float]:
return [v for e, v in self.sort_epochs_and_values(key, key2, mode)]
def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int:
return self.sort_epochs(key, key2, mode)[nbest]
def check_early_stopping(
self,
patience: int,
key1: str,
key2: str,
mode: str,
epoch: int = None,
logger=None,
) -> bool:
if logger is None:
logger = logging
if epoch is None:
epoch = self.get_epoch()
best_epoch = self.get_best_epoch(key1, key2, mode)
if epoch - best_epoch > patience:
logger.info(
f"[Early stopping] {key1}.{key2} has not been "
f"improved {epoch - best_epoch} epochs continuously. "
f"The training was stopped at {epoch}epoch"
)
return True
else:
return False
def has(self, key: str, key2: str, epoch: int = None) -> bool:
if epoch is None:
epoch = self.get_epoch()
return (
epoch in self.stats
and key in self.stats[epoch]
and key2 in self.stats[epoch][key]
)
def log_message(self, epoch: int = None) -> str:
if epoch is None:
epoch = self.get_epoch()
message = ""
for key, d in self.stats[epoch].items():
_message = ""
for key2, v in d.items():
if v is not None:
if len(_message) != 0:
_message += ", "
if isinstance(v, float):
if abs(v) > 1.0e3:
_message += f"{key2}={v:.3e}"
elif abs(v) > 1.0e-3:
_message += f"{key2}={v:.3f}"
else:
_message += f"{key2}={v:.3e}"
elif isinstance(v, datetime.timedelta):
_v = humanfriendly.format_timespan(v)
_message += f"{key2}={_v}"
else:
_message += f"{key2}={v}"
if len(_message) != 0:
if len(message) == 0:
message += f"{epoch}epoch results: "
else:
message += ", "
message += f"[{key}] {_message}"
return message
def get_value(self, key: str, key2: str, epoch: int = None):
if not self.has(key, key2):
raise KeyError(f"{key}.{key2} is not found in stats: {self.get_all_keys()}")
if epoch is None:
epoch = self.get_epoch()
return self.stats[epoch][key][key2]
def get_keys(self, epoch: int = None) -> Tuple[str, ...]:
"""Returns keys1 e.g. train,eval."""
if epoch is None:
epoch = self.get_epoch()
return tuple(self.stats[epoch])
def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]:
"""Returns keys2 e.g. loss,acc."""
if epoch is None:
epoch = self.get_epoch()
d = self.stats[epoch][key]
keys2 = tuple(k for k in d if k not in ("time", "total_count"))
return keys2
def get_all_keys(self, epoch: int = None) -> Tuple[Tuple[str, str], ...]:
if epoch is None:
epoch = self.get_epoch()
all_keys = []
for key in self.stats[epoch]:
for key2 in self.stats[epoch][key]:
all_keys.append((key, key2))
return tuple(all_keys)
def tensorboard_add_scalar(
self, summary_writer, epoch: int = None, key1: str = None
):
if epoch is None:
epoch = self.get_epoch()
total_count = self.stats[epoch]["train"]["total_count"]
if key1 == "train":
summary_writer.add_scalar("iter_epoch", epoch, total_count)
if key1 is not None:
key1_iterator = tuple([key1])
else:
key1_iterator = self.get_keys(epoch)
for key1 in key1_iterator:
for key2 in self.get_keys2(key1):
summary_writer.add_scalar(
f"{key2}", self.stats[epoch][key1][key2], total_count
)
def wandb_log(self, epoch: int = None):
import wandb
if epoch is None:
epoch = self.get_epoch()
d = {}
for key1 in self.get_keys(epoch):
for key2 in self.stats[epoch][key1]:
if key2 in ("time", "total_count"):
continue
key = f"{key1}_{key2}_epoch"
d[wandb_get_prefix(key) + key] = self.stats[epoch][key1][key2]
d["epoch"] = epoch
wandb.log(d)
def state_dict(self):
return {"stats": self.stats, "epoch": self.epoch}
def load_state_dict(self, state_dict: dict):
self.epoch = state_dict["epoch"]
self.stats = state_dict["stats"]
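
A short usage sketch of Reporter and SubReporter, not part of this commit; it assumes funasr_local is importable and uses the same import path as the trainer below. The loss values and batch size are made up for illustration.

from funasr_local.train.reporter import Reporter

reporter = Reporter()
for epoch in (1, 2):
    reporter.set_epoch(epoch)
    with reporter.observe("train") as sub:
        for step in range(3):
            sub.register(dict(loss=1.0 / (epoch + step)), weight=8)  # weight = batch size
    with reporter.observe("valid") as sub:
        sub.register(dict(loss=1.0 / epoch))
    print(reporter.log_message())                       # per-epoch summary line
print(reporter.get_best_epoch("valid", "loss", "min"))  # -> 2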

@@ -0,0 +1,840 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Trainer module."""
import argparse
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
from distutils.version import LooseVersion
import logging
from pathlib import Path
import time
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import humanfriendly
import oss2
from io import BytesIO
import os
import numpy as np
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
from funasr_local.main_funcs.average_nbest_models import average_nbest_models
from funasr_local.main_funcs.calculate_all_attentions import calculate_all_attentions
from funasr_local.schedulers.abs_scheduler import AbsBatchStepScheduler
from funasr_local.schedulers.abs_scheduler import AbsEpochStepScheduler
from funasr_local.schedulers.abs_scheduler import AbsScheduler
from funasr_local.schedulers.abs_scheduler import AbsValEpochStepScheduler
from funasr_local.torch_utils.add_gradient_noise import add_gradient_noise
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.recursive_op import recursive_average
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.train.distributed_utils import DistributedOption
from funasr_local.train.reporter import Reporter
from funasr_local.train.reporter import SubReporter
from funasr_local.utils.build_dataclass import build_dataclass
if torch.distributed.is_available():
from torch.distributed import ReduceOp
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
GradScaler = None
try:
import fairscale
except ImportError:
fairscale = None
@dataclasses.dataclass
class TrainerOptions:
ngpu: int
resume: bool
use_amp: bool
train_dtype: str
grad_noise: bool
accum_grad: int
grad_clip: float
grad_clip_type: float
log_interval: Optional[int]
no_forward_run: bool
use_tensorboard: bool
use_wandb: bool
output_dir: Union[Path, str]
max_epoch: int
max_update: int
seed: int
sharded_ddp: bool
patience: Optional[int]
keep_nbest_models: Union[int, List[int]]
nbest_averaging_interval: int
early_stopping_criterion: Sequence[str]
best_model_criterion: Sequence[Sequence[str]]
val_scheduler_criterion: Sequence[str]
unused_parameters: bool
wandb_model_log_interval: int
use_pai: bool
oss_bucket: Union[oss2.Bucket, None]
batch_interval: int
class Trainer:
"""Trainer having a optimizer.
If you'd like to use multiple optimizers, then inherit this class
and override the methods if necessary - at least "train_one_epoch()"
>>> class TwoOptimizerTrainer(Trainer):
... @classmethod
... def add_arguments(cls, parser):
... ...
...
... @classmethod
... def train_one_epoch(cls, model, optimizers, ...):
... loss1 = model.model1(...)
... loss1.backward()
... optimizers[0].step()
...
... loss2 = model.model2(...)
... loss2.backward()
... optimizers[1].step()
"""
def __init__(self):
raise RuntimeError("This class can't be instantiated.")
@classmethod
def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
"""Reserved for future development of another Trainer"""
pass
@staticmethod
def resume(
checkpoint: Union[str, Path],
model: torch.nn.Module,
reporter: Reporter,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
scaler: Optional[GradScaler],
ngpu: int = 0,
):
states = torch.load(
checkpoint,
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
)
model.load_state_dict(states["model"])
reporter.load_state_dict(states["reporter"])
for optimizer, state in zip(optimizers, states["optimizers"]):
optimizer.load_state_dict(state)
for scheduler, state in zip(schedulers, states["schedulers"]):
if scheduler is not None:
scheduler.load_state_dict(state)
if scaler is not None:
if states["scaler"] is None:
logging.warning("scaler state is not found")
else:
scaler.load_state_dict(states["scaler"])
logging.info(f"The training was resumed using {checkpoint}")
@classmethod
def run(
cls,
model: AbsESPnetModel,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
train_iter_factory: AbsIterFactory,
valid_iter_factory: AbsIterFactory,
trainer_options,
distributed_option: DistributedOption,
) -> None:
"""Perform training. This method performs the main process of training."""
assert check_argument_types()
# NOTE(kamo): Don't check the type too strictly here; just require trainer_options to be a dataclass
assert is_dataclass(trainer_options), type(trainer_options)
assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
if isinstance(trainer_options.keep_nbest_models, int):
keep_nbest_models = [trainer_options.keep_nbest_models]
else:
if len(trainer_options.keep_nbest_models) == 0:
logging.warning("No keep_nbest_models is given. Change to [1]")
trainer_options.keep_nbest_models = [1]
keep_nbest_models = trainer_options.keep_nbest_models
output_dir = Path(trainer_options.output_dir)
reporter = Reporter()
if trainer_options.use_amp:
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
raise RuntimeError(
"Require torch>=1.6.0 for Automatic Mixed Precision"
)
if trainer_options.sharded_ddp:
if fairscale is None:
raise RuntimeError(
"Requiring fairscale. Do 'pip install fairscale'"
)
scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
else:
scaler = GradScaler()
else:
scaler = None
if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
cls.resume(
checkpoint=output_dir / "checkpoint.pb",
model=model,
optimizers=optimizers,
schedulers=schedulers,
reporter=reporter,
scaler=scaler,
ngpu=trainer_options.ngpu,
)
start_epoch = reporter.get_epoch() + 1
if start_epoch == trainer_options.max_epoch + 1:
logging.warning(
f"The training has already reached at max_epoch: {start_epoch}"
)
if distributed_option.distributed:
if trainer_options.sharded_ddp:
dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
module=model,
sharded_optimizer=optimizers,
)
else:
dp_model = torch.nn.parallel.DistributedDataParallel(
model, find_unused_parameters=trainer_options.unused_parameters)
elif distributed_option.ngpu > 1:
dp_model = torch.nn.parallel.DataParallel(
model,
device_ids=list(range(distributed_option.ngpu)),
)
else:
# NOTE(kamo): DataParallel also should work with ngpu=1,
# but for debuggability it's better to keep this block.
dp_model = model
if trainer_options.use_tensorboard and (
not distributed_option.distributed or distributed_option.dist_rank == 0
):
from torch.utils.tensorboard import SummaryWriter
if trainer_options.use_pai:
train_summary_writer = SummaryWriter(
os.path.join(trainer_options.output_dir, "tensorboard/train")
)
valid_summary_writer = SummaryWriter(
os.path.join(trainer_options.output_dir, "tensorboard/valid")
)
else:
train_summary_writer = SummaryWriter(
str(output_dir / "tensorboard" / "train")
)
valid_summary_writer = SummaryWriter(
str(output_dir / "tensorboard" / "valid")
)
else:
train_summary_writer = None
start_time = time.perf_counter()
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
"{}/{}epoch started. Estimated time to finish: {}".format(
iepoch,
trainer_options.max_epoch,
humanfriendly.format_timespan(
(time.perf_counter() - start_time)
/ (iepoch - start_epoch)
* (trainer_options.max_epoch - iepoch + 1)
),
)
)
else:
logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
set_all_random_seed(trainer_options.seed + iepoch)
reporter.set_epoch(iepoch)
# 1. Train and validation for one-epoch
with reporter.observe("train") as sub_reporter:
all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
model=dp_model,
optimizers=optimizers,
schedulers=schedulers,
iterator=train_iter_factory.build_iter(iepoch),
reporter=sub_reporter,
scaler=scaler,
summary_writer=train_summary_writer,
options=trainer_options,
distributed_option=distributed_option,
)
with reporter.observe("valid") as sub_reporter:
cls.validate_one_epoch(
model=dp_model,
iterator=valid_iter_factory.build_iter(iepoch),
reporter=sub_reporter,
options=trainer_options,
distributed_option=distributed_option,
)
# 2. LR Scheduler step
for scheduler in schedulers:
if isinstance(scheduler, AbsValEpochStepScheduler):
scheduler.step(
reporter.get_value(*trainer_options.val_scheduler_criterion)
)
elif isinstance(scheduler, AbsEpochStepScheduler):
scheduler.step()
if trainer_options.sharded_ddp:
for optimizer in optimizers:
if isinstance(optimizer, fairscale.optim.oss.OSS):
optimizer.consolidate_state_dict()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
# 3. Report the results
logging.info(reporter.log_message())
if train_summary_writer is not None:
reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
if trainer_options.use_wandb:
reporter.wandb_log()
# save tensorboard on oss
if trainer_options.use_pai and train_summary_writer is not None:
def write_tensorboard_summary(summary_writer_path, oss_bucket):
file_list = []
for root, dirs, files in os.walk(summary_writer_path, topdown=False):
for name in files:
file_full_path = os.path.join(root, name)
file_list.append(file_full_path)
for file_full_path in file_list:
with open(file_full_path, "rb") as f:
oss_bucket.put_object(file_full_path, f)
write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"), trainer_options.oss_bucket)
write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"), trainer_options.oss_bucket)
# 4. Save/Update the checkpoint
if trainer_options.use_pai:
buffer = BytesIO()
torch.save(
{
"model": model.state_dict(),
"reporter": reporter.state_dict(),
"optimizers": [o.state_dict() for o in optimizers],
"schedulers": [
s.state_dict() if s is not None else None
for s in schedulers
],
"scaler": scaler.state_dict() if scaler is not None else None,
"ema_model": model.encoder.ema.model.state_dict()
if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
},
buffer,
)
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), buffer.getvalue())
else:
torch.save(
{
"model": model.state_dict(),
"reporter": reporter.state_dict(),
"optimizers": [o.state_dict() for o in optimizers],
"schedulers": [
s.state_dict() if s is not None else None
for s in schedulers
],
"scaler": scaler.state_dict() if scaler is not None else None,
},
output_dir / "checkpoint.pb",
)
# 5. Save and log the model and update the link to the best model
if trainer_options.use_pai:
buffer = BytesIO()
torch.save(model.state_dict(), buffer)
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
f"{iepoch}epoch.pb"),buffer.getvalue())
else:
torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
# Creates a sym link latest.pb -> {iepoch}epoch.pb
if trainer_options.use_pai:
p = os.path.join(trainer_options.output_dir, "latest.pb")
if trainer_options.oss_bucket.object_exists(p):
trainer_options.oss_bucket.delete_object(p)
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"), p)
else:
p = output_dir / "latest.pb"
if p.is_symlink() or p.exists():
p.unlink()
p.symlink_to(f"{iepoch}epoch.pb")
_improved = []
for _phase, k, _mode in trainer_options.best_model_criterion:
# e.g. _phase, k, _mode = "train", "loss", "min"
if reporter.has(_phase, k):
best_epoch = reporter.get_best_epoch(_phase, k, _mode)
# Creates sym links if it's the best result
if best_epoch == iepoch:
if trainer_options.use_pai:
p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
if trainer_options.oss_bucket.object_exists(p):
trainer_options.oss_bucket.delete_object(p)
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),p)
else:
p = output_dir / f"{_phase}.{k}.best.pb"
if p.is_symlink() or p.exists():
p.unlink()
p.symlink_to(f"{iepoch}epoch.pb")
_improved.append(f"{_phase}.{k}")
if len(_improved) == 0:
logging.info("There are no improvements in this epoch")
else:
logging.info(
"The best model has been updated: " + ", ".join(_improved)
)
log_model = (
trainer_options.wandb_model_log_interval > 0
and iepoch % trainer_options.wandb_model_log_interval == 0
)
if log_model and trainer_options.use_wandb:
import wandb
logging.info("Logging Model on this epoch :::::")
artifact = wandb.Artifact(
name=f"model_{wandb.run.id}",
type="model",
metadata={"improved": _improved},
)
artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
aliases = [
f"epoch-{iepoch}",
"best" if best_epoch == iepoch else "",
]
wandb.log_artifact(artifact, aliases=aliases)
# 6. Remove the model files excluding n-best epoch and latest epoch
_removed = []
# Get the union set of the n-best among multiple criterion
nbests = set().union(
*[
set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
for ph, k, m in trainer_options.best_model_criterion
if reporter.has(ph, k)
]
)
# Generated n-best averaged model
if (
trainer_options.nbest_averaging_interval > 0
and iepoch % trainer_options.nbest_averaging_interval == 0
):
average_nbest_models(
reporter=reporter,
output_dir=output_dir,
best_model_criterion=trainer_options.best_model_criterion,
nbest=keep_nbest_models,
suffix=f"till{iepoch}epoch",
oss_bucket=trainer_options.oss_bucket,
pai_output_dir=trainer_options.output_dir,
)
for e in range(1, iepoch):
if trainer_options.use_pai:
p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
trainer_options.oss_bucket.delete_object(p)
_removed.append(str(p))
else:
p = output_dir / f"{e}epoch.pb"
if p.exists() and e not in nbests:
p.unlink()
_removed.append(str(p))
if len(_removed) != 0:
logging.info("The model files were removed: " + ", ".join(_removed))
# 7. If no updates have happened, stop the training
if all_steps_are_invalid:
logging.warning(
f"The gradients at all steps are invalid in this epoch. "
f"Something seems wrong. This training was stopped at {iepoch}epoch"
)
break
if max_update_stop:
logging.info(
f"Stopping training due to "
f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
)
break
# 8. Check early stopping
if trainer_options.patience is not None:
if reporter.check_early_stopping(
trainer_options.patience, *trainer_options.early_stopping_criterion
):
break
else:
logging.info(
f"The training was finished at {trainer_options.max_epoch} epochs "
)
# Generated n-best averaged model
if not distributed_option.distributed or distributed_option.dist_rank == 0:
average_nbest_models(
reporter=reporter,
output_dir=output_dir,
best_model_criterion=trainer_options.best_model_criterion,
nbest=keep_nbest_models,
oss_bucket=trainer_options.oss_bucket,
pai_output_dir=trainer_options.output_dir,
)
@classmethod
def train_one_epoch(
cls,
model: torch.nn.Module,
iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
scaler: Optional[GradScaler],
reporter: SubReporter,
summary_writer,
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
grad_clip = options.grad_clip
grad_clip_type = options.grad_clip_type
log_interval = options.log_interval
no_forward_run = options.no_forward_run
ngpu = options.ngpu
use_wandb = options.use_wandb
distributed = distributed_option.distributed
if log_interval is None:
try:
log_interval = max(len(iterator) // 20, 10)
except TypeError:
log_interval = 100
model.train()
all_steps_are_invalid = True
max_update_stop = False
# [For distributed] Because iteration counts are not always equal between
# processes, send a stop flag to the other processes when the iterator is finished
iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
# Get the rank
rank = distributed_option.dist_rank
# Get the number of batch updates
num_batch_updates = 0
# Output dir
output_dir = Path(options.output_dir)
# Batch interval
batch_interval = options.batch_interval
start_time = time.perf_counter()
for iiter, (_, batch) in enumerate(
reporter.measure_iter_time(iterator, "iter_time"), 1
):
assert isinstance(batch, dict), type(batch)
if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
if num_batch_updates % batch_interval == 0:
if options.use_pai and options.oss_bucket is not None:
buffer = BytesIO()
if hasattr(model, "module"):
torch.save(model.module.state_dict(), buffer)
else:
torch.save(model.state_dict(), buffer)
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
else:
if hasattr(model, "module"):
torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
else:
torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
if distributed:
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
if iterator_stop > 0:
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
if no_forward_run:
all_steps_are_invalid = False
continue
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
retval = model(**batch)
# Note(kamo):
# Supporting two patterns for the returned value from the model
# a. dict type
if isinstance(retval, dict):
loss = retval["loss"]
stats = retval["stats"]
weight = retval["weight"]
optim_idx = retval.get("optim_idx")
if optim_idx is not None and not isinstance(optim_idx, int):
if not isinstance(optim_idx, torch.Tensor):
raise RuntimeError(
"optim_idx must be int or 1dim torch.Tensor, "
f"but got {type(optim_idx)}"
)
if optim_idx.dim() >= 2:
raise RuntimeError(
"optim_idx must be int or 1dim torch.Tensor, "
f"but got {optim_idx.dim()}dim tensor"
)
if optim_idx.dim() == 1:
for v in optim_idx:
if v != optim_idx[0]:
raise RuntimeError(
"optim_idx must be 1dim tensor "
"having same values for all entries"
)
optim_idx = optim_idx[0].item()
else:
optim_idx = optim_idx.item()
# b. tuple or list type
else:
loss, stats, weight = retval
optim_idx = None
stats = {k: v for k, v in stats.items() if v is not None}
if ngpu > 1 or distributed:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed)
# Now weight is summation over all workers
loss /= weight
if distributed:
# NOTE(kamo): Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= torch.distributed.get_world_size()
loss /= accum_grad
reporter.register(stats, weight)
with reporter.measure_time("backward_time"):
if scaler is not None:
# Scales loss. Calls backward() on scaled loss
# to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose
# for corresponding forward ops.
scaler.scale(loss).backward()
else:
loss.backward()
if iiter % accum_grad == 0:
if scaler is not None:
# Unscales the gradients of optimizer's assigned params in-place
for iopt, optimizer in enumerate(optimizers):
if optim_idx is not None and iopt != optim_idx:
continue
scaler.unscale_(optimizer)
# gradient noise injection
if grad_noise:
add_gradient_noise(
model,
reporter.get_total_count(),
duration=100,
eta=1.0,
scale_factor=0.55,
)
# compute the gradient norm to check if it is normal or not
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=grad_clip,
norm_type=grad_clip_type,
)
# PyTorch<=1.4, clip_grad_norm_ returns float value
if not isinstance(grad_norm, torch.Tensor):
grad_norm = torch.tensor(grad_norm)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
# Must invoke scaler.update() if unscale_() is used in the iteration
# to avoid the following error:
# RuntimeError: unscale_() has already been called
# on this optimizer since the last update().
# Note that if the gradient has inf/nan values,
# scaler.step skips optimizer.step().
if scaler is not None:
for iopt, optimizer in enumerate(optimizers):
if optim_idx is not None and iopt != optim_idx:
continue
scaler.step(optimizer)
scaler.update()
else:
all_steps_are_invalid = False
with reporter.measure_time("optim_step_time"):
for iopt, (optimizer, scheduler) in enumerate(
zip(optimizers, schedulers)
):
if optim_idx is not None and iopt != optim_idx:
continue
if scaler is not None:
# scaler.step() first unscales the gradients of
# the optimizer's assigned params.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
else:
optimizer.step()
if isinstance(scheduler, AbsBatchStepScheduler):
scheduler.step()
for iopt, optimizer in enumerate(optimizers):
if optim_idx is not None and iopt != optim_idx:
continue
optimizer.zero_grad()
# Register lr and train/load time[sec/step],
# where step refers to accum_grad * mini-batch
reporter.register(
dict(
{
f"optim{i}_lr{j}": pg["lr"]
for i, optimizer in enumerate(optimizers)
for j, pg in enumerate(optimizer.param_groups)
if "lr" in pg
},
train_time=time.perf_counter() - start_time,
),
)
start_time = time.perf_counter()
# update num_updates
if distributed:
if hasattr(model.module, "num_updates"):
model.module.set_num_updates(model.module.get_num_updates() + 1)
options.num_updates = model.module.get_num_updates()
if model.module.get_num_updates() >= options.max_update:
max_update_stop = True
else:
if hasattr(model, "num_updates"):
model.set_num_updates(model.get_num_updates() + 1)
options.num_updates = model.get_num_updates()
if model.get_num_updates() >= options.max_update:
max_update_stop = True
# NOTE(kamo): Call log_message() after next()
reporter.next()
if iiter % log_interval == 0:
num_updates = options.num_updates if hasattr(options, "num_updates") else None
logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
if summary_writer is not None:
reporter.tensorboard_add_scalar(summary_writer, -log_interval)
if use_wandb:
reporter.wandb_log()
if max_update_stop:
break
else:
if distributed:
iterator_stop.fill_(1)
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
return all_steps_are_invalid, max_update_stop
@classmethod
@torch.no_grad()
def validate_one_epoch(
cls,
model: torch.nn.Module,
iterator: Iterable[Dict[str, torch.Tensor]],
reporter: SubReporter,
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
assert check_argument_types()
ngpu = options.ngpu
no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
model.eval()
# [For distributed] Because iteration counts are not always equal between
# processes, send a stop flag to the other processes when the iterator is finished
iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
for (_, batch) in iterator:
assert isinstance(batch, dict), type(batch)
if distributed:
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
if iterator_stop > 0:
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
if no_forward_run:
continue
retval = model(**batch)
if isinstance(retval, dict):
stats = retval["stats"]
weight = retval["weight"]
else:
_, stats, weight = retval
if ngpu > 1 or distributed:
# Apply weighted averaging for stats.
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed)
reporter.register(stats, weight)
reporter.next()
else:
if distributed:
iterator_stop.fill_(1)
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
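
Finally, a sketch of the checkpoint layout that run() saves and resume() loads above, reproduced with a toy model on CPU. This is an illustration, not part of this commit; the reporter entry simply mirrors the shape of Reporter.state_dict().

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
states = {
    "model": model.state_dict(),
    "reporter": {"stats": {}, "epoch": 0},   # shape of Reporter.state_dict()
    "optimizers": [optimizer.state_dict()],
    "schedulers": [None],
    "scaler": None,
}
torch.save(states, "checkpoint.pb")

# Mirrors Trainer.resume(): load on CPU and restore each component.
restored = torch.load("checkpoint.pb", map_location="cpu")
model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizers"][0])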