Using Wget for new ckpt downloadsA

This commit is contained in:
Shivam Mehta
2024-01-12 11:09:25 +00:00
parent 95ec24b599
commit 39cbd85236
4 changed files with 17 additions and 15 deletions

View File

@@ -292,13 +292,6 @@ max-line-length=120
# Maximum number of lines in a module. # Maximum number of lines in a module.
max-module-lines=1000 max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body # Allow the body of a class to be on the same line as the declaration if body
# contains single statement. # contains single statement.
single-line-class-stmt=no single-line-class-stmt=no
@@ -528,5 +521,5 @@ min-public-methods=2
# Exceptions that will emit a warning when being caught. Defaults to # Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception". # "BaseException, Exception".
overgeneral-exceptions=BaseException, overgeneral-exceptions=builtins.BaseException,
Exception builtins.Exception

View File

@@ -29,8 +29,15 @@ args = Namespace(
CURRENTLY_LOADED_MODEL = args.model CURRENTLY_LOADED_MODEL = args.model
MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt" # noqa: E731
VOCODER_LOC = lambda x: LOCATION / f"{x}" # noqa: E731 def MATCHA_TTS_LOC(x):
return LOCATION / f"{x}.ckpt"
def VOCODER_LOC(x):
return LOCATION / f"{x}"
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png" LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
RADIO_OPTIONS = { RADIO_OPTIONS = {
"Multi Speaker (VCTK)": { "Multi Speaker (VCTK)": {
@@ -44,9 +51,9 @@ RADIO_OPTIONS = {
} }
# Ensure all the required models are downloaded # Ensure all the required models are downloaded
assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"]) assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"], use_wget=True)
assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"]) assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"]) assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"], use_wget=True)
assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"]) assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
device = get_device(args) device = get_device(args)

View File

@@ -74,7 +74,7 @@ def assert_required_models_available(args):
model_path = args.checkpoint_path model_path = args.checkpoint_path
else: else:
model_path = save_dir / f"{args.model}.ckpt" model_path = save_dir / f"{args.model}.ckpt"
assert_model_downloaded(model_path, MATCHA_URLS[args.model]) assert_model_downloaded(model_path, MATCHA_URLS[args.model], use_wget=True)
vocoder_path = save_dir / f"{args.vocoder}" vocoder_path = save_dir / f"{args.vocoder}"
assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder]) assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])

View File

@@ -115,7 +115,7 @@ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
return None return None
if metric_name not in metric_dict: if metric_name not in metric_dict:
raise Exception( raise ValueError(
f"Metric value not found! <metric_name={metric_name}>\n" f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n" "Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!" "Make sure `optimized_metric` name in `hparams_search` config is correct!"
@@ -208,8 +208,10 @@ def get_user_data_dir(appname="matcha_tts"):
def assert_model_downloaded(checkpoint_path, url, use_wget=False): def assert_model_downloaded(checkpoint_path, url, use_wget=False):
if Path(checkpoint_path).exists(): if Path(checkpoint_path).exists():
log.debug(f"[+] Model already present at {checkpoint_path}!") log.debug(f"[+] Model already present at {checkpoint_path}!")
print(f"[+] Model already present at {checkpoint_path}!")
return return
log.info(f"[-] Model not found at {checkpoint_path}! Will download it") log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
print(f"[-] Model not found at {checkpoint_path}! Will download it")
checkpoint_path = str(checkpoint_path) checkpoint_path = str(checkpoint_path)
if not use_wget: if not use_wget:
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)