diff --git a/matcha/cli.py b/matcha/cli.py index 635c586..7daf130 100644 --- a/matcha/cli.py +++ b/matcha/cli.py @@ -326,12 +326,13 @@ def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk): for i, batch in enumerate(dataloader): i = i + 1 start_t = dt.datetime.now() + b = batch["x"].shape[0] output = model.synthesise( batch["x"].to(device), batch["x_lengths"].to(device), n_timesteps=args.steps, temperature=args.temperature, - spks=spk, + spks=spk.expand(b) if spk is not None else spk, length_scale=args.speaking_rate, )