diff --git a/Makefile b/Makefile
index bf66306..4b523dd 100644
--- a/Makefile
+++ b/Makefile
@@ -17,7 +17,7 @@ create-package: ## Create wheel and tar gz
 	rm -rf dist/
 	python setup.py bdist_wheel --plat-name=manylinux1_x86_64
 	python setup.py sdist
-	python -m twine upload dist/* --verbose
+	python -m twine upload dist/* --verbose --skip-existing
 
 format: ## Run pre-commit hooks
 	pre-commit run -a
diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml
index 1886902..e3932c8 100644
--- a/configs/debug/default.yaml
+++ b/configs/debug/default.yaml
@@ -7,8 +7,8 @@
 task_name: "debug"
 
 # disable callbacks and loggers during debugging
-callbacks: null
-logger: null
+# callbacks: null
+# logger: null
 
 extras:
   ignore_warnings: False
diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml
index 2bd7da8..266295f 100644
--- a/configs/debug/profiler.yaml
+++ b/configs/debug/profiler.yaml
@@ -7,6 +7,9 @@ defaults:
 
 trainer:
   max_epochs: 1
-  profiler: "simple"
-  # profiler: "advanced"
+  # profiler: "simple"
+  profiler: "advanced"
   # profiler: "pytorch"
+  accelerator: gpu
+
+  limit_train_batches: 0.02
diff --git a/matcha/app.py b/matcha/app.py
index 5554b3b..9ca108d 100644
--- a/matcha/app.py
+++ b/matcha/app.py
@@ -22,20 +22,64 @@ LOCATION = Path(get_user_data_dir())
 
 args = Namespace(
     cpu=False,
-    model="matcha_ljspeech",
-    vocoder="hifigan_T2_v1",
-    spk=None,
+    model="matcha_vctk",
+    vocoder="hifigan_univ_v1",
+    spk=0,
 )
 
-MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
-VOCODER_LOC = LOCATION / f"{args.vocoder}"
+CURRENTLY_LOADED_MODEL = args.model
+
+MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt"  # noqa: E731
+VOCODER_LOC = lambda x: LOCATION / f"{x}"  # noqa: E731
 LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
-assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model])
-assert_model_downloaded(VOCODER_LOC, VOCODER_URLS[args.vocoder])
+RADIO_OPTIONS = {
+    "Multi Speaker (VCTK)": {
+        "model": "matcha_vctk",
+        "vocoder": "hifigan_univ_v1",
+    },
+    "Single Speaker (LJ Speech)": {
+        "model": "matcha_ljspeech",
+        "vocoder": "hifigan_T2_v1",
+    },
+}
+
+# Ensure all the required models are downloaded
+assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
+assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
+assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
+assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
+
 device = get_device(args)
-model = load_matcha(args.model, MATCHA_TTS_LOC, device)
-vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC, device)
+# Load default model
+model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
+vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
+
+
+def load_model(model_name, vocoder_name):
+    model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
+    vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
+    return model, vocoder, denoiser
+
+
+def load_model_ui(model_type, textbox):
+    model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
+
+    global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+    if CURRENTLY_LOADED_MODEL != model_name:
+        model, vocoder, denoiser = load_model(model_name, vocoder_name)
+        CURRENTLY_LOADED_MODEL = model_name
+
+    if model_name == "matcha_ljspeech":
+        spk_slider = gr.update(visible=False, value=-1)
+        single_speaker_examples = gr.update(visible=True)
+        multi_speaker_examples = gr.update(visible=False)
+    else:
+        spk_slider = gr.update(visible=True, value=0)
+        single_speaker_examples = gr.update(visible=False)
+        multi_speaker_examples = gr.update(visible=True)
+
+    return textbox, gr.update(interactive=True), spk_slider, single_speaker_examples, multi_speaker_examples
 
 
 @torch.inference_mode()
@@ -45,13 +89,14 @@ def process_text_gradio(text):
 
 
 @torch.inference_mode()
-def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
+def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
+    spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
     output = model.synthesise(
         text,
         text_length,
         n_timesteps=n_timesteps,
         temperature=temperature,
-        spks=args.spk,
+        spks=spk,
         length_scale=length_scale,
     )
     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
@@ -61,9 +106,27 @@ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
     return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
 
 
-def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
+def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
+    global CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+    if CURRENTLY_LOADED_MODEL != "matcha_vctk":
+        global model, vocoder, denoiser  # pylint: disable=global-statement
+        model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
+        CURRENTLY_LOADED_MODEL = "matcha_vctk"
+
     phones, text, text_lengths = process_text_gradio(text)
-    audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
+    audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
+    return phones, audio, mel_spectrogram
+
+
+def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
+    global CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+    if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
+        global model, vocoder, denoiser  # pylint: disable=global-statement
+        model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
+        CURRENTLY_LOADED_MODEL = "matcha_ljspeech"
+
+    phones, text, text_lengths = process_text_gradio(text)
+    audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
     return phones, audio, mel_spectrogram
 
 
@@ -95,10 +158,18 @@ def main():
                     gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
 
         with gr.Box():
+            radio_options = list(RADIO_OPTIONS.keys())
+            model_type = gr.Radio(
+                radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
+            )
+
             with gr.Row():
                 gr.Markdown("# Text Input")
             with gr.Row():
-                text = gr.Textbox(value="", lines=2, label="Text to synthesise")
+                text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
+                spk_slider = gr.Slider(
+                    minimum=0, maximum=108, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
+                )
 
             with gr.Row():
                 gr.Markdown("### Hyper parameters")
@@ -142,7 +213,7 @@ def main():
                 # with gr.Row():
                 audio = gr.Audio(interactive=False, label="Audio")
 
-        with gr.Row():
+        with gr.Row(visible=False) as example_row_lj_speech:
             examples = gr.Examples(  # pylint: disable=unused-variable
                 examples=[
                     [
@@ -188,12 +259,64 @@ def main():
                         1.0,
                     ],
                 ],
-                fn=run_full_synthesis,
+                fn=ljspeech_example_cacher,
                 inputs=[text, n_timesteps, mel_temp, length_scale],
                 outputs=[phonetised_text, audio, mel_spectrogram],
                 cache_examples=True,
             )
 
+        with gr.Row() as example_row_multispeaker:
+            multi_speaker_examples = gr.Examples(  # pylint: disable=unused-variable
+                examples=[
+                    [
+                        "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
+                        10,
+                        0.677,
+                        1.0,
+                        0,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 13 and I am here to tell you that Matcha-TTS is amazing!",
+                        50,
+                        0.677,
+                        1.0,
+                        13,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
+                        10,
+                        0.677,
+                        1.0,
+                        16,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
+                        50,
+                        0.677,
+                        1.0,
+                        45,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
+                        4,
+                        0.677,
+                        1.0,
+                        58,
+                    ],
+                ],
+                fn=multispeaker_example_cacher,
+                inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
+                outputs=[phonetised_text, audio, mel_spectrogram],
+                cache_examples=True,
+                label="Multi Speaker Examples",
+            )
+
+        model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
+            load_model_ui,
+            inputs=[model_type, text],
+            outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker],
+        )
+
         synth_btn.click(
             fn=process_text_gradio,
             inputs=[
@@ -204,11 +327,11 @@ def main():
             queue=True,
         ).then(
             fn=synthesise_mel,
-            inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale],
+            inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
             outputs=[audio, mel_spectrogram],
         )
 
-    demo.queue(concurrency_count=5).launch(share=True)
+    demo.queue().launch(debug=True)
 
 
 if __name__ == "__main__":
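
Note on the new speaker handling: the updated `synthesise_mel` treats a negative slider value as "no speaker conditioning" (the single-speaker `matcha_ljspeech` checkpoint) and a value in 0-108 as a VCTK speaker index. A minimal sketch of that convention follows; `speaker_to_spks` is a hypothetical helper name used only for illustration and is not part of this diff:

    import torch

    def speaker_to_spks(spk: int, device: torch.device):
        # Mirrors the convention added to synthesise_mel() above:
        # a negative value maps to None (no speaker conditioning),
        # 0..108 becomes a one-element LongTensor passed as spks= to model.synthesise().
        return torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None

    print(speaker_to_spks(-1, torch.device("cpu")))  # None  (matcha_ljspeech)
    print(speaker_to_spks(16, torch.device("cpu")))  # tensor([16])  (matcha_vctk, speaker 16)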