Mirror of https://github.com/TMElyralab/MuseTalk.git (synced 2026-02-04 17:39:20 +08:00)
Update draft training codes
252
train_codes/DataLoader.py
Normal file
@@ -0,0 +1,252 @@
import os, random, cv2, argparse
import torch
from torch.utils import data as data_utils
from os.path import dirname, join, basename, isfile
import numpy as np
from glob import glob
from utils.utils import prepare_mask_and_masked_image
import torchvision.utils as vutils
import torchvision.transforms as transforms
import shutil
from tqdm import tqdm
import ast
import json
import re
import heapq

syncnet_T = 1
RESIZED_IMG = 256

# 68-point facial landmark connectivity (kept for visualization/debugging)
connections = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16),  # jawline
               (17, 18), (18, 19), (19, 20), (20, 21),  # left eyebrow
               (22, 23), (23, 24), (24, 25), (25, 26),  # right eyebrow
               (27, 28), (28, 29), (29, 30),  # nose bridge
               (31, 32), (32, 33), (33, 34), (34, 35),  # lower nose
               (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 36),  # left eye
               (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42),  # right eye
               (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54),  # outer upper lip
               (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 48),  # outer lower lip
               (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60)  # inner lip
               ]


def get_image_list(data_root, split):
    filelist = []
    imgNumList = []
    with open('filelists/{}.txt'.format(split)) as f:
        for line in f:
            line = line.strip()
            if ' ' in line:
                filename = line.split()[0]
                imgNum = int(line.split()[1])
                filelist.append(os.path.join(data_root, filename))
                imgNumList.append(imgNum)
    return filelist, imgNumList


class Dataset(object):
    def __init__(self,
                 data_root,
                 split,
                 use_audio_length_left=1,
                 use_audio_length_right=1,
                 whisper_model_type="tiny"
                 ):
        self.all_videos, self.all_imgNum = get_image_list(data_root, split)
        self.audio_feature = [use_audio_length_left, use_audio_length_right]
        self.all_img_names = []
        self.split = split
        self.img_names_path = '...'
        self.whisper_model_type = whisper_model_type
        self.use_audio_length_left = use_audio_length_left
        self.use_audio_length_right = use_audio_length_right

        if self.whisper_model_type == "tiny":
            self.whisper_path = '...'
            self.whisper_feature_W = 5
            self.whisper_feature_H = 384
        elif self.whisper_model_type == "largeV2":
            self.whisper_path = '...'
            self.whisper_feature_W = 33
            self.whisper_feature_H = 1280
        self.whisper_feature_concateW = self.whisper_feature_W * 2 * (self.use_audio_length_left + self.use_audio_length_right + 1)  # 5*2*(2+2+1) = 50

        # Cache the sorted frame list of each video in a json file to avoid re-globbing.
        for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
            json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
            if not os.path.exists(json_path_names):
                img_names = glob(join(vidname, '*.png'))
                img_names.sort(key=lambda x: int(x.split("/")[-1].split('.')[0]))
                with open(json_path_names, "w") as f:
                    json.dump(img_names, f)
                print(f"save to {json_path_names}")
            else:
                with open(json_path_names, "r") as f:
                    img_names = json.load(f)
            self.all_img_names.append(img_names)

    def get_frame_id(self, frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.png'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        if window_fnames is None:
            return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (RESIZED_IMG, RESIZED_IMG))
            except Exception as e:
                print("read_window error, cannot process file:", fname)
                return None

            window.append(img)

        return window

    def crop_audio_window(self, spec, start_frame):
        # Kept from the original sync-net style loader; it relies on `hparams` and
        # `syncnet_mel_step_size` being defined elsewhere and is unused in this file.
        if type(start_frame) == int:
            start_frame_num = start_frame
        else:
            start_frame_num = self.get_frame_id(start_frame)
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))
        end_idx = start_idx + syncnet_mel_step_size
        return spec[start_idx:end_idx, :]

    def prepare_window(self, window):
        # T x H x W x 3 -> 3 x T x H x W, scaled to [0, 1]
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))
        return x

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            # Randomly pick a frame from the selected video
            vidname = self.all_videos[idx].split('/')[-1]
            video_imgs = self.all_img_names[idx]
            if len(video_imgs) == 0:
                # print("video_imgs = 0:", vidname)
                continue
            img_name = random.choice(video_imgs)
            img_idx = int(basename(img_name).split(".")[0])
            # Pick a reference frame at least 5 frames away from the target frame
            random_element = random.randint(0, len(video_imgs) - 1)
            while abs(random_element - img_idx) <= 5:
                random_element = random.randint(0, len(video_imgs) - 1)
            img_dir = os.path.dirname(img_name)
            ref_image = os.path.join(img_dir, f"{str(random_element)}.png")
            target_window_fnames = self.get_window(img_name)
            ref_window_fnames = self.get_window(ref_image)

            if target_window_fnames is None or ref_window_fnames is None:
                print("No such img", img_name, ref_image)
                continue

            try:
                # Build the target image data
                target_window = self.read_window(target_window_fnames)
                if target_window is None:
                    print("No such target window,", target_window_fnames)
                    continue
                # Build the reference image data
                ref_window = self.read_window(ref_window_fnames)
                if ref_window is None:
                    print("No such ref window,", ref_window)
                    continue
            except Exception as e:
                print(f"Unexpected error: {e}")
                continue

            # Build the target input
            target_window = self.prepare_window(target_window)
            image = gt = target_window.copy().squeeze()
            target_window[:, :, target_window.shape[2] // 2:] = 0.  # keep the upper half of the face, mask out the lower half (V1 input)
            ref_image = self.prepare_window(ref_window).squeeze()

            mask = torch.zeros((ref_image.shape[1], ref_image.shape[2]))
            mask[:ref_image.shape[2] // 2, :] = 1
            image = torch.FloatTensor(image)
            mask, masked_image = prepare_mask_and_masked_image(image, mask)

            # Audio features
            window_index = self.get_frame_id(img_name)
            sub_folder_name = vidname.split('/')[-1]

            # Load the neighbouring audio features around window_index
            audio_feature_all = []
            is_index_out_of_range = False
            if os.path.isdir(os.path.join(self.whisper_path, sub_folder_name)):
                for feat_idx in range(window_index - self.use_audio_length_left, window_index + self.use_audio_length_right + 1):
                    # Check whether the index is out of range
                    audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")

                    if not os.path.exists(audio_feat_path):
                        is_index_out_of_range = True
                        break

                    try:
                        audio_feature_all.append(np.load(audio_feat_path))
                    except Exception as e:
                        print(f"Unexpected error: {e}")
                        print(f"npy load error {audio_feat_path}")
                if is_index_out_of_range:
                    continue
                audio_feature = np.concatenate(audio_feature_all, axis=0)
            else:
                continue

            audio_feature = audio_feature.reshape(1, -1, self.whisper_feature_H)  # 1, -1, 384
            if audio_feature.shape != (1, self.whisper_feature_concateW, self.whisper_feature_H):  # 1, 50, 384
                print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
                continue
            audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))

            return ref_image, image, masked_image, mask, audio_feature


if __name__ == "__main__":
    data_root = '...'
    val_data = Dataset(data_root,
                       'val',
                       use_audio_length_left=2,
                       use_audio_length_right=2,
                       whisper_model_type="tiny"
                       )
    val_data_loader = data_utils.DataLoader(
        val_data, batch_size=4, shuffle=True,
        num_workers=1)
    print("val_dataset:", val_data_loader.__len__())

    for i, data in enumerate(val_data_loader):
        ref_image, image, masked_image, mask, audio_feature = data
        print("ref_image: ", ref_image.shape)
35
train_codes/README.md
Normal file
@@ -0,0 +1,35 @@
# Draft training codes

We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.

## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
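
For orientation, here is a minimal, hedged sketch of what the preprocessing is expected to produce: one cropped face PNG and one Whisper feature `.npy` per frame. `crop_face` and `audio_feature_extractor` are hypothetical placeholders for the inference-time cropping and feature extraction linked above, and the per-frame feature size follows the shapes assumed in `DataLoader.py` (whisper-tiny chunks that flatten to 10×384 per frame). This is not an official script.

```python
import os
import cv2
import numpy as np

def crop_face(frame, size=256):
    # Placeholder: the real pipeline detects the face and crops around it
    # (see the inference script linked above); here we just center-crop and resize.
    h, w = frame.shape[:2]
    s = min(h, w)
    y0, x0 = (h - s) // 2, (w - s) // 2
    return cv2.resize(frame[y0:y0 + s, x0:x0 + s], (size, size))

def preprocess_clip(video_path, audio_feature_extractor, out_root="./data"):
    """Write ./data/images/<name>/<idx>.png and ./data/audios/<name>/<idx>.npy."""
    name = os.path.splitext(os.path.basename(video_path))[0]
    img_dir = os.path.join(out_root, "images", name)
    audio_dir = os.path.join(out_root, "audios", name)
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(audio_dir, exist_ok=True)

    # 1) one cropped face image per frame, named by frame index
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(img_dir, f"{idx}.png"), crop_face(frame))
        idx += 1
    cap.release()

    # 2) one whisper feature chunk per frame, aligned with the frame indices
    #    (audio_feature_extractor is an assumed callable returning a list of arrays)
    for i, feat in enumerate(audio_feature_extractor(video_path)):
        np.save(os.path.join(audio_dir, f"{i}.npy"), np.asarray(feat))
```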

Finally, the data should be organized as follows:
```
./data/
├── images
│   ├── RD_Radio10_000
│   │   ├── 0.png
│   │   ├── 1.png
│   │   └── xxx.png
│   └── RD_Radio11_000
│       ├── 0.png
│       ├── 1.png
│       └── xxx.png
└── audios
    ├── RD_Radio10_000
    │   ├── 0.npy
    │   ├── 1.npy
    │   └── xxx.npy
    └── RD_Radio11_000
        ├── 0.npy
        ├── 1.npy
        └── xxx.npy
```

## Training
After preparing the preprocessed data, simply run:
```
sh train.sh
```
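
The `filelists/` text files referenced by `DataLoader.py` (shown below) contain one `<clip_name> <num_frames>` pair per line. A small helper sketch, not part of the repo, for generating such a list from the layout above:

```python
import os
from glob import glob

data_root = "./data/images"                      # assumed layout from the README above
os.makedirs("filelists", exist_ok=True)
with open("filelists/train.txt", "w") as f:
    for clip in sorted(os.listdir(data_root)):
        num_frames = len(glob(os.path.join(data_root, clip, "*.png")))
        f.write(f"{clip} {num_frames}\n")        # format parsed by DataLoader.get_image_list
```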
322
train_codes/filelists/train.txt
Normal file
@@ -0,0 +1,322 @@
|
||||
RD_Radio10_000 3501
|
||||
RD_Radio11_000 752
|
||||
RD_Radio11_001 1502
|
||||
RD_Radio12_000 876
|
||||
RD_Radio13_000 877
|
||||
RD_Radio14_000 1251
|
||||
RD_Radio16_000 2377
|
||||
RD_Radio17_000 2176
|
||||
RD_Radio18_000 1827
|
||||
RD_Radio19_000 1876
|
||||
RD_Radio1_000 1876
|
||||
RD_Radio20_000 276
|
||||
RD_Radio21_000 1201
|
||||
RD_Radio22_000 752
|
||||
RD_Radio23_000 1602
|
||||
RD_Radio25_000 2126
|
||||
RD_Radio26_000 1052
|
||||
RD_Radio27_000 1376
|
||||
RD_Radio28_000 2252
|
||||
RD_Radio29_000 2252
|
||||
RD_Radio2_000 2076
|
||||
RD_Radio30_000 2877
|
||||
RD_Radio31_000 1377
|
||||
RD_Radio32_000 1502
|
||||
RD_Radio33_000 2252
|
||||
RD_Radio34_000 702
|
||||
RD_Radio34_001 252
|
||||
RD_Radio34_002 501
|
||||
RD_Radio34_003 326
|
||||
RD_Radio34_004 502
|
||||
RD_Radio34_005 502
|
||||
RD_Radio34_006 502
|
||||
RD_Radio34_007 252
|
||||
RD_Radio34_008 876
|
||||
RD_Radio34_009 752
|
||||
RD_Radio35_000 2127
|
||||
RD_Radio36_000 2377
|
||||
RD_Radio37_000 3127
|
||||
RD_Radio38_000 3752
|
||||
RD_Radio39_000 1502
|
||||
RD_Radio3_000 2377
|
||||
RD_Radio40_000 1252
|
||||
RD_Radio41_000 2127
|
||||
RD_Radio42_000 1752
|
||||
RD_Radio43_000 1502
|
||||
RD_Radio44_000 1127
|
||||
RD_Radio45_000 1628
|
||||
RD_Radio46_000 1877
|
||||
RD_Radio47_000 1502
|
||||
RD_Radio48_000 2127
|
||||
RD_Radio49_000 1502
|
||||
RD_Radio4_000 1002
|
||||
RD_Radio50_000 2252
|
||||
RD_Radio51_000 3253
|
||||
RD_Radio52_000 2752
|
||||
RD_Radio53_000 2627
|
||||
RD_Radio54_000 577
|
||||
RD_Radio56_000 1952
|
||||
RD_Radio57_000 3077
|
||||
RD_Radio59_000 1952
|
||||
RD_Radio5_000 1377
|
||||
RD_Radio7_000 2252
|
||||
RD_Radio8_000 2126
|
||||
RD_Radio9_000 1502
|
||||
WDA_AdamSchiff_000 6877
|
||||
WDA_AdamSmith_000 7327
|
||||
WDA_AlexandriaOcasioCortez_000 2252
|
||||
WDA_AmyKlobuchar0_000 6501
|
||||
WDA_AmyKlobuchar1_000 3253
|
||||
WDA_AmyKlobuchar1_001 877
|
||||
WDA_AmyKlobuchar1_002 1502
|
||||
WDA_AmyKlobuchar1_003 1377
|
||||
WDA_AndyKim_000 6952
|
||||
WDA_AndyLevin_000 4876
|
||||
WDA_AnnieKuster_000 4876
|
||||
WDA_BarackObama_000 3575
|
||||
WDA_BarackObama_001 5625
|
||||
WDA_BarbaraLee0_000 6502
|
||||
WDA_BarbaraLee1_000 4702
|
||||
WDA_BenCardin0_000 8127
|
||||
WDA_BenCardin1_000 7102
|
||||
WDA_BenRayLujn_000 7127
|
||||
WDA_BennieThompson1_000 6375
|
||||
WDA_BennieThompson_000 5375
|
||||
WDA_BernieSanders_000 8451
|
||||
WDA_BettyMcCollum_000 7002
|
||||
WDA_BillPascrell_000 7252
|
||||
WDA_BillRichardson_000 3127
|
||||
WDA_BobCasey0_000 10627
|
||||
WDA_BobCasey1_000 252
|
||||
WDA_BobMenendez_000 3002
|
||||
WDA_BobbyScott_000 2002
|
||||
WDA_BradSchneider_000 6627
|
||||
WDA_BrendaLawrence_000 5627
|
||||
WDA_BrianSchatz0_000 502
|
||||
WDA_BrianSchatz1_000 2627
|
||||
WDA_BrianSchatz2_000 1952
|
||||
WDA_ByronDorgan1_000 4277
|
||||
WDA_CarolynMaloney1_000 8377
|
||||
WDA_CatherineCortezMasto_000 2827
|
||||
WDA_CedricRichmond_000 7250
|
||||
WDA_ChrisCoons1_000 4452
|
||||
WDA_ChrisCoons_000 6827
|
||||
WDA_ChrisMurphy0_000 6877
|
||||
WDA_ChrisMurphy1_000 6377
|
||||
WDA_ChrisVanHollen0_000 8202
|
||||
WDA_ChrisVanHollen1_000 7327
|
||||
WDA_ChuckSchumer0_000 5577
|
||||
WDA_ChuckSchumer1_000 4827
|
||||
WDA_ColinAllred_000 4751
|
||||
WDA_DanKildee1_000 6252
|
||||
WDA_DanKildee_000 2502
|
||||
WDA_DavidCicilline_000 5875
|
||||
WDA_DebHaaland_000 5876
|
||||
WDA_DebbieDingell0_000 7752
|
||||
WDA_DebbieDingell1_000 7251
|
||||
WDA_DebbieStabenow0_000 4577
|
||||
WDA_DebbieStabenow1_000 5201
|
||||
WDA_DebbieWassermanSchultz_000 7625
|
||||
WDA_DianaDeGette0_000 6002
|
||||
WDA_DianaDeGette1_000 2202
|
||||
WDA_DianneFeinstein_000 7453
|
||||
WDA_DickDurbin_000 6326
|
||||
WDA_DonaldMcEachin_000 7627
|
||||
WDA_DonnaShalala1_000 7500
|
||||
WDA_DougJones_000 6077
|
||||
WDA_EdMarkey0_000 4752
|
||||
WDA_EdMarkey1_000 6501
|
||||
WDA_ElijahCummings_000 6377
|
||||
WDA_EliotEngel_000 6876
|
||||
WDA_EmanuelCleaver_000 7001
|
||||
WDA_EricSwalwell_000 6377
|
||||
WDA_FrankPallone0_000 5625
|
||||
WDA_FrankPallone1_000 6502
|
||||
WDA_GerryConnolly_000 6752
|
||||
WDA_HakeemJeffries_000 6002
|
||||
WDA_HaleyStevens_000 5001
|
||||
WDA_HenryWaxman_000 1125
|
||||
WDA_HillaryClinton_000 2500
|
||||
WDA_JackReed0_000 2377
|
||||
WDA_JackReed1_000 4877
|
||||
WDA_JackieSpeier_000 4625
|
||||
WDA_JackyRosen1_000 10077
|
||||
WDA_JackyRosen_000 5502
|
||||
WDA_JamesClyburn1_000 6876
|
||||
WDA_JamesClyburn_000 6875
|
||||
WDA_JanSchakowsky0_000 4128
|
||||
WDA_JanSchakowsky1_000 6251
|
||||
WDA_JeanneShaheen0_000 6702
|
||||
WDA_JeanneShaheen1_000 6577
|
||||
WDA_JeffMerkley1_000 6952
|
||||
WDA_JerryNadler_000 8377
|
||||
WDA_JimHimes_000 7000
|
||||
WDA_JimmyGomez_000 4627
|
||||
WDA_JoaquinCastro_000 5126
|
||||
WDA_JoeCrowley0_000 4877
|
||||
WDA_JoeCrowley1_000 1127
|
||||
WDA_JoeCrowley1_001 627
|
||||
WDA_JoeCrowley1_002 502
|
||||
WDA_JoeCrowley1_003 752
|
||||
WDA_JoeDonnelly_000 1377
|
||||
WDA_JoeKennedy_000 1327
|
||||
WDA_JoeManchin_000 3377
|
||||
WDA_JoeNeguse_000 1628
|
||||
WDA_JoeNeguse_001 1126
|
||||
WDA_JoeNeguse_002 1251
|
||||
WDA_JohnLewis0_000 6252
|
||||
WDA_JohnLewis1_000 7252
|
||||
WDA_JohnSarbanes0_000 6377
|
||||
WDA_JohnSarbanes1_000 1877
|
||||
WDA_JohnYarmuth1_000 5377
|
||||
WDA_JonTester0_000 3252
|
||||
WDA_JonTester1_000 5327
|
||||
WDA_KarenBass_000 7126
|
||||
WDA_KatherineClark_000 1552
|
||||
WDA_KathyCastor1_000 6001
|
||||
WDA_KathyCastor_000 2252
|
||||
WDA_KimSchrier_000 6877
|
||||
WDA_KirstenGillibrand_000 9627
|
||||
WDA_LaurenUnderwood_000 9877
|
||||
WDA_LisaBluntRochester_000 4500
|
||||
WDA_LloydDoggett0_000 7377
|
||||
WDA_LloydDoggett1_000 2252
|
||||
WDA_LucilleRoybal-Allard_000 2877
|
||||
WDA_LucyMcBath_000 4000
|
||||
WDA_MarciaFudge_000 7377
|
||||
WDA_MarkWarner1_000 377
|
||||
WDA_MarkWarner1_001 1327
|
||||
WDA_MarkWarner2_000 377
|
||||
WDA_MarkWarner_000 3127
|
||||
WDA_MartinHeinrich_000 5202
|
||||
WDA_MattCartwright_000 5377
|
||||
WDA_MazieHirono0_000 3752
|
||||
WDA_MichelleLujanGrisham_000 6875
|
||||
WDA_MichelleObama_000 2000
|
||||
WDA_MikeDoyle_000 8750
|
||||
WDA_MikeThompson0_000 4625
|
||||
WDA_MikeThompson1_000 1827
|
||||
WDA_NancyPelosi0_000 10251
|
||||
WDA_NancyPelosi1_000 1377
|
||||
WDA_NancyPelosi3_000 1127
|
||||
WDA_NitaLowey_000 5876
|
||||
WDA_NydiaVelzquez_000 5500
|
||||
WDA_PatrickLeahy0_000 6951
|
||||
WDA_PatrickLeahy1_000 9953
|
||||
WDA_PattyMurray0_000 3627
|
||||
WDA_PattyMurray1_000 5702
|
||||
WDA_PeterDeFazio_000 6628
|
||||
WDA_RaulRuiz_000 5752
|
||||
WDA_RichardBlumenthal_000 7827
|
||||
WDA_RichardNeal0_000 6252
|
||||
WDA_RichardNeal1_000 6877
|
||||
WDA_RobinKelly_000 4377
|
||||
WDA_RonWyden0_000 5152
|
||||
WDA_RonWyden1_000 8327
|
||||
WDA_ScottPeters0_000 7002
|
||||
WDA_ScottPeters1_000 3952
|
||||
WDA_SeanCasten_000 6126
|
||||
WDA_SeanPatrickMaloney_000 6502
|
||||
WDA_SheldonWhitehouse0_000 6327
|
||||
WDA_SheldonWhitehouse1_000 5702
|
||||
WDA_SherrodBrown0_000 7452
|
||||
WDA_SherrodBrown1_000 7327
|
||||
WDA_StenyHoyer_000 3502
|
||||
WDA_StephanieMurphy_000 4000
|
||||
WDA_SuzanDelBene_000 6875
|
||||
WDA_TammyBaldwin0_000 5327
|
||||
WDA_TammyBaldwin1_000 2503
|
||||
WDA_TammyDuckworth_000 5702
|
||||
WDA_TedLieu_000 6877
|
||||
WDA_TerriSewell0_000 2127
|
||||
WDA_TerriSewell1_000 10952
|
||||
WDA_TerriSewell_000 6752
|
||||
WDA_TimWalz_000 6628
|
||||
WDA_TinaSmith_000 5576
|
||||
WDA_TomCarper_000 4701
|
||||
WDA_TomPerez_000 5500
|
||||
WDA_TomUdall_000 4077
|
||||
WDA_VeronicaEscobar0_000 4377
|
||||
WDA_VeronicaEscobar1_000 2453
|
||||
WDA_WhipJimClyburn_000 6876
|
||||
WDA_XavierBecerra_000 877
|
||||
WDA_XavierBecerra_001 627
|
||||
WDA_XavierBecerra_002 1377
|
||||
WDA_ZoeLofgren_000 4625
|
||||
WRA_AdamKinzinger0_000 4251
|
||||
WRA_AdamKinzinger1_000 2751
|
||||
WRA_AdamKinzinger2_000 1002
|
||||
WRA_AdamKinzinger2_001 1001
|
||||
WRA_AdamKinzinger2_002 502
|
||||
WRA_AllenWest_000 6876
|
||||
WRA_AnnMarieBuerkle_000 452
|
||||
WRA_AnnWagner_000 3127
|
||||
WRA_AustinScott0_000 1177
|
||||
WRA_BillCassidy0_000 427
|
||||
WRA_BobCorker_000 1127
|
||||
WRA_BobGoodlatte0_000 1001
|
||||
WRA_BobGoodlatte0_001 752
|
||||
WRA_BobGoodlatte0_002 876
|
||||
WRA_BobGoodlatte0_003 1002
|
||||
WRA_BobbySchilling_001 1000
|
||||
WRA_BobbySchilling_002 1001
|
||||
WRA_CandiceMiller0_000 3001
|
||||
WRA_CarlyFiorina0_000 876
|
||||
WRA_CarlyFiorina_000 1902
|
||||
WRA_CathyMcMorrisRodgers0_000 6077
|
||||
WRA_CathyMcMorrisRodgers1_000 7375
|
||||
WRA_CathyMcMorrisRodgers1_001 7500
|
||||
WRA_CathyMcMorrisRodgers1_002 7500
|
||||
WRA_CathyMcMorrisRodgers2_000 2751
|
||||
WRA_ChuckGrassley_000 4502
|
||||
WRA_CoryGardner0_000 8577
|
||||
WRA_CoryGardner1_000 2002
|
||||
WRA_CoryGardner_000 1252
|
||||
WRA_DanNewhouse_000 3627
|
||||
WRA_DanSullivan_000 6877
|
||||
WRA_DaveCamp_000 752
|
||||
WRA_DaveCamp_001 876
|
||||
WRA_DaveCamp_002 627
|
||||
WRA_DavidVitter_000 2877
|
||||
WRA_DeanHeller_000 377
|
||||
WRA_DebFischer0_000 7577
|
||||
WRA_DebFischer1_000 1577
|
||||
WRA_DebFischer2_000 5202
|
||||
WRA_DianeBlack0_000 451
|
||||
WRA_DianeBlack0_001 301
|
||||
WRA_DianeBlack1_000 250
|
||||
WRA_DuncanHunter_000 727
|
||||
WRA_EricCantor_000 2952
|
||||
WRA_ErikPaulsen_000 876
|
||||
WRA_ErikPaulsen_001 377
|
||||
WRA_ErikPaulsen_002 627
|
||||
WRA_ErikPaulsen_003 752
|
||||
WRA_FredUpton_000 1701
|
||||
WRA_GeoffDavis_000 250
|
||||
WRA_GeorgeLeMieux_000 752
|
||||
WRA_GeorgeLeMieux_001 1001
|
||||
WRA_GregWalden1_000 752
|
||||
WRA_GregWalden1_001 1127
|
||||
WRA_GregWalden1_002 1327
|
||||
WRA_GregWalden_000 377
|
||||
WRA_JaimeHerreraBeutler0_000 952
|
||||
WRA_JebHensarling0_001 1176
|
||||
WRA_JebHensarling1_000 2327
|
||||
WRA_JebHensarling1_001 727
|
||||
WRA_JebHensarling2_000 2302
|
||||
WRA_JebHensarling2_001 877
|
||||
WRA_JebHensarling2_002 2127
|
||||
WRA_JebHensarling2_003 1877
|
||||
WRA_JeffFlake_000 7202
|
||||
WRA_JeffFlake_001 7502
|
||||
WRA_JeffFlake_002 7377
|
||||
WRA_JimInhofe_000 3127
|
||||
WRA_JimRisch_000 1377
|
||||
WRA_JoeHeck1_000 250
|
||||
WRA_JoePitts_000 1077
|
||||
WRA_JohnBarrasso0_000 5077
|
||||
WRA_JohnBarrasso1_000 3452
|
||||
WRA_JohnBoehner0_000 1626
|
||||
WRA_JohnBoehner1_000 2951
|
||||
WRA_JohnBoozman_000 1877
|
||||
WRA_JohnHoeven_000 2577
|
||||
30
train_codes/filelists/val.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
WRA_JohnKasich0_000 1302
|
||||
WRA_JohnKasich1_000 1301
|
||||
WRA_JohnKasich1_001 1127
|
||||
WRA_JohnKasich3_000 1052
|
||||
WRA_JohnThune_000 1452
|
||||
WRA_JohnnyIsakson_000 5951
|
||||
WRA_JohnnyIsakson_001 5000
|
||||
WRA_JonKyl_000 626
|
||||
WRA_JoniErnst0_000 2077
|
||||
WRA_JoniErnst1_000 1452
|
||||
WRA_JuddGregg_000 1252
|
||||
WRA_JuddGregg_001 952
|
||||
WRA_JuddGregg_002 753
|
||||
WRA_KayBaileyHutchison_000 6325
|
||||
WRA_KellyAyotte_000 8077
|
||||
WRA_KevinBrady2_000 1276
|
||||
WRA_KevinBrady3_000 752
|
||||
WRA_KevinBrady_000 2503
|
||||
WRA_KevinMcCarthy0_000 1002
|
||||
WRA_KevinMcCarthy0_001 1127
|
||||
WRA_KristiNoem0_000 727
|
||||
WRA_KristiNoem1_000 452
|
||||
WRA_KristiNoem2_000 5952
|
||||
WRA_KristiNoem2_001 7277
|
||||
WRA_LamarAlexander0_000 1527
|
||||
WRA_LamarAlexander_000 1552
|
||||
WRA_LisaMurkowski0_000 1752
|
||||
WRA_LynnJenkins_000 877
|
||||
WRA_LynnJenkins_001 1002
|
||||
WRA_MarcoRubio_000 526
|
||||
36
train_codes/musetalk.json
Normal file
@@ -0,0 +1,36 @@
{
  "_class_name": "UNet2DConditionModel",
  "_diffusers_version": "0.6.0.dev0",
  "act_fn": "silu",
  "attention_head_dim": 8,
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "center_input_sample": false,
  "cross_attention_dim": 384,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 8,
  "layers_per_block": 2,
  "mid_block_scale_factor": 1,
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "out_channels": 4,
  "sample_size": 64,
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ]
}
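
This config is consumed by `train.py` below, which reads it with `json.load` and builds a randomly initialized UNet from it. A minimal sketch of that step (the config path is an assumption; `from_config` is used here because it tolerates bookkeeping keys such as `"_class_name"`):

```python
import json
from diffusers import UNet2DConditionModel

with open("train_codes/musetalk.json") as f:
    unet_config = json.load(f)

unet = UNet2DConditionModel.from_config(unet_config)  # randomly initialized, no pretrained weights
# in_channels = 8: 4 masked-image latent channels + 4 reference latent channels
# (train.py only expects an extra mask channel when in_channels == 9);
# cross_attention_dim = 384 matches the whisper-tiny audio features.
```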
642
train_codes/train.py
Executable file
@@ -0,0 +1,642 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from PIL import Image, ImageDraw
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
|
||||
from DataLoader import Dataset
|
||||
from utils.utils import preprocess_img_tensor
|
||||
from torch.utils import data as data_utils
|
||||
from model_utils import validation,PositionalEncoding
|
||||
import time
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--unet_config_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="the configuration of unet file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reconstruction",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to add prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument("--testing_speed", action="store_true", help="Whether to caculate the running time")
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
||||
" using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Conduct validation every X updates."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_out_dir",
|
||||
type=str,
|
||||
default = '',
|
||||
help=(
|
||||
"Conduct validation every X updates."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more docs"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_audio_length_left",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of audio length (left).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_audio_length_right",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of audio length (right)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper_model_type",
|
||||
type=str,
|
||||
default="landmark_nearest",
|
||||
choices=["tiny","largeV2"],
|
||||
help="Determine whisper feature type",
|
||||
)
|
||||
|
||||
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args  # without this, main() would receive None from parse_args()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
print(args)
|
||||
args.output_dir = f"output/{args.output_dir}"
|
||||
args.val_out_dir = f"val/{args.val_out_dir}"
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
os.makedirs(args.val_out_dir, exist_ok=True)
|
||||
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
|
||||
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
project_config=project_config,
|
||||
)
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
    if args.seed is not None:
        # set_seed(args.seed)
        set_seed(args.seed + accelerator.process_index)  # per-process seed; bare `seed` was undefined
|
||||
|
||||
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
        if args.push_to_hub:
            from huggingface_hub import create_repo  # not imported at the top of this draft
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id
|
||||
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
with open(args.unet_config_file, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
|
||||
#text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
|
||||
# Todo:
|
||||
print("Loading AutoencoderKL")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
|
||||
vae_fp32 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
print("Loading UNet2DConditionModel")
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if args.whisper_model_type == "tiny":
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
elif args.whisper_model_type == "largeV2":
|
||||
pe = PositionalEncoding(d_model=1280)
|
||||
else:
|
||||
print(f"not support whisper_model_type {args.whisper_model_type}")
|
||||
|
||||
print("Loading models done...")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
    params_to_optimize = itertools.chain(unet.parameters())  # only the UNet is optimized
    optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
print("loading train_dataset ...")
|
||||
train_dataset = Dataset(args.data_root,
|
||||
'train',
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
)
|
||||
print("train_dataset:",train_dataset.__len__())
|
||||
train_data_loader = data_utils.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True,
|
||||
num_workers=8)
|
||||
print("loading val_dataset ...")
|
||||
val_dataset = Dataset(args.data_root,
|
||||
'val',
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
)
|
||||
print("val_dataset:",val_dataset.__len__())
|
||||
val_data_loader = data_utils.DataLoader(
|
||||
val_dataset, batch_size=1, shuffle=False,
|
||||
num_workers=8)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
|
||||
|
||||
unet, optimizer, train_data_loader, val_data_loader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_data_loader, val_data_loader,lr_scheduler
|
||||
)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
vae_fp32.requires_grad_(False)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
vae_fp32.to(accelerator.device, dtype=weight_dtype)
|
||||
vae_fp32.encoder = None
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.decoder = None
|
||||
pe.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_data_loader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth", config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# caluate the elapsed time
|
||||
elapsed_time = []
|
||||
start = time.time()
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
# for step, batch in enumerate(train_dataloader):
|
||||
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
dataloader_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
|
||||
"""
|
||||
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
|
||||
print("ref_image: ",ref_image.shape)
|
||||
print("masks: ", masks.shape)
|
||||
print("masked_image: ", masked_image.shape)
|
||||
print("audio feature: ", audio_feature.shape)
|
||||
print("image: ", image.shape)
|
||||
"""
|
||||
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
|
||||
image = preprocess_img_tensor(image).to(vae.device)
|
||||
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
|
||||
|
||||
img_process_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
|
||||
# Convert ref images to latent space
|
||||
ref_latents = vae.encode(
|
||||
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
|
||||
).latent_dist.sample()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
|
||||
vae_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
|
||||
|
||||
|
||||
bsz = latents.shape[0]
|
||||
# fix timestep for each image
|
||||
timesteps = torch.tensor([0], device=latents.device)
|
||||
# concatenate the latents with the mask and the masked latents
|
||||
"""
|
||||
print("=============vae latents=====".format(epoch,step))
|
||||
print("ref_latents: ",ref_latents.shape)
|
||||
print("mask: ", mask.shape)
|
||||
print("masked_latents: ", masked_latents.shape)
|
||||
"""
|
||||
|
||||
if unet_config['in_channels'] == 9:
|
||||
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
|
||||
else:
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
# Predict the noise residual
|
||||
image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
|
||||
|
||||
if args.reconstruction: # decode the image from the predicted latents
|
||||
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
|
||||
image_pred_img = vae_fp32.decode(image_pred_img).sample
|
||||
|
||||
# Mask the top half of the image and calculate the loss only for the lower half of the image.
|
||||
image_pred_img = image_pred_img[:, :, image_pred_img.shape[2]//2:, :]
|
||||
image = image[:, :, image.shape[2]//2:, :]
|
||||
loss_lip = F.l1_loss(image_pred_img.float(), image.float(), reduction="mean") # the loss of the decoded images
|
||||
loss_latents = F.l1_loss(image_pred.float(), latents.float(), reduction="mean") # the loss of the latents
|
||||
|
||||
loss = 2.0*loss_lip + loss_latents # add some weight to balance the loss
|
||||
|
||||
else:
|
||||
loss = F.mse_loss(image_pred.float(), latents.float(), reduction="mean")
|
||||
#
|
||||
|
||||
unet_elapsed_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters())
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
backward_elapsed_time = time.time() - start
|
||||
start = time.time()
|
||||
|
||||
if args.testing_speed is True and accelerator.is_main_process:
|
||||
elapsed_time.append(
|
||||
[dataloader_time, unet_elapsed_time, vae_time, backward_elapsed_time,img_process_time]
|
||||
)
|
||||
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
|
||||
if global_step % args.validation_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
logger.info(
|
||||
f"Running validation... epoch={epoch}, global_step={global_step}"
|
||||
)
|
||||
print("===========start validation==========")
|
||||
validation(
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
vae_fp32=accelerator.unwrap_model(vae_fp32),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
unet_config=unet_config,
|
||||
weight_dtype=weight_dtype,
|
||||
epoch=epoch,
|
||||
global_step=global_step,
|
||||
val_data_loader=val_data_loader,
|
||||
output_dir = args.val_out_dir,
|
||||
whisper_model_type = args.whisper_model_type
|
||||
)
|
||||
logger.info(f"Saved samples to images/val")
|
||||
start = time.time()
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0],
|
||||
"unet": unet_elapsed_time,
|
||||
"backward": backward_elapsed_time,
|
||||
"data": dataloader_time,
|
||||
"img_process":img_process_time,
|
||||
"vae":vae_time
|
||||
}
|
||||
progress_bar.set_postfix(**logs)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.log(
|
||||
{
|
||||
"loss/step_loss": logs["loss"],
|
||||
"parameter/lr": logs["lr"],
|
||||
"time/unet_forward_time": unet_elapsed_time,
|
||||
"time/unet_backward_time": backward_elapsed_time,
|
||||
"time/data_time": dataloader_time,
|
||||
"time/img_process_time":img_process_time,
|
||||
"time/vae_time": vae_time
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
27
train_codes/train.sh
Normal file
@@ -0,0 +1,27 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export DATASET="..."
export UNET_CONFIG="./musetalk.json"

# Note: train.py currently raises an error when gradient_accumulation_steps > 1
# is combined with multiple processes; adjust one of the two if that check is kept.
accelerate launch --multi_gpu train.py \
    --mixed_precision="fp16" \
    --unet_config_file=$UNET_CONFIG \
    --pretrained_model_name_or_path=$VAE_MODEL \
    --data_root=$DATASET \
    --train_batch_size=8 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --max_train_steps=200000 \
    --learning_rate=5e-05 \
    --max_grad_norm=1 \
    --lr_scheduler="cosine" \
    --lr_warmup_steps=0 \
    --output_dir="..." \
    --val_out_dir='...' \
    --testing_speed \
    --checkpointing_steps=1000 \
    --validation_steps=1000 \
    --reconstruction \
    --resume_from_checkpoint="latest" \
    --use_audio_length_left=2 \
    --use_audio_length_right=2 \
    --whisper_model_type="tiny"
129
train_codes/utils/model_utils.py
Normal file
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import time
import math
import os
import logging
import json
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
)

from utils import decode_latents, preprocess_img_tensor

RESIZED_IMG = 256


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding as used in the Transformer.
    """
    def __init__(self, d_model=384, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        b, seq_len, d_model = x.size()
        pe = self.pe[:, :seq_len, :]
        # print(b, seq_len, d_model)
        x = x + pe.to(x.device)
        return x


def validation(vae: torch.nn.Module,
               vae_fp32: torch.nn.Module,
               unet: torch.nn.Module,
               unet_config,
               weight_dtype: torch.dtype,
               epoch: int,
               global_step: int,
               val_data_loader,
               output_dir,
               whisper_model_type,
               UNet2DConditionModel=UNet2DConditionModel
               ):

    # Build a frozen copy of the UNet for validation
    unet_copy = UNet2DConditionModel(**unet_config)
    unet_copy.load_state_dict(unet.state_dict())
    unet_copy.to(vae.device).to(dtype=weight_dtype)
    unet_copy.eval()

    if whisper_model_type == "tiny":
        pe = PositionalEncoding(d_model=384)
    elif whisper_model_type == "largeV2":
        pe = PositionalEncoding(d_model=1280)
    elif whisper_model_type == "tiny-conv":
        pe = PositionalEncoding(d_model=384)
        print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")
    else:
        print(f"not support whisper_model_type {whisper_model_type}")
    pe.to(vae.device, dtype=weight_dtype)

    start = time.time()
    with torch.no_grad():
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):

            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            # Convert images to latent space
            latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()  # init image
            latents = latents * vae.config.scaling_factor
            # Convert masked images to latent space
            masked_latents = vae.encode(
                masked_image.reshape(image.shape).to(dtype=weight_dtype)  # masked image
            ).latent_dist.sample()
            masked_latents = masked_latents * vae.config.scaling_factor
            # Convert ref images to latent space
            ref_latents = vae.encode(
                ref_image.reshape(image.shape).to(dtype=weight_dtype)  # ref image
            ).latent_dist.sample()
            ref_latents = ref_latents * vae.config.scaling_factor

            mask = torch.stack(
                [
                    torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
                    for mask in masks
                ]
            )
            mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
            bsz = latents.shape[0]
            timesteps = torch.tensor([0], device=latents.device)

            if unet_config['in_channels'] == 9:
                latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
            else:
                latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)

            image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states=audio_feature).sample

            # Side-by-side visualization: masked input | reference | ground truth | prediction
            image = Image.new('RGB', (RESIZED_IMG * 4, RESIZED_IMG))
            image.paste(decode_latents(vae_fp32, masked_latents), (0, 0))
            image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
            image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG * 2, 0))
            image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG * 3, 0))

            val_img_dir = f"images/{output_dir}/{global_step}"
            if not os.path.exists(val_img_dir):
                os.makedirs(val_img_dir)
            image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step, step))

            print("validation at step:{0}, time:{1}".format(step, time.time() - start))

    print("validation done in epoch:{0}, time:{1}".format(epoch, time.time() - start))
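
A quick shape check for `PositionalEncoding`, assuming the module is importable the same way `train.py` imports it; the tensor sizes mirror the concatenated whisper-tiny features produced by `DataLoader.py` (batch, 50, 384), and the values are purely illustrative:

```python
import torch
from model_utils import PositionalEncoding  # same import path train.py uses

pe = PositionalEncoding(d_model=384)
audio_feature = torch.randn(4, 50, 384)   # batch of concatenated whisper-tiny features
out = pe(audio_feature)                   # adds sinusoidal encodings along the sequence axis
assert out.shape == (4, 50, 384)
```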
74
train_codes/utils/utils.py
Normal file
@@ -0,0 +1,74 @@
import os

import numpy as np
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import cv2
from glob import glob

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from einops import rearrange

from diffusers import AutoencoderKL


def preprocess_img_tensor(image_tensor):
    # Assume the input is a PyTorch tensor of shape (N, C, H, W)
    N, C, H, W = image_tensor.shape
    # Compute a new width and height that are multiples of 32
    new_w = W - W % 32
    new_h = H - H % 32
    # Normalize to [-1, 1] with torchvision.transforms
    transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # Apply the transform to each image and store the results in a new tensor
    preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
    for i in range(N):
        # Use F.interpolate instead of transforms.Resize
        resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
        preprocessed_images[i] = transform(resized_image.squeeze(0))

    return preprocessed_images


def prepare_mask_and_masked_image(image_tensor, mask_tensor):
    # Assume image_tensor has shape [C, H, W] and mask_tensor has shape [H, W]
    # Normalize the image tensor (currently unused downstream)
    image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
    # Normalize and binarize the mask tensor
    # mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    # Create the masked image
    masked_image_tensor = image_tensor * (mask_tensor > 0.5)

    return mask_tensor, masked_image_tensor


def encode_latents(vae, image):
    # init_image = preprocess_image(image)
    init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
    init_latents = 0.18215 * init_latent_dist.sample()
    return init_latents


def decode_latents(vae, latents, ref_images=None):
    latents = (1 / 0.18215) * latents
    image = vae.decode(latents.to(vae.dtype)).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    image = (image * 255).round().astype("uint8")
    if ref_images is not None:
        ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        ref_images = (ref_images * 255).round().astype("uint8")
        h = image.shape[1]
        image[:, :h // 2] = ref_images[:, :h // 2]
    image = [Image.fromarray(im) for im in image]

    return image[0]
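
For orientation, a small usage sketch matching how `DataLoader.py` calls `prepare_mask_and_masked_image` (pixel values in [0, 1], mask set to 1 on the upper half that stays visible); the import path is the one used by `DataLoader.py`, and the sizes are illustrative:

```python
import torch
from utils.utils import prepare_mask_and_masked_image  # import path used by DataLoader.py

image = torch.rand(3, 256, 256)   # C x H x W, values in [0, 1] as produced by DataLoader.prepare_window
mask = torch.zeros(256, 256)
mask[:128, :] = 1                 # keep the upper half of the face; the lower half is masked out
mask, masked_image = prepare_mask_and_masked_image(image, mask)
# masked_image keeps the upper half of the image and zeroes out the lower half
```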