Update draft training codes

This commit is contained in:
czk32611
2024-04-28 11:34:49 +08:00
parent 6e32247cb1
commit d73daf1808
9 changed files with 1547 additions and 0 deletions

252
train_codes/DataLoader.py Normal file
View File

@@ -0,0 +1,252 @@
import os, random, cv2, argparse
import torch
from torch.utils import data as data_utils
from os.path import dirname, join, basename, isfile
import numpy as np
from glob import glob
from utils.utils import prepare_mask_and_masked_image
import torchvision.utils as vutils
import torchvision.transforms as transforms
import shutil
from tqdm import tqdm
import ast
import json
import re
import heapq
syncnet_T = 1
RESIZED_IMG = 256
connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7),(7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13),(13,14),(14,15),(15,16), # 下颌线
(17, 18), (18, 19), (19, 20), (20, 21), #左眉毛
(22, 23), (23, 24), (24, 25), (25, 26), #右眉毛
(27, 28),(28,29),(29,30),# 鼻梁
(31,32),(32,33),(33,34),(34,35), #鼻子
(36,37),(37,38),(38, 39), (39, 40), (40, 41), (41, 36), # 左眼
(42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42), # 右眼
(48, 49),(49, 50), (50, 51),(51, 52),(52, 53), (53, 54), # 上嘴唇 外延
(54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48), # 下嘴唇 外延
(60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60) #嘴唇内圈
]
def get_image_list(data_root, split):
filelist = []
imgNumList = []
with open('filelists/{}.txt'.format(split)) as f:
for line in f:
line = line.strip()
if ' ' in line:
filename = line.split()[0]
imgNum = int(line.split()[1])
filelist.append(os.path.join(data_root, filename))
imgNumList.append(imgNum)
return filelist, imgNumList
class Dataset(object):
def __init__(self,
data_root,
split,
use_audio_length_left=1,
use_audio_length_right=1,
whisper_model_type = "tiny"
):
self.all_videos, self.all_imgNum = get_image_list(data_root, split)
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
self.img_names_path = '...'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
self.use_audio_length_right = use_audio_length_right
if self.whisper_model_type =="tiny":
self.whisper_path = '...'
self.whisper_feature_W = 5
self.whisper_feature_H = 384
elif self.whisper_model_type =="largeV2":
self.whisper_path = '...'
self.whisper_feature_W = 33
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*2+2+1= 50
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
if not os.path.exists(json_path_names):
img_names = glob(join(vidname, '*.png'))
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
with open(json_path_names, "w") as f:
json.dump(img_names,f)
print(f"save to {json_path_names}")
else:
with open(json_path_names, "r") as f:
img_names = json.load(f)
self.all_img_names.append(img_names)
def get_frame_id(self, frame):
return int(basename(frame).split('.')[0])
def get_window(self, start_frame):
start_id = self.get_frame_id(start_frame)
vidname = dirname(start_frame)
window_fnames = []
for frame_id in range(start_id, start_id + syncnet_T):
frame = join(vidname, '{}.png'.format(frame_id))
if not isfile(frame):
return None
window_fnames.append(frame)
return window_fnames
def read_window(self, window_fnames):
if window_fnames is None: return None
window = []
for fname in window_fnames:
img = cv2.imread(fname)
if img is None:
return None
try:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
except Exception as e:
print("read_window has error fname not exist:",fname)
return None
window.append(img)
return window
def crop_audio_window(self, spec, start_frame):
if type(start_frame) == int:
start_frame_num = start_frame
else:
start_frame_num = self.get_frame_id(start_frame)
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx : end_idx, :]
def prepare_window(self, window):
# 1 x H x W x 3
x = np.asarray(window) / 255.
x = np.transpose(x, (3, 0, 1, 2))
return x
def __len__(self):
return len(self.all_videos)
def __getitem__(self, idx):
while 1:
idx = random.randint(0, len(self.all_videos) - 1)
#随机选择某个video里
vidname = self.all_videos[idx].split('/')[-1]
video_imgs = self.all_img_names[idx]
if len(video_imgs) == 0:
# print("video_imgs = 0:",vidname)
continue
img_name = random.choice(video_imgs)
img_idx = int(basename(img_name).split(".")[0])
random_element = random.randint(0,len(video_imgs)-1)
while abs(random_element - img_idx) <= 5:
random_element = random.randint(0,len(video_imgs)-1)
img_dir = os.path.dirname(img_name)
ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
target_window_fnames = self.get_window(img_name)
ref_window_fnames = self.get_window(ref_image)
if target_window_fnames is None or ref_window_fnames is None:
print("No such img",img_name, ref_image)
continue
try:
#构建目标img数据
target_window = self.read_window(target_window_fnames)
if target_window is None :
print("No such target window,",target_window_fnames)
continue
#构建参考img数据
ref_window = self.read_window(ref_window_fnames)
if ref_window is None:
print("No such target ref window,",ref_window)
continue
except Exception as e:
print(f"发生未知错误:{e}")
continue
#构建target输入
target_window = self.prepare_window(target_window)
image = gt = target_window.copy().squeeze()
target_window[:, :, target_window.shape[2]//2:] = 0. # upper half face, mask掉下半部分 V1输入
ref_image = self.prepare_window(ref_window).squeeze()
mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
mask[:ref_image.shape[2]//2,:] = 1
image = torch.FloatTensor(image)
mask, masked_image = prepare_mask_and_masked_image(image,mask)
#音频特征
window_index = self.get_frame_id(img_name)
sub_folder_name = vidname.split('/')[-1]
## 根据window_index加载相邻的音频
audio_feature_all = []
is_index_out_of_range = False
if os.path.isdir(os.path.join(self.whisper_path, sub_folder_name)):
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
# 判定是否越界
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
if not os.path.exists(audio_feat_path):
is_index_out_of_range = True
break
try:
audio_feature_all.append(np.load(audio_feat_path))
except Exception as e:
print(f"发生未知错误:{e}")
print(f"npy load error {audio_feat_path}")
if is_index_out_of_range:
continue
audio_feature = np.concatenate(audio_feature_all, axis=0)
else:
continue
audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H) #1 -1 384
if audio_feature.shape != (1,self.whisper_feature_concateW, self.whisper_feature_H): #1 50 384
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
continue
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
return ref_image, image, masked_image, mask, audio_feature
if __name__ == "__main__":
data_root = '...'
val_data = Dataset(data_root,
'val',
use_audio_length_left = 2,
use_audio_length_right = 2,
whisper_model_type = "tiny"
)
val_data_loader = data_utils.DataLoader(
val_data, batch_size=4, shuffle=True,
num_workers=1)
print("val_dataset:",val_data_loader.__len__())
for i, data in enumerate(val_data_loader):
ref_image, image, masked_image, mask, audio_feature = data
print("ref_image: ", ref_image.shape)