mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-04 09:29:24 +08:00
feat: Initial commit
This commit is contained in:
53
utils/cache.py
Normal file
53
utils/cache.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import os
|
||||
import SharedArray
|
||||
|
||||
try:
|
||||
from multiprocessing.shared_memory import ShareableList
|
||||
except ImportError:
|
||||
import warnings
|
||||
|
||||
warnings.warn("Please update python version >= 3.8 to enable shared_memory")
|
||||
import numpy as np
|
||||
|
||||
|
||||
def shared_array(name, var=None):
|
||||
if var is not None:
|
||||
# check exist
|
||||
if os.path.exists(f"/dev/shm/{name}"):
|
||||
return SharedArray.attach(f"shm://{name}")
|
||||
# create shared_array
|
||||
data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
|
||||
data[...] = var[...]
|
||||
data.flags.writeable = False
|
||||
else:
|
||||
data = SharedArray.attach(f"shm://{name}").copy()
|
||||
return data
|
||||
|
||||
|
||||
def shared_dict(name, var=None):
|
||||
name = str(name)
|
||||
assert "." not in name # '.' is used as sep flag
|
||||
data = {}
|
||||
if var is not None:
|
||||
assert isinstance(var, dict)
|
||||
keys = var.keys()
|
||||
# current version only cache np.array
|
||||
keys_valid = []
|
||||
for key in keys:
|
||||
if isinstance(var[key], np.ndarray):
|
||||
keys_valid.append(key)
|
||||
keys = keys_valid
|
||||
|
||||
ShareableList(sequence=keys, name=name + ".keys")
|
||||
for key in keys:
|
||||
if isinstance(var[key], np.ndarray):
|
||||
data[key] = shared_array(name=f"{name}.{key}", var=var[key])
|
||||
else:
|
||||
keys = list(ShareableList(name=name + ".keys"))
|
||||
for key in keys:
|
||||
data[key] = shared_array(name=f"{name}.{key}")
|
||||
return data
|
||||
Reference in New Issue
Block a user