import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack

from matcha.models.components.decoder import SinusoidalPosEmb, TimestepEmbedding
from matcha.models.components.text_encoder import LayerNorm


def _conditioning_channels(params):
    """Return the channel count of the conditioning fed to the duration estimators.

    This is ``params.n_channels`` plus ``params.spk_emb_dim`` when ``params.n_spks > 1``
    (presumably speaker embeddings concatenated upstream -- confirm against the caller).
    """
    return params.n_channels + (params.spk_emb_dim if params.n_spks > 1 else 0)


# ---------------------------------------------------------------------------
# Estimator networks
# ---------------------------------------------------------------------------


class DurationPredictorNetwork(nn.Module):
    """Two-layer masked 1-D convolutional regressor with a single output channel.

    Args:
        in_channels: channels of the input ``x``.
        filter_channels: hidden width of both convolutions.
        kernel_size: kernel size of both convolutions (padded by ``kernel_size // 2``).
        p_dropout: dropout probability applied after each layer norm.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        """Predict one value per frame.

        Args:
            x: input features, ``(batch, in_channels, time)``.
            x_mask: mask broadcastable against ``x`` (presumably ``(batch, 1, time)``).

        Returns:
            Masked prediction with a single channel, ``(batch, 1, time)``.
        """
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class DurationPredictorNetworkWithTimeStep(nn.Module):
    """Same layout as ``DurationPredictorNetwork`` but additionally conditioned on a time step.

    The time step is embedded with ``SinusoidalPosEmb`` followed by ``TimestepEmbedding`` and
    added to the hidden activations after each convolution + ReLU.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.time_embeddings = SinusoidalPosEmb(filter_channels)
        self.time_mlp = TimestepEmbedding(
            in_channels=filter_channels,
            time_embed_dim=filter_channels,
            act_fn="silu",
        )

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask, enc_outputs, t):
        """Estimate the flow velocity for state ``x`` at time ``t``.

        Args:
            x: current flow state; concatenated with ``enc_outputs`` along the channel axis,
                so ``x.shape[1] + enc_outputs.shape[1]`` must equal ``in_channels``.
            x_mask: mask broadcastable against the hidden activations.
            enc_outputs: conditioning features, ``(batch, channels, time)``.
            t: time step -- a 0-d tensor (ODE solver) or one value per batch item (training).

        Returns:
            Masked velocity estimate, ``(batch, 1, time)``.
        """
        # (batch or 1, filter_channels) -> (batch or 1, filter_channels, 1): broadcast over time.
        t = self.time_mlp(self.time_embeddings(t)).unsqueeze(-1)
        x = pack([x, enc_outputs], "b * t")[0]

        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = x + t
        x = self.norm_1(x)
        x = self.drop(x)

        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = x + t
        x = self.norm_2(x)
        x = self.drop(x)

        x = self.proj(x * x_mask)
        return x * x_mask


# ---------------------------------------------------------------------------
# Duration predictors (each defines its own training loss)
# ---------------------------------------------------------------------------


class DeterministicDurationPredictor(nn.Module):
    """Direct regression of durations trained with a masked MSE loss."""

    def __init__(self, params):
        super().__init__()
        self.estimator = DurationPredictorNetwork(
            _conditioning_channels(params),
            params.filter_channels,
            params.kernel_size,
            params.p_dropout,
        )

    @torch.inference_mode()
    def forward(self, x, x_mask):
        """Predict durations for conditioning ``x``; returns ``(batch, 1, time)``."""
        return self.estimator(x, x_mask)

    def compute_loss(self, durations, enc_outputs, x_mask):
        """Sum of squared errors divided by the number of unmasked mask entries."""
        prediction = self.estimator(enc_outputs, x_mask)
        return F.mse_loss(prediction, durations, reduction="sum") / torch.sum(x_mask)


