Adding docstrings

This commit is contained in:
Shivam Mehta
2023-09-17 06:49:12 +00:00
parent 0554a5b87c
commit c079e5254a
3 changed files with 82 additions and 25 deletions

View File

@@ -376,6 +376,24 @@ class TextEncoder(nn.Module):
)
def forward(self, x, x_lengths, spks=None):
"""Run forward pass to the transformer based encoder and duration predictor
Args:
x (torch.Tensor): text input
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): text input lengths
shape: (batch_size,)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size,)
Returns:
mu (torch.Tensor): average output of the encoder
shape: (batch_size, n_feats, max_text_length)
logw (torch.Tensor): log duration predicted by the duration predictor
shape: (batch_size, 1, max_text_length)
x_mask (torch.Tensor): mask for the text input
shape: (batch_size, 1, max_text_length)
"""
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)