Update draft training codes

This commit is contained in:
czk32611
2024-04-28 11:34:49 +08:00
parent 6e32247cb1
commit d73daf1808
9 changed files with 1547 additions and 0 deletions

252
train_codes/DataLoader.py Normal file

@@ -0,0 +1,252 @@
import os, random, cv2, argparse
import torch
from torch.utils import data as data_utils
from os.path import dirname, join, basename, isfile
import numpy as np
from glob import glob
from utils.utils import prepare_mask_and_masked_image
import torchvision.utils as vutils
import torchvision.transforms as transforms
import shutil
from tqdm import tqdm
import ast
import json
import re
import heapq
syncnet_T = 1
RESIZED_IMG = 256
connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16),  # jaw line
               (17, 18), (18, 19), (19, 20), (20, 21),  # left eyebrow
               (22, 23), (23, 24), (24, 25), (25, 26),  # right eyebrow
               (27, 28), (28, 29), (29, 30),  # nose bridge
               (31, 32), (32, 33), (33, 34), (34, 35),  # nose
               (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 36),  # left eye
               (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42),  # right eye
               (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54),  # outer upper lip
               (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48),  # outer lower lip
               (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60)  # inner lip
               ]
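# Illustrative sketch (not part of the original code): `connections` lists the edges of the
# standard 68-point facial landmark topology. Assuming `landmarks` is a (68, 2) array of
# (x, y) points for one face, the topology could be drawn onto an image like this:
#
#   def draw_landmark_connections(image, landmarks, color=(0, 255, 0)):
#       for start_idx, end_idx in connections:
#           pt1 = tuple(int(v) for v in landmarks[start_idx])
#           pt2 = tuple(int(v) for v in landmarks[end_idx])
#           cv2.line(image, pt1, pt2, color, thickness=1)
#       return image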
def get_image_list(data_root, split):
filelist = []
imgNumList = []
with open('filelists/{}.txt'.format(split)) as f:
for line in f:
line = line.strip()
if ' ' in line:
filename = line.split()[0]
imgNum = int(line.split()[1])
filelist.append(os.path.join(data_root, filename))
imgNumList.append(imgNum)
return filelist, imgNumList
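# Each line of filelists/{split}.txt is expected to be "<clip_name> <num_frames>",
# e.g. "RD_Radio10_000 3501" (see the filelists added in this commit); lines without a
# space are skipped, and the frame counts are stored but not used elsewhere in this loader.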
class Dataset(object):
def __init__(self,
data_root,
split,
use_audio_length_left=1,
use_audio_length_right=1,
whisper_model_type = "tiny"
):
self.all_videos, self.all_imgNum = get_image_list(data_root, split)
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
self.img_names_path = '...'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
self.use_audio_length_right = use_audio_length_right
if self.whisper_model_type =="tiny":
self.whisper_path = '...'
self.whisper_feature_W = 5
self.whisper_feature_H = 384
elif self.whisper_model_type =="largeV2":
self.whisper_path = '...'
self.whisper_feature_W = 33
self.whisper_feature_H = 1280
        self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1)  # e.g. tiny with left=right=2: 5*2*(2+2+1) = 50
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
if not os.path.exists(json_path_names):
img_names = glob(join(vidname, '*.png'))
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
with open(json_path_names, "w") as f:
json.dump(img_names,f)
print(f"save to {json_path_names}")
else:
with open(json_path_names, "r") as f:
img_names = json.load(f)
self.all_img_names.append(img_names)
def get_frame_id(self, frame):
return int(basename(frame).split('.')[0])
def get_window(self, start_frame):
start_id = self.get_frame_id(start_frame)
vidname = dirname(start_frame)
window_fnames = []
for frame_id in range(start_id, start_id + syncnet_T):
frame = join(vidname, '{}.png'.format(frame_id))
if not isfile(frame):
return None
window_fnames.append(frame)
return window_fnames
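    # With syncnet_T = 1 each "window" is a single frame; the windowed interface below
    # appears to be kept from multi-frame (SyncNet-style) training code.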
def read_window(self, window_fnames):
if window_fnames is None: return None
window = []
for fname in window_fnames:
img = cv2.imread(fname)
if img is None:
return None
try:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
except Exception as e:
print("read_window has error fname not exist:",fname)
return None
window.append(img)
return window
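    # NOTE: this method is not called in this file and relies on `hparams.fps` and
    # `syncnet_mel_step_size`, which are not defined here (Wav2Lip-style mel windowing).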
def crop_audio_window(self, spec, start_frame):
if type(start_frame) == int:
start_frame_num = start_frame
else:
start_frame_num = self.get_frame_id(start_frame)
start_idx = int(80. * (start_frame_num / float(hparams.fps)))
end_idx = start_idx + syncnet_mel_step_size
return spec[start_idx : end_idx, :]
    def prepare_window(self, window):
        # list of T images of shape (H, W, 3) -> float array of shape (3, T, H, W) in [0, 1]
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))
        return x
def __len__(self):
return len(self.all_videos)
def __getitem__(self, idx):
while 1:
idx = random.randint(0, len(self.all_videos) - 1)
            # randomly pick one of the videos
vidname = self.all_videos[idx].split('/')[-1]
video_imgs = self.all_img_names[idx]
if len(video_imgs) == 0:
# print("video_imgs = 0:",vidname)
continue
img_name = random.choice(video_imgs)
img_idx = int(basename(img_name).split(".")[0])
random_element = random.randint(0,len(video_imgs)-1)
while abs(random_element - img_idx) <= 5:
random_element = random.randint(0,len(video_imgs)-1)
img_dir = os.path.dirname(img_name)
ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
target_window_fnames = self.get_window(img_name)
ref_window_fnames = self.get_window(ref_image)
if target_window_fnames is None or ref_window_fnames is None:
print("No such img",img_name, ref_image)
continue
try:
                # build the target image window
target_window = self.read_window(target_window_fnames)
if target_window is None :
print("No such target window,",target_window_fnames)
continue
                # build the reference image window
                ref_window = self.read_window(ref_window_fnames)
                if ref_window is None:
                    print("No such ref window,", ref_window_fnames)
continue
except Exception as e:
print(f"发生未知错误:{e}")
continue
            # build the target network inputs
target_window = self.prepare_window(target_window)
image = gt = target_window.copy().squeeze()
            target_window[:, :, target_window.shape[2]//2:] = 0.  # V1 input: keep the upper half of the face, mask out the lower half
ref_image = self.prepare_window(ref_window).squeeze()
mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
mask[:ref_image.shape[2]//2,:] = 1
image = torch.FloatTensor(image)
mask, masked_image = prepare_mask_and_masked_image(image,mask)
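            # mask is 1 on the upper half of the face and 0 on the lower half, so masked_image
            # keeps only the upper half; the lower (mouth) region is what the model must
            # inpaint from the audio features and the reference image.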
            # audio features
window_index = self.get_frame_id(img_name)
sub_folder_name = vidname.split('/')[-1]
            ## load the neighbouring audio features around window_index
audio_feature_all = []
is_index_out_of_range = False
if os.path.isdir(os.path.join(self.whisper_path, sub_folder_name)):
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
                    # check whether the index is out of range
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
if not os.path.exists(audio_feat_path):
is_index_out_of_range = True
break
try:
audio_feature_all.append(np.load(audio_feat_path))
except Exception as e:
print(f"发生未知错误:{e}")
print(f"npy load error {audio_feat_path}")
if is_index_out_of_range:
continue
audio_feature = np.concatenate(audio_feature_all, axis=0)
else:
continue
audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H) #1 -1 384
if audio_feature.shape != (1,self.whisper_feature_concateW, self.whisper_feature_H): #1 50 384
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
continue
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
return ref_image, image, masked_image, mask, audio_feature
if __name__ == "__main__":
data_root = '...'
val_data = Dataset(data_root,
'val',
use_audio_length_left = 2,
use_audio_length_right = 2,
whisper_model_type = "tiny"
)
val_data_loader = data_utils.DataLoader(
val_data, batch_size=4, shuffle=True,
num_workers=1)
print("val_dataset:",val_data_loader.__len__())
for i, data in enumerate(val_data_loader):
ref_image, image, masked_image, mask, audio_feature = data
print("ref_image: ", ref_image.shape)

35
train_codes/README.md Normal file

@@ -0,0 +1,35 @@
# Draft training codes
We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.
## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts the audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
Finally, the data should be organized as follows:
```
./data/
├── images
│ └──RD_Radio10_000
│ └── 0.png
│ └── 1.png
│ └── xxx.png
│ └──RD_Radio11_000
│ └── 0.png
│ └── 1.png
│ └── xxx.png
├── audios
│ └──RD_Radio10_000
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
│ └──RD_Radio11_000
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
```
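The dataloader additionally expects `filelists/train.txt` and `filelists/val.txt`, where each line is `<clip_name> <num_frames>` (see the filelists added in this commit). Below is a minimal sketch for generating such a filelist from the `./data/images` layout above; the paths and the train/val split are assumptions you should adapt to your setup.
```
import os
from glob import glob

image_root = "./data/images"   # assumption: the layout shown above
os.makedirs("filelists", exist_ok=True)
with open("filelists/train.txt", "w") as f:
    for clip in sorted(os.listdir(image_root)):
        num_frames = len(glob(os.path.join(image_root, clip, "*.png")))
        if num_frames > 0:
            f.write(f"{clip} {num_frames}\n")
```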
## Training
After preparing the preprocessed data, simply run:
```
sh train.sh
```

322
train_codes/filelists/train.txt Normal file

@@ -0,0 +1,322 @@
RD_Radio10_000 3501
RD_Radio11_000 752
RD_Radio11_001 1502
RD_Radio12_000 876
RD_Radio13_000 877
RD_Radio14_000 1251
RD_Radio16_000 2377
RD_Radio17_000 2176
RD_Radio18_000 1827
RD_Radio19_000 1876
RD_Radio1_000 1876
RD_Radio20_000 276
RD_Radio21_000 1201
RD_Radio22_000 752
RD_Radio23_000 1602
RD_Radio25_000 2126
RD_Radio26_000 1052
RD_Radio27_000 1376
RD_Radio28_000 2252
RD_Radio29_000 2252
RD_Radio2_000 2076
RD_Radio30_000 2877
RD_Radio31_000 1377
RD_Radio32_000 1502
RD_Radio33_000 2252
RD_Radio34_000 702
RD_Radio34_001 252
RD_Radio34_002 501
RD_Radio34_003 326
RD_Radio34_004 502
RD_Radio34_005 502
RD_Radio34_006 502
RD_Radio34_007 252
RD_Radio34_008 876
RD_Radio34_009 752
RD_Radio35_000 2127
RD_Radio36_000 2377
RD_Radio37_000 3127
RD_Radio38_000 3752
RD_Radio39_000 1502
RD_Radio3_000 2377
RD_Radio40_000 1252
RD_Radio41_000 2127
RD_Radio42_000 1752
RD_Radio43_000 1502
RD_Radio44_000 1127
RD_Radio45_000 1628
RD_Radio46_000 1877
RD_Radio47_000 1502
RD_Radio48_000 2127
RD_Radio49_000 1502
RD_Radio4_000 1002
RD_Radio50_000 2252
RD_Radio51_000 3253
RD_Radio52_000 2752
RD_Radio53_000 2627
RD_Radio54_000 577
RD_Radio56_000 1952
RD_Radio57_000 3077
RD_Radio59_000 1952
RD_Radio5_000 1377
RD_Radio7_000 2252
RD_Radio8_000 2126
RD_Radio9_000 1502
WDA_AdamSchiff_000 6877
WDA_AdamSmith_000 7327
WDA_AlexandriaOcasioCortez_000 2252
WDA_AmyKlobuchar0_000 6501
WDA_AmyKlobuchar1_000 3253
WDA_AmyKlobuchar1_001 877
WDA_AmyKlobuchar1_002 1502
WDA_AmyKlobuchar1_003 1377
WDA_AndyKim_000 6952
WDA_AndyLevin_000 4876
WDA_AnnieKuster_000 4876
WDA_BarackObama_000 3575
WDA_BarackObama_001 5625
WDA_BarbaraLee0_000 6502
WDA_BarbaraLee1_000 4702
WDA_BenCardin0_000 8127
WDA_BenCardin1_000 7102
WDA_BenRayLujn_000 7127
WDA_BennieThompson1_000 6375
WDA_BennieThompson_000 5375
WDA_BernieSanders_000 8451
WDA_BettyMcCollum_000 7002
WDA_BillPascrell_000 7252
WDA_BillRichardson_000 3127
WDA_BobCasey0_000 10627
WDA_BobCasey1_000 252
WDA_BobMenendez_000 3002
WDA_BobbyScott_000 2002
WDA_BradSchneider_000 6627
WDA_BrendaLawrence_000 5627
WDA_BrianSchatz0_000 502
WDA_BrianSchatz1_000 2627
WDA_BrianSchatz2_000 1952
WDA_ByronDorgan1_000 4277
WDA_CarolynMaloney1_000 8377
WDA_CatherineCortezMasto_000 2827
WDA_CedricRichmond_000 7250
WDA_ChrisCoons1_000 4452
WDA_ChrisCoons_000 6827
WDA_ChrisMurphy0_000 6877
WDA_ChrisMurphy1_000 6377
WDA_ChrisVanHollen0_000 8202
WDA_ChrisVanHollen1_000 7327
WDA_ChuckSchumer0_000 5577
WDA_ChuckSchumer1_000 4827
WDA_ColinAllred_000 4751
WDA_DanKildee1_000 6252
WDA_DanKildee_000 2502
WDA_DavidCicilline_000 5875
WDA_DebHaaland_000 5876
WDA_DebbieDingell0_000 7752
WDA_DebbieDingell1_000 7251
WDA_DebbieStabenow0_000 4577
WDA_DebbieStabenow1_000 5201
WDA_DebbieWassermanSchultz_000 7625
WDA_DianaDeGette0_000 6002
WDA_DianaDeGette1_000 2202
WDA_DianneFeinstein_000 7453
WDA_DickDurbin_000 6326
WDA_DonaldMcEachin_000 7627
WDA_DonnaShalala1_000 7500
WDA_DougJones_000 6077
WDA_EdMarkey0_000 4752
WDA_EdMarkey1_000 6501
WDA_ElijahCummings_000 6377
WDA_EliotEngel_000 6876
WDA_EmanuelCleaver_000 7001
WDA_EricSwalwell_000 6377
WDA_FrankPallone0_000 5625
WDA_FrankPallone1_000 6502
WDA_GerryConnolly_000 6752
WDA_HakeemJeffries_000 6002
WDA_HaleyStevens_000 5001
WDA_HenryWaxman_000 1125
WDA_HillaryClinton_000 2500
WDA_JackReed0_000 2377
WDA_JackReed1_000 4877
WDA_JackieSpeier_000 4625
WDA_JackyRosen1_000 10077
WDA_JackyRosen_000 5502
WDA_JamesClyburn1_000 6876
WDA_JamesClyburn_000 6875
WDA_JanSchakowsky0_000 4128
WDA_JanSchakowsky1_000 6251
WDA_JeanneShaheen0_000 6702
WDA_JeanneShaheen1_000 6577
WDA_JeffMerkley1_000 6952
WDA_JerryNadler_000 8377
WDA_JimHimes_000 7000
WDA_JimmyGomez_000 4627
WDA_JoaquinCastro_000 5126
WDA_JoeCrowley0_000 4877
WDA_JoeCrowley1_000 1127
WDA_JoeCrowley1_001 627
WDA_JoeCrowley1_002 502
WDA_JoeCrowley1_003 752
WDA_JoeDonnelly_000 1377
WDA_JoeKennedy_000 1327
WDA_JoeManchin_000 3377
WDA_JoeNeguse_000 1628
WDA_JoeNeguse_001 1126
WDA_JoeNeguse_002 1251
WDA_JohnLewis0_000 6252
WDA_JohnLewis1_000 7252
WDA_JohnSarbanes0_000 6377
WDA_JohnSarbanes1_000 1877
WDA_JohnYarmuth1_000 5377
WDA_JonTester0_000 3252
WDA_JonTester1_000 5327
WDA_KarenBass_000 7126
WDA_KatherineClark_000 1552
WDA_KathyCastor1_000 6001
WDA_KathyCastor_000 2252
WDA_KimSchrier_000 6877
WDA_KirstenGillibrand_000 9627
WDA_LaurenUnderwood_000 9877
WDA_LisaBluntRochester_000 4500
WDA_LloydDoggett0_000 7377
WDA_LloydDoggett1_000 2252
WDA_LucilleRoybal-Allard_000 2877
WDA_LucyMcBath_000 4000
WDA_MarciaFudge_000 7377
WDA_MarkWarner1_000 377
WDA_MarkWarner1_001 1327
WDA_MarkWarner2_000 377
WDA_MarkWarner_000 3127
WDA_MartinHeinrich_000 5202
WDA_MattCartwright_000 5377
WDA_MazieHirono0_000 3752
WDA_MichelleLujanGrisham_000 6875
WDA_MichelleObama_000 2000
WDA_MikeDoyle_000 8750
WDA_MikeThompson0_000 4625
WDA_MikeThompson1_000 1827
WDA_NancyPelosi0_000 10251
WDA_NancyPelosi1_000 1377
WDA_NancyPelosi3_000 1127
WDA_NitaLowey_000 5876
WDA_NydiaVelzquez_000 5500
WDA_PatrickLeahy0_000 6951
WDA_PatrickLeahy1_000 9953
WDA_PattyMurray0_000 3627
WDA_PattyMurray1_000 5702
WDA_PeterDeFazio_000 6628
WDA_RaulRuiz_000 5752
WDA_RichardBlumenthal_000 7827
WDA_RichardNeal0_000 6252
WDA_RichardNeal1_000 6877
WDA_RobinKelly_000 4377
WDA_RonWyden0_000 5152
WDA_RonWyden1_000 8327
WDA_ScottPeters0_000 7002
WDA_ScottPeters1_000 3952
WDA_SeanCasten_000 6126
WDA_SeanPatrickMaloney_000 6502
WDA_SheldonWhitehouse0_000 6327
WDA_SheldonWhitehouse1_000 5702
WDA_SherrodBrown0_000 7452
WDA_SherrodBrown1_000 7327
WDA_StenyHoyer_000 3502
WDA_StephanieMurphy_000 4000
WDA_SuzanDelBene_000 6875
WDA_TammyBaldwin0_000 5327
WDA_TammyBaldwin1_000 2503
WDA_TammyDuckworth_000 5702
WDA_TedLieu_000 6877
WDA_TerriSewell0_000 2127
WDA_TerriSewell1_000 10952
WDA_TerriSewell_000 6752
WDA_TimWalz_000 6628
WDA_TinaSmith_000 5576
WDA_TomCarper_000 4701
WDA_TomPerez_000 5500
WDA_TomUdall_000 4077
WDA_VeronicaEscobar0_000 4377
WDA_VeronicaEscobar1_000 2453
WDA_WhipJimClyburn_000 6876
WDA_XavierBecerra_000 877
WDA_XavierBecerra_001 627
WDA_XavierBecerra_002 1377
WDA_ZoeLofgren_000 4625
WRA_AdamKinzinger0_000 4251
WRA_AdamKinzinger1_000 2751
WRA_AdamKinzinger2_000 1002
WRA_AdamKinzinger2_001 1001
WRA_AdamKinzinger2_002 502
WRA_AllenWest_000 6876
WRA_AnnMarieBuerkle_000 452
WRA_AnnWagner_000 3127
WRA_AustinScott0_000 1177
WRA_BillCassidy0_000 427
WRA_BobCorker_000 1127
WRA_BobGoodlatte0_000 1001
WRA_BobGoodlatte0_001 752
WRA_BobGoodlatte0_002 876
WRA_BobGoodlatte0_003 1002
WRA_BobbySchilling_001 1000
WRA_BobbySchilling_002 1001
WRA_CandiceMiller0_000 3001
WRA_CarlyFiorina0_000 876
WRA_CarlyFiorina_000 1902
WRA_CathyMcMorrisRodgers0_000 6077
WRA_CathyMcMorrisRodgers1_000 7375
WRA_CathyMcMorrisRodgers1_001 7500
WRA_CathyMcMorrisRodgers1_002 7500
WRA_CathyMcMorrisRodgers2_000 2751
WRA_ChuckGrassley_000 4502
WRA_CoryGardner0_000 8577
WRA_CoryGardner1_000 2002
WRA_CoryGardner_000 1252
WRA_DanNewhouse_000 3627
WRA_DanSullivan_000 6877
WRA_DaveCamp_000 752
WRA_DaveCamp_001 876
WRA_DaveCamp_002 627
WRA_DavidVitter_000 2877
WRA_DeanHeller_000 377
WRA_DebFischer0_000 7577
WRA_DebFischer1_000 1577
WRA_DebFischer2_000 5202
WRA_DianeBlack0_000 451
WRA_DianeBlack0_001 301
WRA_DianeBlack1_000 250
WRA_DuncanHunter_000 727
WRA_EricCantor_000 2952
WRA_ErikPaulsen_000 876
WRA_ErikPaulsen_001 377
WRA_ErikPaulsen_002 627
WRA_ErikPaulsen_003 752
WRA_FredUpton_000 1701
WRA_GeoffDavis_000 250
WRA_GeorgeLeMieux_000 752
WRA_GeorgeLeMieux_001 1001
WRA_GregWalden1_000 752
WRA_GregWalden1_001 1127
WRA_GregWalden1_002 1327
WRA_GregWalden_000 377
WRA_JaimeHerreraBeutler0_000 952
WRA_JebHensarling0_001 1176
WRA_JebHensarling1_000 2327
WRA_JebHensarling1_001 727
WRA_JebHensarling2_000 2302
WRA_JebHensarling2_001 877
WRA_JebHensarling2_002 2127
WRA_JebHensarling2_003 1877
WRA_JeffFlake_000 7202
WRA_JeffFlake_001 7502
WRA_JeffFlake_002 7377
WRA_JimInhofe_000 3127
WRA_JimRisch_000 1377
WRA_JoeHeck1_000 250
WRA_JoePitts_000 1077
WRA_JohnBarrasso0_000 5077
WRA_JohnBarrasso1_000 3452
WRA_JohnBoehner0_000 1626
WRA_JohnBoehner1_000 2951
WRA_JohnBoozman_000 1877
WRA_JohnHoeven_000 2577

30
train_codes/filelists/val.txt Normal file

@@ -0,0 +1,30 @@
WRA_JohnKasich0_000 1302
WRA_JohnKasich1_000 1301
WRA_JohnKasich1_001 1127
WRA_JohnKasich3_000 1052
WRA_JohnThune_000 1452
WRA_JohnnyIsakson_000 5951
WRA_JohnnyIsakson_001 5000
WRA_JonKyl_000 626
WRA_JoniErnst0_000 2077
WRA_JoniErnst1_000 1452
WRA_JuddGregg_000 1252
WRA_JuddGregg_001 952
WRA_JuddGregg_002 753
WRA_KayBaileyHutchison_000 6325
WRA_KellyAyotte_000 8077
WRA_KevinBrady2_000 1276
WRA_KevinBrady3_000 752
WRA_KevinBrady_000 2503
WRA_KevinMcCarthy0_000 1002
WRA_KevinMcCarthy0_001 1127
WRA_KristiNoem0_000 727
WRA_KristiNoem1_000 452
WRA_KristiNoem2_000 5952
WRA_KristiNoem2_001 7277
WRA_LamarAlexander0_000 1527
WRA_LamarAlexander_000 1552
WRA_LisaMurkowski0_000 1752
WRA_LynnJenkins_000 877
WRA_LynnJenkins_001 1002
WRA_MarcoRubio_000 526

36
train_codes/musetalk.json Normal file

@@ -0,0 +1,36 @@
{
"_class_name": "UNet2DConditionModel",
"_diffusers_version": "0.6.0.dev0",
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [
320,
640,
1280,
1280
],
"center_input_sample": false,
"cross_attention_dim": 384,
"down_block_types": [
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D"
],
"downsample_padding": 1,
"flip_sin_to_cos": true,
"freq_shift": 0,
"in_channels": 8,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"out_channels": 4,
"sample_size": 64,
"up_block_types": [
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D"
]
}

642
train_codes/train.py Executable file

@@ -0,0 +1,642 @@
import argparse
import itertools
import math
import os
import random
from pathlib import Path
import json
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
import time
import pandas as pd
from huggingface_hub import create_repo  # needed when --push_to_hub is set
# Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.13.0.dev0")
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--unet_config_file",
type=str,
default=None,
required=True,
help="the configuration of unet file.",
)
parser.add_argument(
"--reconstruction",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--data_root",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--testing_speed", action="store_true", help="Whether to caculate the running time")
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=1000,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
" using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--validation_steps",
type=int,
default=1000,
help=(
"Conduct validation every X updates."
),
)
parser.add_argument(
"--val_out_dir",
type=str,
default = '',
help=(
"Conduct validation every X updates."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--use_audio_length_left",
type=int,
default=1,
help="number of audio length (left).",
)
parser.add_argument(
"--use_audio_length_right",
type=int,
default=1,
help="number of audio length (right)",
)
parser.add_argument(
"--whisper_model_type",
type=str,
default="landmark_nearest",
choices=["tiny","largeV2"],
help="Determine whisper feature type",
)
args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank
    return args
def main():
args = parse_args()
print(args)
args.output_dir = f"output/{args.output_dir}"
args.val_out_dir = f"val/{args.val_out_dir}"
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.val_out_dir, exist_ok=True)
logging_dir = Path(args.output_dir, args.logging_dir)
project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
project_config=project_config,
)
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
if args.seed is not None:
        # use a different seed on each process so that random sampling differs across processes
        set_seed(args.seed + accelerator.process_index)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load models and create wrapper for stable diffusion
with open(args.unet_config_file, 'r') as f:
unet_config = json.load(f)
#text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
# Todo:
print("Loading AutoencoderKL")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
vae_fp32 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
print("Loading UNet2DConditionModel")
unet = UNet2DConditionModel(**unet_config)
if args.whisper_model_type == "tiny":
pe = PositionalEncoding(d_model=384)
elif args.whisper_model_type == "largeV2":
pe = PositionalEncoding(d_model=1280)
else:
print(f"not support whisper_model_type {args.whisper_model_type}")
print("Loading models done...")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
    params_to_optimize = itertools.chain(unet.parameters())
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
print("loading train_dataset ...")
train_dataset = Dataset(args.data_root,
'train',
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("train_dataset:",train_dataset.__len__())
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True,
num_workers=8)
print("loading val_dataset ...")
val_dataset = Dataset(args.data_root,
'val',
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("val_dataset:",val_dataset.__len__())
val_data_loader = data_utils.DataLoader(
val_dataset, batch_size=1, shuffle=False,
num_workers=8)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
unet, optimizer, train_data_loader, val_data_loader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_data_loader, val_data_loader,lr_scheduler
)
vae.requires_grad_(False)
vae_fp32.requires_grad_(False)
weight_dtype = torch.float32
vae_fp32.to(accelerator.device, dtype=weight_dtype)
vae_fp32.encoder = None
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae.to(accelerator.device, dtype=weight_dtype)
vae.decoder = None
pe.to(accelerator.device, dtype=weight_dtype)
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
    # calculate the elapsed time of each training stage (used with --testing_speed)
elapsed_time = []
start = time.time()
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# for step, batch in enumerate(train_dataloader):
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
dataloader_time = time.time() - start
start = time.time()
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
"""
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
print("ref_image: ",ref_image.shape)
print("masks: ", masks.shape)
print("masked_image: ", masked_image.shape)
print("audio feature: ", audio_feature.shape)
print("image: ", image.shape)
"""
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
img_process_time = time.time() - start
start = time.time()
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
# Convert masked images to latent space
masked_latents = vae.encode(
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
# Convert ref images to latent space
ref_latents = vae.encode(
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
).latent_dist.sample()
ref_latents = ref_latents * vae.config.scaling_factor
vae_time = time.time() - start
start = time.time()
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
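                # the binary mask is downsampled by the VAE factor of 8 (e.g. 256 -> 32) so that
                # it can be concatenated with the latents below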
bsz = latents.shape[0]
# fix timestep for each image
timesteps = torch.tensor([0], device=latents.device)
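                # No noise is added and the timestep is fixed to 0: the UNet acts as a single-step
                # generator that directly regresses the target latents (see the loss below),
                # not as a diffusion denoiser.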
# concatenate the latents with the mask and the masked latents
"""
print("=============vae latents=====".format(epoch,step))
print("ref_latents: ",ref_latents.shape)
print("mask: ", mask.shape)
print("masked_latents: ", masked_latents.shape)
"""
if unet_config['in_channels'] == 9:
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
else:
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
audio_feature = audio_feature.to(dtype=weight_dtype)
# Predict the noise residual
image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
if args.reconstruction: # decode the image from the predicted latents
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
image_pred_img = vae_fp32.decode(image_pred_img).sample
# Mask the top half of the image and calculate the loss only for the lower half of the image.
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
image = image[:, :, image.shape[2]//2:, :]
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images
loss_latents = F.l1_loss(image_pred.float(), latents.float(), reduction="mean") # the loss of the latents
loss = 2.0*loss_lip + loss_latents # add some weight to balance the loss
else:
loss = F.mse_loss(image_pred.float(), latents.float(), reduction="mean")
#
unet_elapsed_time = time.time() - start
start = time.time()
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters())
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
backward_elapsed_time = time.time() - start
start = time.time()
                if args.testing_speed and accelerator.is_main_process:
elapsed_time.append(
[dataloader_time, unet_elapsed_time, vae_time, backward_elapsed_time,img_process_time]
)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if global_step % args.validation_steps == 0:
if accelerator.is_main_process:
logger.info(
f"Running validation... epoch={epoch}, global_step={global_step}"
)
print("===========start validation==========")
validation(
vae=accelerator.unwrap_model(vae),
vae_fp32=accelerator.unwrap_model(vae_fp32),
unet=accelerator.unwrap_model(unet),
unet_config=unet_config,
weight_dtype=weight_dtype,
epoch=epoch,
global_step=global_step,
val_data_loader=val_data_loader,
output_dir = args.val_out_dir,
whisper_model_type = args.whisper_model_type
)
logger.info(f"Saved samples to images/val")
start = time.time()
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0],
"unet": unet_elapsed_time,
"backward": backward_elapsed_time,
"data": dataloader_time,
"img_process":img_process_time,
"vae":vae_time
}
progress_bar.set_postfix(**logs)
# accelerator.log(logs, step=global_step)
accelerator.log(
{
"loss/step_loss": logs["loss"],
"parameter/lr": logs["lr"],
"time/unet_forward_time": unet_elapsed_time,
"time/unet_backward_time": backward_elapsed_time,
"time/data_time": dataloader_time,
"time/img_process_time":img_process_time,
"time/vae_time": vae_time
},
step=global_step,
)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__":
main()

27
train_codes/train.sh Normal file

@@ -0,0 +1,27 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export DATASET="..."
export UNET_CONFIG="./musetalk.json"
accelerate launch --multi_gpu train.py \
--mixed_precision="fp16" \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
--data_root=$DATASET \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=200000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="..." \
--val_out_dir='...' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \
--reconstruction \
--resume_from_checkpoint="latest" \
--use_audio_length_left=2 \
--use_audio_length_right=2 \
--whisper_model_type="tiny" \

129
train_codes/model_utils.py Normal file

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from torch import Tensor, nn
import logging
import json
RESIZED_IMG = 256
class PositionalEncoding(nn.Module):
"""
    Positional encoding as used in the Transformer.
"""
def __init__(self, d_model=384, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
b, seq_len, d_model = x.size()
pe = self.pe[:, :seq_len, :]
#print(b, seq_len, d_model)
x = x + pe.to(x.device)
return x
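# Usage sketch (not part of the original file): the audio features enter as (batch, seq_len, d_model),
# e.g. for the tiny whisper features:
#   pe = PositionalEncoding(d_model=384)
#   audio_feature = torch.randn(2, 50, 384)
#   audio_feature = pe(audio_feature)   # same shape, with positional encoding added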
def validation(vae: torch.nn.Module,
vae_fp32: torch.nn.Module,
unet:torch.nn.Module,
unet_config,
weight_dtype: torch.dtype,
epoch: int,
global_step: int,
val_data_loader,
output_dir,
whisper_model_type,
UNet2DConditionModel=UNet2DConditionModel
):
# Get the validation pipeline
unet_copy = UNet2DConditionModel(**unet_config)
unet_copy.load_state_dict(unet.state_dict())
unet_copy.to(vae.device).to(dtype=weight_dtype)
unet_copy.eval()
if whisper_model_type == "tiny":
pe = PositionalEncoding(d_model=384)
elif whisper_model_type == "largeV2":
pe = PositionalEncoding(d_model=1280)
elif whisper_model_type == "tiny-conv":
pe = PositionalEncoding(d_model=384)
print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")
else:
print(f"not support whisper_model_type {whisper_model_type}")
pe.to(vae.device, dtype=weight_dtype)
start = time.time()
with torch.no_grad():
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
# Convert masked images to latent space
masked_latents = vae.encode(
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
# Convert ref images to latent space
ref_latents = vae.encode(
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
).latent_dist.sample()
ref_latents = ref_latents * vae.config.scaling_factor
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
bsz = latents.shape[0]
timesteps = torch.tensor([0], device=latents.device)
if unet_config['in_channels'] == 9:
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
else:
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
image = Image.new('RGB', (RESIZED_IMG*4, RESIZED_IMG))
image.paste(decode_latents(vae_fp32,masked_latents), (0, 0))
image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG*2, 0))
image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG*3, 0))
val_img_dir = f"images/{output_dir}/{global_step}"
if not os.path.exists(val_img_dir):
os.makedirs(val_img_dir)
image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step,step))
print("valtion in step:{0}, time:{1}".format(step,time.time()-start))
print("valtion_done in epoch:{0}, time:{1}".format(epoch,time.time()-start))

74
train_codes/utils/utils.py Normal file

@@ -0,0 +1,74 @@
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from einops import rearrange
from diffusers import AutoencoderKL
def preprocess_img_tensor(image_tensor):
    # expects a PyTorch tensor of shape (N, C, H, W)
    N, C, H, W = image_tensor.shape
    # round the width and height down to multiples of 32
    new_w = W - W % 32
    new_h = H - H % 32
    # normalize to [-1, 1] after resizing
    transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # apply the transform to each image and collect the results in a new tensor
    preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
    for i in range(N):
        # use F.interpolate instead of transforms.Resize so we can work on tensors directly
        resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
        preprocessed_images[i] = transform(resized_image.squeeze(0))
    return preprocessed_images
def prepare_mask_and_masked_image(image_tensor, mask_tensor):
    # image_tensor: (C, H, W) in [0, 1]; mask_tensor: (H, W) with 1 for pixels to keep
    # normalization carried over from an earlier version; unused below
    image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
    # binarize the mask
    # mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    # build the masked image: keep only the region where the mask is 1
    masked_image_tensor = image_tensor * (mask_tensor > 0.5)
    return mask_tensor, masked_image_tensor
def encode_latents(vae, image):
# init_image = preprocess_image(image)
init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
init_latents = 0.18215 * init_latent_dist.sample()
return init_latents
def decode_latents(vae, latents, ref_images=None):
latents = (1/ 0.18215) * latents
image = vae.decode(latents.to(vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
if ref_images is not None:
ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
ref_images = (ref_images * 255).round().astype("uint8")
h = image.shape[1]
image[:, :h//2] = ref_images[:, :h//2]
image = [Image.fromarray(im) for im in image]
return image[0]
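# Usage sketch (not part of the original file; the VAE path is an assumption taken from train.sh):
#   vae = AutoencoderKL.from_pretrained("./sd-vae-ft-mse/")
#   images = preprocess_img_tensor(batch)        # batch: (N, 3, H, W) in [0, 1]
#   latents = encode_latents(vae, images)        # (N, 4, H/8, W/8), scaled by 0.18215
#   pil_image = decode_latents(vae, latents)     # first image of the batch as a PIL.Image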