class FlowMatchingDurationPrediction(nn.Module):
    """Conditional flow-matching duration predictor sampled with a fixed-step Euler solver."""

    def __init__(self, params) -> None:
        super().__init__()
        self.estimator = DurationPredictorNetworkWithTimeStep(
            1 + _conditioning_channels(params),  # 1 channel for the durations being generated
            params.filter_channels,
            params.kernel_size,
            params.p_dropout,
        )
        self.sigma_min = params.sigma_min
        self.n_steps = params.n_steps

    @torch.inference_mode()
    def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
        """Sample durations by integrating the learned flow from Gaussian noise.

        Args:
            enc_outputs: conditioning features, ``(batch, channels, time)``.
            mask: sequence mask, presumably ``(batch, 1, time)``.
            n_timesteps: number of Euler steps; ``None`` falls back to ``params.n_steps``.
            temperature: scale applied to the initial noise.

        Returns:
            Masked samples, ``(batch, 1, time)`` (in whatever domain the durations were
            trained in, e.g. log-durations -- TODO confirm).

        Raises:
            ValueError: if the resolved number of steps is smaller than 1.
        """
        # NOTE(review): with the default of 500, ``params.n_steps`` is only used when
        # ``n_timesteps=None`` is passed explicitly (``DP.forward`` never does) -- confirm
        # whether the default was meant to be ``None``.
        if n_timesteps is None:
            n_timesteps = self.n_steps
        if n_timesteps < 1:
            raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")

        b, _, t = enc_outputs.shape
        z = torch.randn((b, 1, t), device=enc_outputs.device, dtype=enc_outputs.dtype) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=enc_outputs.device)
        sample = self.solve_euler(z, t_span=t_span, enc_outputs=enc_outputs, mask=mask)
        # The initial noise is unmasked, so padded frames would otherwise carry noise.
        return sample * mask

    def solve_euler(self, x, t_span, enc_outputs, mask):
        """Integrate ``dx/dt = estimator(x, t)`` with explicit Euler steps over ``t_span``.

        Args:
            x: initial state (noise), ``(batch, 1, time)``.
            t_span: 1-D tensor of monotonically increasing time points; one step is taken
                between each consecutive pair.
            enc_outputs: conditioning features passed to the estimator.
            mask: sequence mask passed to the estimator.

        Returns:
            The state at ``t_span[-1]`` (``x`` unchanged if ``t_span`` has fewer than 2 points).
        """
        for t, t_next in zip(t_span[:-1], t_span[1:]):
            dphi_dt = self.estimator(x, mask, enc_outputs, t)
            x = x + (t_next - t) * dphi_dt
        return x

    def compute_loss(self, x1, enc_outputs, mask):
        """Conditional flow-matching (OT path) loss.

        Args:
            x1: target durations, one channel: ``(batch, 1, time)``.
            enc_outputs: conditioning features, ``(batch, channels, time)``; detached so the
                encoder receives no gradient from the duration predictor.
            mask: sequence mask, presumably ``(batch, 1, time)``.

        Returns:
            Scalar loss: summed squared error / (mask sum * target channels).
        """
        enc_outputs = enc_outputs.detach()  # don't update encoder from the duration predictor
        b = enc_outputs.shape[0]

        # One random time per batch item, and noise sample x_0 ~ N(0, I).
        t = torch.rand([b, 1, 1], device=enc_outputs.device, dtype=enc_outputs.dtype)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # reshape(b) rather than squeeze(): keeps a (1,)-shaped tensor when batch == 1.
        prediction = self.estimator(y, mask, enc_outputs, t.reshape(b))
        return F.mse_loss(prediction, u, reduction="sum") / (torch.sum(mask) * u.shape[1])


# ---------------------------------------------------------------------------
# VITS discrete normalising-flow based (stochastic) duration predictor
# ---------------------------------------------------------------------------


class Log(nn.Module):
    """Masked elementwise log flow (inputs clamped to >= 1e-5); inverse is masked exp."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        return torch.exp(x) * x_mask


class ElementwiseAffine(nn.Module):
    """Learned per-channel affine flow ``y = m + exp(logs) * x`` (masked)."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = (self.m + torch.exp(self.logs) * x) * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        return (x - self.m) * torch.exp(-self.logs) * x_mask


class Flip(nn.Module):
    """Reverse the channel order so consecutive ``ConvFlow`` layers transform different halves."""

    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if reverse:
            return x
        # A permutation has unit Jacobian determinant.
        logdet = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
        return x, logdet


