update hifigan

This commit is contained in:
lyuxiang.lx
2024-10-16 12:15:49 +08:00
parent 789ee9e5e7
commit 73784974ce
3 changed files with 64 additions and 5 deletions

View File

@@ -18,6 +18,7 @@ import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed
@@ -112,7 +113,10 @@ def main():
# load checkpoint
model = configs[args.model]
if args.checkpoint is not None:
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
if os.path.exists(args.checkpoint):
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
else:
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
@@ -125,7 +129,7 @@ def main():
save_model(model, 'init', info_dict)
# Get executor
executor = Executor()
executor = Executor(gan=gan)
# Start training loop
for epoch in range(info_dict['max_epoch']):