modified dataloader.py and inference.py for training and inference

Shounak Banerjee
2024-06-03 11:09:12 +00:00
parent 7254ca6306
commit b4a592d7f3
6 changed files with 106 additions and 58 deletions

View File

@@ -15,14 +15,27 @@ from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
from accelerate import Accelerator
# load model weights
audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(
mixed_precision="fp16",
)
unet = accelerator.prepare(
unet,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
global pe
if not (args.unet_checkpoint == None):
print("unet ckpt loaded")
accelerator.load_state(args.unet_checkpoint)
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
@@ -63,8 +76,6 @@ def main(args):
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
#print(input_img_list)
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
@@ -80,23 +91,26 @@ def main(args):
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
crop_i=0
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
crop_i+=1
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
@@ -107,7 +121,6 @@ def main(args):
dtype=unet.model.dtype) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
@@ -122,22 +135,29 @@ def main(args):
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
# print(bbox)
continue
combine_frame = get_image(ori_frame,res_frame,bbox)
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame) cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4" cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
os.remove("temp.mp4")
shutil.rmtree(result_img_save_path)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
os.system(cmd_img2video)
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
# print(cmd_combine_audio)
# os.system(cmd_combine_audio)
# shutil.rmtree(result_img_save_path)
print(f"result is save to {output_vid_name}") print(f"result is save to {output_vid_name}")
if __name__ == "__main__": if __name__ == "__main__":
@@ -156,6 +176,7 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Whether use float16 to speed up inference", help="Whether use float16 to speed up inference",
) )
parser.add_argument("--unet_checkpoint", type=str, default=None)
args = parser.parse_args()
main(args)
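Note on the new `--unet_checkpoint` flow above: the flag is expected to point at a state folder written by `accelerator.save_state` during training, which `accelerator.load_state` then restores into the prepared UNet. A minimal sketch of that pairing (the stand-in model and paths are illustrative assumptions, not the repository's modules):

```python
# Minimal sketch of the save/load hand-off assumed by --unet_checkpoint.
# The Linear layer stands in for the prepared UNet; paths are illustrative.
import torch
from accelerate import Accelerator

accelerator = Accelerator()              # inference.py constructs it with mixed_precision="fp16"
model = torch.nn.Linear(384, 384)        # stand-in for the UNet
model = accelerator.prepare(model)

# Training side: periodically persist model/optimizer/RNG state to a folder.
accelerator.save_state("output/checkpoint-1000")

# Inference side: restore that state into the prepared model before running.
accelerator.load_state("output/checkpoint-1000")
```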

View File

@@ -57,13 +57,13 @@ class Dataset(object):
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
self.img_names_path = '...'
self.img_names_path = '../data'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
self.use_audio_length_right = use_audio_length_right
if self.whisper_model_type =="tiny":
self.whisper_path = '...'
self.whisper_path = '../data/audios'
self.whisper_feature_W = 5
self.whisper_feature_H = 384
elif self.whisper_model_type =="largeV2":
@@ -72,6 +72,10 @@ class Dataset(object):
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) # 5*2*(2+2+1) = 50
if(self.split=="train"):
self.all_videos=["../data/images/train"]
if(self.split=="val"):
self.all_videos=["../data/images/test"]
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
if not os.path.exists(json_path_names):
@@ -79,7 +83,6 @@ class Dataset(object):
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
with open(json_path_names, "w") as f:
json.dump(img_names,f)
print(f"save to {json_path_names}")
else:
with open(json_path_names, "r") as f:
img_names = json.load(f)
@@ -147,7 +150,6 @@ class Dataset(object):
vidname = self.all_videos[idx].split('/')[-1]
video_imgs = self.all_img_names[idx]
if len(video_imgs) == 0:
# print("video_imgs = 0:",vidname)
continue
img_name = random.choice(video_imgs)
img_idx = int(basename(img_name).split(".")[0])
@@ -205,7 +207,6 @@ class Dataset(object):
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
# check whether the index is out of range
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
if not os.path.exists(audio_feat_path):
is_index_out_of_range = True
break
@@ -226,8 +227,6 @@ class Dataset(object):
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}") print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
continue continue
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature)) audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
return ref_image, image, masked_image, mask, audio_feature return ref_image, image, masked_image, mask, audio_feature
@@ -243,10 +242,8 @@ if __name__ == "__main__":
val_data_loader = data_utils.DataLoader(
val_data, batch_size=4, shuffle=True,
num_workers=1)
print("val_dataset:",val_data_loader.__len__())
for i, data in enumerate(val_data_loader):
ref_image, image, masked_image, mask, audio_feature = data
print("ref_image: ", ref_image.shape)

View File

@@ -1,32 +1,35 @@
# Draft training codes
# Data preprocessing
We provide the draft training codes here. Unfortunately, the data preprocessing code is still being reorganized.
Create two config YAML files, one for training and the other for testing (both in the same format as configs/inference/test.yaml).
The train YAML file should contain the training video paths and the corresponding audio paths.
The test YAML file should contain the validation video paths and the corresponding audio paths.
## Setup
Run:
```
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
```
This creates folders which contain the image frames and npy files.
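A quick way to sanity-check the preprocessing output before training is to confirm that every extracted frame has a matching feature file. A small sketch, assuming the `./data/images/<split>` and `./data/audios/<split>` layout shown below:

```python
# Hypothetical sanity check: each N.png frame should have a matching N.npy
# whisper feature. Paths assume the ./data layout described in this README.
import os

for split in ("train", "test"):
    frames = {os.path.splitext(f)[0] for f in os.listdir(f"./data/images/{split}")
              if f.endswith(".png")}
    feats = {os.path.splitext(f)[0] for f in os.listdir(f"./data/audios/{split}")
             if f.endswith(".npy")}
    print(f"{split}: {len(frames)} frames, {len(feats)} features, "
          f"{len(frames - feats)} frames without a feature")
```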
We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 200k+ steps. Using multiple GPUs should speed up training.
## Data preprocessing
## Data organization
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
Finally, the data should be organized as follows:
```
./data/
├── images
│ └──RD_Radio10_000
│ └──train
│ └── 0.png
│ └── 1.png
│ └── xxx.png
│ └──RD_Radio11_000
│ └──test
│ └── 0.png
│ └── 1.png
│ └── xxx.png
├── audios
│ └──RD_Radio10_000
│ └──train
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
│ └──RD_Radio11_000
│ └──test
│ └── 0.npy
│ └── 1.npy
│ └── xxx.npy
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
```
sh train.sh
```
## Inference with a trained checkpoint
After training, the model checkpoints are usually saved under train_codes/output. Simply run:
```
python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
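If you are unsure which folder to pass, the checkpoints follow the `checkpoint-<step>` naming implied by the sorting logic in train.py. A hypothetical helper for picking the most recent one:

```python
# Hypothetical helper: pick the newest checkpoint-<step> folder to pass as
# --unet_checkpoint (mirrors the int(x.split("-")[1]) sorting in train.py).
import os

def latest_checkpoint(output_dir="train_codes/output"):
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    dirs.sort(key=lambda x: int(x.split("-")[1]))
    return os.path.join(output_dir, dirs[-1]) if dirs else None

print(latest_checkpoint())
```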
## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)

View File

@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
import sys
sys.path.append("./")
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -235,12 +238,16 @@ def parse_args():
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def print_model_dtypes(model, model_name):
for name, param in model.named_parameters():
if(param.dtype!=torch.float32):
print(f"{name}: {param.dtype}")
def main():
args = parse_args()
print(args)
args.output_dir = f"output/{args.output_dir}"
args.val_out_dir = f"val/{args.val_out_dir}"
os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +339,7 @@ def main():
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters())
itertools.chain(unet.parameters()))
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -348,7 +355,6 @@ def main():
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("train_dataset:",train_dataset.__len__())
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True,
num_workers=8)
@@ -359,7 +365,6 @@ def main():
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("val_dataset:",val_dataset.__len__())
val_data_loader = data_utils.DataLoader(
val_dataset, batch_size=1, shuffle=False,
num_workers=8)
@@ -388,6 +393,7 @@ def main():
vae_fp32.requires_grad_(False)
weight_dtype = torch.float32
# weight_dtype = torch.float16
vae_fp32.to(accelerator.device, dtype=weight_dtype)
vae_fp32.encoder = None
if accelerator.mixed_precision == "fp16":
@@ -412,6 +418,8 @@ def main():
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print(f" Num batches each epoch = {len(train_data_loader)}")
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_data_loader)}") logger.info(f" Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +441,9 @@ def main():
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
# path="../models/pytorch_model.bin"
#TODO change path
# path=None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -459,9 +470,10 @@ def main():
elapsed_time = []
start = time.time()
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# for step, batch in enumerate(train_dataloader):
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +482,23 @@ def main():
continue
dataloader_time = time.time() - start
start = time.time()
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
"""
# """
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
print("ref_image: ",ref_image.shape)
# print("ref_image: ",ref_image.shape)
print("masks: ", masks.shape)
# print("masks: ", masks.shape)
print("masked_image: ", masked_image.shape)
# print("masked_image: ", masked_image.shape)
print("audio feature: ", audio_feature.shape)
# print("audio feature: ", audio_feature.shape)
print("image: ", image.shape)
# print("image: ", image.shape)
"""
# """
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
img_process_time = time.time() - start
start = time.time()
with accelerator.accumulate(unet):
vae = vae.half()
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
@@ -592,12 +603,23 @@ def main():
f"Running validation... epoch={epoch}, global_step={global_step}" f"Running validation... epoch={epoch}, global_step={global_step}"
) )
print("===========start validation==========") print("===========start validation==========")
# Use the helper function to check the data types for each model
vae_new = vae.float()
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
print(f"weight_dtype: {weight_dtype}")
print(f"epoch type: {type(epoch)}")
print(f"global_step type: {type(global_step)}")
validation(
vae=accelerator.unwrap_model(vae),
# vae=accelerator.unwrap_model(vae),
vae=accelerator.unwrap_model(vae_new),
vae_fp32=accelerator.unwrap_model(vae_fp32),
unet=accelerator.unwrap_model(unet),
unet_config=unet_config,
weight_dtype=weight_dtype,
# weight_dtype=weight_dtype,
weight_dtype=torch.float32,
epoch=epoch,
global_step=global_step,
val_data_loader=val_data_loader,

View File

@@ -1,8 +1,8 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export VAE_MODEL="../models/sd-vae-ft-mse/"
export DATASET="..."
export DATASET="../data"
export UNET_CONFIG="./musetalk.json"
export UNET_CONFIG="../models/musetalk/musetalk.json"
accelerate launch --multi_gpu train.py \
accelerate launch train.py \
--mixed_precision="fp16" \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
@@ -10,13 +10,13 @@ accelerate launch --multi_gpu train.py \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=200000 \
--max_train_steps=50000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="..." \
--output_dir="output" \
--val_out_dir='...' \
--val_out_dir='val' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \
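For reference, the effective batch size per optimizer step follows the `total_batch_size` formula in train.py; with the settings above on a single process it works out as:

```python
# Effective batch size for the settings above, following the total_batch_size
# formula in train.py (single process after dropping --multi_gpu).
train_batch_size = 8
gradient_accumulation_steps = 4
num_processes = 1  # scales linearly when more GPUs/processes are used

total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)  # 8 * 1 * 4 = 32 samples per optimizer step
```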

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union