From d7b9a37359086d659510f5c2d5f0140a5e6585cf Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Mon, 18 Sep 2023 11:37:58 +0000 Subject: [PATCH] adding more validation to multispeaker CLI --- matcha/app.py | 4 ++-- matcha/cli.py | 45 ++++++++++++++++++++++++++++++--------------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/matcha/app.py b/matcha/app.py index f34f0d1..5554b3b 100644 --- a/matcha/app.py +++ b/matcha/app.py @@ -8,7 +8,7 @@ import torch from matcha.cli import ( MATCHA_URLS, - VOCODER_URL, + VOCODER_URLS, assert_model_downloaded, get_device, load_matcha, @@ -31,7 +31,7 @@ MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt" VOCODER_LOC = LOCATION / f"{args.vocoder}" LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png" assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model]) -assert_model_downloaded(VOCODER_LOC, VOCODER_URL[args.vocoder]) +assert_model_downloaded(VOCODER_LOC, VOCODER_URLS[args.vocoder]) device = get_device(args) model = load_matcha(args.model, MATCHA_TTS_LOC, device) diff --git a/matcha/cli.py b/matcha/cli.py index 16bfd14..06459ff 100644 --- a/matcha/cli.py +++ b/matcha/cli.py @@ -1,6 +1,7 @@ import argparse import datetime as dt import os +import warnings from pathlib import Path import matplotlib.pyplot as plt @@ -23,7 +24,10 @@ MATCHA_URLS = { MULTISPEAKER_MODEL = {"matcha_vctk"} SINGLESPEAKER_MODEL = {"matcha_ljspeech"} -VOCODER_URL = {"hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link"} +VOCODER_URLS = { + "hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link", + "hifigan_univ_v1": "https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link", +} def plot_spectrogram_to_numpy(spectrogram, filename): @@ -62,10 +66,14 @@ def get_texts(args): def assert_required_models_available(args): save_dir = get_user_data_dir() - model_path = save_dir / f"{args.model}.ckpt" + if not hasattr(args, "checkpoint_path") and args.checkpoint_path is None: + model_path = args.checkpoint_path + else: + model_path = save_dir / f"{args.model}.ckpt" + assert_model_downloaded(model_path, MATCHA_URLS[args.model]) + vocoder_path = save_dir / f"{args.vocoder}" - assert_model_downloaded(model_path, MATCHA_URLS[args.model]) - assert_model_downloaded(vocoder_path, VOCODER_URL[args.vocoder]) + assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder]) return {"matcha": model_path, "vocoder": vocoder_path} @@ -81,7 +89,7 @@ def load_hifigan(checkpoint_path, device): def load_vocoder(vocoder_name, checkpoint_path, device): print(f"[!] Loading {vocoder_name}!") vocoder = None - if vocoder_name == "hifigan_T2_v1": + if vocoder_name in ("hifigan_T2_v1", "hifigan_univ_v1"): vocoder = load_hifigan(checkpoint_path, device) else: raise NotImplementedError( @@ -126,16 +134,23 @@ def validate_args(args): assert args.temperature >= 0, "Sampling temperature cannot be negative" assert args.speaking_rate > 0, "Speaking rate must be greater than 0" assert args.steps > 0, "Number of ODE steps must be greater than 0" - if args.model in SINGLESPEAKER_MODEL: - assert args.spk is None, f"Speaker ID is not supported for {args.model}" - if args.spk is not None: - assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108" - assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model" + if args.checkpoint_path is None: + if args.model in SINGLESPEAKER_MODEL: + assert args.spk is None, f"Speaker ID is not supported for {args.model}" - if args.model in MULTISPEAKER_MODEL: - if args.spk is None: - print("[!] Speaker ID not provided! Using speaker ID 0") - args.spk = 0 + if args.spk is not None: + assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108" + assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model" + + if args.model in MULTISPEAKER_MODEL: + if args.spk is None: + print("[!] Speaker ID not provided! Using speaker ID 0") + args.spk = 0 + args.vocoder = "hifigan_univ_v1" + else: + if args.vocoder != "hifigan_univ_v1": + warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech." + warnings.warn(warn_, UserWarning) if args.batched: assert args.batch_size > 0, "Batch size must be greater than 0" @@ -168,7 +183,7 @@ def cli(): type=str, default="hifigan_T2_v1", help="Vocoder to use", - choices=VOCODER_URL.keys(), + choices=VOCODER_URLS.keys(), ) parser.add_argument("--text", type=str, default=None, help="Text to synthesize") parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")