mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-05 18:29:19 +08:00
Initial commit
This commit is contained in:
21
matcha/hifigan/LICENSE
Normal file
21
matcha/hifigan/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Jungil Kong
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
101
matcha/hifigan/README.md
Normal file
101
matcha/hifigan/README.md
Normal file
@@ -0,0 +1,101 @@
|
||||
# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
|
||||
|
||||
### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
|
||||
|
||||
In our [paper](https://arxiv.org/abs/2010.05646),
|
||||
we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
|
||||
We provide our implementation and pretrained models as open source in this repository.
|
||||
|
||||
**Abstract :**
|
||||
Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
|
||||
Although such methods improve the sampling efficiency and memory usage,
|
||||
their sample quality has not yet reached that of autoregressive and flow-based generative models.
|
||||
In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
|
||||
As speech audio consists of sinusoidal signals with various periods,
|
||||
we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
|
||||
A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
|
||||
demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
|
||||
real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
|
||||
speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
|
||||
faster than real-time on CPU with comparable quality to an autoregressive counterpart.
|
||||
|
||||
Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
1. Python >= 3.6
|
||||
2. Clone this repository.
|
||||
3. Install python requirements. Please refer [requirements.txt](requirements.txt)
|
||||
4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
|
||||
And move all wav files to `LJSpeech-1.1/wavs`
|
||||
|
||||
## Training
|
||||
|
||||
```
|
||||
python train.py --config config_v1.json
|
||||
```
|
||||
|
||||
To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
|
||||
Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
|
||||
You can change the path by adding `--checkpoint_path` option.
|
||||
|
||||
Validation loss during training with V1 generator.<br>
|
||||

|
||||
|
||||
## Pretrained Model
|
||||
|
||||
You can also use pretrained models we provide.<br/>
|
||||
[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
|
||||
Details of each folder are as in follows:
|
||||
|
||||
| Folder Name | Generator | Dataset | Fine-Tuned |
|
||||
| ------------ | --------- | --------- | ------------------------------------------------------ |
|
||||
| LJ_V1 | V1 | LJSpeech | No |
|
||||
| LJ_V2 | V2 | LJSpeech | No |
|
||||
| LJ_V3 | V3 | LJSpeech | No |
|
||||
| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
|
||||
| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
|
||||
| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
|
||||
| VCTK_V1 | V1 | VCTK | No |
|
||||
| VCTK_V2 | V2 | VCTK | No |
|
||||
| VCTK_V3 | V3 | VCTK | No |
|
||||
| UNIVERSAL_V1 | V1 | Universal | No |
|
||||
|
||||
We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
|
||||
|
||||
## Fine-Tuning
|
||||
|
||||
1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
|
||||
The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
|
||||
Example:
|
||||
` Audio File : LJ001-0001.wav
|
||||
Mel-Spectrogram File : LJ001-0001.npy`
|
||||
2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
|
||||
3. Run the following command.
|
||||
```
|
||||
python train.py --fine_tuning True --config config_v1.json
|
||||
```
|
||||
For other command line options, please refer to the training section.
|
||||
|
||||
## Inference from wav file
|
||||
|
||||
1. Make `test_files` directory and copy wav files into the directory.
|
||||
2. Run the following command.
|
||||
` python inference.py --checkpoint_file [generator checkpoint file path]`
|
||||
Generated wav files are saved in `generated_files` by default.<br>
|
||||
You can change the path by adding `--output_dir` option.
|
||||
|
||||
## Inference for end-to-end speech synthesis
|
||||
|
||||
1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
|
||||
You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
|
||||
[Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
|
||||
2. Run the following command.
|
||||
` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
|
||||
Generated wav files are saved in `generated_files_from_mel` by default.<br>
|
||||
You can change the path by adding `--output_dir` option.
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
|
||||
and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
|
||||
0
matcha/hifigan/__init__.py
Normal file
0
matcha/hifigan/__init__.py
Normal file
28
matcha/hifigan/config.py
Normal file
28
matcha/hifigan/config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
v1 = {
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0004,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"upsample_rates": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_initial_channel": 256,
|
||||
"segment_size": 8192,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
"sampling_rate": 22050,
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_loss": None,
|
||||
"num_workers": 4,
|
||||
"dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
|
||||
}
|
||||
64
matcha/hifigan/denoiser.py
Normal file
64
matcha/hifigan/denoiser.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
|
||||
|
||||
"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
|
||||
import torch
|
||||
|
||||
|
||||
class Denoiser(torch.nn.Module):
|
||||
"""Removes model bias from audio produced with waveglow"""
|
||||
|
||||
def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
|
||||
super().__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = int(filter_length / n_overlap)
|
||||
self.win_length = win_length
|
||||
|
||||
dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
|
||||
self.device = device
|
||||
if mode == "zeros":
|
||||
mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
|
||||
elif mode == "normal":
|
||||
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
|
||||
else:
|
||||
raise Exception(f"Mode {mode} if not supported")
|
||||
|
||||
def stft_fn(audio, n_fft, hop_length, win_length, window):
|
||||
spec = torch.stft(
|
||||
audio,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.view_as_real(spec)
|
||||
return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
|
||||
|
||||
self.stft = lambda x: stft_fn(
|
||||
audio=x,
|
||||
n_fft=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=torch.hann_window(self.win_length, device=device),
|
||||
)
|
||||
self.istft = lambda x, y: torch.istft(
|
||||
torch.complex(x * torch.cos(y), x * torch.sin(y)),
|
||||
n_fft=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=torch.hann_window(self.win_length, device=device),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
bias_audio = vocoder(mel_input).float().squeeze(0)
|
||||
bias_spec, _ = self.stft(bias_audio)
|
||||
|
||||
self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, audio, strength=0.0005):
|
||||
audio_spec, audio_angles = self.stft(audio)
|
||||
audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
|
||||
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
|
||||
audio_denoised = self.istft(audio_spec_denoised, audio_angles)
|
||||
return audio_denoised
|
||||
17
matcha/hifigan/env.py
Normal file
17
matcha/hifigan/env.py
Normal file
@@ -0,0 +1,17 @@
|
||||
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
|
||||
217
matcha/hifigan/meldataset.py
Normal file
217
matcha/hifigan/meldataset.py
Normal file
@@ -0,0 +1,217 @@
|
||||
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from librosa.util import normalize
|
||||
from scipy.io.wavfile import read
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def load_wav(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return data, sampling_rate
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window # pylint: disable=global-statement
|
||||
if fmax not in mel_basis:
|
||||
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
||||
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
||||
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.view_as_real(
|
||||
torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[str(y.device)],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
|
||||
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def get_dataset_filelist(a):
|
||||
with open(a.input_training_file, encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
|
||||
with open(a.input_validation_file, encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
return training_files, validation_files
|
||||
|
||||
|
||||
class MelDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
training_files,
|
||||
segment_size,
|
||||
n_fft,
|
||||
num_mels,
|
||||
hop_size,
|
||||
win_size,
|
||||
sampling_rate,
|
||||
fmin,
|
||||
fmax,
|
||||
split=True,
|
||||
shuffle=True,
|
||||
n_cache_reuse=1,
|
||||
device=None,
|
||||
fmax_loss=None,
|
||||
fine_tuning=False,
|
||||
base_mels_path=None,
|
||||
):
|
||||
self.audio_files = training_files
|
||||
random.seed(1234)
|
||||
if shuffle:
|
||||
random.shuffle(self.audio_files)
|
||||
self.segment_size = segment_size
|
||||
self.sampling_rate = sampling_rate
|
||||
self.split = split
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.fmax_loss = fmax_loss
|
||||
self.cached_wav = None
|
||||
self.n_cache_reuse = n_cache_reuse
|
||||
self._cache_ref_count = 0
|
||||
self.device = device
|
||||
self.fine_tuning = fine_tuning
|
||||
self.base_mels_path = base_mels_path
|
||||
|
||||
def __getitem__(self, index):
|
||||
filename = self.audio_files[index]
|
||||
if self._cache_ref_count == 0:
|
||||
audio, sampling_rate = load_wav(filename)
|
||||
audio = audio / MAX_WAV_VALUE
|
||||
if not self.fine_tuning:
|
||||
audio = normalize(audio) * 0.95
|
||||
self.cached_wav = audio
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
|
||||
self._cache_ref_count = self.n_cache_reuse
|
||||
else:
|
||||
audio = self.cached_wav
|
||||
self._cache_ref_count -= 1
|
||||
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
if not self.fine_tuning:
|
||||
if self.split:
|
||||
if audio.size(1) >= self.segment_size:
|
||||
max_audio_start = audio.size(1) - self.segment_size
|
||||
audio_start = random.randint(0, max_audio_start)
|
||||
audio = audio[:, audio_start : audio_start + self.segment_size]
|
||||
else:
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||
|
||||
mel = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax,
|
||||
center=False,
|
||||
)
|
||||
else:
|
||||
mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
|
||||
mel = torch.from_numpy(mel)
|
||||
|
||||
if len(mel.shape) < 3:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
if self.split:
|
||||
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
||||
|
||||
if audio.size(1) >= self.segment_size:
|
||||
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
|
||||
else:
|
||||
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||
|
||||
mel_loss = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax_loss,
|
||||
center=False,
|
||||
)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_files)
|
||||
368
matcha/hifigan/models.py
Normal file
368
matcha/hifigan/models.py
Normal file
@@ -0,0 +1,368 @@
|
||||
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
from .xutils import get_padding, init_weights
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorS(use_spectral_norm=True),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
]
|
||||
)
|
||||
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
y = self.meanpools[i - 1](y)
|
||||
y_hat = self.meanpools[i - 1](y_hat)
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
60
matcha/hifigan/xutils.py
Normal file
60
matcha/hifigan/xutils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
import torch
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def apply_weight_norm(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
weight_norm(m)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def save_checkpoint(filepath, obj):
|
||||
print(f"Saving checkpoint to {filepath}")
|
||||
torch.save(obj, filepath)
|
||||
print("Complete.")
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + "????????")
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
return sorted(cp_list)[-1]
|
||||
Reference in New Issue
Block a user