mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
temporary commit to save changes
This commit is contained in:
@@ -140,6 +140,8 @@ def parse_args():
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
|
||||
parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
@@ -350,7 +352,7 @@ def main():
|
||||
|
||||
print("loading train_dataset ...")
|
||||
train_dataset = Dataset(args.data_root,
|
||||
'train',
|
||||
args.train_json,
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
@@ -360,7 +362,7 @@ def main():
|
||||
num_workers=8)
|
||||
print("loading val_dataset ...")
|
||||
val_dataset = Dataset(args.data_root,
|
||||
'val',
|
||||
args.val_json,
|
||||
use_audio_length_left=args.use_audio_length_left,
|
||||
use_audio_length_right=args.use_audio_length_right,
|
||||
whisper_model_type=args.whisper_model_type
|
||||
|
||||
Reference in New Issue
Block a user