mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
funasr_local/datasets/large_datasets/utils/clipping.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import numpy as np
import torch

from funasr_local.datasets.collate_fn import crop_to_max_size


def clipping(data):
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [x["key"] for x in data]

    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key":
            continue
        else:
            if data[0][data_name].dtype.kind == "i":
                tensor_type = torch.int64
            else:
                tensor_type = torch.float32

            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)

            length_clip = min(tensor_lengths)
            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
                diff = length - length_clip
                assert diff >= 0
                if diff == 0:
                    tensor_clip[i] = tensor
                else:
                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)

            batch[data_name] = tensor_clip
            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)

    return keys, batch
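For context, clipping is a batch-assembly helper: it collects the "key" field of each sample, converts every other field to a tensor (int64 for integer arrays, float32 otherwise), clips all samples to the shortest length in the batch via crop_to_max_size, and returns the keys plus a batch dict holding each field's clipped tensor and its lengths. Below is a minimal usage sketch; the feature name "speech", the shapes, and the import path (assumed to mirror the file path) are illustrative assumptions, not taken from the repository.

    import numpy as np
    from funasr_local.datasets.large_datasets.utils.clipping import clipping

    # Two hypothetical samples with 80-dim frame features of different lengths.
    # Each non-"key" field must be a 2-D array (frames x feature_dim).
    data = [
        {"key": "utt1", "speech": np.random.randn(120, 80).astype(np.float32)},
        {"key": "utt2", "speech": np.random.randn(95, 80).astype(np.float32)},
    ]

    keys, batch = clipping(data)
    # keys                     -> ["utt1", "utt2"]
    # batch["speech"].shape    -> torch.Size([2, 95, 80])  (clipped to the shortest sample)
    # batch["speech_lengths"]  -> tensor([95, 95])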