diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py
new file mode 100644
index 0000000..8dacbde
--- /dev/null
+++ b/train_codes/DataLoader.py
@@ -0,0 +1,252 @@
+import os, random, cv2, argparse
+import torch
+from torch.utils import data as data_utils
+from os.path import dirname, join, basename, isfile
+import numpy as np
+from glob import glob
+from utils.utils import prepare_mask_and_masked_image
+import torchvision.utils as vutils
+import torchvision.transforms as transforms
+import shutil
+from tqdm import tqdm
+import ast
+import json
+import re
+import heapq
+
+syncnet_T = 1
+RESIZED_IMG = 256
+
+# 68-point facial landmark connections (currently unused in this loader)
+connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16),  # jawline
+               (17, 18), (18, 19), (19, 20), (20, 21),  # left eyebrow
+               (22, 23), (23, 24), (24, 25), (25, 26),  # right eyebrow
+               (27, 28), (28, 29), (29, 30),  # nose bridge
+               (31, 32), (32, 33), (33, 34), (34, 35),  # nose
+               (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 36),  # left eye
+               (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42),  # right eye
+               (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54),  # upper lip (outer contour)
+               (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48),  # lower lip (outer contour)
+               (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60)  # inner lip contour
+               ]
+
+
+def get_image_list(data_root, split):
+    filelist = []
+    imgNumList = []
+    with open('filelists/{}.txt'.format(split)) as f:
+        for line in f:
+            line = line.strip()
+            if ' ' in line:
+                filename = line.split()[0]
+                imgNum = int(line.split()[1])
+                filelist.append(os.path.join(data_root, filename))
+                imgNumList.append(imgNum)
+    return filelist, imgNumList
+
+
+class Dataset(object):
+    def __init__(self,
+                 data_root,
+                 split,
+                 use_audio_length_left=1,
+                 use_audio_length_right=1,
+                 whisper_model_type="tiny"
+                 ):
+        self.all_videos, self.all_imgNum = get_image_list(data_root, split)
+        self.audio_feature = [use_audio_length_left, use_audio_length_right]
+        self.all_img_names = []
+        self.split = split
+        self.img_names_path = '...'
+        self.whisper_model_type = whisper_model_type
+        self.use_audio_length_left = use_audio_length_left
+        self.use_audio_length_right = use_audio_length_right
+
+        if self.whisper_model_type == "tiny":
+            self.whisper_path = '...'
+            self.whisper_feature_W = 5
+            self.whisper_feature_H = 384
+        elif self.whisper_model_type == "largeV2":
+            self.whisper_path = '...'
+            self.whisper_feature_W = 33
+            self.whisper_feature_H = 1280
+        self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1)  # e.g. 5*2*(2+2+1) = 50
+
+        for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
+            json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
+            if not os.path.exists(json_path_names):
+                img_names = glob(join(vidname, '*.png'))
+                img_names.sort(key=lambda x: int(x.split("/")[-1].split('.')[0]))
+                with open(json_path_names, "w") as f:
+                    json.dump(img_names, f)
+                print(f"save to {json_path_names}")
+            else:
+                with open(json_path_names, "r") as f:
+                    img_names = json.load(f)
+            self.all_img_names.append(img_names)
+
+    def get_frame_id(self, frame):
+        return int(basename(frame).split('.')[0])
+
+    def get_window(self, start_frame):
+        start_id = self.get_frame_id(start_frame)
+        vidname = dirname(start_frame)
+
+        window_fnames = []
+        for frame_id in range(start_id, start_id + syncnet_T):
+            frame = join(vidname, '{}.png'.format(frame_id))
+            if not isfile(frame):
+                return None
+            window_fnames.append(frame)
+        return window_fnames
+
+    def read_window(self, window_fnames):
+        if window_fnames is None: return None
+        window = []
+        for fname in window_fnames:
+            img = cv2.imread(fname)
+            if img is None:
+                return None
+            try:
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
+            except Exception as e:
+                print("read_window failed to process:", fname)
+                return None
+
+            window.append(img)
+
+        return window
+
+    # NOTE: unused in this loader; relies on `hparams` and `syncnet_mel_step_size`, which are not defined in this file
+    def crop_audio_window(self, spec, start_frame):
+        if type(start_frame) == int:
+            start_frame_num = start_frame
+        else:
+            start_frame_num = self.get_frame_id(start_frame)
+        start_idx = int(80. * (start_frame_num / float(hparams.fps)))
+
+        end_idx = start_idx + syncnet_mel_step_size
+
+        return spec[start_idx : end_idx, :]
+
+    def prepare_window(self, window):
+        # 1 x H x W x 3
+        x = np.asarray(window) / 255.
+        x = np.transpose(x, (3, 0, 1, 2))
+
+        return x
+
+    def __len__(self):
+        return len(self.all_videos)
+
+    def __getitem__(self, idx):
+        while 1:
+            idx = random.randint(0, len(self.all_videos) - 1)
+            # randomly pick a video
+            vidname = self.all_videos[idx].split('/')[-1]
+            video_imgs = self.all_img_names[idx]
+            if len(video_imgs) == 0:
+#                print("video_imgs = 0:", vidname)
+                continue
+            img_name = random.choice(video_imgs)
+            img_idx = int(basename(img_name).split(".")[0])
+            random_element = random.randint(0, len(video_imgs)-1)
+            while abs(random_element - img_idx) <= 5:
+                random_element = random.randint(0, len(video_imgs)-1)
+            img_dir = os.path.dirname(img_name)
+            ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
+            target_window_fnames = self.get_window(img_name)
+            ref_window_fnames = self.get_window(ref_image)
+
+            if target_window_fnames is None or ref_window_fnames is None:
+                print("No such image:", img_name, ref_image)
+                continue
+
+            try:
+                # build the target image data
+                target_window = self.read_window(target_window_fnames)
+                if target_window is None:
+                    print("No such target window:", target_window_fnames)
+                    continue
+                # build the reference image data
+                ref_window = self.read_window(ref_window_fnames)
+
+                if ref_window is None:
+                    print("No such reference window:", ref_window_fnames)
+                    continue
+            except Exception as e:
+                print(f"Unexpected error: {e}")
+                continue
+
+            # build the target input
+            target_window = self.prepare_window(target_window)
+            image = gt = target_window.copy().squeeze()
+            target_window[:, :, target_window.shape[2]//2:] = 0.  # keep the upper half of the face, mask out the lower half (V1 input)
+            ref_image = self.prepare_window(ref_window).squeeze()
+
+            mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
+            mask[:ref_image.shape[2]//2, :] = 1
+            image = torch.FloatTensor(image)
+            mask, masked_image = prepare_mask_and_masked_image(image, mask)
+
+            # audio features
+            window_index = self.get_frame_id(img_name)
+            sub_folder_name = vidname.split('/')[-1]
+
+            ## load the neighboring audio features around window_index
+            audio_feature_all = []
+            is_index_out_of_range = False
+            if os.path.isdir(os.path.join(self.whisper_path, sub_folder_name)):
+                for feat_idx in range(window_index-self.use_audio_length_left, window_index+self.use_audio_length_right+1):
+                    # check whether the index is out of range
+                    audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
+
+                    if not os.path.exists(audio_feat_path):
+                        is_index_out_of_range = True
+                        break
+
+                    try:
+                        audio_feature_all.append(np.load(audio_feat_path))
+                    except Exception as e:
+                        print(f"Unexpected error: {e}")
+                        print(f"npy load error {audio_feat_path}")
+                if is_index_out_of_range:
+                    continue
+                audio_feature = np.concatenate(audio_feature_all, axis=0)
+            else:
+                continue
+
+            audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H)  # 1, -1, 384
+            if audio_feature.shape != (1, self.whisper_feature_concateW, self.whisper_feature_H):  # e.g. 1, 50, 384
+                print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
+                continue
+            audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
+
+            return ref_image, image, masked_image, mask, audio_feature
+
+
+if __name__ == "__main__":
+    data_root = '...'
+    val_data = Dataset(data_root,
+                       'val',
+                       use_audio_length_left=2,
+                       use_audio_length_right=2,
+                       whisper_model_type="tiny"
+                       )
+    val_data_loader = data_utils.DataLoader(
+        val_data, batch_size=4, shuffle=True,
+        num_workers=1)
+    print("val_dataset:", val_data_loader.__len__())
+
+    for i, data in enumerate(val_data_loader):
+        ref_image, image, masked_image, mask, audio_feature = data
+        print("ref_image: ", ref_image.shape)
+
\ No newline at end of file
diff --git a/train_codes/README.md b/train_codes/README.md
new file mode 100644
index 0000000..fdd56ee
--- /dev/null
+++ b/train_codes/README.md
@@ -0,0 +1,35 @@
+# Draft training codes
+
+We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.
+
+## Data preprocessing
+You can refer to the inference code, which shows how to [crop the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extract audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
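+
+For reference, the snippet below is a minimal sketch of how the expected layout could be produced for a single clip. It only covers dumping frames with OpenCV; the per-frame `*.npy` audio features are assumed to come from the Whisper-based extractor used in `scripts/inference.py`, and `extract_audio_features` is a hypothetical placeholder rather than an API of this repo.
+
+```python
+import os
+import cv2
+import numpy as np
+
+def extract_frames(video_path, out_dir):
+    """Dump every frame of `video_path` as {idx}.png under `out_dir`."""
+    os.makedirs(out_dir, exist_ok=True)
+    cap = cv2.VideoCapture(video_path)
+    idx = 0
+    while True:
+        ok, frame = cap.read()
+        if not ok:
+            break
+        # The real pipeline crops/resizes to the face region first
+        # (see scripts/inference.py); that step is omitted here.
+        cv2.imwrite(os.path.join(out_dir, f"{idx}.png"), frame)
+        idx += 1
+    cap.release()
+    return idx
+
+def preprocess_clip(video_path, clip_name, extract_audio_features, data_root="./data"):
+    """Write ./data/images/<clip>/{i}.png and ./data/audios/<clip>/{i}.npy for one clip."""
+    n_frames = extract_frames(video_path, os.path.join(data_root, "images", clip_name))
+    # `extract_audio_features` is a stand-in for the Whisper feature extraction used at
+    # inference time; it should return one feature array per video frame.
+    feats = extract_audio_features(video_path, n_frames)
+    audio_dir = os.path.join(data_root, "audios", clip_name)
+    os.makedirs(audio_dir, exist_ok=True)
+    for i, feat in enumerate(feats):
+        np.save(os.path.join(audio_dir, f"{i}.npy"), feat)
+```
+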
+ +Finally, the data should be organized as follows: +``` +./data/ +├── images +│ └──RD_Radio10_000 +│ └── 0.png +│ └── 1.png +│ └── xxx.png +│ └──RD_Radio11_000 +│ └── 0.png +│ └── 1.png +│ └── xxx.png +├── audios +│ └──RD_Radio10_000 +│ └── 0.npy +│ └── 1.npy +│ └── xxx.npy +│ └──RD_Radio11_000 +│ └── 0.npy +│ └── 1.npy +│ └── xxx.npy +``` + +## Training +Simply run after preparing the preprocessed data +``` +sh train.sh +``` \ No newline at end of file diff --git a/train_codes/filelists/train.txt b/train_codes/filelists/train.txt new file mode 100644 index 0000000..3d1fb52 --- /dev/null +++ b/train_codes/filelists/train.txt @@ -0,0 +1,322 @@ +RD_Radio10_000 3501 +RD_Radio11_000 752 +RD_Radio11_001 1502 +RD_Radio12_000 876 +RD_Radio13_000 877 +RD_Radio14_000 1251 +RD_Radio16_000 2377 +RD_Radio17_000 2176 +RD_Radio18_000 1827 +RD_Radio19_000 1876 +RD_Radio1_000 1876 +RD_Radio20_000 276 +RD_Radio21_000 1201 +RD_Radio22_000 752 +RD_Radio23_000 1602 +RD_Radio25_000 2126 +RD_Radio26_000 1052 +RD_Radio27_000 1376 +RD_Radio28_000 2252 +RD_Radio29_000 2252 +RD_Radio2_000 2076 +RD_Radio30_000 2877 +RD_Radio31_000 1377 +RD_Radio32_000 1502 +RD_Radio33_000 2252 +RD_Radio34_000 702 +RD_Radio34_001 252 +RD_Radio34_002 501 +RD_Radio34_003 326 +RD_Radio34_004 502 +RD_Radio34_005 502 +RD_Radio34_006 502 +RD_Radio34_007 252 +RD_Radio34_008 876 +RD_Radio34_009 752 +RD_Radio35_000 2127 +RD_Radio36_000 2377 +RD_Radio37_000 3127 +RD_Radio38_000 3752 +RD_Radio39_000 1502 +RD_Radio3_000 2377 +RD_Radio40_000 1252 +RD_Radio41_000 2127 +RD_Radio42_000 1752 +RD_Radio43_000 1502 +RD_Radio44_000 1127 +RD_Radio45_000 1628 +RD_Radio46_000 1877 +RD_Radio47_000 1502 +RD_Radio48_000 2127 +RD_Radio49_000 1502 +RD_Radio4_000 1002 +RD_Radio50_000 2252 +RD_Radio51_000 3253 +RD_Radio52_000 2752 +RD_Radio53_000 2627 +RD_Radio54_000 577 +RD_Radio56_000 1952 +RD_Radio57_000 3077 +RD_Radio59_000 1952 +RD_Radio5_000 1377 +RD_Radio7_000 2252 +RD_Radio8_000 2126 +RD_Radio9_000 1502 +WDA_AdamSchiff_000 6877 +WDA_AdamSmith_000 7327 +WDA_AlexandriaOcasioCortez_000 2252 +WDA_AmyKlobuchar0_000 6501 +WDA_AmyKlobuchar1_000 3253 +WDA_AmyKlobuchar1_001 877 +WDA_AmyKlobuchar1_002 1502 +WDA_AmyKlobuchar1_003 1377 +WDA_AndyKim_000 6952 +WDA_AndyLevin_000 4876 +WDA_AnnieKuster_000 4876 +WDA_BarackObama_000 3575 +WDA_BarackObama_001 5625 +WDA_BarbaraLee0_000 6502 +WDA_BarbaraLee1_000 4702 +WDA_BenCardin0_000 8127 +WDA_BenCardin1_000 7102 +WDA_BenRayLujn_000 7127 +WDA_BennieThompson1_000 6375 +WDA_BennieThompson_000 5375 +WDA_BernieSanders_000 8451 +WDA_BettyMcCollum_000 7002 +WDA_BillPascrell_000 7252 +WDA_BillRichardson_000 3127 +WDA_BobCasey0_000 10627 +WDA_BobCasey1_000 252 +WDA_BobMenendez_000 3002 +WDA_BobbyScott_000 2002 +WDA_BradSchneider_000 6627 +WDA_BrendaLawrence_000 5627 +WDA_BrianSchatz0_000 502 +WDA_BrianSchatz1_000 2627 +WDA_BrianSchatz2_000 1952 +WDA_ByronDorgan1_000 4277 +WDA_CarolynMaloney1_000 8377 +WDA_CatherineCortezMasto_000 2827 +WDA_CedricRichmond_000 7250 +WDA_ChrisCoons1_000 4452 +WDA_ChrisCoons_000 6827 +WDA_ChrisMurphy0_000 6877 +WDA_ChrisMurphy1_000 6377 +WDA_ChrisVanHollen0_000 8202 +WDA_ChrisVanHollen1_000 7327 +WDA_ChuckSchumer0_000 5577 +WDA_ChuckSchumer1_000 4827 +WDA_ColinAllred_000 4751 +WDA_DanKildee1_000 6252 +WDA_DanKildee_000 2502 +WDA_DavidCicilline_000 5875 +WDA_DebHaaland_000 5876 +WDA_DebbieDingell0_000 7752 +WDA_DebbieDingell1_000 7251 +WDA_DebbieStabenow0_000 4577 +WDA_DebbieStabenow1_000 5201 +WDA_DebbieWassermanSchultz_000 7625 +WDA_DianaDeGette0_000 6002 +WDA_DianaDeGette1_000 2202 
+WDA_DianneFeinstein_000 7453 +WDA_DickDurbin_000 6326 +WDA_DonaldMcEachin_000 7627 +WDA_DonnaShalala1_000 7500 +WDA_DougJones_000 6077 +WDA_EdMarkey0_000 4752 +WDA_EdMarkey1_000 6501 +WDA_ElijahCummings_000 6377 +WDA_EliotEngel_000 6876 +WDA_EmanuelCleaver_000 7001 +WDA_EricSwalwell_000 6377 +WDA_FrankPallone0_000 5625 +WDA_FrankPallone1_000 6502 +WDA_GerryConnolly_000 6752 +WDA_HakeemJeffries_000 6002 +WDA_HaleyStevens_000 5001 +WDA_HenryWaxman_000 1125 +WDA_HillaryClinton_000 2500 +WDA_JackReed0_000 2377 +WDA_JackReed1_000 4877 +WDA_JackieSpeier_000 4625 +WDA_JackyRosen1_000 10077 +WDA_JackyRosen_000 5502 +WDA_JamesClyburn1_000 6876 +WDA_JamesClyburn_000 6875 +WDA_JanSchakowsky0_000 4128 +WDA_JanSchakowsky1_000 6251 +WDA_JeanneShaheen0_000 6702 +WDA_JeanneShaheen1_000 6577 +WDA_JeffMerkley1_000 6952 +WDA_JerryNadler_000 8377 +WDA_JimHimes_000 7000 +WDA_JimmyGomez_000 4627 +WDA_JoaquinCastro_000 5126 +WDA_JoeCrowley0_000 4877 +WDA_JoeCrowley1_000 1127 +WDA_JoeCrowley1_001 627 +WDA_JoeCrowley1_002 502 +WDA_JoeCrowley1_003 752 +WDA_JoeDonnelly_000 1377 +WDA_JoeKennedy_000 1327 +WDA_JoeManchin_000 3377 +WDA_JoeNeguse_000 1628 +WDA_JoeNeguse_001 1126 +WDA_JoeNeguse_002 1251 +WDA_JohnLewis0_000 6252 +WDA_JohnLewis1_000 7252 +WDA_JohnSarbanes0_000 6377 +WDA_JohnSarbanes1_000 1877 +WDA_JohnYarmuth1_000 5377 +WDA_JonTester0_000 3252 +WDA_JonTester1_000 5327 +WDA_KarenBass_000 7126 +WDA_KatherineClark_000 1552 +WDA_KathyCastor1_000 6001 +WDA_KathyCastor_000 2252 +WDA_KimSchrier_000 6877 +WDA_KirstenGillibrand_000 9627 +WDA_LaurenUnderwood_000 9877 +WDA_LisaBluntRochester_000 4500 +WDA_LloydDoggett0_000 7377 +WDA_LloydDoggett1_000 2252 +WDA_LucilleRoybal-Allard_000 2877 +WDA_LucyMcBath_000 4000 +WDA_MarciaFudge_000 7377 +WDA_MarkWarner1_000 377 +WDA_MarkWarner1_001 1327 +WDA_MarkWarner2_000 377 +WDA_MarkWarner_000 3127 +WDA_MartinHeinrich_000 5202 +WDA_MattCartwright_000 5377 +WDA_MazieHirono0_000 3752 +WDA_MichelleLujanGrisham_000 6875 +WDA_MichelleObama_000 2000 +WDA_MikeDoyle_000 8750 +WDA_MikeThompson0_000 4625 +WDA_MikeThompson1_000 1827 +WDA_NancyPelosi0_000 10251 +WDA_NancyPelosi1_000 1377 +WDA_NancyPelosi3_000 1127 +WDA_NitaLowey_000 5876 +WDA_NydiaVelzquez_000 5500 +WDA_PatrickLeahy0_000 6951 +WDA_PatrickLeahy1_000 9953 +WDA_PattyMurray0_000 3627 +WDA_PattyMurray1_000 5702 +WDA_PeterDeFazio_000 6628 +WDA_RaulRuiz_000 5752 +WDA_RichardBlumenthal_000 7827 +WDA_RichardNeal0_000 6252 +WDA_RichardNeal1_000 6877 +WDA_RobinKelly_000 4377 +WDA_RonWyden0_000 5152 +WDA_RonWyden1_000 8327 +WDA_ScottPeters0_000 7002 +WDA_ScottPeters1_000 3952 +WDA_SeanCasten_000 6126 +WDA_SeanPatrickMaloney_000 6502 +WDA_SheldonWhitehouse0_000 6327 +WDA_SheldonWhitehouse1_000 5702 +WDA_SherrodBrown0_000 7452 +WDA_SherrodBrown1_000 7327 +WDA_StenyHoyer_000 3502 +WDA_StephanieMurphy_000 4000 +WDA_SuzanDelBene_000 6875 +WDA_TammyBaldwin0_000 5327 +WDA_TammyBaldwin1_000 2503 +WDA_TammyDuckworth_000 5702 +WDA_TedLieu_000 6877 +WDA_TerriSewell0_000 2127 +WDA_TerriSewell1_000 10952 +WDA_TerriSewell_000 6752 +WDA_TimWalz_000 6628 +WDA_TinaSmith_000 5576 +WDA_TomCarper_000 4701 +WDA_TomPerez_000 5500 +WDA_TomUdall_000 4077 +WDA_VeronicaEscobar0_000 4377 +WDA_VeronicaEscobar1_000 2453 +WDA_WhipJimClyburn_000 6876 +WDA_XavierBecerra_000 877 +WDA_XavierBecerra_001 627 +WDA_XavierBecerra_002 1377 +WDA_ZoeLofgren_000 4625 +WRA_AdamKinzinger0_000 4251 +WRA_AdamKinzinger1_000 2751 +WRA_AdamKinzinger2_000 1002 +WRA_AdamKinzinger2_001 1001 +WRA_AdamKinzinger2_002 502 +WRA_AllenWest_000 6876 +WRA_AnnMarieBuerkle_000 452 
+WRA_AnnWagner_000 3127 +WRA_AustinScott0_000 1177 +WRA_BillCassidy0_000 427 +WRA_BobCorker_000 1127 +WRA_BobGoodlatte0_000 1001 +WRA_BobGoodlatte0_001 752 +WRA_BobGoodlatte0_002 876 +WRA_BobGoodlatte0_003 1002 +WRA_BobbySchilling_001 1000 +WRA_BobbySchilling_002 1001 +WRA_CandiceMiller0_000 3001 +WRA_CarlyFiorina0_000 876 +WRA_CarlyFiorina_000 1902 +WRA_CathyMcMorrisRodgers0_000 6077 +WRA_CathyMcMorrisRodgers1_000 7375 +WRA_CathyMcMorrisRodgers1_001 7500 +WRA_CathyMcMorrisRodgers1_002 7500 +WRA_CathyMcMorrisRodgers2_000 2751 +WRA_ChuckGrassley_000 4502 +WRA_CoryGardner0_000 8577 +WRA_CoryGardner1_000 2002 +WRA_CoryGardner_000 1252 +WRA_DanNewhouse_000 3627 +WRA_DanSullivan_000 6877 +WRA_DaveCamp_000 752 +WRA_DaveCamp_001 876 +WRA_DaveCamp_002 627 +WRA_DavidVitter_000 2877 +WRA_DeanHeller_000 377 +WRA_DebFischer0_000 7577 +WRA_DebFischer1_000 1577 +WRA_DebFischer2_000 5202 +WRA_DianeBlack0_000 451 +WRA_DianeBlack0_001 301 +WRA_DianeBlack1_000 250 +WRA_DuncanHunter_000 727 +WRA_EricCantor_000 2952 +WRA_ErikPaulsen_000 876 +WRA_ErikPaulsen_001 377 +WRA_ErikPaulsen_002 627 +WRA_ErikPaulsen_003 752 +WRA_FredUpton_000 1701 +WRA_GeoffDavis_000 250 +WRA_GeorgeLeMieux_000 752 +WRA_GeorgeLeMieux_001 1001 +WRA_GregWalden1_000 752 +WRA_GregWalden1_001 1127 +WRA_GregWalden1_002 1327 +WRA_GregWalden_000 377 +WRA_JaimeHerreraBeutler0_000 952 +WRA_JebHensarling0_001 1176 +WRA_JebHensarling1_000 2327 +WRA_JebHensarling1_001 727 +WRA_JebHensarling2_000 2302 +WRA_JebHensarling2_001 877 +WRA_JebHensarling2_002 2127 +WRA_JebHensarling2_003 1877 +WRA_JeffFlake_000 7202 +WRA_JeffFlake_001 7502 +WRA_JeffFlake_002 7377 +WRA_JimInhofe_000 3127 +WRA_JimRisch_000 1377 +WRA_JoeHeck1_000 250 +WRA_JoePitts_000 1077 +WRA_JohnBarrasso0_000 5077 +WRA_JohnBarrasso1_000 3452 +WRA_JohnBoehner0_000 1626 +WRA_JohnBoehner1_000 2951 +WRA_JohnBoozman_000 1877 +WRA_JohnHoeven_000 2577 \ No newline at end of file diff --git a/train_codes/filelists/val.txt b/train_codes/filelists/val.txt new file mode 100644 index 0000000..a7f5e0b --- /dev/null +++ b/train_codes/filelists/val.txt @@ -0,0 +1,30 @@ +WRA_JohnKasich0_000 1302 +WRA_JohnKasich1_000 1301 +WRA_JohnKasich1_001 1127 +WRA_JohnKasich3_000 1052 +WRA_JohnThune_000 1452 +WRA_JohnnyIsakson_000 5951 +WRA_JohnnyIsakson_001 5000 +WRA_JonKyl_000 626 +WRA_JoniErnst0_000 2077 +WRA_JoniErnst1_000 1452 +WRA_JuddGregg_000 1252 +WRA_JuddGregg_001 952 +WRA_JuddGregg_002 753 +WRA_KayBaileyHutchison_000 6325 +WRA_KellyAyotte_000 8077 +WRA_KevinBrady2_000 1276 +WRA_KevinBrady3_000 752 +WRA_KevinBrady_000 2503 +WRA_KevinMcCarthy0_000 1002 +WRA_KevinMcCarthy0_001 1127 +WRA_KristiNoem0_000 727 +WRA_KristiNoem1_000 452 +WRA_KristiNoem2_000 5952 +WRA_KristiNoem2_001 7277 +WRA_LamarAlexander0_000 1527 +WRA_LamarAlexander_000 1552 +WRA_LisaMurkowski0_000 1752 +WRA_LynnJenkins_000 877 +WRA_LynnJenkins_001 1002 +WRA_MarcoRubio_000 526 \ No newline at end of file diff --git a/train_codes/musetalk.json b/train_codes/musetalk.json new file mode 100644 index 0000000..b822db8 --- /dev/null +++ b/train_codes/musetalk.json @@ -0,0 +1,36 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.6.0.dev0", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 384, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 8, + 
"layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ] +} diff --git a/train_codes/train.py b/train_codes/train.py new file mode 100755 index 0000000..cb50c1c --- /dev/null +++ b/train_codes/train.py @@ -0,0 +1,642 @@ +import argparse +import itertools +import math +import os +import random +from pathlib import Path +import json +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from PIL import Image, ImageDraw +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm + +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version + +from DataLoader import Dataset +from utils.utils import preprocess_img_tensor +from torch.utils import data as data_utils +from model_utils import validation,PositionalEncoding +import time +import pandas as pd +from PIL import Image + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.13.0.dev0") + +logger = get_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--unet_config_file", + type=str, + default=None, + required=True, + help="the configuration of unet file.", + ) + parser.add_argument( + "--reconstruction", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + + parser.add_argument( + "--data_root", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + + + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--testing_speed", action="store_true", help="Whether to caculate the running time") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=1000, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint and are suitable for resuming training" + " using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--validation_steps", + type=int, + default=1000, + help=( + "Conduct validation every X updates." + ), + ) + parser.add_argument( + "--val_out_dir", + type=str, + default = '', + help=( + "Conduct validation every X updates." + ), + ) + + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. 
Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument(
+        "--use_audio_length_left",
+        type=int,
+        default=1,
+        help="number of audio feature frames used on the left.",
+    )
+    parser.add_argument(
+        "--use_audio_length_right",
+        type=int,
+        default=1,
+        help="number of audio feature frames used on the right.",
+    )
+    parser.add_argument(
+        "--whisper_model_type",
+        type=str,
+        default="tiny",
+        choices=["tiny", "largeV2"],
+        help="Determine the whisper feature type.",
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    return args
+
+
+def main():
+    args = parse_args()
+    print(args)
+    args.output_dir = f"output/{args.output_dir}"
+    args.val_out_dir = f"val/{args.val_out_dir}"
+    os.makedirs(args.output_dir, exist_ok=True)
+    os.makedirs(args.val_out_dir, exist_ok=True)
+
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    project_config = ProjectConfiguration(
+        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+    )
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with="tensorboard",
+        project_config=project_config,
+    )
+
+    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported in distributed training for this script. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
+    if args.seed is not None:
+#        set_seed(args.seed)
+        set_seed(args.seed + accelerator.process_index)
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            # NOTE: `create_repo` (from huggingface_hub) is not imported above; add the import if --push_to_hub is used.
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
+    # Load models and create wrapper for stable diffusion
+    with open(args.unet_config_file, 'r') as f:
+        unet_config = json.load(f)
+
+    #text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+
+    # Todo:
+    print("Loading AutoencoderKL")
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    vae_fp32 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+    print("Loading UNet2DConditionModel")
+    unet = UNet2DConditionModel(**unet_config)
+
+    if args.whisper_model_type == "tiny":
+        pe = PositionalEncoding(d_model=384)
+    elif args.whisper_model_type == "largeV2":
+        pe = PositionalEncoding(d_model=1280)
+    else:
+        print(f"unsupported whisper_model_type: {args.whisper_model_type}")
+
+    print("Loading models done...")
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    params_to_optimize = (
+        itertools.chain(unet.parameters())
+    )
+    optimizer = optimizer_class(
+        params_to_optimize,
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    print("loading train_dataset ...")
+    train_dataset = Dataset(args.data_root,
+                            'train',
+                            use_audio_length_left=args.use_audio_length_left,
+                            use_audio_length_right=args.use_audio_length_right,
+                            whisper_model_type=args.whisper_model_type
+                            )
+    print("train_dataset:", train_dataset.__len__())
+    train_data_loader = data_utils.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True,
+        num_workers=8)
+    print("loading val_dataset ...")
+    val_dataset = Dataset(args.data_root,
+                          'val',
+                          use_audio_length_left=args.use_audio_length_left,
+                          use_audio_length_right=args.use_audio_length_right,
+                          whisper_model_type=args.whisper_model_type
+                          )
+    print("val_dataset:", val_dataset.__len__())
+    val_data_loader = data_utils.DataLoader(
+        val_dataset, batch_size=1, shuffle=False,
+        num_workers=8)
+
+    # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + + + unet, optimizer, train_data_loader, val_data_loader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_data_loader, val_data_loader,lr_scheduler + ) + + vae.requires_grad_(False) + vae_fp32.requires_grad_(False) + + weight_dtype = torch.float32 + vae_fp32.to(accelerator.device, dtype=weight_dtype) + vae_fp32.encoder = None + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + vae.to(accelerator.device, dtype=weight_dtype) + vae.decoder = None + pe.to(accelerator.device, dtype=weight_dtype) + + num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_data_loader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + # caluate the elapsed time + elapsed_time = [] + start = time.time() + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() +# for step, batch in enumerate(train_dataloader): + for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + dataloader_time = time.time() - start + start = time.time() + + masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device) + """ + print("=============epoch:{0}=step:{1}=====".format(epoch,step)) + print("ref_image: ",ref_image.shape) + print("masks: ", masks.shape) + print("masked_image: ", masked_image.shape) + print("audio feature: ", audio_feature.shape) + print("image: ", image.shape) + """ + ref_image = preprocess_img_tensor(ref_image).to(vae.device) + image = preprocess_img_tensor(image).to(vae.device) + masked_image = preprocess_img_tensor(masked_image).to(vae.device) + + img_process_time = time.time() - start + start = time.time() + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image + latents = latents * vae.config.scaling_factor + + # Convert masked images to latent space + masked_latents = vae.encode( + masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image + ).latent_dist.sample() + masked_latents = masked_latents * vae.config.scaling_factor + + # Convert ref images to latent space + ref_latents = vae.encode( + ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image + ).latent_dist.sample() + ref_latents = ref_latents * vae.config.scaling_factor + + vae_time = time.time() - start + start = time.time() + + mask = torch.stack( + [ + torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8)) + for mask in masks + ] + ) + mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1]) + + + bsz = latents.shape[0] + # fix timestep for each image + timesteps = torch.tensor([0], device=latents.device) + # concatenate the latents with the mask and the masked latents + """ + print("=============vae latents=====".format(epoch,step)) + print("ref_latents: ",ref_latents.shape) + print("mask: ", mask.shape) + print("masked_latents: ", masked_latents.shape) + """ + + if unet_config['in_channels'] == 9: + latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1) + else: + latent_model_input = torch.cat([masked_latents, ref_latents], dim=1) + + audio_feature = audio_feature.to(dtype=weight_dtype) + # Predict the noise residual + image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample + + if args.reconstruction: # decode the image from the predicted latents + image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred + image_pred_img = vae_fp32.decode(image_pred_img).sample + + # Mask the top half of the image and calculate the loss only for the lower half of the image. 
+ image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :] + image = image[:, :, image.shape[2]//2:, :] + loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images + loss_latents = F.l1_loss(image_pred.float(), latents.float(), reduction="mean") # the loss of the latents + + loss = 2.0*loss_lip + loss_latents # add some weight to balance the loss + + else: + loss = F.mse_loss(image_pred.float(), latents.float(), reduction="mean") +# + + unet_elapsed_time = time.time() - start + start = time.time() + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters()) + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + backward_elapsed_time = time.time() - start + start = time.time() + + if args.testing_speed is True and accelerator.is_main_process: + elapsed_time.append( + [dataloader_time, unet_elapsed_time, vae_time, backward_elapsed_time,img_process_time] + ) + + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + + if global_step % args.validation_steps == 0: + if accelerator.is_main_process: + logger.info( + f"Running validation... epoch={epoch}, global_step={global_step}" + ) + print("===========start validation==========") + validation( + vae=accelerator.unwrap_model(vae), + vae_fp32=accelerator.unwrap_model(vae_fp32), + unet=accelerator.unwrap_model(unet), + unet_config=unet_config, + weight_dtype=weight_dtype, + epoch=epoch, + global_step=global_step, + val_data_loader=val_data_loader, + output_dir = args.val_out_dir, + whisper_model_type = args.whisper_model_type + ) + logger.info(f"Saved samples to images/val") + start = time.time() + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], + "unet": unet_elapsed_time, + "backward": backward_elapsed_time, + "data": dataloader_time, + "img_process":img_process_time, + "vae":vae_time + } + progress_bar.set_postfix(**logs) +# accelerator.log(logs, step=global_step) + + accelerator.log( + { + "loss/step_loss": logs["loss"], + "parameter/lr": logs["lr"], + "time/unet_forward_time": unet_elapsed_time, + "time/unet_backward_time": backward_elapsed_time, + "time/data_time": dataloader_time, + "time/img_process_time":img_process_time, + "time/vae_time": vae_time + }, + step=global_step, + ) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + + +if __name__ == "__main__": + main() diff --git a/train_codes/train.sh b/train_codes/train.sh new file mode 100644 index 0000000..908a676 --- /dev/null +++ b/train_codes/train.sh @@ -0,0 +1,27 @@ +export VAE_MODEL="./sd-vae-ft-mse/" +export DATASET="..." 
+export UNET_CONFIG="./musetalk.json" + +accelerate launch --multi_gpu train.py \ +--mixed_precision="fp16" \ +--unet_config_file=$UNET_CONFIG \ +--pretrained_model_name_or_path=$VAE_MODEL \ +--data_root=$DATASET \ +--train_batch_size=8 \ +--gradient_accumulation_steps=4 \ +--gradient_checkpointing \ +--max_train_steps=200000 \ +--learning_rate=5e-05 \ +--max_grad_norm=1 \ +--lr_scheduler="cosine" \ +--lr_warmup_steps=0 \ +--output_dir="..." \ +--val_out_dir='...' \ +--testing_speed \ +--checkpointing_steps=1000 \ +--validation_steps=1000 \ +--reconstruction \ +--resume_from_checkpoint="latest" \ +--use_audio_length_left=2 \ +--use_audio_length_right=2 \ +--whisper_model_type="tiny" \ diff --git a/train_codes/utils/model_utils.py b/train_codes/utils/model_utils.py new file mode 100644 index 0000000..e0fd2e2 --- /dev/null +++ b/train_codes/utils/model_utils.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn + +import torch +import torch.nn as nn +import time +import math +from utils import decode_latents, preprocess_img_tensor +import os +from PIL import Image +from typing import Any, Dict, List, Optional, Tuple, Union +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, +) +from torch import Tensor, nn +import logging +import json + +RESIZED_IMG = 256 + +class PositionalEncoding(nn.Module): + """ + Transformer 中的位置编码(positional encoding) + """ + def __init__(self, d_model=384, max_len=5000): + super(PositionalEncoding, self).__init__() + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + b, seq_len, d_model = x.size() + pe = self.pe[:, :seq_len, :] + #print(b, seq_len, d_model) + x = x + pe.to(x.device) + return x + +def validation(vae: torch.nn.Module, + vae_fp32: torch.nn.Module, + unet:torch.nn.Module, + unet_config, + weight_dtype: torch.dtype, + epoch: int, + global_step: int, + val_data_loader, + output_dir, + whisper_model_type, + UNet2DConditionModel=UNet2DConditionModel + ): + + # Get the validation pipeline + unet_copy = UNet2DConditionModel(**unet_config) + + unet_copy.load_state_dict(unet.state_dict()) + unet_copy.to(vae.device).to(dtype=weight_dtype) + unet_copy.eval() + + if whisper_model_type == "tiny": + pe = PositionalEncoding(d_model=384) + elif whisper_model_type == "largeV2": + pe = PositionalEncoding(d_model=1280) + elif whisper_model_type == "tiny-conv": + pe = PositionalEncoding(d_model=384) + print(f" whisper_model_type: {whisper_model_type} Validation does not need PE") + else: + print(f"not support whisper_model_type {whisper_model_type}") + pe.to(vae.device, dtype=weight_dtype) + + start = time.time() + with torch.no_grad(): + for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader): + + + masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device) + ref_image = preprocess_img_tensor(ref_image).to(vae.device) + image = preprocess_img_tensor(image).to(vae.device) + masked_image = preprocess_img_tensor(masked_image).to(vae.device) + + # Convert images to latent space + latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image + latents = latents * vae.config.scaling_factor + # Convert masked images to latent space + masked_latents = vae.encode( + 
masked_image.reshape(image.shape).to(dtype=weight_dtype)  # masked image
+            ).latent_dist.sample()
+            masked_latents = masked_latents * vae.config.scaling_factor
+            # Convert ref images to latent space
+            ref_latents = vae.encode(
+                ref_image.reshape(image.shape).to(dtype=weight_dtype)  # ref image
+            ).latent_dist.sample()
+            ref_latents = ref_latents * vae.config.scaling_factor
+
+            mask = torch.stack(
+                [
+                    torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
+                    for mask in masks
+                ]
+            )
+            mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
+            bsz = latents.shape[0]
+            timesteps = torch.tensor([0], device=latents.device)
+
+            if unet_config['in_channels'] == 9:
+                latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
+            else:
+                latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
+
+            image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states=audio_feature).sample
+
+            # paste masked input, reference, ground truth and prediction side by side
+            image = Image.new('RGB', (RESIZED_IMG*4, RESIZED_IMG))
+            image.paste(decode_latents(vae_fp32, masked_latents), (0, 0))
+            image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
+            image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG*2, 0))
+            image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG*3, 0))
+
+            val_img_dir = f"images/{output_dir}/{global_step}"
+            if not os.path.exists(val_img_dir):
+                os.makedirs(val_img_dir)
+            image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step, step))
+
+            print("validation at step: {0}, time: {1}".format(step, time.time()-start))
+
+    print("validation done in epoch: {0}, time: {1}".format(epoch, time.time()-start))
+
diff --git a/train_codes/utils/utils.py b/train_codes/utils/utils.py
new file mode 100644
index 0000000..06d192c
--- /dev/null
+++ b/train_codes/utils/utils.py
@@ -0,0 +1,74 @@
+import matplotlib.pyplot as plt
+import PIL
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from einops import rearrange
+
+import torch
+import torchvision.transforms as transforms
+
+from diffusers import AutoencoderKL
+import matplotlib.pyplot as plt
+import PIL
+import os
+import cv2
+from glob import glob
+
+
+def preprocess_img_tensor(image_tensor):
+    # the input is assumed to be a PyTorch tensor of shape (N, C, H, W)
+    N, C, H, W = image_tensor.shape
+    # compute a new width and height that are multiples of 32
+    new_w = W - W % 32
+    new_h = H - H % 32
+    # normalization transform from torchvision.transforms
+    transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    # apply the transform to every image and store the results in a new tensor
+    preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
+    for i in range(N):
+        # use F.interpolate instead of transforms.Resize
+        resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
+        preprocessed_images[i] = transform(resized_image.squeeze(0))
+
+    return preprocessed_images
+
+
+def prepare_mask_and_masked_image(image_tensor, mask_tensor):
+    # image_tensor is assumed to have shape [C, H, W]; mask_tensor has shape [H, W]
+#    # normalize the image tensor
+    image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
+#    # normalize and binarize the mask tensor
+#    mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
+    mask_tensor[mask_tensor < 0.5] = 0
+    mask_tensor[mask_tensor >= 0.5] = 1
+    # create the masked image
+    masked_image_tensor = image_tensor * (mask_tensor > 0.5)
+
+    return mask_tensor, masked_image_tensor
+
+
+def encode_latents(vae, image):
+#    init_image = preprocess_image(image)
+    init_latent_dist =
vae.encode(image.to(vae.dtype)).latent_dist + init_latents = 0.18215 * init_latent_dist.sample() + return init_latents + +def decode_latents(vae, latents, ref_images=None): + latents = (1/ 0.18215) * latents + image = vae.decode(latents.to(vae.dtype)).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy() + image = (image * 255).round().astype("uint8") + if ref_images is not None: + ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy() + ref_images = (ref_images * 255).round().astype("uint8") + h = image.shape[1] + image[:, :h//2] = ref_images[:, :h//2] + image = [Image.fromarray(im) for im in image] + + return image[0] +