mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
Using Wget for new ckpt downloadsA
This commit is contained in:
@@ -29,8 +29,15 @@ args = Namespace(
|
||||
|
||||
CURRENTLY_LOADED_MODEL = args.model
|
||||
|
||||
MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt" # noqa: E731
|
||||
VOCODER_LOC = lambda x: LOCATION / f"{x}" # noqa: E731
|
||||
|
||||
def MATCHA_TTS_LOC(x):
|
||||
return LOCATION / f"{x}.ckpt"
|
||||
|
||||
|
||||
def VOCODER_LOC(x):
|
||||
return LOCATION / f"{x}"
|
||||
|
||||
|
||||
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
|
||||
RADIO_OPTIONS = {
|
||||
"Multi Speaker (VCTK)": {
|
||||
@@ -44,9 +51,9 @@ RADIO_OPTIONS = {
|
||||
}
|
||||
|
||||
# Ensure all the required models are downloaded
|
||||
assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
|
||||
assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"], use_wget=True)
|
||||
assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
|
||||
assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
|
||||
assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"], use_wget=True)
|
||||
assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
|
||||
|
||||
device = get_device(args)
|
||||
|
||||
@@ -74,7 +74,7 @@ def assert_required_models_available(args):
|
||||
model_path = args.checkpoint_path
|
||||
else:
|
||||
model_path = save_dir / f"{args.model}.ckpt"
|
||||
assert_model_downloaded(model_path, MATCHA_URLS[args.model])
|
||||
assert_model_downloaded(model_path, MATCHA_URLS[args.model], use_wget=True)
|
||||
|
||||
vocoder_path = save_dir / f"{args.vocoder}"
|
||||
assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
|
||||
|
||||
@@ -115,7 +115,7 @@ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
|
||||
return None
|
||||
|
||||
if metric_name not in metric_dict:
|
||||
raise Exception(
|
||||
raise ValueError(
|
||||
f"Metric value not found! <metric_name={metric_name}>\n"
|
||||
"Make sure metric name logged in LightningModule is correct!\n"
|
||||
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
|
||||
@@ -208,8 +208,10 @@ def get_user_data_dir(appname="matcha_tts"):
|
||||
def assert_model_downloaded(checkpoint_path, url, use_wget=False):
|
||||
if Path(checkpoint_path).exists():
|
||||
log.debug(f"[+] Model already present at {checkpoint_path}!")
|
||||
print(f"[+] Model already present at {checkpoint_path}!")
|
||||
return
|
||||
log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
|
||||
print(f"[-] Model not found at {checkpoint_path}! Will download it")
|
||||
checkpoint_path = str(checkpoint_path)
|
||||
if not use_wget:
|
||||
gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
|
||||
|
||||
Reference in New Issue
Block a user