diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py index e10dfcb..48f8266 100644 --- a/matcha/data/text_mel_datamodule.py +++ b/matcha/data/text_mel_datamodule.py @@ -234,9 +234,9 @@ class TextMelBatchCollate: def __call__(self, batch): B = len(batch) - y_max_length = max([item["y"].shape[-1] for item in batch]) + y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator y_max_length = fix_len_compatibility(y_max_length) - x_max_length = max([item["x"].shape[-1] for item in batch]) + x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator n_feats = batch[0]["y"].shape[-2] y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)