Mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
funasr_local/main_funcs/__init__.py (new file, 0 lines)
funasr_local/main_funcs/average_nbest_models.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import logging
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Union
import warnings
import os
from io import BytesIO

import torch
from typeguard import check_argument_types
from typing import Collection

from funasr_local.train.reporter import Reporter


@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
    suffix: Optional[str] = None,
    oss_bucket=None,
    pai_output_dir=None,
) -> None:
    """Generate an averaged model from the n-best models.

    Args:
        output_dir: The directory containing the model file for each epoch
        reporter: Reporter instance
        best_model_criterion: Criteria used to decide the best model,
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number of best model files to be averaged
        suffix: A suffix added to the averaged model file name
    """
    assert check_argument_types()
    if isinstance(nbest, int):
        nbests = [nbest]
    else:
        nbests = list(nbest)
    if len(nbests) == 0:
        warnings.warn("At least one nbest value is required")
        nbests = [1]
    if suffix is not None:
        suffix = suffix + "."
    else:
        suffix = ""

    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if len(_nbests) == 0:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue
            elif n == 1:
                # The averaged model is the same as the best model
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pb"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                if sym_op.is_symlink() or sym_op.exists():
                    sym_op.unlink()
                sym_op.symlink_to(op.name)
            else:
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                logging.info(
                    f'Averaging {n}best models: criterion="{ph}.{cr}": {op}'
                )

                avg = None
                # 2.a. Averaging model
                for e, _ in epoch_and_values[:n]:
                    if e not in _loaded:
                        if oss_bucket is None:
                            _loaded[e] = torch.load(
                                output_dir / f"{e}epoch.pb",
                                map_location="cpu",
                            )
                        else:
                            buffer = BytesIO(
                                oss_bucket.get_object(
                                    os.path.join(pai_output_dir, f"{e}epoch.pb")
                                ).read()
                            )
                            _loaded[e] = torch.load(buffer)
                    states = _loaded[e]

                    if avg is None:
                        avg = states
                    else:
                        # Accumulate
                        for k in avg:
                            avg[k] = avg[k] + states[k]
                for k in avg:
                    if str(avg[k].dtype).startswith("torch.int"):
                        # Integer types are not averaged, only accumulated,
                        # e.g. BatchNorm.num_batches_tracked.
                        # (If any case requires averaging, or another reduction
                        # such as max/min, for integer types, please report it.)
                        pass
                    else:
                        avg[k] = avg[k] / n

                # 2.b. Save the averaged model and create a symlink
                if oss_bucket is None:
                    torch.save(avg, op)
                else:
                    buffer = BytesIO()
                    torch.save(avg, buffer)
                    oss_bucket.put_object(
                        os.path.join(
                            pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                        ),
                        buffer.getvalue(),
                    )

        # 3. *.*.ave.pb is a symlink to the max ave model
        if oss_bucket is None:
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
            if sym_op.is_symlink() or sym_op.exists():
                sym_op.unlink()
            sym_op.symlink_to(op.name)
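
A minimal sketch of the averaging rule in step 2, using two made-up toy state dicts in place of the f"{e}epoch.pb" checkpoints: float tensors are summed and divided by n, while integer buffers are only accumulated.

import torch

# Two toy checkpoints standing in for loaded state dicts (hypothetical values).
states_a = {"w": torch.tensor([1.0, 3.0]), "steps": torch.tensor(10)}
states_b = {"w": torch.tensor([3.0, 5.0]), "steps": torch.tensor(12)}

n = 2
avg = {k: states_a[k] + states_b[k] for k in states_a}
for k in avg:
    if str(avg[k].dtype).startswith("torch.int"):
        pass  # integer buffers (e.g. num_batches_tracked) stay accumulated
    else:
        avg[k] = avg[k] / n

print(avg["w"])      # tensor([2., 4.])
print(avg["steps"])  # tensor(22): accumulated, not averaged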
funasr_local/main_funcs/calculate_all_attentions.py (new file, 160 lines)
@@ -0,0 +1,160 @@
from collections import defaultdict
from typing import Dict
from typing import List

import torch

from funasr_local.modules.rnn.attentions import AttAdd
from funasr_local.modules.rnn.attentions import AttCov
from funasr_local.modules.rnn.attentions import AttCovLoc
from funasr_local.modules.rnn.attentions import AttDot
from funasr_local.modules.rnn.attentions import AttForward
from funasr_local.modules.rnn.attentions import AttForwardTA
from funasr_local.modules.rnn.attentions import AttLoc
from funasr_local.modules.rnn.attentions import AttLoc2D
from funasr_local.modules.rnn.attentions import AttLocRec
from funasr_local.modules.rnn.attentions import AttMultiHeadAdd
from funasr_local.modules.rnn.attentions import AttMultiHeadDot
from funasr_local.modules.rnn.attentions import AttMultiHeadLoc
from funasr_local.modules.rnn.attentions import AttMultiHeadMultiResLoc
from funasr_local.modules.rnn.attentions import NoAtt
from funasr_local.modules.attention import MultiHeadedAttention


from funasr_local.train.abs_espnet_model import AbsESPnetModel


@torch.no_grad()
def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from all the attention layers.

    Args:
        model:
        batch: same as forward
    Returns:
        return_dict: A dict of lists of tensors,
            key_names x batch x (D1, D2, ...)

    """
    bs = len(next(iter(batch.values())))
    assert all(len(v) == bs for v in batch.values()), {
        k: v.shape for k, v in batch.items()
    }

    # 1. Register a forward hook to save the output from specific layers
    outputs = {}
    handles = {}
    for name, modu in model.named_modules():

        def hook(module, input, output, name=name):
            if isinstance(module, MultiHeadedAttention):
                # NOTE(kamo): MultiHeadedAttention doesn't return the attention weight
                # attn: (B, Head, Tout, Tin)
                outputs[name] = module.attn.detach().cpu()
            elif isinstance(module, AttLoc2D):
                c, w = output
                # w: previous concatenated attentions
                # w: (B, nprev, Tin)
                att_w = w[:, -1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, (AttCov, AttCovLoc)):
                c, w = output
                assert isinstance(w, list), type(w)
                # w: list of previous attentions
                # w: nprev x (B, Tin)
                att_w = w[-1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, AttLocRec):
                # w: (B, Tin)
                c, (w, (att_h, att_c)) = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttMultiHeadDot,
                    AttMultiHeadAdd,
                    AttMultiHeadLoc,
                    AttMultiHeadMultiResLoc,
                ),
            ):
                c, w = output
                # w: nhead x (B, Tin)
                assert isinstance(w, list), type(w)
                att_w = [_w.detach().cpu() for _w in w]
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttAdd,
                    AttDot,
                    AttForward,
                    AttForwardTA,
                    AttLoc,
                    NoAtt,
                ),
            ):
                c, w = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)

        handle = modu.register_forward_hook(hook)
        handles[name] = handle

    # 2. Forward one sample at a time.
    # Batch mode is avoided to keep the memory requirements small for each model.
    keys = []
    for k in batch:
        if not k.endswith("_lengths"):
            keys.append(k)

    return_dict = defaultdict(list)
    for ibatch in range(bs):
        # *: (B, L, ...) -> (1, L2, ...)
        _sample = {
            k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]]
            if k + "_lengths" in batch
            else batch[k][ibatch, None]
            for k in keys
        }

        # *_lengths: (B,) -> (1,)
        _sample.update(
            {
                k + "_lengths": batch[k + "_lengths"][ibatch, None]
                for k in keys
                if k + "_lengths" in batch
            }
        )
        model(**_sample)

        # Derive the attention results
        for name, output in outputs.items():
            if isinstance(output, list):
                if isinstance(output[0], list):
                    # output: nhead x (Tout, Tin)
                    output = torch.stack(
                        [
                            # Tout x (1, Tin) -> (Tout, Tin)
                            torch.cat([o[idx] for o in output], dim=0)
                            for idx in range(len(output[0]))
                        ],
                        dim=0,
                    )
                else:
                    # Tout x (1, Tin) -> (Tout, Tin)
                    output = torch.cat(output, dim=0)
            else:
                # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin)
                output = output.squeeze(0)
            # output: (Tout, Tin) or (NHead, Tout, Tin)
            return_dict[name].append(output)
        outputs.clear()

    # 3. Remove all hooks
    for _, handle in handles.items():
        handle.remove()

    return dict(return_dict)
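
A minimal sketch of the forward-hook pattern used above, on a hypothetical plain torch module rather than an AbsESPnetModel:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Softmax(dim=-1))
outputs, handles = {}, {}
for name, module in model.named_modules():

    def hook(module, inputs, output, name=name):
        # Capture only the layer type we care about, keyed by module name.
        if isinstance(module, torch.nn.Softmax):
            outputs[name] = output.detach().cpu()

    handles[name] = module.register_forward_hook(hook)

model(torch.randn(1, 4))  # the forward pass triggers the hooks
for handle in handles.values():
    handle.remove()       # always detach the hooks afterwards
print(list(outputs))      # ['1']: the Softmax layer's output was captured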
funasr_local/main_funcs/collect_stats.py (new file, 126 lines)
@@ -0,0 +1,126 @@
from collections import defaultdict
import logging
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.fileio.npy_scp import NpyScpWriter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.forward_adaptor import ForwardAdaptor
from funasr_local.train.abs_espnet_model import AbsESPnetModel


@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Run in collect_stats mode.

    Derives the shape information from the data and gathers statistics.
    This method is used before executing train().
    """
    assert check_argument_types()

    npy_scp_writers = {}
    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        if log_interval is None:
            try:
                log_interval = max(len(itr) // 20, 10)
            except TypeError:
                log_interval = 100

        sum_dict = defaultdict(lambda: 0)
        sq_dict = defaultdict(lambda: 0)
        count_dict = defaultdict(lambda: 0)

        with DatadirWriter(output_dir / mode) as datadir_writer:
            for iiter, (keys, batch) in enumerate(itr, 1):
                batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")

                # 1. Write shape file
                for name in batch:
                    if name.endswith("_lengths"):
                        continue
                    for i, (key, data) in enumerate(zip(keys, batch[name])):
                        if f"{name}_lengths" in batch:
                            lg = int(batch[f"{name}_lengths"][i])
                            data = data[:lg]
                        datadir_writer[f"{name}_shape"][key] = ",".join(
                            map(str, data.shape)
                        )

                # 2. Extract feats
                if ngpu <= 1:
                    data = model.collect_feats(**batch)
                else:
                    # Note that data_parallel can parallelize only "forward()"
                    data = data_parallel(
                        ForwardAdaptor(model, "collect_feats"),
                        (),
                        range(ngpu),
                        module_kwargs=batch,
                    )

                # 3. Calculate sum and square sum
                for key, v in data.items():
                    for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                        # Truncate the zero-padding region
                        if f"{key}_lengths" in data:
                            length = data[f"{key}_lengths"][i]
                            # seq: (Length, Dim, ...)
                            seq = seq[:length]
                        else:
                            # seq: (Dim, ...) -> (1, Dim, ...)
                            seq = seq[None]
                        # Accumulate value, its square, and count
                        sum_dict[key] += seq.sum(0)
                        sq_dict[key] += (seq**2).sum(0)
                        count_dict[key] += len(seq)

                        # 4. [Option] Write derived features as npy format files.
                        if write_collected_feats:
                            # Instantiate NpyScpWriter for the first iteration
                            if (key, mode) not in npy_scp_writers:
                                p = output_dir / mode / "collect_feats"
                                npy_scp_writers[(key, mode)] = NpyScpWriter(
                                    p / f"data_{key}", p / f"{key}.scp"
                                )
                            # Save the array as an npy file
                            npy_scp_writers[(key, mode)][uttid] = seq

                if iiter % log_interval == 0:
                    logging.info(f"Niter: {iiter}")

        for key in sum_dict:
            np.savez(
                output_dir / mode / f"{key}_stats.npz",
                count=count_dict[key],
                sum=sum_dict[key],
                sum_square=sq_dict[key],
            )

        # batch_keys and stats_keys are used by aggregate_stats_dirs.py
        with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
            f.write(
                "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
            )
        with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
            f.write("\n".join(sum_dict) + "\n")
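
A sketch of how the saved count/sum/sum_square statistics can be turned into a global mean and variance (e.g. for feature normalization). The stats path below is hypothetical; only the f"{key}_stats.npz" naming follows the code above.

import numpy as np

stats = np.load("exp/stats/train/feats_stats.npz")  # hypothetical path
count, s, sq = stats["count"], stats["sum"], stats["sum_square"]
mean = s / count
var = sq / count - mean**2          # E[x^2] - E[x]^2
std = np.sqrt(np.maximum(var, 1e-20))  # floor to avoid dividing by zero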
funasr_local/main_funcs/pack_funcs.py (new file, 302 lines)
@@ -0,0 +1,302 @@
from datetime import datetime
from io import BytesIO
from io import TextIOWrapper
import os
from pathlib import Path
import sys
import tarfile
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Union
import zipfile

import yaml


class Archiver:
    def __init__(self, file, mode="r"):
        if Path(file).suffix == ".tar":
            self.type = "tar"
        elif Path(file).suffix == ".tgz" or Path(file).suffixes == [".tar", ".gz"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:gz"
        elif Path(file).suffix == ".tbz2" or Path(file).suffixes == [".tar", ".bz2"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:bz2"
        elif Path(file).suffix == ".txz" or Path(file).suffixes == [".tar", ".xz"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:xz"
        elif Path(file).suffix == ".zip":
            self.type = "zip"
        else:
            raise ValueError(f"Cannot detect archive format: type={file}")

        if self.type == "tar":
            self.fopen = tarfile.open(file, mode=mode)
        elif self.type == "zip":
            self.fopen = zipfile.ZipFile(file, mode=mode)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.fopen.close()

    def close(self):
        self.fopen.close()

    def __iter__(self):
        if self.type == "tar":
            return iter(self.fopen)
        elif self.type == "zip":
            return iter(self.fopen.infolist())
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def add(self, filename, arcname=None, recursive: bool = True):
        if arcname is not None:
            print(f"adding: {arcname}")
        else:
            print(f"adding: {filename}")

        if recursive and Path(filename).is_dir():
            for f in Path(filename).glob("**/*"):
                if f.is_dir():
                    continue

                if arcname is not None:
                    _arcname = Path(arcname) / f
                else:
                    _arcname = None

                self.add(f, _arcname)
            return

        if self.type == "tar":
            return self.fopen.add(filename, arcname)
        elif self.type == "zip":
            return self.fopen.write(filename, arcname)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def addfile(self, info, fileobj):
        print(f"adding: {self.get_name_from_info(info)}")

        if self.type == "tar":
            return self.fopen.addfile(info, fileobj)
        elif self.type == "zip":
            return self.fopen.writestr(info, fileobj.read())
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]:
        """Generate a TarInfo/ZipInfo using system information."""
        if self.type == "tar":
            tarinfo = tarfile.TarInfo(str(name))
            if os.name == "posix":
                tarinfo.gid = os.getgid()
                tarinfo.uid = os.getuid()
            tarinfo.mtime = datetime.now().timestamp()
            tarinfo.size = size
            # Keep mode as default
            return tarinfo
        elif self.type == "zip":
            zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6])
            zipinfo.file_size = size
            return zipinfo
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def get_name_from_info(self, info):
        if self.type == "tar":
            assert isinstance(info, tarfile.TarInfo), type(info)
            return info.name
        elif self.type == "zip":
            assert isinstance(info, zipfile.ZipInfo), type(info)
            return info.filename
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def extract(self, info, path=None):
        if self.type == "tar":
            return self.fopen.extract(info, path)
        elif self.type == "zip":
            return self.fopen.extract(info, path)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def extractfile(self, info, mode="r"):
        if self.type == "tar":
            f = self.fopen.extractfile(info)
            if mode == "r":
                return TextIOWrapper(f)
            else:
                return f
        elif self.type == "zip":
            if mode == "rb":
                mode = "r"
            return self.fopen.open(info, mode)
        else:
            raise ValueError(f"Not supported: type={self.type}")
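
A hypothetical direct use of Archiver; the archive and file names below are made up. Note that the ".tgz" suffix silently upgrades write mode to "w:gz":

with Archiver("bundle.tgz", mode="w") as archive:
    archive.add("exp/config.yaml")              # a single file
    archive.add("exp/images", recursive=True)   # a directory, added file by file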

def find_path_and_change_it_recursive(value, src: str, tgt: str):
    if isinstance(value, dict):
        return {
            k: find_path_and_change_it_recursive(v, src, tgt) for k, v in value.items()
        }
    elif isinstance(value, (list, tuple)):
        return [find_path_and_change_it_recursive(v, src, tgt) for v in value]
    elif isinstance(value, str) and Path(value) == Path(src):
        return tgt
    else:
        return value


def get_dict_from_cache(meta: Union[Path, str]) -> Optional[Dict[str, str]]:
    meta = Path(meta)
    outpath = meta.parent.parent
    if not meta.exists():
        return None

    with meta.open("r", encoding="utf-8") as f:
        d = yaml.safe_load(f)
        assert isinstance(d, dict), type(d)
        yaml_files = d["yaml_files"]
        files = d["files"]
        assert isinstance(yaml_files, dict), type(yaml_files)
        assert isinstance(files, dict), type(files)

        retval = {}
        for key, value in list(yaml_files.items()) + list(files.items()):
            if not (outpath / value).exists():
                return None
            retval[key] = str(outpath / value)
        return retval


def unpack(
    input_archive: Union[Path, str],
    outpath: Union[Path, str],
    use_cache: bool = True,
) -> Dict[str, str]:
    """Scan all files in the archive file and return them as a dict of files.

    Examples:
        tarfile:
            model.pb
            some1.file
            some2.file

        >>> unpack("tarfile", "out")
        {'asr_model_file': 'out/model.pb'}
    """
    input_archive = Path(input_archive)
    outpath = Path(outpath)

    with Archiver(input_archive) as archive:
        for info in archive:
            if Path(archive.get_name_from_info(info)).name == "meta.yaml":
                if (
                    use_cache
                    and (outpath / Path(archive.get_name_from_info(info))).exists()
                ):
                    retval = get_dict_from_cache(
                        outpath / Path(archive.get_name_from_info(info))
                    )
                    if retval is not None:
                        return retval
                d = yaml.safe_load(archive.extractfile(info))
                assert isinstance(d, dict), type(d)
                yaml_files = d["yaml_files"]
                files = d["files"]
                assert isinstance(yaml_files, dict), type(yaml_files)
                assert isinstance(files, dict), type(files)
                break
        else:
            raise RuntimeError("Format error: meta.yaml not found")

        for info in archive:
            fname = archive.get_name_from_info(info)
            outname = outpath / fname
            outname.parent.mkdir(parents=True, exist_ok=True)
            if fname in set(yaml_files.values()):
                d = yaml.safe_load(archive.extractfile(info))
                # Rewrite yaml
                for info2 in archive:
                    name = archive.get_name_from_info(info2)
                    d = find_path_and_change_it_recursive(d, name, str(outpath / name))
                with outname.open("w", encoding="utf-8") as f:
                    yaml.safe_dump(d, f)
            else:
                archive.extract(info, path=outpath)

        retval = {}
        for key, value in list(yaml_files.items()) + list(files.items()):
            retval[key] = str(outpath / value)
        return retval


def _to_relative_or_resolve(f):
    # Resolve to avoid symbolic links
    p = Path(f).resolve()
    try:
        # Change to relative if possible
        p = p.relative_to(Path(".").resolve())
    except ValueError:
        pass
    return str(p)


def pack(
    files: Dict[str, Union[str, Path]],
    yaml_files: Dict[str, Union[str, Path]],
    outpath: Union[str, Path],
    option: Iterable[Union[str, Path]] = (),
):
    for v in list(files.values()) + list(yaml_files.values()) + list(option):
        if not Path(v).exists():
            raise FileNotFoundError(f"No such file or directory: {v}")

    files = {k: _to_relative_or_resolve(v) for k, v in files.items()}
    yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()}
    option = [_to_relative_or_resolve(v) for v in option]

    meta_objs = dict(
        files=files,
        yaml_files=yaml_files,
        timestamp=datetime.now().timestamp(),
        python=sys.version,
    )

    try:
        import torch

        meta_objs.update(torch=str(torch.__version__))
    except ImportError:
        pass
    try:
        import espnet

        meta_objs.update(espnet=espnet.__version__)
    except ImportError:
        pass

    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    with Archiver(outpath, mode="w") as archive:
        # Write packed/meta.yaml
        fileobj = BytesIO(yaml.safe_dump(meta_objs).encode())
        info = archive.generate_info("meta.yaml", fileobj.getbuffer().nbytes)
        archive.addfile(info, fileobj=fileobj)

        for f in list(yaml_files.values()) + list(files.values()) + list(option):
            archive.add(f)

    print(f"Generate: {outpath}")
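
A hypothetical pack()/unpack() round trip; the paths and dict keys below are made up for illustration:

from funasr_local.main_funcs.pack_funcs import pack, unpack

pack(
    files={"asr_model_file": "exp/model.pb"},
    yaml_files={"asr_train_config": "exp/config.yaml"},
    outpath="model.zip",
)
# Later, possibly on another machine:
paths = unpack("model.zip", "downloaded")
# e.g. {'asr_train_config': 'downloaded/exp/config.yaml',
#       'asr_model_file': 'downloaded/exp/model.pb'}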