mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
127
funasr_local/main_funcs/average_nbest_models.py
Normal file
127
funasr_local/main_funcs/average_nbest_models.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Union
|
||||
import warnings
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
from typing import Collection
|
||||
|
||||
from funasr_local.train.reporter import Reporter
|
||||
|
||||
|
||||
@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
    suffix: Optional[str] = None,
    oss_bucket=None,
    pai_output_dir=None,
) -> None:
    """Generate averaged model from n-best models

    For every criterion known to ``reporter`` and every requested ``n``,
    the ``n`` best ``{epoch}epoch.pb`` checkpoints are averaged key by key
    and written as ``{phase}.{criterion}.ave_{n}best.{suffix}pb``.
    For ``n == 1`` no averaging is done; a symlink to the best checkpoint
    is created instead.

    Args:
        output_dir: The directory contains the model file for each epoch
        reporter: Reporter instance
        best_model_criterion: Give criterions to decide the best model.
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number of best model files to be averaged. Either a single
            int or a collection of ints; an empty collection falls back
            to ``[1]`` with a warning.
        suffix: A suffix added to the averaged model file name
        oss_bucket: Optional remote bucket object. When given, checkpoints
            are read with ``oss_bucket.get_object(key).read()`` and the
            averaged models are written with ``oss_bucket.put_object``
            instead of using the local file system.
        pai_output_dir: Key prefix joined with the file name to build the
            remote object key. Only used when ``oss_bucket`` is given.
    """
    # Function-scope import so this block does not depend on the
    # file-level import list.
    import copy

    assert check_argument_types()
    if isinstance(nbest, int):
        nbests = [nbest]
    else:
        nbests = list(nbest)
    if not nbests:
        warnings.warn("At least 1 nbest values are required")
        nbests = [1]
    suffix = suffix + "." if suffix is not None else ""

    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
    max_nbest = max(nbests)
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[:max_nbest])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    # Cache of raw checkpoints keyed by epoch, shared across all criteria
    # and all values of n. Entries must never be modified after loading.
    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        # Only average as many models as there are recorded epochs.
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if not _nbests:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue

            if n == 1:
                # The averaged model is same as the best model
                # NOTE(review): this symlink is created locally even when
                # oss_bucket is given -- confirm that is intended.
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pb"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                if sym_op.is_symlink() or sym_op.exists():
                    sym_op.unlink()
                sym_op.symlink_to(op.name)
                continue

            op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
            logging.info(
                f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
            )

            avg = None
            # 2.a. Averaging model
            for e, _ in epoch_and_values[:n]:
                if e not in _loaded:
                    if oss_bucket is None:
                        _loaded[e] = torch.load(
                            output_dir / f"{e}epoch.pb",
                            map_location="cpu",
                        )
                    else:
                        buffer = BytesIO(
                            oss_bucket.get_object(
                                os.path.join(pai_output_dir, f"{e}epoch.pb")
                            ).read()
                        )
                        # Load on CPU, consistent with the local branch.
                        _loaded[e] = torch.load(buffer, map_location="cpu")
                states = _loaded[e]

                if avg is None:
                    # Shallow copy: rebinding avg[k] below must not touch
                    # the cached checkpoint, which is reused for other
                    # values of n and for other criteria. copy.copy keeps
                    # the mapping type and its instance attributes.
                    avg = copy.copy(states)
                else:
                    # Accumulated (creates new tensors, inputs untouched)
                    for k in avg:
                        avg[k] = avg[k] + states[k]
            for k in avg:
                if str(avg[k].dtype).startswith("torch.int"):
                    # For int type, not averaged, but only accumulated.
                    # e.g. BatchNorm.num_batches_tracked
                    # (If there are any cases that requires averaging
                    # or the other reducing method, e.g. max/min, for integer type,
                    # please report.)
                    pass
                else:
                    avg[k] = avg[k] / n

            # 2.b. Save the ave model and create a symlink
            if oss_bucket is None:
                torch.save(avg, op)
            else:
                buffer = BytesIO()
                torch.save(avg, buffer)
                oss_bucket.put_object(
                    os.path.join(
                        pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                    ),
                    buffer.getvalue(),
                )

        # 3. *.*.ave.pb is a symlink to the max ave model
        if oss_bucket is None:
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
            if sym_op.is_symlink() or sym_op.exists():
                sym_op.unlink()
            sym_op.symlink_to(op.name)
|
||||
Reference in New Issue
Block a user