Adding possibility of getting durations out

This commit is contained in:
Shivam Mehta
2024-02-24 15:10:19 +00:00
parent def0855608
commit 8e87111a98
6 changed files with 516 additions and 25 deletions

View File

@@ -109,7 +109,7 @@ class TextMelDataModule(LightningDataModule):
"""Clean up after fit or test."""
pass # pylint: disable=unnecessary-pass
def state_dict(self): # pylint: disable=no-self-use
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
@@ -167,7 +167,7 @@ class TextMelDataset(torch.utils.data.Dataset):
text = self.get_text(text, add_blank=self.add_blank)
mel = self.get_mel(filepath)
return {"x": text, "y": mel, "spk": spk}
return {"x": text, "y": mel, "spk": spk, "filepath": filepath}
def get_mel(self, filepath):
audio, sr = ta.load(filepath)
@@ -207,15 +207,16 @@ class TextMelBatchCollate:
def __call__(self, batch):
B = len(batch)
y_max_length = max([item["y"].shape[-1] for item in batch])
y_max_length = max([item["y"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
y_max_length = fix_len_compatibility(y_max_length)
x_max_length = max([item["x"].shape[-1] for item in batch])
x_max_length = max([item["x"].shape[-1] for item in batch]) # pylint: disable=consider-using-generator
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
y_lengths, x_lengths = [], []
spks = []
filepaths = []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
y_lengths.append(y_.shape[-1])
@@ -223,9 +224,10 @@ class TextMelBatchCollate:
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
spks.append(item["spk"])
filepaths.append(item["filepath"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks}
return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks, "filepaths": filepaths}