temporary commit to save changes

This commit is contained in:
Shounak Banerjee
2024-06-13 14:14:52 +00:00
parent d74c4c098b
commit af82f3b00f
5 changed files with 125 additions and 15 deletions

View File

@@ -48,15 +48,15 @@ def get_image_list(data_root, split):
class Dataset(object):
def __init__(self,
data_root,
split,
json_path,
use_audio_length_left=1,
use_audio_length_right=1,
whisper_model_type = "tiny"
):
self.all_videos, self.all_imgNum = get_image_list(data_root, split)
# self.all_videos, self.all_imgNum = get_image_list(data_root, split)
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
# self.split = split
self.img_names_path = '../data'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
@@ -72,10 +72,13 @@ class Dataset(object):
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*2+2+1= 50
if(self.split=="train"):
self.all_videos=["../data/images/train"]
if(self.split=="val"):
self.all_videos=["../data/images/test"]
# if(self.split=="train"):
# self.all_videos=["../data/images/train"]
# if(self.split=="val"):
# self.all_videos=["../data/images/test"]
with open(json_path, 'r') as file:
self.all_videos = json.load(file)
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
if not os.path.exists(json_path_names):

View File

@@ -140,6 +140,8 @@ def parse_args():
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
parser.add_argument(
"--hub_model_id",
type=str,
@@ -350,7 +352,7 @@ def main():
print("loading train_dataset ...")
train_dataset = Dataset(args.data_root,
'train',
args.train_json,
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
@@ -360,7 +362,7 @@ def main():
num_workers=8)
print("loading val_dataset ...")
val_dataset = Dataset(args.data_root,
'val',
args.val_json,
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type

View File

@@ -18,10 +18,12 @@ accelerate launch train.py \
--output_dir="output" \
--val_out_dir='val' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \
--checkpointing_steps=2000 \
--validation_steps=2000 \
--reconstruction \
--resume_from_checkpoint="latest" \
--use_audio_length_left=2 \
--use_audio_length_right=2 \
--whisper_model_type="tiny" \
--train_json="/root/MuseTalk/train.json" \
--val_json="/root/MuseTalk/val.json" \