Using Wget for new ckpt downloadsA

This commit is contained in:
Shivam Mehta
2024-01-12 11:09:25 +00:00
parent 95ec24b599
commit 39cbd85236
4 changed files with 17 additions and 15 deletions

View File

@@ -115,7 +115,7 @@ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
return None
if metric_name not in metric_dict:
raise Exception(
raise ValueError(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
@@ -208,8 +208,10 @@ def get_user_data_dir(appname="matcha_tts"):
def assert_model_downloaded(checkpoint_path, url, use_wget=False):
if Path(checkpoint_path).exists():
log.debug(f"[+] Model already present at {checkpoint_path}!")
print(f"[+] Model already present at {checkpoint_path}!")
return
log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
print(f"[-] Model not found at {checkpoint_path}! Will download it")
checkpoint_path = str(checkpoint_path)
if not use_wget:
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)