mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
108
funasr_local/bin/aggregate_stats_dirs.py
Normal file
108
funasr_local/bin/aggregate_stats_dirs.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from funasr_local.utils.cli_utils import get_commandline_args
|
||||
|
||||
|
||||
def aggregate_stats_dirs(
|
||||
input_dir: Iterable[Union[str, Path]],
|
||||
output_dir: Union[str, Path],
|
||||
log_level: str,
|
||||
skip_sum_stats: bool,
|
||||
):
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) (levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
input_dirs = [Path(p) for p in input_dir]
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
for mode in ["train", "valid"]:
|
||||
with (input_dirs[0] / mode / "batch_keys").open("r", encoding="utf-8") as f:
|
||||
batch_keys = [line.strip() for line in f if line.strip() != ""]
|
||||
with (input_dirs[0] / mode / "stats_keys").open("r", encoding="utf-8") as f:
|
||||
stats_keys = [line.strip() for line in f if line.strip() != ""]
|
||||
(output_dir / mode).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for key in batch_keys:
|
||||
with (output_dir / mode / f"{key}_shape").open(
|
||||
"w", encoding="utf-8"
|
||||
) as fout:
|
||||
for idir in input_dirs:
|
||||
with (idir / mode / f"{key}_shape").open(
|
||||
"r", encoding="utf-8"
|
||||
) as fin:
|
||||
# Read to the last in order to sort keys
|
||||
# because the order can be changed if num_workers>=1
|
||||
lines = fin.readlines()
|
||||
lines = sorted(lines, key=lambda x: x.split()[0])
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
for key in stats_keys:
|
||||
if not skip_sum_stats:
|
||||
sum_stats = None
|
||||
for idir in input_dirs:
|
||||
stats = np.load(idir / mode / f"{key}_stats.npz")
|
||||
if sum_stats is None:
|
||||
sum_stats = dict(**stats)
|
||||
else:
|
||||
for k in stats:
|
||||
sum_stats[k] += stats[k]
|
||||
|
||||
np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats)
|
||||
|
||||
# if --write_collected_feats=true
|
||||
p = Path(mode) / "collect_feats" / f"{key}.scp"
|
||||
scp = input_dirs[0] / p
|
||||
if scp.exists():
|
||||
(output_dir / p).parent.mkdir(parents=True, exist_ok=True)
|
||||
with (output_dir / p).open("w", encoding="utf-8") as fout:
|
||||
for idir in input_dirs:
|
||||
with (idir / p).open("r", encoding="utf-8") as fin:
|
||||
for line in fin:
|
||||
fout.write(line)
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate statistics directories into one directory",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_level",
|
||||
type=lambda x: x.upper(),
|
||||
default="INFO",
|
||||
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
||||
help="The verbose level of logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_sum_stats",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Skip computing the sum of statistics.",
|
||||
)
|
||||
|
||||
parser.add_argument("--input_dir", action="append", help="Input directories")
|
||||
parser.add_argument("--output_dir", required=True, help="Output directory")
|
||||
return parser
|
||||
|
||||
|
||||
def main(cmd=None):
|
||||
print(get_commandline_args(), file=sys.stderr)
|
||||
parser = get_parser()
|
||||
args = parser.parse_args(cmd)
|
||||
kwargs = vars(args)
|
||||
aggregate_stats_dirs(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user