class DDSConv(nn.Module):
    """Dilated and depth-separable convolution stack with residual connections.

    Layer ``i`` uses dilation ``kernel_size ** i``; padding keeps the sequence length
    unchanged for odd ``kernel_size``.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


# Rational-quadratic spline (Durkan et al., "Neural Spline Flows"), as used by VITS.
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply an elementwise monotonic rational-quadratic spline and return ``(outputs, logabsdet)``.

    With ``tails=None`` the spline is defined on ``[0, 1]`` and ``unnormalized_derivatives``
    must have ``num_bins + 1`` entries on its last axis. With ``tails="linear"`` the spline
    covers ``[-tail_bound, tail_bound]``, is the identity outside it, and takes
    ``num_bins - 1`` interior derivatives.
    """
    spline_kwargs = dict(
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    if tails is None:
        return _rational_quadratic_spline(
            inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, **spline_kwargs
        )
    return _unconstrained_rational_quadratic_spline(
        inputs,
        unnormalized_widths,
        unnormalized_heights,
        unnormalized_derivatives,
        tails=tails,
        tail_bound=tail_bound,
        **spline_kwargs,
    )


def _search_sorted(bin_locations, inputs, eps=1e-6):
    """Index of the bin containing each input (bins given by their sorted edges)."""
    # Nudge the right-most edge (without mutating the caller's tensor) so inputs equal to
    # the upper bound land in the last bin.
    edges = torch.cat([bin_locations[..., :-1], bin_locations[..., -1:] + eps], dim=-1)
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1


def _bin_edges(unnormalized, min_size, low, high):
    """Turn unnormalized bin sizes into ``(edges, sizes)`` spanning exactly ``[low, high]``."""
    num_bins = unnormalized.shape[-1]
    sizes = F.softmax(unnormalized, dim=-1)
    sizes = min_size + (1 - min_size * num_bins) * sizes
    edges = torch.cumsum(sizes, dim=-1)
    edges = F.pad(edges, pad=(1, 0), mode="constant", value=0.0)
    edges = (high - low) * edges + low
    edges[..., 0] = low
    edges[..., -1] = high
    return edges, edges[..., 1:] - edges[..., :-1]


def _unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Spline on ``[-tail_bound, tail_bound]`` with identity (linear) tails outside it."""
    if tails != "linear":
        raise RuntimeError(f"{tails} tails are not implemented.")

    inside = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside = ~inside

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)
    outputs[outside] = inputs[outside]

    # Fix the boundary slopes to exactly 1 so the spline joins the identity tails smoothly:
    # min_derivative + softplus(c) == 1  <=>  c == log(exp(1 - min_derivative) - 1).
    unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
    constant = math.log(math.expm1(1 - min_derivative))
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant

    outputs[inside], logabsdet[inside] = _rational_quadratic_spline(
        inputs[inside],
        unnormalized_widths[inside, :],
        unnormalized_heights[inside, :],
        unnormalized_derivatives[inside, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    return outputs, logabsdet


def _rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Monotonic rational-quadratic spline mapping ``[left, right]`` onto ``[bottom, top]``."""
    # numel() guard: torch.min/max raise on empty tensors (e.g. no inputs inside the bound).
    if inputs.numel() and (torch.min(inputs) < left or torch.max(inputs) > right):
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]
    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    cumwidths, widths = _bin_edges(unnormalized_widths, min_bin_width, left, right)
    cumheights, heights = _bin_edges(unnormalized_heights, min_bin_height, bottom, top)
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    delta = heights / widths

    # In the inverse direction the input lives on the output (height) axis.
    bin_idx = _search_sorted(cumheights if inverse else cumwidths, inputs)[..., None]

    def at_bin(values):
        return values.gather(-1, bin_idx)[..., 0]

    input_cumwidths = at_bin(cumwidths)
    input_bin_widths = at_bin(widths)
    input_cumheights = at_bin(cumheights)
    input_heights = at_bin(heights)
    input_delta = at_bin(delta)
    input_derivatives = at_bin(derivatives)
    input_derivatives_plus_one = at_bin(derivatives[..., 1:])
    slope_sum = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        # Solve the quadratic a*theta^2 + b*theta + c = 0 for the position inside the bin.
        shifted = inputs - input_cumheights
        a = shifted * slope_sum + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - shifted * slope_sum
        c = -input_delta * shifted
        discriminant = b.pow(2) - 4 * a * c
        if not bool((discriminant >= 0).all()):
            raise RuntimeError("Negative discriminant in inverse rational-quadratic spline")
        theta = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = theta * input_bin_widths + input_cumwidths
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths

    theta_one_minus_theta = theta * (1 - theta)
    denominator = input_delta + slope_sum * theta_one_minus_theta
    if not inverse:
        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
    return outputs, (-logabsdet if inverse else logabsdet)


