From 39cbd85236e46236a61cccb444afd8063b4cf41b Mon Sep 17 00:00:00 2001 From: Shivam Mehta Date: Fri, 12 Jan 2024 11:09:25 +0000 Subject: [PATCH] Using Wget for new ckpt downloadsA --- .pylintrc | 11 ++--------- matcha/app.py | 15 +++++++++++---- matcha/cli.py | 2 +- matcha/utils/utils.py | 4 +++- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/.pylintrc b/.pylintrc index 7ab186a..9628641 100644 --- a/.pylintrc +++ b/.pylintrc @@ -292,13 +292,6 @@ max-line-length=120 # Maximum number of lines in a module. max-module-lines=1000 -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no @@ -528,5 +521,5 @@ min-public-methods=2 # Exceptions that will emit a warning when being caught. Defaults to # "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception +overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/matcha/app.py b/matcha/app.py index 16e8077..0091e7a 100644 --- a/matcha/app.py +++ b/matcha/app.py @@ -29,8 +29,15 @@ args = Namespace( CURRENTLY_LOADED_MODEL = args.model -MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt" # noqa: E731 -VOCODER_LOC = lambda x: LOCATION / f"{x}" # noqa: E731 + +def MATCHA_TTS_LOC(x): + return LOCATION / f"{x}.ckpt" + + +def VOCODER_LOC(x): + return LOCATION / f"{x}" + + LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png" RADIO_OPTIONS = { "Multi Speaker (VCTK)": { @@ -44,9 +51,9 @@ RADIO_OPTIONS = { } # Ensure all the required models are downloaded -assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"]) +assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"], use_wget=True) assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"]) -assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"]) +assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"], use_wget=True) assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"]) device = get_device(args) diff --git a/matcha/cli.py b/matcha/cli.py index f3c29a7..29d4f48 100644 --- a/matcha/cli.py +++ b/matcha/cli.py @@ -74,7 +74,7 @@ def assert_required_models_available(args): model_path = args.checkpoint_path else: model_path = save_dir / f"{args.model}.ckpt" - assert_model_downloaded(model_path, MATCHA_URLS[args.model]) + assert_model_downloaded(model_path, MATCHA_URLS[args.model], use_wget=True) vocoder_path = save_dir / f"{args.vocoder}" assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder]) diff --git a/matcha/utils/utils.py b/matcha/utils/utils.py index 5f8162d..adb9290 100644 --- a/matcha/utils/utils.py +++ b/matcha/utils/utils.py @@ -115,7 +115,7 @@ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: return None if metric_name not in metric_dict: - raise Exception( + raise ValueError( f"Metric value not found! \n" "Make sure metric name logged in LightningModule is correct!\n" "Make sure `optimized_metric` name in `hparams_search` config is correct!" @@ -208,8 +208,10 @@ def get_user_data_dir(appname="matcha_tts"): def assert_model_downloaded(checkpoint_path, url, use_wget=False): if Path(checkpoint_path).exists(): log.debug(f"[+] Model already present at {checkpoint_path}!") + print(f"[+] Model already present at {checkpoint_path}!") return log.info(f"[-] Model not found at {checkpoint_path}! Will download it") + print(f"[-] Model not found at {checkpoint_path}! Will download it") checkpoint_path = str(checkpoint_path) if not use_wget: gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)