Mirror of https://github.com/shivammehta25/Matcha-TTS.git (synced 2026-02-05 02:09:21 +08:00)
Adding docstrings
```diff
@@ -376,6 +376,24 @@ class TextEncoder(nn.Module):
         )
 
     def forward(self, x, x_lengths, spks=None):
+        """Run forward pass through the transformer-based encoder and duration predictor.
+
+        Args:
+            x (torch.Tensor): text input
+                shape: (batch_size, max_text_length)
+            x_lengths (torch.Tensor): text input lengths
+                shape: (batch_size,)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size,)
+
+        Returns:
+            mu (torch.Tensor): average output of the encoder
+                shape: (batch_size, n_feats, max_text_length)
+            logw (torch.Tensor): log duration predicted by the duration predictor
+                shape: (batch_size, 1, max_text_length)
+            x_mask (torch.Tensor): mask for the text input
+                shape: (batch_size, 1, max_text_length)
+        """
         x = self.emb(x) * math.sqrt(self.n_channels)
         x = torch.transpose(x, 1, -1)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
```
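For context, the unchanged tail of the hunk is the usual Glow-TTS-style encoder preamble: scale the token embeddings by sqrt(n_channels) (the Transformer embedding-scaling convention), move channels to dim 1, and build a padding mask from the sequence lengths. The sketch below reproduces those three lines standalone so the documented shapes can be checked. The `sequence_mask` helper here is written to the conventional semantics (mask[i, t] is true for t < x_lengths[i]) rather than copied from the repository, and the tensor sizes are hypothetical, so treat it as an approximation of the real code, not a drop-in.

```python
import math

import torch


def sequence_mask(length: torch.Tensor, max_length=None) -> torch.Tensor:
    # Conventional Glow-TTS-style mask: mask[i, t] is True for t < length[i].
    # (Assumed semantics for illustration, not the repository's exact helper.)
    if max_length is None:
        max_length = int(length.max())
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)


# Hypothetical sizes, for illustration only.
batch_size, max_text_length, n_vocab, n_channels = 2, 7, 100, 192

tokens = torch.randint(0, n_vocab, (batch_size, max_text_length))
x_lengths = torch.tensor([7, 4])  # (batch_size,)
emb = torch.nn.Embedding(n_vocab, n_channels)

# The three lines from the hunk, with self.* replaced by locals:
x = emb(tokens) * math.sqrt(n_channels)  # (batch_size, max_text_length, n_channels)
x = torch.transpose(x, 1, -1)            # (batch_size, n_channels, max_text_length)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

assert x.shape == (batch_size, n_channels, max_text_length)
assert x_mask.shape == (batch_size, 1, max_text_length)  # matches the docstring
```

Downstream, multiplying activations by a mask of shape (batch_size, 1, max_text_length) broadcasts over the channel dimension, which keeps padded positions from contributing to the transformer blocks and the duration predictor.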