Mirror of https://github.com/shivammehta25/Matcha-TTS.git (synced 2026-02-04 09:49:21 +08:00)

Initial commit

configs/__init__.py (new file, 1 line)
# this file is needed here to include configs when building project as a package

configs/callbacks/default.yaml (new file, 5 lines)
defaults:
  - model_checkpoint.yaml
  - model_summary.yaml
  - rich_progress_bar.yaml
  - _self_

configs/callbacks/model_checkpoint.yaml (new file, 17 lines)
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html

model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  dirpath: ${paths.output_dir}/checkpoints # directory to save the model file
  filename: checkpoint_{epoch:03d} # checkpoint filename
  monitor: epoch # name of the logged metric which determines when model is improving
  verbose: False # verbosity mode
  save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 10 # save k best models (determined by above metric)
  mode: "max" # "max" means higher metric value is better, can be also "min"
  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model's weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: 100 # number of epochs between checkpoints
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation

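For context, entries in this callbacks group are turned into Lightning callback objects through Hydra's instantiate utility. The sketch below shows that pattern under the usual lightning-hydra-template assumptions; the helper name and exact wiring are illustrative, not code from this commit.

# Hypothetical sketch: build a Lightning callback for every config entry that declares a _target_.
import hydra
from lightning.pytorch import Callback
from omegaconf import DictConfig


def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
    """Instantiate each callback entry (e.g. model_checkpoint above) via its _target_."""
    callbacks: list[Callback] = []
    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            callbacks.append(hydra.utils.instantiate(cb_conf))
    return callbacks
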
configs/callbacks/model_summary.yaml (new file, 5 lines)
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html

model_summary:
  _target_: lightning.pytorch.callbacks.RichModelSummary
  max_depth: 3 # the maximum depth of layer nesting that the summary will include

configs/callbacks/none.yaml (new empty file)

configs/callbacks/rich_progress_bar.yaml (new file, 4 lines)
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html

rich_progress_bar:
  _target_: lightning.pytorch.callbacks.RichProgressBar

configs/data/ljspeech.yaml (new file, 21 lines)
_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: ljspeech
train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt
valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt
batch_size: 32
num_workers: 20
pin_memory: True
cleaners: [english_cleaners2]
add_blank: True
n_spks: 1
n_fft: 1024
n_feats: 80
sample_rate: 22050
hop_length: 256
win_length: 1024
f_min: 0
f_max: 8000
data_statistics: # Computed for ljspeech dataset
  mel_mean: -5.536622
  mel_std: 2.116101
seed: ${seed}

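To illustrate how the audio settings and data_statistics above are typically consumed, here is a hedged sketch based on torchaudio; the repository's actual matcha.data.text_mel_datamodule.TextMelDataModule may compute mel spectrograms differently.

# Illustrative only: mel extraction and normalisation using the values from this config.
import torch
import torchaudio

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=1024, win_length=1024, hop_length=256,
    f_min=0.0, f_max=8000.0, n_mels=80, power=1.0,
)


def normalised_log_mel(wav: torch.Tensor, mel_mean: float = -5.536622, mel_std: float = 2.116101) -> torch.Tensor:
    """Log-compress the mel spectrogram and z-normalise it with the dataset statistics."""
    mel = torch.log(torch.clamp(mel_transform(wav), min=1e-5))
    return (mel - mel_mean) / mel_std
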
configs/data/vctk.yaml (new file, 14 lines)
defaults:
  - ljspeech
  - _self_

_target_: matcha.data.text_mel_datamodule.TextMelDataModule
name: vctk
train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt
valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt
batch_size: 32
add_blank: True
n_spks: 109
data_statistics: # Computed for vctk dataset
  mel_mean: -6.630575
  mel_std: 2.482914

configs/debug/default.yaml (new file, 35 lines)
# @package _global_

# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one

# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"

# disable callbacks and loggers during debugging
callbacks: null
logger: null

extras:
  ignore_warnings: False
  enforce_tags: False

# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
  job_logging:
    root:
      level: DEBUG

  # use this to also set hydra loggers to 'DEBUG'
  # verbose: True

trainer:
  max_epochs: 1
  accelerator: cpu # debuggers don't like gpus
  devices: 1 # debuggers don't like multiprocessing
  detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor

data:
  num_workers: 0 # debuggers don't like multiprocessing
  pin_memory: False # disable gpu memory pin

configs/debug/fdr.yaml (new file, 9 lines)
# @package _global_

# runs 1 train, 1 validation and 1 test step

defaults:
  - default

trainer:
  fast_dev_run: true

configs/debug/limit.yaml (new file, 12 lines)
# @package _global_

# uses only 1% of the training data and 5% of validation/test data

defaults:
  - default

trainer:
  max_epochs: 3
  limit_train_batches: 0.01
  limit_val_batches: 0.05
  limit_test_batches: 0.05

configs/debug/overfit.yaml (new file, 13 lines)
# @package _global_

# overfits to 3 batches

defaults:
  - default

trainer:
  max_epochs: 20
  overfit_batches: 3

# model ckpt and early stopping need to be disabled during overfitting
callbacks: null

configs/debug/profiler.yaml (new file, 12 lines)
# @package _global_

# runs with execution time profiling

defaults:
  - default

trainer:
  max_epochs: 1
  profiler: "simple"
  # profiler: "advanced"
  # profiler: "pytorch"

configs/eval.yaml (new file, 18 lines)
# @package _global_

defaults:
  - _self_
  - data: mnist # choose datamodule with `test_dataloader()` for evaluation
  - model: mnist
  - logger: null
  - trainer: default
  - paths: default
  - extras: default
  - hydra: default

task_name: "eval"

tags: ["dev"]

# passing checkpoint path is necessary for evaluation
ckpt_path: ???

configs/experiment/ljspeech.yaml (new file, 14 lines)
# @package _global_

# to execute this experiment run:
# python train.py experiment=ljspeech

defaults:
  - override /data: ljspeech.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["ljspeech"]

run_name: ljspeech

configs/experiment/ljspeech_min_memory.yaml (new file, 18 lines)
# @package _global_

# to execute this experiment run:
# python train.py experiment=ljspeech_min_memory

defaults:
  - override /data: ljspeech.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["ljspeech"]

run_name: ljspeech_min


model:
  out_size: 172

configs/experiment/multispeaker.yaml (new file, 14 lines)
# @package _global_

# to execute this experiment run:
# python train.py experiment=multispeaker

defaults:
  - override /data: vctk.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["multispeaker"]

run_name: multispeaker

configs/extras/default.yaml (new file, 8 lines)
# disable python warnings if they annoy you
ignore_warnings: False

# ask user for tags if none are provided in the config
enforce_tags: True

# pretty print config tree at the start of the run using Rich library
print_config: True

configs/hparams_search/mnist_optuna.yaml (new file, 52 lines)
# @package _global_

# example hyperparameter optimization of some experiment with Optuna:
# python train.py -m hparams_search=mnist_optuna experiment=example

defaults:
  - override /hydra/sweeper: optuna

# choose metric which will be optimized by Optuna
# make sure this is the correct name of some metric logged in lightning module!
optimized_metric: "val/acc_best"

# here we define Optuna hyperparameter search
# it optimizes for value returned from function with @hydra.main decorator
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
  mode: "MULTIRUN" # set hydra to multirun by default if this config is attached

  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper

    # storage URL to persist optimization results
    # for example, you can use SQLite if you set 'sqlite:///example.db'
    storage: null

    # name of the study to persist optimization results
    study_name: null

    # number of parallel workers
    n_jobs: 1

    # 'minimize' or 'maximize' the objective
    direction: maximize

    # total number of runs that will be executed
    n_trials: 20

    # choose Optuna hyperparameter sampler
    # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
    # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
    sampler:
      _target_: optuna.samplers.TPESampler
      seed: 1234
      n_startup_trials: 10 # number of random sampling runs before optimization starts

    # define hyperparameter search space
    params:
      model.optimizer.lr: interval(0.0001, 0.1)
      data.batch_size: choice(32, 64, 128, 256)
      model.net.lin1_size: choice(64, 128, 256)
      model.net.lin2_size: choice(64, 128, 256)
      model.net.lin3_size: choice(32, 64, 128, 256)

configs/hydra/default.yaml (new file, 19 lines)
# https://hydra.cc/docs/configure_hydra/intro/

# enable color logging
defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

# output directory, generated dynamically on each run
run:
  dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
  dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
  subdir: ${hydra.job.num}

job_logging:
  handlers:
    file:
      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
      filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log

configs/local/.gitkeep (new empty file)

configs/logger/aim.yaml (new file, 28 lines)
# https://aimstack.io/

# example usage in lightning module:
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py

# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
# `aim up`

aim:
  _target_: aim.pytorch_lightning.AimLogger
  repo: ${paths.root_dir} # .aim folder will be created here
  # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#

  # aim allows to group runs under experiment name
  experiment: null # any string, set to "default" if not specified

  train_metric_prefix: "train/"
  val_metric_prefix: "val/"
  test_metric_prefix: "test/"

  # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
  system_tracking_interval: 10 # set to null to disable system metrics tracking

  # enable/disable logging of system params such as installed packages, git info, env vars, etc.
  log_system_params: true

  # enable/disable tracking console logs (default value is true)
  capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550

configs/logger/comet.yaml (new file, 12 lines)
# https://www.comet.ml

comet:
  _target_: lightning.pytorch.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
  save_dir: "${paths.output_dir}"
  project_name: "lightning-hydra-template"
  rest_api_key: null
  # experiment_name: ""
  experiment_key: null # set to resume experiment
  offline: False
  prefix: ""

configs/logger/csv.yaml (new file, 7 lines)
# csv logger built in lightning

csv:
  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
  save_dir: "${paths.output_dir}"
  name: "csv/"
  prefix: ""

configs/logger/many_loggers.yaml (new file, 9 lines)
# train with many loggers at once

defaults:
  # - comet
  - csv
  # - mlflow
  # - neptune
  - tensorboard
  - wandb

configs/logger/mlflow.yaml (new file, 12 lines)
# https://mlflow.org

mlflow:
  _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
  # experiment_name: ""
  # run_name: ""
  tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
  tags: null
  # save_dir: "./mlruns"
  prefix: ""
  artifact_location: null
  # run_id: ""

configs/logger/neptune.yaml (new file, 9 lines)
# https://neptune.ai

neptune:
  _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
  api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
  project: username/lightning-hydra-template
  # name: ""
  log_model_checkpoints: True
  prefix: ""

configs/logger/tensorboard.yaml (new file, 10 lines)
# https://www.tensorflow.org/tensorboard/

tensorboard:
  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  save_dir: "${paths.output_dir}/tensorboard/"
  name: null
  log_graph: False
  default_hp_metric: True
  prefix: ""
  # version: ""

configs/logger/wandb.yaml (new file, 16 lines)
# https://wandb.ai

wandb:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  # name: "" # name of the run (normally generated by wandb)
  save_dir: "${paths.output_dir}"
  offline: False
  id: null # pass correct id to resume experiment!
  anonymous: null # enable anonymous logging
  project: "lightning-hydra-template"
  log_model: False # upload lightning ckpts
  prefix: "" # a string to put at the beginning of metric keys
  # entity: "" # set to name of your wandb team
  group: ""
  tags: []
  job_type: ""

configs/model/cfm/default.yaml (new file, 3 lines)
name: CFM
solver: euler
sigma_min: 1e-4

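The solver and sigma_min fields configure the ODE integration used by the conditional flow-matching decoder at synthesis time. Below is an illustrative fixed-step Euler integrator of the kind `solver: euler` refers to; it is a generic sketch, not the repository's actual CFM implementation, and the function signature is an assumption.

# Hypothetical sketch of a fixed-step Euler ODE solver over a learned vector field.
import torch


def euler_solve(vector_field, x: torch.Tensor, n_steps: int = 10) -> torch.Tensor:
    """Integrate dx/dt = vector_field(x, t) from t=0 to t=1 with n_steps Euler steps."""
    t = torch.zeros(x.shape[0], device=x.device)
    dt = 1.0 / n_steps
    for _ in range(n_steps):
        x = x + dt * vector_field(x, t)
        t = t + dt
    return x
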
configs/model/decoder/default.yaml (new file, 7 lines)
channels: [256, 256]
dropout: 0.05
attention_head_dim: 64
n_blocks: 1
num_mid_blocks: 2
num_heads: 2
act_fn: snakebeta

configs/model/encoder/default.yaml (new file, 18 lines)
encoder_type: RoPE Encoder
encoder_params:
  n_feats: ${model.n_feats}
  n_channels: 192
  filter_channels: 768
  filter_channels_dp: 256
  n_heads: 2
  n_layers: 6
  kernel_size: 3
  p_dropout: 0.1
  spk_emb_dim: 64
  n_spks: 1
  prenet: true

duration_predictor_params:
  filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp}
  kernel_size: 3
  p_dropout: ${model.encoder.encoder_params.p_dropout}

configs/model/matcha.yaml (new file, 14 lines)
defaults:
  - _self_
  - encoder: default.yaml
  - decoder: default.yaml
  - cfm: default.yaml
  - optimizer: adam.yaml

_target_: matcha.models.matcha_tts.MatchaTTS
n_vocab: 178
n_spks: ${data.n_spks}
spk_emb_dim: 64
n_feats: 80
data_statistics: ${data.data_statistics}
out_size: null # Must be divisible by 4

configs/model/optimizer/adam.yaml (new file, 4 lines)
_target_: torch.optim.Adam
_partial_: true
lr: 1e-4
weight_decay: 0.0

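The `_partial_: true` flag tells Hydra to return a partially applied constructor rather than an optimizer instance, since the model parameters are only available later inside the LightningModule. A small sketch of that behaviour follows; the variable names and the stand-in model are illustrative, not part of this commit.

# Sketch of what `_partial_: true` buys you: instantiate() returns a functools.partial
# over torch.optim.Adam, which can be called later with the model's parameters.
import hydra
import torch
from omegaconf import OmegaConf

optimizer_cfg = OmegaConf.create(
    {"_target_": "torch.optim.Adam", "_partial_": True, "lr": 1e-4, "weight_decay": 0.0}
)
make_optimizer = hydra.utils.instantiate(optimizer_cfg)  # functools.partial(Adam, lr=1e-4, ...)

model = torch.nn.Linear(80, 80)  # stand-in module for demonstration
optimizer = make_optimizer(params=model.parameters())
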
configs/paths/default.yaml (new file, 18 lines)
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
root_dir: ${oc.env:PROJECT_ROOT}

# path to data directory
data_dir: ${paths.root_dir}/data/

# path to logging directory
log_dir: ${paths.root_dir}/logs/

# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}

# path to working directory
work_dir: ${hydra:runtime.cwd}

configs/train.yaml (new file, 51 lines)
# @package _global_

# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
  - _self_
  - data: ljspeech
  - model: matcha
  - callbacks: default
  - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - trainer: default
  - paths: default
  - extras: default
  - hydra: default

  # experiment configs allow for version control of specific hyperparameters
  # e.g. best hyperparameters for given model and datamodule
  - experiment: null

  # config for hyperparameter optimization
  - hparams_search: null

  # optional local config for machine/user specific settings
  # it's optional since it doesn't need to exist and is excluded from version control
  - optional local: default

  # debugging config (enable through command line, e.g. `python train.py debug=default`)
  - debug: null

# task name, determines output directory path
task_name: "train"

run_name: ???

# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]

# set False to skip model training
train: True

# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: True

# simply provide checkpoint path to resume training
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: 1234

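This root config is what a @hydra.main entry point composes at startup. Below is a minimal sketch of such an entry point under the usual lightning-hydra-template assumptions; the repository's real training script does more (logger and callback instantiation, tag enforcement, testing, checkpoint export).

# Minimal, assumption-laden sketch of a train entry point consuming configs/train.yaml.
import hydra
import lightning as L
from omegaconf import DictConfig


@hydra.main(version_base="1.3", config_path="configs", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    L.seed_everything(cfg.seed, workers=True)
    datamodule = hydra.utils.instantiate(cfg.data)
    model = hydra.utils.instantiate(cfg.model)
    trainer = hydra.utils.instantiate(cfg.trainer)
    if cfg.get("train"):
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))


if __name__ == "__main__":
    main()
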
configs/trainer/cpu.yaml (new file, 5 lines)
defaults:
  - default

accelerator: cpu
devices: 1

configs/trainer/ddp.yaml (new file, 9 lines)
defaults:
  - default

strategy: ddp

accelerator: gpu
devices: [0,1]
num_nodes: 1
sync_batchnorm: True

configs/trainer/ddp_sim.yaml (new file, 7 lines)
defaults:
  - default

# simulate DDP on CPU, useful for debugging
accelerator: cpu
devices: 2
strategy: ddp_spawn

configs/trainer/default.yaml (new file, 20 lines)
_target_: lightning.pytorch.trainer.Trainer

default_root_dir: ${paths.output_dir}

max_epochs: -1

accelerator: gpu
devices: [0]

# mixed precision for extra speed-up
precision: 16-mixed

# perform a validation loop every N training epochs
check_val_every_n_epoch: 1

# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False

gradient_clip_val: 5.0

configs/trainer/gpu.yaml
Normal file
5
configs/trainer/gpu.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
accelerator: gpu
|
||||
devices: 1
|
||||
configs/trainer/mps.yaml (new file, 5 lines)
defaults:
  - default

accelerator: mps
devices: 1