mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
@@ -48,22 +48,22 @@ def get_image_list(data_root, split):
|
||||
class Dataset(object):
|
||||
def __init__(self,
|
||||
data_root,
|
||||
split,
|
||||
json_path,
|
||||
use_audio_length_left=1,
|
||||
use_audio_length_right=1,
|
||||
whisper_model_type = "tiny"
|
||||
):
|
||||
self.all_videos, self.all_imgNum = get_image_list(data_root, split)
|
||||
# self.all_videos, self.all_imgNum = get_image_list(data_root, split)
|
||||
self.audio_feature = [use_audio_length_left,use_audio_length_right]
|
||||
self.all_img_names = []
|
||||
self.split = split
|
||||
self.img_names_path = '...'
|
||||
# self.split = split
|
||||
self.img_names_path = '../data'
|
||||
self.whisper_model_type = whisper_model_type
|
||||
self.use_audio_length_left = use_audio_length_left
|
||||
self.use_audio_length_right = use_audio_length_right
|
||||
|
||||
if self.whisper_model_type =="tiny":
|
||||
self.whisper_path = '...'
|
||||
self.whisper_path = '../data/audios'
|
||||
self.whisper_feature_W = 5
|
||||
self.whisper_feature_H = 384
|
||||
elif self.whisper_model_type =="largeV2":
|
||||
@@ -71,6 +71,8 @@ class Dataset(object):
|
||||
self.whisper_feature_W = 33
|
||||
self.whisper_feature_H = 1280
|
||||
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50
|
||||
with open(json_path, 'r') as file:
|
||||
self.all_videos = json.load(file)
|
||||
|
||||
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
|
||||
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
|
||||
@@ -79,7 +81,6 @@ class Dataset(object):
|
||||
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
|
||||
with open(json_path_names, "w") as f:
|
||||
json.dump(img_names,f)
|
||||
print(f"save to {json_path_names}")
|
||||
else:
|
||||
with open(json_path_names, "r") as f:
|
||||
img_names = json.load(f)
|
||||
@@ -135,7 +136,6 @@ class Dataset(object):
|
||||
vidname = self.all_videos[idx].split('/')[-1]
|
||||
video_imgs = self.all_img_names[idx]
|
||||
if len(video_imgs) == 0:
|
||||
# print("video_imgs = 0:",vidname)
|
||||
continue
|
||||
img_name = random.choice(video_imgs)
|
||||
img_idx = int(basename(img_name).split(".")[0])
|
||||
@@ -193,7 +193,6 @@ class Dataset(object):
|
||||
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
|
||||
# 判定是否越界
|
||||
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
|
||||
|
||||
if not os.path.exists(audio_feat_path):
|
||||
is_index_out_of_range = True
|
||||
break
|
||||
@@ -214,8 +213,6 @@ class Dataset(object):
|
||||
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
|
||||
continue
|
||||
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
|
||||
|
||||
|
||||
return ref_image, image, masked_image, mask, audio_feature
|
||||
|
||||
|
||||
@@ -231,10 +228,8 @@ if __name__ == "__main__":
|
||||
val_data_loader = data_utils.DataLoader(
|
||||
val_data, batch_size=4, shuffle=True,
|
||||
num_workers=1)
|
||||
print("val_dataset:",val_data_loader.__len__())
|
||||
|
||||
for i, data in enumerate(val_data_loader):
|
||||
ref_image, image, masked_image, mask, audio_feature = data
|
||||
print("ref_image: ", ref_image.shape)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user