diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 5a861fd..b19ccab 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -7,6 +7,7 @@ version: 2
updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
+ target-branch: "dev"
schedule:
interval: "daily"
ignore:
diff --git a/Makefile b/Makefile
index bf66306..4b523dd 100644
--- a/Makefile
+++ b/Makefile
@@ -17,7 +17,7 @@ create-package: ## Create wheel and tar gz
rm -rf dist/
python setup.py bdist_wheel --plat-name=manylinux1_x86_64
python setup.py sdist
- python -m twine upload dist/* --verbose
+ python -m twine upload dist/* --verbose --skip-existing
format: ## Run pre-commit hooks
pre-commit run -a
diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml
index 1886902..e3932c8 100644
--- a/configs/debug/default.yaml
+++ b/configs/debug/default.yaml
@@ -7,8 +7,8 @@
task_name: "debug"
# disable callbacks and loggers during debugging
-callbacks: null
-logger: null
+# callbacks: null
+# logger: null
extras:
ignore_warnings: False
diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml
index 2bd7da8..266295f 100644
--- a/configs/debug/profiler.yaml
+++ b/configs/debug/profiler.yaml
@@ -7,6 +7,9 @@ defaults:
trainer:
max_epochs: 1
- profiler: "simple"
- # profiler: "advanced"
+ # profiler: "simple"
+ profiler: "advanced"
# profiler: "pytorch"
+ accelerator: gpu
+
+ limit_train_batches: 0.02
diff --git a/matcha/VERSION b/matcha/VERSION
index 50ec600..8acdd82 100644
--- a/matcha/VERSION
+++ b/matcha/VERSION
@@ -1 +1 @@
-0.0.1.dev4
+0.0.1
diff --git a/matcha/app.py b/matcha/app.py
index 5554b3b..16e8077 100644
--- a/matcha/app.py
+++ b/matcha/app.py
@@ -22,20 +22,73 @@ LOCATION = Path(get_user_data_dir())
args = Namespace(
cpu=False,
- model="matcha_ljspeech",
- vocoder="hifigan_T2_v1",
- spk=None,
+ model="matcha_vctk",
+ vocoder="hifigan_univ_v1",
+ spk=0,
)
-MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
-VOCODER_LOC = LOCATION / f"{args.vocoder}"
+CURRENTLY_LOADED_MODEL = args.model
+
+MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt" # noqa: E731
+VOCODER_LOC = lambda x: LOCATION / f"{x}" # noqa: E731
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
-assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model])
-assert_model_downloaded(VOCODER_LOC, VOCODER_URLS[args.vocoder])
+RADIO_OPTIONS = {
+ "Multi Speaker (VCTK)": {
+ "model": "matcha_vctk",
+ "vocoder": "hifigan_univ_v1",
+ },
+ "Single Speaker (LJ Speech)": {
+ "model": "matcha_ljspeech",
+ "vocoder": "hifigan_T2_v1",
+ },
+}
+
+# Ensure all the required models are downloaded
+assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
+assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
+assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
+assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
+
device = get_device(args)
-model = load_matcha(args.model, MATCHA_TTS_LOC, device)
-vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC, device)
+# Load default model
+model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
+vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
+
+
+def load_model(model_name, vocoder_name):
+ model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
+ vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
+ return model, vocoder, denoiser
+
+
+def load_model_ui(model_type, textbox):
+ model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
+
+ global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
+ if CURRENTLY_LOADED_MODEL != model_name:
+ model, vocoder, denoiser = load_model(model_name, vocoder_name)
+ CURRENTLY_LOADED_MODEL = model_name
+
+ if model_name == "matcha_ljspeech":
+ spk_slider = gr.update(visible=False, value=-1)
+ single_speaker_examples = gr.update(visible=True)
+ multi_speaker_examples = gr.update(visible=False)
+ length_scale = gr.update(value=0.95)
+ else:
+ spk_slider = gr.update(visible=True, value=0)
+ single_speaker_examples = gr.update(visible=False)
+ multi_speaker_examples = gr.update(visible=True)
+ length_scale = gr.update(value=0.85)
+
+ return (
+ textbox,
+ gr.update(interactive=True),
+ spk_slider,
+ single_speaker_examples,
+ multi_speaker_examples,
+ length_scale,
+ )
@torch.inference_mode()
@@ -45,13 +98,14 @@ def process_text_gradio(text):
@torch.inference_mode()
-def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
+def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
+ spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
output = model.synthesise(
text,
text_length,
n_timesteps=n_timesteps,
temperature=temperature,
- spks=args.spk,
+ spks=spk,
length_scale=length_scale,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
@@ -61,9 +115,27 @@ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
-def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
+def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
+ global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
+ if CURRENTLY_LOADED_MODEL != "matcha_vctk":
+ global model, vocoder, denoiser # pylint: disable=global-statement
+ model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
+ CURRENTLY_LOADED_MODEL = "matcha_vctk"
+
phones, text, text_lengths = process_text_gradio(text)
- audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
+ audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
+ return phones, audio, mel_spectrogram
+
+
+def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
+ global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
+ if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
+ global model, vocoder, denoiser # pylint: disable=global-statement
+ model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
+ CURRENTLY_LOADED_MODEL = "matcha_ljspeech"
+
+ phones, text, text_lengths = process_text_gradio(text)
+ audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
return phones, audio, mel_spectrogram
@@ -92,13 +164,24 @@ def main():
with gr.Box():
with gr.Row():
gr.Markdown(description, scale=3)
- gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
+ with gr.Column():
+ gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
+ html = '
'
+ gr.HTML(html)
with gr.Box():
+ radio_options = list(RADIO_OPTIONS.keys())
+ model_type = gr.Radio(
+ radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
+ )
+
with gr.Row():
gr.Markdown("# Text Input")
with gr.Row():
- text = gr.Textbox(value="", lines=2, label="Text to synthesise")
+ text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
+ spk_slider = gr.Slider(
+ minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
+ )
with gr.Row():
gr.Markdown("### Hyper parameters")
@@ -142,58 +225,110 @@ def main():
# with gr.Row():
audio = gr.Audio(interactive=False, label="Audio")
- with gr.Row():
+ with gr.Row(visible=False) as example_row_lj_speech:
examples = gr.Examples( # pylint: disable=unused-variable
examples=[
[
"We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
50,
0.677,
- 1.0,
+ 0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
2,
0.677,
- 1.0,
+ 0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
4,
0.677,
- 1.0,
+ 0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
10,
0.677,
- 1.0,
+ 0.95,
],
[
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
50,
0.677,
- 1.0,
+ 0.95,
],
[
"The narrative of these events is based largely on the recollections of the participants.",
10,
0.677,
- 1.0,
+ 0.95,
],
[
"The jury did not believe him, and the verdict was for the defendants.",
10,
0.677,
- 1.0,
+ 0.95,
],
],
- fn=run_full_synthesis,
+ fn=ljspeech_example_cacher,
inputs=[text, n_timesteps, mel_temp, length_scale],
outputs=[phonetised_text, audio, mel_spectrogram],
cache_examples=True,
)
+ with gr.Row() as example_row_multispeaker:
+ multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable
+ examples=[
+ [
+ "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
+ 10,
+ 0.677,
+ 0.85,
+ 0,
+ ],
+ [
+ "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
+ 10,
+ 0.677,
+ 0.85,
+ 16,
+ ],
+ [
+ "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
+ 50,
+ 0.677,
+ 0.85,
+ 44,
+ ],
+ [
+ "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
+ 50,
+ 0.677,
+ 0.85,
+ 45,
+ ],
+ [
+ "Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
+ 4,
+ 0.677,
+ 0.85,
+ 58,
+ ],
+ ],
+ fn=multispeaker_example_cacher,
+ inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
+ outputs=[phonetised_text, audio, mel_spectrogram],
+ cache_examples=True,
+ label="Multi Speaker Examples",
+ )
+
+ model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
+ load_model_ui,
+ inputs=[model_type, text],
+ outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
+ )
+
synth_btn.click(
fn=process_text_gradio,
inputs=[
@@ -204,11 +339,11 @@ def main():
queue=True,
).then(
fn=synthesise_mel,
- inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale],
+ inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
outputs=[audio, mel_spectrogram],
)
- demo.queue(concurrency_count=5).launch(share=True)
+ demo.queue().launch(share=True)
if __name__ == "__main__":
diff --git a/matcha/cli.py b/matcha/cli.py
index 06459ff..9e3f7fb 100644
--- a/matcha/cli.py
+++ b/matcha/cli.py
@@ -18,17 +18,21 @@ from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
MATCHA_URLS = {
- "matcha_ljspeech": "https://drive.google.com/file/d/1BBzmMU7k3a_WetDfaFblMoN18GqQeHCg/view?usp=drive_link"
-} # , "matcha_vctk": ""} # Coming soon
-
-MULTISPEAKER_MODEL = {"matcha_vctk"}
-SINGLESPEAKER_MODEL = {"matcha_ljspeech"}
+ "matcha_ljspeech": "https://drive.google.com/file/d/1BBzmMU7k3a_WetDfaFblMoN18GqQeHCg/view?usp=drive_link",
+ "matcha_vctk": "https://drive.google.com/file/d/1enuxmfslZciWGAl63WGh2ekVo00FYuQ9/view?usp=drive_link",
+}
VOCODER_URLS = {
"hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link",
"hifigan_univ_v1": "https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link",
}
+MULTISPEAKER_MODEL = {
+ "matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)}
+}
+
+SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}}
+
def plot_spectrogram_to_numpy(spectrogram, filename):
fig, ax = plt.subplots(figsize=(12, 3))
@@ -132,28 +136,70 @@ def validate_args(args):
args.text or args.file
), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
assert args.temperature >= 0, "Sampling temperature cannot be negative"
- assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
assert args.steps > 0, "Number of ODE steps must be greater than 0"
- if args.checkpoint_path is None:
- if args.model in SINGLESPEAKER_MODEL:
- assert args.spk is None, f"Speaker ID is not supported for {args.model}"
- if args.spk is not None:
- assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108"
- assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model"
+ if args.checkpoint_path is None:
+ # When using pretrained models
+ if args.model in SINGLESPEAKER_MODEL.keys():
+ args = validate_args_for_single_speaker_model(args)
if args.model in MULTISPEAKER_MODEL:
- if args.spk is None:
- print("[!] Speaker ID not provided! Using speaker ID 0")
- args.spk = 0
- args.vocoder = "hifigan_univ_v1"
+ args = validate_args_for_multispeaker_model(args)
else:
+ # When using a custom model
if args.vocoder != "hifigan_univ_v1":
warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
warnings.warn(warn_, UserWarning)
+ if args.speaking_rate is None:
+ args.speaking_rate = 1.0
if args.batched:
assert args.batch_size > 0, "Batch size must be greater than 0"
+ assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
+
+ return args
+
+
+def validate_args_for_multispeaker_model(args):
+ if args.vocoder is not None:
+ if args.vocoder != MULTISPEAKER_MODEL[args.model]["vocoder"]:
+ warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}"
+ warnings.warn(warn_, UserWarning)
+ else:
+ args.vocoder = MULTISPEAKER_MODEL[args.model]["vocoder"]
+
+ if args.speaking_rate is None:
+ args.speaking_rate = MULTISPEAKER_MODEL[args.model]["speaking_rate"]
+
+ spk_range = MULTISPEAKER_MODEL[args.model]["spk_range"]
+ if args.spk is not None:
+ assert (
+ args.spk >= spk_range[0] and args.spk <= spk_range[-1]
+ ), f"Speaker ID must be between {spk_range} for this model."
+ else:
+ available_spk_id = MULTISPEAKER_MODEL[args.model]["spk"]
+ warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}"
+ warnings.warn(warn_, UserWarning)
+ args.spk = available_spk_id
+
+ return args
+
+
+def validate_args_for_single_speaker_model(args):
+ if args.vocoder is not None:
+ if args.vocoder != SINGLESPEAKER_MODEL[args.model]["vocoder"]:
+ warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}"
+ warnings.warn(warn_, UserWarning)
+ else:
+ args.vocoder = SINGLESPEAKER_MODEL[args.model]["vocoder"]
+
+ if args.speaking_rate is None:
+ args.speaking_rate = SINGLESPEAKER_MODEL[args.model]["speaking_rate"]
+
+ if args.spk != SINGLESPEAKER_MODEL[args.model]["spk"]:
+ warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}"
+ warnings.warn(warn_, UserWarning)
+ args.spk = SINGLESPEAKER_MODEL[args.model]["spk"]
return args
@@ -181,8 +227,8 @@ def cli():
parser.add_argument(
"--vocoder",
type=str,
- default="hifigan_T2_v1",
- help="Vocoder to use",
+ default=None,
+ help="Vocoder to use (default: will use the one suggested with the pretrained model))",
choices=VOCODER_URLS.keys(),
)
parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
@@ -197,7 +243,7 @@ def cli():
parser.add_argument(
"--speaking_rate",
type=float,
- default=1.0,
+ default=None,
help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
)
parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
@@ -214,8 +260,10 @@ def cli():
default=os.getcwd(),
help="Output folder to save results (default: current dir)",
)
- parser.add_argument("--batched", action="store_true")
- parser.add_argument("--batch_size", type=int, default=32)
+ parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)")
+ parser.add_argument(
+ "--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)"
+ )
args = parser.parse_args()
@@ -348,6 +396,8 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
def print_config(args):
print("[!] Configurations: ")
+ print(f"\t- Model: {args.model}")
+ print(f"\t- Vocoder: {args.vocoder}")
print(f"\t- Temperature: {args.temperature}")
print(f"\t- Speaking rate: {args.speaking_rate}")
print(f"\t- Number of ODE steps: {args.steps}")
diff --git a/requirements.txt b/requirements.txt
index ac1abf8..c058372 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,7 +35,7 @@ torchaudio
matplotlib
pandas
conformer==0.3.2
-diffusers==0.21.1
+diffusers==0.21.2
notebook
ipywidgets
gradio