online feature

This commit is contained in:
lyuxiang.lx
2026-01-28 15:19:07 +00:00
parent 1822c5c908
commit 66b80dbccb
22 changed files with 133 additions and 116 deletions

View File

@@ -49,6 +49,7 @@ def get_args():
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
@@ -96,6 +97,7 @@ def get_args():
@record
def main():
args = get_args()
os.environ['onnx_path'] = args.onnx_path
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
# gan train has some special initialization logic
@@ -104,12 +106,10 @@ def main():
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
if gan is True:
override_dict.pop('hift')
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
except Exception:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if args.qwen_pretrain_path is not None:
override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if gan is True:
configs['train_conf'] = configs['train_conf_gan']
configs['train_conf'].update(vars(args))