Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
funasr_local/torch_utils/__init__.py  (new file, 0 lines)
funasr_local/torch_utils/add_gradient_noise.py  (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch


def add_gradient_noise(
    model: torch.nn.Module,
    iteration: int,
    duration: float = 100,
    eta: float = 1.0,
    scale_factor: float = 0.55,
):
    """Adds noise from a standard normal distribution to the gradients.

    The standard deviation (`sigma`) is controlled
    by the three hyper-parameters below.
    `sigma` goes to zero (no noise) with more iterations.

    Args:
        model: Model.
        iteration: Number of iterations.
        duration: {100, 1000}: Number of durations to control
            the interval of the `sigma` change.
        eta: {0.01, 0.3, 1.0}: The magnitude of `sigma`.
        scale_factor: {0.55}: The scale of `sigma`.
    """
    interval = (iteration // duration) + 1
    sigma = eta / interval**scale_factor
    for param in model.parameters():
        if param.grad is not None:
            _shape = param.grad.size()
            noise = sigma * torch.randn(_shape).to(param.device)
            param.grad += noise
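A minimal usage sketch (not part of this commit; it assumes `funasr_local` is importable as a package and uses a toy model and optimizer for illustration): the noise is added after the backward pass and before the optimizer step, and its standard deviation decays as `eta / ((iteration // duration) + 1) ** scale_factor`.

# Illustrative sketch only; model/optimizer/loss are placeholders.
import torch
from funasr_local.torch_utils.add_gradient_noise import add_gradient_noise

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

for iteration in range(3):
    x = torch.randn(4, 8)
    y = torch.randint(0, 2, (4,))
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    # Perturb gradients with decaying Gaussian noise, then step.
    add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55)
    optimizer.step()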
funasr_local/torch_utils/device_funcs.py  (new file, 71 lines)
@@ -0,0 +1,71 @@
import dataclasses
import warnings

import numpy as np
import torch


def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device of object recursively"""
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        return type(data)(
            *[
                to_device(v, device, dtype, non_blocking, copy)
                for v in dataclasses.astuple(data)
            ]
        )
    # maybe namedtuple. I don't know the correct way to judge namedtuple.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data


def force_gatherable(data, device):
    """Change object to gatherable in torch.nn.DataParallel recursively

    The difference from to_device() is changing to torch.Tensor if float or int
    value is found.

    The restriction to the returned value in DataParallel:
        The object must be
        - torch.cuda.Tensor
        - 1 or more dimension. 0-dimension-tensor sends warning.
        or a list, tuple, dict.

    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data
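A minimal sketch of how these two helpers behave (illustration only, not from the commit; the batch contents are made up): `to_device()` converts nested numpy arrays to tensors on the target device, while `force_gatherable()` turns Python scalars and 0-dim tensors into 1-dim tensors so DataParallel can gather them.

# Illustrative sketch only.
import numpy as np
import torch
from funasr_local.torch_utils.device_funcs import to_device, force_gatherable

batch = {
    "speech": np.zeros((2, 16000), dtype=np.float32),
    "lengths": torch.tensor([16000, 12000]),
}
batch = to_device(batch, device="cpu")      # numpy arrays become tensors on the target device
assert isinstance(batch["speech"], torch.Tensor)

loss = torch.tensor(1.23)                   # 0-dim tensor
stats = {"loss": loss, "acc": 0.5, "num": 2}  # mixes tensors, float, int
loss, stats, weight = force_gatherable((loss, stats, 2), device="cpu")
print(loss.shape, stats["acc"], weight)     # everything is now a 1-dim tensor on "cpu"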
funasr_local/torch_utils/forward_adaptor.py  (new file, 33 lines)
@@ -0,0 +1,33 @@
import torch
from typeguard import check_argument_types


class ForwardAdaptor(torch.nn.Module):
    """Wrapped module to parallelize a specified method

    torch.nn.DataParallel parallelizes only "forward()";
    a method with any other name cannot be parallelized
    unless the module is wrapped, as this class does.

    Examples:
        >>> class A(torch.nn.Module):
        ...     def foo(self, x):
        ...         ...
        >>> model = A()
        >>> model = ForwardAdaptor(model, "foo")
        >>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> x = torch.randn(2, 10)
        >>> model(x)
    """

    def __init__(self, module: torch.nn.Module, name: str):
        assert check_argument_types()
        super().__init__()
        self.module = module
        self.name = name
        if not hasattr(module, name):
            raise ValueError(f"{module} doesn't have {name}")

    def forward(self, *args, **kwargs):
        func = getattr(self.module, self.name)
        return func(*args, **kwargs)
funasr_local/torch_utils/initialize.py  (new file, 102 lines)
@@ -0,0 +1,102 @@
#!/usr/bin/env python3

"""Initialize modules for espnet2 neural networks."""

import math
import torch
from typeguard import check_argument_types


def initialize(model: torch.nn.Module, init: str):
    """Initialize weights of a neural network module.

    Parameters are initialized using the given method or distribution.

    Custom initialization routines can be implemented into submodules
    as function `espnet_initialization_fn` within the custom module.

    Args:
        model: Target.
        init: Method of initialization.
    """
    assert check_argument_types()

    if init == "chainer":
        # 1. lecun_normal_init_parameters
        for p in model.parameters():
            data = p.data
            if data.dim() == 1:
                # bias
                data.zero_()
            elif data.dim() == 2:
                # linear weight
                n = data.size(1)
                stdv = 1.0 / math.sqrt(n)
                data.normal_(0, stdv)
            elif data.dim() in (3, 4):
                # conv weight
                n = data.size(1)
                for k in data.size()[2:]:
                    n *= k
                stdv = 1.0 / math.sqrt(n)
                data.normal_(0, stdv)
            else:
                raise NotImplementedError

        for mod in model.modules():
            # 2. embed weight ~ Normal(0, 1)
            if isinstance(mod, torch.nn.Embedding):
                mod.weight.data.normal_(0, 1)
            # 3. forget-bias = 1.0
            elif isinstance(mod, torch.nn.RNNCellBase):
                n = mod.bias_ih.size(0)
                mod.bias_ih.data[n // 4 : n // 2].fill_(1.0)
            elif isinstance(mod, torch.nn.RNNBase):
                for name, param in mod.named_parameters():
                    if "bias" in name:
                        n = param.size(0)
                        param.data[n // 4 : n // 2].fill_(1.0)
            if hasattr(mod, "espnet_initialization_fn"):
                mod.espnet_initialization_fn()

    else:
        # weight init
        for p in model.parameters():
            if p.dim() > 1:
                if init == "xavier_uniform":
                    torch.nn.init.xavier_uniform_(p.data)
                elif init == "xavier_normal":
                    torch.nn.init.xavier_normal_(p.data)
                elif init == "kaiming_uniform":
                    torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
                elif init == "kaiming_normal":
                    torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
                else:
                    raise ValueError("Unknown initialization: " + init)
        # bias init
        for p in model.parameters():
            if p.dim() == 1:
                p.data.zero_()

        # reset some modules with default init
        for m in model.modules():
            if isinstance(
                m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)
            ):
                m.reset_parameters()
            if hasattr(m, "espnet_initialization_fn"):
                m.espnet_initialization_fn()

        # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
        if getattr(model, "encoder", None) and getattr(
            model.encoder, "reload_pretrained_parameters", None
        ):
            model.encoder.reload_pretrained_parameters()
        if getattr(model, "frontend", None) and getattr(
            model.frontend, "reload_pretrained_parameters", None
        ):
            model.frontend.reload_pretrained_parameters()
        if getattr(model, "postencoder", None) and getattr(
            model.postencoder, "reload_pretrained_parameters", None
        ):
            model.postencoder.reload_pretrained_parameters()
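A short sketch of the non-"chainer" path (illustration only, using a toy model): any parameter with more than one dimension is re-initialized with the named scheme, and 1-dim parameters (biases) are zeroed.

# Illustrative sketch only.
import torch
from funasr_local.torch_utils.initialize import initialize

model = torch.nn.Sequential(
    torch.nn.Linear(80, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 4),
)
initialize(model, "xavier_uniform")   # also: xavier_normal / kaiming_uniform / kaiming_normal / chainer
assert torch.all(model[0].bias == 0)  # biases are zeroed by the non-chainer branch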
funasr_local/torch_utils/load_pretrained_model.py  (new file, 125 lines)
@@ -0,0 +1,125 @@
from typing import Any
from typing import Dict
from typing import Union
from io import BytesIO

import logging
import torch
import torch.nn
import torch.optim


def filter_state_dict(
    dst_state: Dict[str, Union[float, torch.Tensor]],
    src_state: Dict[str, Union[float, torch.Tensor]],
):
    """Filter out entries whose name or size does not match between state dicts.

    Args:
        dst_state: reference state dict for filtering
        src_state: target state dict for filtering

    """
    match_state = {}
    for key, value in src_state.items():
        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
            match_state[key] = value
        else:
            if key not in dst_state:
                logging.warning(
                    f"Filter out {key} from pretrained dict"
                    + " because of name not found in target dict"
                )
            else:
                logging.warning(
                    f"Filter out {key} from pretrained dict"
                    + " because of size mismatch"
                    + f"({dst_state[key].size()}-{src_state[key].size()})"
                )
    return match_state


def load_pretrained_model(
    init_param: str,
    model: torch.nn.Module,
    ignore_init_mismatch: bool,
    map_location: str = "cpu",
    oss_bucket=None,
):
    """Load a model state and set it to the model.

    Args:
        init_param: <file_path>:<src_key>:<dst_key>:<exclude_keys>

    Examples:
        >>> load_pretrained_model("somewhere/model.pb", model)
        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder", model)
        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder:", model)
        >>> load_pretrained_model(
        ...     "somewhere/model.pb:decoder:decoder:decoder.embed", model
        ... )
        >>> load_pretrained_model("somewhere/decoder.pb::decoder", model)
    """
    sps = init_param.split(":", 4)
    if len(sps) == 4:
        path, src_key, dst_key, excludes = sps
    elif len(sps) == 3:
        path, src_key, dst_key = sps
        excludes = None
    elif len(sps) == 2:
        path, src_key = sps
        dst_key, excludes = None, None
    else:
        (path,) = sps
        src_key, dst_key, excludes = None, None, None
    if src_key == "":
        src_key = None
    if dst_key == "":
        dst_key = None

    if dst_key is None:
        obj = model
    else:

        def get_attr(obj: Any, key: str):
            """Get a nested attribute.

            >>> class A(torch.nn.Module):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.linear = torch.nn.Linear(10, 10)
            >>> a = A()
            >>> assert a.linear.weight is get_attr(a, 'linear.weight')

            """
            if key.strip() == "":
                return obj
            for k in key.split("."):
                obj = getattr(obj, k)
            return obj

        obj = get_attr(model, dst_key)

    if oss_bucket is None:
        src_state = torch.load(path, map_location=map_location)
    else:
        buffer = BytesIO(oss_bucket.get_object(path).read())
        src_state = torch.load(buffer, map_location=map_location)
    if excludes is not None:
        for e in excludes.split(","):
            src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}

    if src_key is not None:
        src_state = {
            k[len(src_key) + 1 :]: v
            for k, v in src_state.items()
            if k.startswith(src_key)
        }

    dst_state = obj.state_dict()
    if ignore_init_mismatch:
        src_state = filter_state_dict(dst_state, src_state)

    logging.info("Loaded src_state keys: {}".format(src_state.keys()))
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)
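A minimal sketch of `filter_state_dict()` (illustration only; the state dicts are made up): entries are kept only when the key exists in the destination and the tensor shapes agree, which is what `load_pretrained_model()` relies on when `ignore_init_mismatch=True`.

# Illustrative sketch only.
import torch
from funasr_local.torch_utils.load_pretrained_model import filter_state_dict

dst_state = {"linear.weight": torch.zeros(4, 8), "linear.bias": torch.zeros(4)}
src_state = {
    "linear.weight": torch.randn(4, 8),  # kept: name and size match
    "linear.bias": torch.randn(5),       # dropped: size mismatch (logged as a warning)
    "extra.weight": torch.randn(2, 2),   # dropped: name not in dst_state
}
print(sorted(filter_state_dict(dst_state, src_state)))  # ['linear.weight']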
funasr_local/torch_utils/model_summary.py  (new file, 70 lines)
@@ -0,0 +1,70 @@
import humanfriendly
import numpy as np
import torch


def get_human_readable_count(number: int) -> str:
    """Return human_readable_count

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.
    Examples:
        >>> get_human_readable_count(123)
        '123 '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'
    Args:
        number: a positive integer number
    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    return f"{number:.2f} {labels[index]}"


def to_bytes(dtype) -> int:
    # torch.float16 -> 16
    return int(str(dtype)[-2:]) // 8


def model_summary(model: torch.nn.Module) -> str:
    message = "Model structure:\n"
    message += str(model)
    tot_params = sum(p.numel() for p in model.parameters())
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
    tot_params = get_human_readable_count(tot_params)
    num_params = get_human_readable_count(num_params)
    message += "\n\nModel summary:\n"
    message += f"    Class Name: {model.__class__.__name__}\n"
    message += f"    Total Number of model parameters: {tot_params}\n"
    message += (
        f"    Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
    )
    num_bytes = humanfriendly.format_size(
        sum(
            p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad
        )
    )
    message += f"    Size: {num_bytes}\n"
    dtype = next(iter(model.parameters())).dtype
    message += f"    Type: {dtype}"
    return message
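A short usage sketch (illustration only; the model is a placeholder): the summary string is typically logged once at startup.

# Illustrative sketch only.
import torch
from funasr_local.torch_utils.model_summary import model_summary

model = torch.nn.Sequential(torch.nn.Linear(80, 256), torch.nn.Linear(256, 4))
print(model_summary(model))  # structure, class name, parameter counts, size, dtype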
funasr_local/torch_utils/pytorch_version.py  (new file, 16 lines)
@@ -0,0 +1,16 @@
import torch


def pytorch_cudnn_version() -> str:
    message = (
        f"pytorch.version={torch.__version__}, "
        f"cuda.available={torch.cuda.is_available()}, "
    )

    if torch.backends.cudnn.enabled:
        message += (
            f"cudnn.version={torch.backends.cudnn.version()}, "
            f"cudnn.benchmark={torch.backends.cudnn.benchmark}, "
            f"cudnn.deterministic={torch.backends.cudnn.deterministic}"
        )
    return message
funasr_local/torch_utils/recursive_op.py  (new file, 47 lines)
@@ -0,0 +1,47 @@
"""Torch utility module."""
import torch

if torch.distributed.is_available():
    from torch.distributed import ReduceOp


def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
    assert weight.dim() == 1, weight.size()
    if isinstance(obj, (tuple, list)):
        return type(obj)(recursive_sum(v, weight, distributed) for v in obj)
    elif isinstance(obj, dict):
        return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}
    elif isinstance(obj, torch.Tensor):
        assert obj.size() == weight.size(), (obj.size(), weight.size())
        obj = (obj * weight.type(obj.dtype)).sum()
        if distributed:
            torch.distributed.all_reduce(obj, op=ReduceOp.SUM)
        return obj
    elif obj is None:
        return None
    else:
        raise ValueError(type(obj))


def recursive_divide(a, b: torch.Tensor):
    if isinstance(a, (tuple, list)):
        return type(a)(recursive_divide(v, b) for v in a)
    elif isinstance(a, dict):
        return {k: recursive_divide(v, b) for k, v in a.items()}
    elif isinstance(a, torch.Tensor):
        assert a.size() == b.size(), (a.size(), b.size())
        return a / b.type(a.dtype)
    elif a is None:
        return None
    else:
        raise ValueError(type(a))


def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    obj = recursive_sum(obj, weight, distributed)
    weight = weight.sum()
    if distributed:
        torch.distributed.all_reduce(weight, op=ReduceOp.SUM)
    # Normalize weight to be sum-to-1
    obj = recursive_divide(obj, weight)
    return obj, weight
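A minimal sketch of `recursive_average()` (illustration only; the statistics and weights are made up): each per-utterance statistic is weighted, summed, and divided by the total weight, which is typically the batch size per worker.

# Illustrative sketch only.
import torch
from funasr_local.torch_utils.recursive_op import recursive_average

stats = {"loss": torch.tensor([1.0, 3.0]), "acc": torch.tensor([0.5, 1.0])}
weight = torch.tensor([1.0, 1.0])
stats, weight = recursive_average(stats, weight, distributed=False)
print(stats["loss"], weight)  # tensor(2.) tensor(2.)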
funasr_local/torch_utils/set_all_random_seed.py  (new file, 10 lines)
@@ -0,0 +1,10 @@
import random

import numpy as np
import torch


def set_all_random_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)