Merge pull request #3 from shivammehta25/dev

Adding multispeaker 🍵 Matcha-TTS
Shivam Mehta
2023-09-21 15:23:15 +02:00
committed by GitHub
8 changed files with 243 additions and 54 deletions

View File

@@ -7,6 +7,7 @@ version: 2
 updates:
   - package-ecosystem: "pip" # See documentation for possible values
     directory: "/" # Location of package manifests
+    target-branch: "dev"
     schedule:
       interval: "daily"
     ignore:

View File

@@ -17,7 +17,7 @@ create-package: ## Create wheel and tar gz
 	rm -rf dist/
 	python setup.py bdist_wheel --plat-name=manylinux1_x86_64
 	python setup.py sdist
-	python -m twine upload dist/* --verbose
+	python -m twine upload dist/* --verbose --skip-existing

 format: ## Run pre-commit hooks
 	pre-commit run -a

View File

@@ -7,8 +7,8 @@
 task_name: "debug"

 # disable callbacks and loggers during debugging
-callbacks: null
-logger: null
+# callbacks: null
+# logger: null

 extras:
   ignore_warnings: False

View File

@@ -7,6 +7,9 @@ defaults:
 trainer:
   max_epochs: 1
-  profiler: "simple"
-  # profiler: "advanced"
+  # profiler: "simple"
+  profiler: "advanced"
   # profiler: "pytorch"
+  accelerator: gpu
+  limit_train_batches: 0.02
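
For reference, the YAML change above corresponds to the following Trainer arguments — a minimal sketch, assuming a Lightning 2.x environment that exposes lightning.pytorch.Trainer:

    from lightning.pytorch import Trainer

    trainer = Trainer(
        max_epochs=1,
        profiler="advanced",        # cProfile-based per-function timings
        accelerator="gpu",
        limit_train_batches=0.02,   # profile on only 2% of the training batches
    )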

View File

@@ -1 +1 @@
-0.0.1.dev4
+0.0.1

View File

@@ -22,20 +22,73 @@ LOCATION = Path(get_user_data_dir())
 args = Namespace(
     cpu=False,
-    model="matcha_ljspeech",
-    vocoder="hifigan_T2_v1",
-    spk=None,
+    model="matcha_vctk",
+    vocoder="hifigan_univ_v1",
+    spk=0,
 )

-MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
-VOCODER_LOC = LOCATION / f"{args.vocoder}"
+CURRENTLY_LOADED_MODEL = args.model
+
+MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt"  # noqa: E731
+VOCODER_LOC = lambda x: LOCATION / f"{x}"  # noqa: E731
 LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
-assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model])
-assert_model_downloaded(VOCODER_LOC, VOCODER_URLS[args.vocoder])
+RADIO_OPTIONS = {
+    "Multi Speaker (VCTK)": {
+        "model": "matcha_vctk",
+        "vocoder": "hifigan_univ_v1",
+    },
+    "Single Speaker (LJ Speech)": {
+        "model": "matcha_ljspeech",
+        "vocoder": "hifigan_T2_v1",
+    },
+}
+
+# Ensure all the required models are downloaded
+assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
+assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
+assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
+assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
+
 device = get_device(args)
-model = load_matcha(args.model, MATCHA_TTS_LOC, device)
-vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC, device)
+
+# Load default model
+model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
+vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
+
+
+def load_model(model_name, vocoder_name):
+    model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
+    vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
+    return model, vocoder, denoiser
+
+
+def load_model_ui(model_type, textbox):
+    model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
+
+    global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+
+    if CURRENTLY_LOADED_MODEL != model_name:
+        model, vocoder, denoiser = load_model(model_name, vocoder_name)
+        CURRENTLY_LOADED_MODEL = model_name
+
+    if model_name == "matcha_ljspeech":
+        spk_slider = gr.update(visible=False, value=-1)
+        single_speaker_examples = gr.update(visible=True)
+        multi_speaker_examples = gr.update(visible=False)
+        length_scale = gr.update(value=0.95)
+    else:
+        spk_slider = gr.update(visible=True, value=0)
+        single_speaker_examples = gr.update(visible=False)
+        multi_speaker_examples = gr.update(visible=True)
+        length_scale = gr.update(value=0.85)
+
+    return (
+        textbox,
+        gr.update(interactive=True),
+        spk_slider,
+        single_speaker_examples,
+        multi_speaker_examples,
+        length_scale,
+    )
+
+
 @torch.inference_mode()
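
The lambda helpers above turn checkpoint paths into functions of the model name, so one app instance can address several checkpoints. A purely illustrative sketch (the stand-in LOCATION is an assumption; the real value comes from get_user_data_dir() and is platform-specific):

    from pathlib import Path

    LOCATION = Path.home() / ".local" / "share"  # stand-in for get_user_data_dir()
    MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt"  # noqa: E731

    # Resolving a RADIO_OPTIONS entry to its checkpoint path:
    choice = {"model": "matcha_vctk", "vocoder": "hifigan_univ_v1"}
    print(MATCHA_TTS_LOC(choice["model"]))  # e.g. /home/user/.local/share/matcha_vctk.ckpt
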
@@ -45,13 +98,14 @@ def process_text_gradio(text):
 @torch.inference_mode()
-def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
+def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
+    spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
     output = model.synthesise(
         text,
         text_length,
         n_timesteps=n_timesteps,
         temperature=temperature,
-        spks=args.spk,
+        spks=spk,
         length_scale=length_scale,
     )
     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
@@ -61,9 +115,27 @@ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
     return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())


-def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
+def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
+    global CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+    if CURRENTLY_LOADED_MODEL != "matcha_vctk":
+        global model, vocoder, denoiser  # pylint: disable=global-statement
+        model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
+        CURRENTLY_LOADED_MODEL = "matcha_vctk"
+
     phones, text, text_lengths = process_text_gradio(text)
-    audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
+    audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
+    return phones, audio, mel_spectrogram
+
+
+def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
+    global CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+    if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
+        global model, vocoder, denoiser  # pylint: disable=global-statement
+        model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
+        CURRENTLY_LOADED_MODEL = "matcha_ljspeech"
+
+    phones, text, text_lengths = process_text_gradio(text)
+    audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
     return phones, audio, mel_spectrogram
@@ -92,13 +164,24 @@ def main():
         with gr.Box():
             with gr.Row():
                 gr.Markdown(description, scale=3)
-                gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
+                with gr.Column():
+                    gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
+                    html = '<br><iframe width="560" height="315" src="https://www.youtube.com/embed/xmvJkz3bqw0?si=jN7ILyDsbPwJCGoa" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>'
+                    gr.HTML(html)

         with gr.Box():
+            radio_options = list(RADIO_OPTIONS.keys())
+            model_type = gr.Radio(
+                radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
+            )
+
             with gr.Row():
                 gr.Markdown("# Text Input")
             with gr.Row():
-                text = gr.Textbox(value="", lines=2, label="Text to synthesise")
+                text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
+                spk_slider = gr.Slider(
+                    minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
+                )

             with gr.Row():
                 gr.Markdown("### Hyper parameters")
@@ -142,58 +225,110 @@ def main():
             # with gr.Row():
             audio = gr.Audio(interactive=False, label="Audio")

-        with gr.Row():
+        with gr.Row(visible=False) as example_row_lj_speech:
             examples = gr.Examples(  # pylint: disable=unused-variable
                 examples=[
                     [
                         "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
                         50,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                     [
                         "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                         2,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                     [
                         "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                         4,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                     [
                         "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                         10,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                     [
                         "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                         50,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                     [
                         "The narrative of these events is based largely on the recollections of the participants.",
                         10,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                     [
                         "The jury did not believe him, and the verdict was for the defendants.",
                         10,
                         0.677,
-                        1.0,
+                        0.95,
                     ],
                 ],
-                fn=run_full_synthesis,
+                fn=ljspeech_example_cacher,
                 inputs=[text, n_timesteps, mel_temp, length_scale],
                 outputs=[phonetised_text, audio, mel_spectrogram],
                 cache_examples=True,
             )

+        with gr.Row() as example_row_multispeaker:
+            multi_speaker_examples = gr.Examples(  # pylint: disable=unused-variable
+                examples=[
+                    [
+                        "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
+                        10,
+                        0.677,
+                        0.85,
+                        0,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
+                        10,
+                        0.677,
+                        0.85,
+                        16,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
+                        50,
+                        0.677,
+                        0.85,
+                        44,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
+                        50,
+                        0.677,
+                        0.85,
+                        45,
+                    ],
+                    [
+                        "Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
+                        4,
+                        0.677,
+                        0.85,
+                        58,
+                    ],
+                ],
+                fn=multispeaker_example_cacher,
+                inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
+                outputs=[phonetised_text, audio, mel_spectrogram],
+                cache_examples=True,
+                label="Multi Speaker Examples",
+            )
+
+        model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
+            load_model_ui,
+            inputs=[model_type, text],
+            outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
+        )
+
         synth_btn.click(
             fn=process_text_gradio,
             inputs=[
@@ -204,11 +339,11 @@ def main():
             queue=True,
         ).then(
             fn=synthesise_mel,
-            inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale],
+            inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
             outputs=[audio, mel_spectrogram],
         )

-    demo.queue(concurrency_count=5).launch(share=True)
+    demo.queue().launch(share=True)

 if __name__ == "__main__":
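
The model_type.change(...).then(...) chain added above follows a common Gradio 3.x pattern: disable the button synchronously, run the slow model swap, then re-enable via the callback's return value. A minimal self-contained sketch of the same pattern (names here are hypothetical):

    import gradio as gr

    with gr.Blocks() as demo:
        choice = gr.Radio(["A", "B"], value="A")
        btn = gr.Button("Synthesise")

        def slow_reload(c):
            # ... swap models for choice `c` here ...
            return gr.update(interactive=True)  # re-enable the button when done

        choice.change(lambda: gr.update(interactive=False), outputs=[btn]).then(
            slow_reload, inputs=[choice], outputs=[btn]
        )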

View File

@@ -18,17 +18,21 @@ from matcha.text import sequence_to_text, text_to_sequence
 from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse

 MATCHA_URLS = {
-    "matcha_ljspeech": "https://drive.google.com/file/d/1BBzmMU7k3a_WetDfaFblMoN18GqQeHCg/view?usp=drive_link"
-}  # , "matcha_vctk": ""}  # Coming soon
-
-MULTISPEAKER_MODEL = {"matcha_vctk"}
-SINGLESPEAKER_MODEL = {"matcha_ljspeech"}
+    "matcha_ljspeech": "https://drive.google.com/file/d/1BBzmMU7k3a_WetDfaFblMoN18GqQeHCg/view?usp=drive_link",
+    "matcha_vctk": "https://drive.google.com/file/d/1enuxmfslZciWGAl63WGh2ekVo00FYuQ9/view?usp=drive_link",
+}

 VOCODER_URLS = {
     "hifigan_T2_v1": "https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link",
     "hifigan_univ_v1": "https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link",
 }

+MULTISPEAKER_MODEL = {
+    "matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)}
+}
+SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}}

 def plot_spectrogram_to_numpy(spectrogram, filename):
     fig, ax = plt.subplots(figsize=(12, 3))
@@ -132,28 +136,70 @@ def validate_args(args):
         args.text or args.file
     ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
     assert args.temperature >= 0, "Sampling temperature cannot be negative"
-    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
     assert args.steps > 0, "Number of ODE steps must be greater than 0"

-    if args.checkpoint_path is None:
-        if args.model in SINGLESPEAKER_MODEL:
-            assert args.spk is None, f"Speaker ID is not supported for {args.model}"
-
-    if args.spk is not None:
-        assert args.spk >= 0 and args.spk < 109, "Speaker ID must be between 0 and 108"
-        assert args.model in MULTISPEAKER_MODEL, "Speaker ID is only supported for multispeaker model"
+    if args.checkpoint_path is None:
+        # When using pretrained models
+        if args.model in SINGLESPEAKER_MODEL.keys():
+            args = validate_args_for_single_speaker_model(args)

-    if args.model in MULTISPEAKER_MODEL:
-        if args.spk is None:
-            print("[!] Speaker ID not provided! Using speaker ID 0")
-            args.spk = 0
-        args.vocoder = "hifigan_univ_v1"
+        if args.model in MULTISPEAKER_MODEL:
+            args = validate_args_for_multispeaker_model(args)
     else:
+        # When using a custom model
         if args.vocoder != "hifigan_univ_v1":
             warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
             warnings.warn(warn_, UserWarning)
+        if args.speaking_rate is None:
+            args.speaking_rate = 1.0

     if args.batched:
         assert args.batch_size > 0, "Batch size must be greater than 0"
+    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
+
+    return args
+
+
+def validate_args_for_multispeaker_model(args):
+    if args.vocoder is not None:
+        if args.vocoder != MULTISPEAKER_MODEL[args.model]["vocoder"]:
+            warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}"
+            warnings.warn(warn_, UserWarning)
+    else:
+        args.vocoder = MULTISPEAKER_MODEL[args.model]["vocoder"]
+
+    if args.speaking_rate is None:
+        args.speaking_rate = MULTISPEAKER_MODEL[args.model]["speaking_rate"]
+
+    spk_range = MULTISPEAKER_MODEL[args.model]["spk_range"]
+    if args.spk is not None:
+        assert (
+            args.spk >= spk_range[0] and args.spk <= spk_range[-1]
+        ), f"Speaker ID must be between {spk_range} for this model."
+    else:
+        available_spk_id = MULTISPEAKER_MODEL[args.model]["spk"]
+        warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}"
+        warnings.warn(warn_, UserWarning)
+        args.spk = available_spk_id
+
+    return args
+
+
+def validate_args_for_single_speaker_model(args):
+    if args.vocoder is not None:
+        if args.vocoder != SINGLESPEAKER_MODEL[args.model]["vocoder"]:
+            warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}"
+            warnings.warn(warn_, UserWarning)
+    else:
+        args.vocoder = SINGLESPEAKER_MODEL[args.model]["vocoder"]
+
+    if args.speaking_rate is None:
+        args.speaking_rate = SINGLESPEAKER_MODEL[args.model]["speaking_rate"]
+
+    if args.spk != SINGLESPEAKER_MODEL[args.model]["spk"]:
+        warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}"
+        warnings.warn(warn_, UserWarning)
+        args.spk = SINGLESPEAKER_MODEL[args.model]["spk"]

     return args
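
The net effect of the refactored validation: with a pretrained model, any of --vocoder, --speaking_rate, and --spk may now be omitted and are filled in from the model's registry entry. A sketch of the resulting behaviour (the Namespace fields mirror the CLI flags; illustrative, not a test from the commit):

    from argparse import Namespace

    args = Namespace(
        model="matcha_vctk", checkpoint_path=None, vocoder=None,
        speaking_rate=None, spk=None, text="Hello world", file=None,
        temperature=0.667, steps=10, batched=False, batch_size=32,
    )
    args = validate_args(args)
    # Emits "[!] Speaker ID not provided! Using speaker ID 0" as a UserWarning,
    # then returns with spk=0, vocoder="hifigan_univ_v1", speaking_rate=0.85.
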
@@ -181,8 +227,8 @@ def cli():
     parser.add_argument(
         "--vocoder",
         type=str,
-        default="hifigan_T2_v1",
-        help="Vocoder to use",
+        default=None,
+        help="Vocoder to use (default: will use the one suggested with the pretrained model))",
         choices=VOCODER_URLS.keys(),
     )
     parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
@@ -197,7 +243,7 @@ def cli():
     parser.add_argument(
         "--speaking_rate",
         type=float,
-        default=1.0,
+        default=None,
         help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
     )
     parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
@@ -214,8 +260,10 @@ def cli():
         default=os.getcwd(),
         help="Output folder to save results (default: current dir)",
     )
-    parser.add_argument("--batched", action="store_true")
-    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)")
+    parser.add_argument(
+        "--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)"
+    )

     args = parser.parse_args()
@@ -348,6 +396,8 @@ def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
 def print_config(args):
     print("[!] Configurations: ")
+    print(f"\t- Model: {args.model}")
+    print(f"\t- Vocoder: {args.vocoder}")
     print(f"\t- Temperature: {args.temperature}")
     print(f"\t- Speaking rate: {args.speaking_rate}")
     print(f"\t- Number of ODE steps: {args.steps}")

View File

@@ -35,7 +35,7 @@ torchaudio
 matplotlib
 pandas
 conformer==0.3.2
-diffusers==0.21.1
+diffusers==0.21.2
 notebook
 ipywidgets
 gradio