adding more validation to multispeaker CLI

This commit is contained in:
Shivam Mehta
2023-09-18 11:37:58 +00:00
parent ec43ef0732
commit d7b9a37359
2 changed files with 32 additions and 17 deletions

View File

@@ -8,7 +8,7 @@ import torch
from matcha.cli import ( from matcha.cli import (
MATCHA_URLS, MATCHA_URLS,
VOCODER_URL, VOCODER_URLS,
assert_model_downloaded, assert_model_downloaded,
get_device, get_device,
load_matcha, load_matcha,
@@ -31,7 +31,7 @@ MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
VOCODER_LOC = LOCATION / f"{args.vocoder}" VOCODER_LOC = LOCATION / f"{args.vocoder}"
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png" LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model]) assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model])
assert_model_downloaded(VOCODER_LOC, VOCODER_URL[args.vocoder]) assert_model_downloaded(VOCODER_LOC, VOCODER_URLS[args.vocoder])
device = get_device(args) device = get_device(args)
model = load_matcha(args.model, MATCHA_TTS_LOC, device) model = load_matcha(args.model, MATCHA_TTS_LOC, device)

View File

@@ -1,6 +1,7 @@
import argparse import argparse
import datetime as dt import datetime as dt
import os import os
import warnings
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -23,7 +24,10 @@ MATCHA_URLS = {
MULTISPEAKER_MODEL = {"matcha_vctk"} MULTISPEAKER_MODEL = {"matcha_vctk"}
SINGLESPEAKER_MODEL = {"matcha_ljspeech"} SINGLESPEAKER_MODEL = {"matcha_ljspeech"}
VOCODER_URL = {"hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link"} VOCODER_URLS = {
"hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link",
"hifigan_univ_v1": "https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link",
}
def plot_spectrogram_to_numpy(spectrogram, filename): def plot_spectrogram_to_numpy(spectrogram, filename):
@@ -62,10 +66,14 @@ def get_texts(args):
def assert_required_models_available(args):
    """Resolve the Matcha and vocoder checkpoint paths used for synthesis.

    Args:
        args: Parsed CLI namespace. Reads ``args.model`` and ``args.vocoder``,
            and ``args.checkpoint_path`` when that attribute is present.

    Returns:
        dict: ``{"matcha": <model path>, "vocoder": <vocoder path>}``.

    Raises:
        KeyError: If ``args.model`` (when no custom checkpoint is given) or
            ``args.vocoder`` has no entry in the corresponding URL table.
    """
    save_dir = get_user_data_dir()

    # getattr with a default covers both "attribute missing" and "attribute is
    # None". The previous check (`not hasattr(...) and ... is None`) could never
    # select the custom checkpoint and raised AttributeError when the attribute
    # was absent.
    custom_checkpoint = getattr(args, "checkpoint_path", None)
    if custom_checkpoint is not None:
        # A user-supplied checkpoint is used as given; it has no entry in
        # MATCHA_URLS, so it is not passed to assert_model_downloaded.
        model_path = custom_checkpoint
    else:
        model_path = save_dir / f"{args.model}.ckpt"
        assert_model_downloaded(model_path, MATCHA_URLS[args.model])

    # The vocoder is always resolved inside the user data directory.
    vocoder_path = save_dir / f"{args.vocoder}"
    assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
    return {"matcha": model_path, "vocoder": vocoder_path}
@@ -81,7 +89,7 @@ def load_hifigan(checkpoint_path, device):
def load_vocoder(vocoder_name, checkpoint_path, device): def load_vocoder(vocoder_name, checkpoint_path, device):
print(f"[!] Loading {vocoder_name}!") print(f"[!] Loading {vocoder_name}!")
vocoder = None vocoder = None
if vocoder_name == "hifigan_T2_v1": if vocoder_name in ("hifigan_T2_v1", "hifigan_univ_v1"):
vocoder = load_hifigan(checkpoint_path, device) vocoder = load_hifigan(checkpoint_path, device)
else: else:
raise NotImplementedError( raise NotImplementedError(
@@ -126,16 +134,23 @@ def validate_args(args):
assert args.temperature >= 0, "Sampling temperature cannot be negative" assert args.temperature >= 0, "Sampling temperature cannot be negative"
assert args.speaking_rate > 0, "Speaking rate must be greater than 0" assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
assert args.steps > 0, "Number of ODE steps must be greater than 0" assert args.steps > 0, "Number of ODE steps must be greater than 0"
if args.model in SINGLESPEAKER_MODEL: if args.checkpoint_path is None:
assert args.spk is None, f"Speaker ID is not supported for {args.model}" if args.model in SINGLESPEAKER_MODEL:
if args.spk is not None: assert args.spk is None, f"Speaker ID is not supported for {args.model}"
assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108"
assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model"
if args.model in MULTISPEAKER_MODEL: if args.spk is not None:
if args.spk is None: assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108"
print("[!] Speaker ID not provided! Using speaker ID 0") assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model"
args.spk = 0
if args.model in MULTISPEAKER_MODEL:
if args.spk is None:
print("[!] Speaker ID not provided! Using speaker ID 0")
args.spk = 0
args.vocoder = "hifigan_univ_v1"
else:
if args.vocoder != "hifigan_univ_v1":
warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
warnings.warn(warn_, UserWarning)
if args.batched: if args.batched:
assert args.batch_size > 0, "Batch size must be greater than 0" assert args.batch_size > 0, "Batch size must be greater than 0"
@@ -168,7 +183,7 @@ def cli():
type=str, type=str,
default="hifigan_T2_v1", default="hifigan_T2_v1",
help="Vocoder to use", help="Vocoder to use",
choices=VOCODER_URL.keys(), choices=VOCODER_URLS.keys(),
) )
parser.add_argument("--text", type=str, default=None, help="Text to synthesize") parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")