Mirror of https://github.com/shivammehta25/Matcha-TTS.git (synced 2026-02-04 09:49:21 +08:00)
Initial commit
matcha/models/__init__.py (new file, 0 lines)

matcha/models/baselightningmodule.py (new file, 209 lines)
"""
This is a base Lightning module that can be used to train a model.
The benefit of this abstraction is that all the logic outside of the model definition can be reused for different models.
"""
import inspect
from abc import ABC
from typing import Any, Dict

import torch
from lightning import LightningModule
from lightning.pytorch.utilities import grad_norm

from matcha import utils
from matcha.utils.utils import plot_tensor

log = utils.get_pylogger(__name__)


class BaseLightningClass(LightningModule, ABC):
    def update_data_statistics(self, data_statistics):
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler not in (None, {}):
            scheduler_args = {}
            # Manage last epoch for exponential schedulers
            current_epoch = -1  # fallback so the assignment below is always defined
            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
                if hasattr(self, "ckpt_loaded_epoch"):
                    current_epoch = self.ckpt_loaded_epoch - 1
                else:
                    current_epoch = -1

            scheduler_args.update({"optimizer": optimizer})
            scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
            scheduler.last_epoch = current_epoch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": self.hparams.scheduler.lightning_args.interval,
                    "frequency": self.hparams.scheduler.lightning_args.frequency,
                    "name": "learning_rate",
                },
            }

        return {"optimizer": optimizer}

    def get_losses(self, batch):
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init

    def training_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "sub_loss/train_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        loss_dict = self.get_losses(batch)
        self.log(
            "sub_loss/val_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss

    def on_validation_end(self) -> None:
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))
            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(2):
                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
                    self.logger.experiment.add_image(
                        f"original/{i}",
                        plot_tensor(y.squeeze().cpu()),
                        self.current_epoch,
                        dataformats="HWC",
                    )

            log.debug("Synthesising...")
            for i in range(2):
                x = one_batch["x"][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                attn = output["attn"]
                self.logger.experiment.add_image(
                    f"generated_enc/{i}",
                    plot_tensor(y_enc.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"generated_dec/{i}",
                    plot_tensor(y_dec.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"alignment/{i}",
                    plot_tensor(attn.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )

    def on_before_optimizer_step(self, optimizer):
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
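configure_optimizers() above expects self.hparams.optimizer and self.hparams.scheduler to be partially-instantiated objects, presumably supplied by the training configs. A minimal sketch of that contract, assuming Hydra-style partial instantiation; the classes and values below are illustrative stand-ins, not taken from this commit:

# Sketch only: stand-ins for the partially instantiated objects configure_optimizers() consumes.
import functools
from types import SimpleNamespace

import torch

# called later as optimizer(params=self.parameters())
optimizer = functools.partial(torch.optim.Adam, lr=1e-4)

scheduler = SimpleNamespace(
    # called as scheduler.scheduler(optimizer=...); its last_epoch is patched when resuming
    scheduler=functools.partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999),
    # copied into the Lightning "lr_scheduler" dict returned above
    lightning_args=SimpleNamespace(interval="step", frequency=1),
)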

matcha/models/components/__init__.py (new file, 0 lines)

matcha/models/components/decoder.py (new file, 394 lines)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import TimestepEmbedding
from einops import pack, rearrange, repeat


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Block1D(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `True`):
            option to use a transposed convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class ConformerWrapper(ConformerBlock):
    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        return super().forward(x=hidden_states, mask=attention_mask.bool())


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]

            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the conditional 1D U-Net estimator.

        Args:
            x (torch.Tensor): noisy sample, shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder condition, concatenated with x along the channel dimension
            t (torch.Tensor): timestep, shape (batch_size)
            spks (torch.Tensor, optional): speaker embedding, shape (batch_size, condition_channels). Defaults to None.
            cond (optional): placeholder for future use. Defaults to None.

        Returns:
            torch.Tensor: estimated vector field, shape (batch_size, out_channels, time)
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
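For reference, the timestep embedding computed by SinusoidalPosEmb above can be reproduced with a few lines of plain PyTorch. This sketch (the function name is ours, not part of the repo) shows the geometrically spaced frequencies and the concatenated sin/cos halves that turn a batch of B scalar timesteps into a (B, dim) embedding:

import math

import torch


def sinusoidal_pos_emb(x: torch.Tensor, dim: int, scale: float = 1000.0) -> torch.Tensor:
    # same construction as SinusoidalPosEmb.forward: geometric frequencies, then sin/cos halves
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
    angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)      # (B, dim // 2)
    return torch.cat((angles.sin(), angles.cos()), dim=-1)    # (B, dim)


t = torch.tensor([0.0, 0.5, 1.0])                # three flow-matching timesteps in [0, 1]
print(sinusoidal_pos_emb(t, dim=256).shape)      # torch.Size([3, 256])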

matcha/models/components/flow_matching.py (new file, 114 lines)
from abc import ABC

import torch
import torch.nn.functional as F

from matcha.models.components.decoder import Decoder
from matcha.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Samples from the learnt flow by integrating the estimated vector field from noise.

        Args:
            mu (torch.Tensor): output of the encoder (the condition)
            mask (torch.Tensor): output mask
            n_timesteps (int): number of Euler steps used to solve the ODE
            temperature (float, optional): scales the initial Gaussian noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
            cond (optional): placeholder for future conditioning. Defaults to None.

        Returns:
            sample: generated mel-spectrogram, same shape as mu
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed-step Euler solver for the flow ODE.

        Args:
            x (torch.Tensor): initial noise sample
            t_span (torch.Tensor): integration time points from 0 to 1
            mu (torch.Tensor): output of the encoder (the condition)
            mask (torch.Tensor): output mask
            spks (torch.Tensor, optional): speaker embedding
            cond (optional): placeholder for future conditioning
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        sol = []

        steps = 1
        while steps <= len(t_span) - 1:
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if steps < len(t_span) - 1:
                dt = t_span[steps + 1] - t
            steps += 1

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes the conditional flow matching loss (logged as diff_loss).

        Args:
            x1 (torch.Tensor): target mel-spectrogram
            mask (torch.Tensor): target mask
            mu (torch.Tensor): output of the encoder
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
            cond (optional): placeholder for future conditioning. Defaults to None.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y


class CFM(BASECFM):
    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        # Just change the architecture of the estimator here
        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
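A toy numeric check of the optimal-transport flow used by compute_loss and solve_euler above (shapes and names are illustrative, not from the repo): along the conditional path y = (1 - (1 - sigma_min) * t) * x0 + t * x1 the target field u = x1 - (1 - sigma_min) * x0 is constant in t, so an estimator that returned it exactly would let the fixed-step Euler solver walk from the noise x0 to the data x1 up to an error of order sigma_min:

import torch

sigma_min = 1e-4
x0 = torch.randn(2, 80, 100)           # noise, shaped like (batch, n_feats, frames)
x1 = torch.randn(2, 80, 100)           # "data" sample
u = x1 - (1 - sigma_min) * x0          # conditional flow-matching target, constant along the path

n_steps = 10
x, dt = x0.clone(), 1.0 / n_steps
for _ in range(n_steps):
    x = x + dt * u                     # fixed-step Euler update, as in solve_euler

print(torch.allclose(x, x1, atol=1e-2))  # True: endpoint differs from x1 only by ~sigma_min * x0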

matcha/models/components/text_encoder.py (new file, 392 lines)
""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch
import torch.nn as nn
from einops import rearrange

import matcha.utils as utils
from matcha.utils.model import sequence_mask

log = utils.get_pylogger(__name__)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[batch_size, n_heads, seq_len, d]`
        """
        # Cache $\cos$ and $\sin$ values
        x = rearrange(x, "b h t d -> t b h d")

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
        key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
        value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)

        query = self.query_rotary_pe(query)
        key = self.key_rotary_pe(key)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)

        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    @staticmethod
    def _attention_bias_proximal(length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(nn.Module):
    def __init__(
        self,
        encoder_type,
        encoder_params,
        duration_predictor_params,
        n_vocab,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.encoder_type = encoder_type
        self.n_vocab = n_vocab
        self.n_feats = encoder_params.n_feats
        self.n_channels = encoder_params.n_channels
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)

        if encoder_params.prenet:
            self.prenet = ConvReluNorm(
                self.n_channels,
                self.n_channels,
                self.n_channels,
                kernel_size=5,
                n_layers=3,
                p_dropout=0.5,
            )
        else:
            self.prenet = lambda x, x_mask: x

        self.encoder = Encoder(
            encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
            encoder_params.filter_channels,
            encoder_params.n_heads,
            encoder_params.n_layers,
            encoder_params.kernel_size,
            encoder_params.p_dropout,
        )

        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
        self.proj_w = DurationPredictor(
            self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
            duration_predictor_params.filter_channels_dp,
            duration_predictor_params.kernel_size,
            duration_predictor_params.p_dropout,
        )

    def forward(self, x, x_lengths, spks=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask
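A standalone sketch of the rotation applied by RotaryPositionalEmbeddings above (plain PyTorch, the helper name is ours): features are grouped into d/2 pairs (i, i + d/2) and each pair is rotated by position * theta_i, so the transform preserves vector norms, which the check below confirms:

import torch


def rope(x: torch.Tensor, base: int = 10_000) -> torch.Tensor:
    # x: (seq_len, d) for a single head; same frequency layout as RotaryPositionalEmbeddings
    seq_len, d = x.shape
    theta = 1.0 / base ** (torch.arange(0, d, 2).float() / d)            # (d // 2,)
    angles = torch.arange(seq_len).float()[:, None] * theta[None, :]     # (seq_len, d // 2)
    cos, sin = angles.cos().repeat(1, 2), angles.sin().repeat(1, 2)      # (seq_len, d)
    neg_half = torch.cat([-x[:, d // 2 :], x[:, : d // 2]], dim=-1)
    return x * cos + neg_half * sin


q = torch.randn(5, 8)                   # 5 positions, 8 features
print(torch.allclose(rope(q).norm(dim=-1), q.norm(dim=-1)))  # True: rotations preserve norms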

matcha/models/matcha_tts.py (new file, 211 lines)
import datetime as dt
import math
import random

import torch

import matcha.utils.monotonic_align as monotonic_align
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
    denormalize,
    duration_loss,
    fix_len_compatibility,
    generate_path,
    sequence_mask,
)

log = utils.get_pylogger(__name__)


class MatchaTTS(BaseLightningClass):  # 🍵
    def __init__(
        self,
        n_vocab,
        n_spks,
        spk_emb_dim,
        n_feats,
        encoder,
        decoder,
        cfm,
        data_statistics,
        out_size,
        optimizer=None,
        scheduler=None,
    ):
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.n_vocab = n_vocab
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.n_feats = n_feats
        self.out_size = out_size

        if n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)

        self.encoder = TextEncoder(
            encoder.encoder_type,
            encoder.encoder_params,
            encoder.duration_predictor_params,
            n_vocab,
            n_spks,
            spk_emb_dim,
        )

        self.decoder = CFM(
            in_channels=2 * encoder.encoder_params.n_feats,
            out_channel=encoder.encoder_params.n_feats,
            cfm_params=cfm,
            decoder_params=decoder,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        self.update_data_statistics(data_statistics)

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns a dictionary with:

        1. encoder outputs
        2. decoder outputs
        3. generated alignment
        4. denormalised mel-spectrogram, mel lengths and the real-time factor (RTF)

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
            x_lengths (torch.Tensor): lengths of texts in batch.
            n_timesteps (int): number of ODE solver steps used by the flow-matching decoder.
            temperature (float, optional): controls variance of the terminal distribution.
            spks (torch.Tensor, optional): speaker ids for multi-speaker synthesis. Defaults to None.
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.
        """
        # For RTF computation
        t = dt.datetime.now()

        if self.n_spks > 1:
            # Get speaker embedding
            spks = self.spk_emb(spks.long())

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = int(y_lengths.max())
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample by performing reverse dynamics
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        t = (dt.datetime.now() - t).total_seconds()
        rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            "attn": attn[:, :, :y_max_length],
            "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
            "mel_lengths": y_lengths,
            "rtf": rtf,
        }

    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
        """
        Computes 3 losses:
        1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
        2. prior loss: loss between mel-spectrogram and encoder outputs.
        3. flow matching loss (logged as diff_loss): loss between the decoder's estimated vector field and the
           conditional flow target built from Gaussian noise.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
            x_lengths (torch.Tensor): lengths of texts in batch.
            y (torch.Tensor): batch of corresponding mel-spectrograms.
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
            out_size (int, optional): length (in mel frames) of the segment to cut, on which the decoder is trained.
                Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
            cond (optional): placeholder for future conditioning. Defaults to None.
        """
        if self.n_spks > 1:
            # Get speaker embedding
            spks = self.spk_emb(spks)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
        y_max_length = y.shape[-1]

        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
        with torch.no_grad():
            const = -0.5 * math.log(2 * math.pi) * self.n_feats
            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            y_square = torch.matmul(factor.transpose(1, 2), y**2)
            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            log_prior = y_square - y_mu_double + mu_square + const

            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
            attn = attn.detach()

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Cut a small segment of mel-spectrogram in order to increase batch size
        #   - "Hack" taken from Grad-TTS; in the case of Grad-TTS, batch size 32 does not fit on a 24GB GPU without it
        #   - Matcha-TTS does not need this hack, but it works with it as well
        if out_size is not None:
            max_offset = (y_lengths - out_size).clamp(0)
            offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
            out_offset = torch.LongTensor(
                [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
            ).to(y_lengths)
            attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
            y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)

            y_cut_lengths = []
            for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
                y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
                y_cut_lengths.append(y_cut_length)
                cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
                y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
                attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]

            y_cut_lengths = torch.LongTensor(y_cut_lengths)
            y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)

            attn = attn_cut
            y = y_cut
            y_mask = y_cut_mask

        # Align encoded text with mel-spectrogram and get mu_y segment
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)

        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)

        return dur_loss, prior_loss, diff_loss
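A toy sketch of how synthesise() above expands the ceil'd durations w_ceil into the hard alignment attn (generate_path itself lives in matcha/utils/model.py and is not reproduced here; the loop below is only an illustration): multiplying the encoder outputs by a 0/1 monotonic path is the same as repeating each phone's vector for its predicted number of frames:

import torch

durations = torch.tensor([2, 1, 3])                   # frames per phone, like w_ceil for one utterance
n_frames = int(durations.sum())

attn = torch.zeros(len(durations), n_frames)          # (n_phones, n_frames) monotonic 0/1 path
frame = 0
for phone, d in enumerate(durations.tolist()):
    attn[phone, frame : frame + d] = 1.0              # phone covers its d consecutive frames
    frame += d

mu_x = torch.randn(80, len(durations))                # (n_feats, n_phones) encoder outputs
mu_y = mu_x @ attn                                    # (n_feats, n_frames), the matmul done in synthesise
print(torch.equal(mu_y, mu_x.repeat_interleave(durations, dim=1)))  # True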