class ConvFlow(nn.Module):
    """Coupling flow: the first channel half parameterises a spline applied to the second half."""

    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        # widths (num_bins) + heights (num_bins) + interior derivatives (num_bins - 1) per channel.
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        # Zero init: the spline starts out close to the identity.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*?, t] -> [b, c, t, ?]

        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = h[..., 2 * self.num_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        if reverse:
            return x
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        return x, logdet


class StochasticDurationPredictor(nn.Module):
    """VITS stochastic duration predictor (not selectable through ``DP`` in this file).

    Training (``reverse=False``) returns a per-item variational NLL bound for durations ``w``
    (one channel, ``(batch, 1, time)``); inference (``reverse=True``) returns a sampled
    ``logw`` of shape ``(batch, 1, time)``.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        # NOTE: the ``filter_channels`` argument is ignored and replaced by ``in_channels``
        # (VITS legacy); kept as-is to preserve parameter shapes.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for _ in range(n_flows):
            self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())

        # Posterior encoder for the variational dequantisation of durations.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        for _ in range(4):
            self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Return the NLL bound (``reverse=False``) or a sampled ``logw`` (``reverse=True``).

        Raises:
            ValueError: if ``reverse=False`` and ``w`` is not given.
        """
        # Conditioning inputs are detached: no gradient flows back into them from here.
        x = self.pre(x.detach())
        if g is not None:
            x = x + self.cond(g.detach())
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if reverse:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2), device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            logw = z[:, :1]
            return logw

        if w is None:
            raise ValueError("w (target durations) is required when reverse=False")

        # Posterior q(u, z1 | w): sample and push through the posterior flows.
        h_w = self.post_pre(w)
        h_w = self.post_convs(h_w, x_mask)
        h_w = self.post_proj(h_w) * x_mask
        e_q = torch.randn(w.size(0), 2, w.size(2), device=x.device, dtype=x.dtype) * x_mask
        z_q = e_q
        logdet_tot_q = 0
        for flow in self.post_flows:
            z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
            logdet_tot_q = logdet_tot_q + logdet_q
        z_u, z1 = torch.split(z_q, [1, 1], 1)
        u = torch.sigmoid(z_u) * x_mask
        z0 = (w - u) * x_mask
        logdet_tot_q = logdet_tot_q + torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
        logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

        # Prior flows on (log(w - u), z1).
        z0, logdet_tot = self.log_flow(z0, x_mask)
        z = torch.cat([z0, z1], 1)
        for flow in self.flows:
            z, logdet = flow(z, x_mask, g=x, reverse=reverse)
            logdet_tot = logdet_tot + logdet
        nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
        return nll + logq  # [b]


# ---------------------------------------------------------------------------
# Wrapper selecting a duration predictor from the configuration
# ---------------------------------------------------------------------------


class DP(nn.Module):
    """Instantiate the duration predictor named by ``params.name`` and delegate to it.

    Supported names: ``"deterministic"`` and ``"flow_matching"``.

    Raises:
        ValueError: for any other ``params.name``.
    """

    def __init__(self, params):
        super().__init__()
        self.name = params.name
        if params.name == "deterministic":
            self.dp = DeterministicDurationPredictor(params)
        elif params.name == "flow_matching":
            self.dp = FlowMatchingDurationPrediction(params)
        else:
            raise ValueError(f"Invalid duration predictor configuration: {params.name}")

    @torch.inference_mode()
    def forward(self, enc_outputs, mask):
        """Predict durations with the wrapped predictor (its default settings)."""
        return self.dp(enc_outputs, mask)

    def compute_loss(self, durations, enc_outputs, mask):
        """Training loss of the wrapped predictor."""
        return self.dp.compute_loss(durations, enc_outputs, mask)