mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update hifigan
This commit is contained in:
@@ -18,6 +18,7 @@ import datetime
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import deepspeed
|
||||
@@ -112,7 +113,10 @@ def main():
|
||||
# load checkpoint
|
||||
model = configs[args.model]
|
||||
if args.checkpoint is not None:
|
||||
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
|
||||
if os.path.exists(args.checkpoint):
|
||||
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
|
||||
else:
|
||||
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
||||
|
||||
# Dispatch model from cpu to gpu
|
||||
model = wrap_cuda_model(args, model)
|
||||
@@ -125,7 +129,7 @@ def main():
|
||||
save_model(model, 'init', info_dict)
|
||||
|
||||
# Get executor
|
||||
executor = Executor()
|
||||
executor = Executor(gan=gan)
|
||||
|
||||
# Start training loop
|
||||
for epoch in range(info_dict['max_epoch']):
|
||||
|
||||
@@ -393,8 +393,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
|
||||
"speech_token_len": speech_token_len,
|
||||
"speech_feat": speech_feat,
|
||||
"speech_feat_len": speech_feat_len,
|
||||
"pitch_feat": pitch_feat,
|
||||
"pitch_feat_len": pitch_feat_len,
|
||||
"text": text,
|
||||
"text_token": text_token,
|
||||
"text_token_len": text_token_len,
|
||||
|
||||
Reference in New Issue
Block a user