mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
71
funasr_local/torch_utils/device_funcs.py
Normal file
71
funasr_local/torch_utils/device_funcs.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import dataclasses
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
|
||||
"""Change the device of object recursively"""
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
|
||||
}
|
||||
elif dataclasses.is_dataclass(data) and not isinstance(data, type):
|
||||
return type(data)(
|
||||
*[
|
||||
to_device(v, device, dtype, non_blocking, copy)
|
||||
for v in dataclasses.astuple(data)
|
||||
]
|
||||
)
|
||||
# maybe namedtuple. I don't know the correct way to judge namedtuple.
|
||||
elif isinstance(data, tuple) and type(data) is not tuple:
|
||||
return type(data)(
|
||||
*[to_device(o, device, dtype, non_blocking, copy) for o in data]
|
||||
)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.to(device, dtype, non_blocking, copy)
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def force_gatherable(data, device):
|
||||
"""Change object to gatherable in torch.nn.DataParallel recursively
|
||||
|
||||
The difference from to_device() is changing to torch.Tensor if float or int
|
||||
value is found.
|
||||
|
||||
The restriction to the returned value in DataParallel:
|
||||
The object must be
|
||||
- torch.cuda.Tensor
|
||||
- 1 or more dimension. 0-dimension-tensor sends warning.
|
||||
or a list, tuple, dict.
|
||||
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {k: force_gatherable(v, device) for k, v in data.items()}
|
||||
# DataParallel can't handle NamedTuple well
|
||||
elif isinstance(data, tuple) and type(data) is not tuple:
|
||||
return type(data)(*[force_gatherable(o, device) for o in data])
|
||||
elif isinstance(data, (list, tuple, set)):
|
||||
return type(data)(force_gatherable(v, device) for v in data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
return force_gatherable(torch.from_numpy(data), device)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
if data.dim() == 0:
|
||||
# To 1-dim array
|
||||
data = data[None]
|
||||
return data.to(device)
|
||||
elif isinstance(data, float):
|
||||
return torch.tensor([data], dtype=torch.float, device=device)
|
||||
elif isinstance(data, int):
|
||||
return torch.tensor([data], dtype=torch.long, device=device)
|
||||
elif data is None:
|
||||
return None
|
||||
else:
|
||||
warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
|
||||
return data
|
||||
Reference in New Issue
Block a user