Callbacks#

Default#

defaults:
  - model_checkpoint.yaml
  - early_stopping.yaml
  - model_summary.yaml
  - rich_progress_bar.yaml
  - learning_rate_monitor.yaml
  - stop_on_nan.yaml
  - _self_

model_checkpoint:
  dirpath: ${env.paths.output_dir}/checkpoints
  filename: "epoch_{epoch:03d}"
  save_last: True
  auto_insert_metric_name: False

model_summary:
  max_depth: -1
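
Each entry in the defaults list points at one of the per-callback YAML files documented below, and the keys underneath override fields of those defaults (here, where checkpoints are written and how deep the model summary recurses). Each entry carrying a `_target_` is then instantiated into a Lightning callback. A minimal sketch of that pattern, assuming a hypothetical `instantiate_callbacks` helper rather than the project's actual API:

# Hypothetical sketch: build Lightning Callback objects from a Hydra callbacks
# group by instantiating every entry that declares a _target_.
from typing import List

import hydra
from lightning.pytorch import Callback
from omegaconf import DictConfig


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    callbacks: List[Callback] = []
    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            callbacks.append(hydra.utils.instantiate(cb_conf))
    return callbacks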

Training#

Early Stopping (early_stopping)#

config/callbacks/early_stopping.yaml#
# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.EarlyStopping.html

# Monitor a metric and stop training when it stops improving.
# Look at the above link for more detailed information.
early_stopping:
  _target_: lightning.pytorch.callbacks.EarlyStopping
  monitor: ??? # quantity to be monitored, must be specified !!!
  min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
  patience: 10 # number of checks with no improvement after which training will be stopped
  verbose: True # verbosity mode
  mode: "min" # "max" means higher metric value is better, can be also "min"
  strict: True # whether to crash the training if monitor is not found in the validation metrics
  check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
  stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
  divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
  check_on_train_epoch_end: False # whether to run early stopping at the end of the training epoch
  # log_rank_zero_only: False  # this keyword argument isn't available in stable version
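
Because `monitor` is the mandatory `???` value, this config cannot be instantiated until a metric name is supplied, e.g. by an experiment config or a command-line override. For reference, an equivalent hand-written construction, assuming a hypothetical metric name `val/loss/total` logged by the LightningModule:

# Equivalent construction of the callback above; "val/loss/total" is an assumed
# metric name and must match a value logged via self.log(...) during validation.
from lightning.pytorch.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor="val/loss/total",  # fills in the required ??? field
    min_delta=0.0,
    patience=10,
    verbose=True,
    mode="min",
    strict=True,
    check_finite=True,
    check_on_train_epoch_end=False,
)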

Checkpointing (model_checkpoint)#

config/callbacks/model_checkpoint.yaml#
# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.ModelCheckpoint.html

# Save the model periodically by monitoring a quantity.
# Look at the above link for more detailed information.
model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  dirpath: null # directory to save the model file
  filename: null # checkpoint filename
  monitor: null # name of the logged metric which determines when model is improving
  verbose: True # verbosity mode
  save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 1 # save k best models (determined by above metric)
  mode: "min" # "max" means higher metric value is better, can be also "min"
  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model’s weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: null # number of epochs between checkpoints
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
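
Combined with the overrides in the Default section, this yields checkpoints named like `epoch_007.ckpt` (the metric name is kept out of the filename because `auto_insert_metric_name` is disabled there) plus a `last.ckpt` copy. An equivalent hand-written construction, where the output directory and monitored metric are assumptions standing in for the interpolated config values:

# Sketch mirroring the composed defaults above; dirpath and monitor are assumed
# values standing in for ${env.paths.output_dir}/checkpoints and a logged metric.
from lightning.pytorch.callbacks import ModelCheckpoint

model_checkpoint = ModelCheckpoint(
    dirpath="outputs/checkpoints",
    filename="epoch_{epoch:03d}",   # e.g. epoch_007.ckpt
    monitor="val/loss/total",
    verbose=True,
    save_last=True,
    save_top_k=1,
    mode="min",
    auto_insert_metric_name=False,  # keep the metric name out of the filename
    save_weights_only=False,
)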

Stop on NaN (stop_on_nan)#

config/callbacks/stop_on_nan.yaml#
# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.EarlyStopping.html

# Reuses EarlyStopping with an effectively infinite patience, so training stops
# only when the monitored quantity becomes NaN or infinite (via check_finite).
# Look at the above link for more detailed information.
stop_on_nan:
  _target_: lightning.pytorch.callbacks.EarlyStopping
  monitor: train/loss/total # quantity to be monitored (here, the total training loss)
  min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
  patience: 10_000_000 # number of checks with no improvement after which training will be stopped
  verbose: True # verbosity mode
  mode: "min" # "max" means higher metric value is better, can be also "min"
  strict: True # whether to crash the training if monitor is not found in the validation metrics
  check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
  stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
  divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
  check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
  # log_rank_zero_only: False  # this keyword argument isn't available in stable version

Exponential Moving Average (ema)#

config/callbacks/ema.yaml#
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py

# Maintains an exponential moving average (EMA) of model weights.
# Look at the above link for more detailed information regarding the original implementation.
ema:
  _target_: proteinworkshop.utils.EMA
  decay: 0.9999 # weight decay factor for the EMA
  apply_ema_every_n_steps: 1 # after how many steps to apply the EMA
  start_step: 0 # when to start the EMA
  save_ema_weights_in_callback_state: true # whether to store the EMA weights in the callback state (so they are saved within checkpoints)
  evaluate_ema_weights_instead: true # whether to run validation/testing using the EMA weights instead of the original weights
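
With `decay: 0.9999`, the averaged weights react very slowly to individual updates, which smooths out optimisation noise; validation and testing can then be run against these smoothed weights. Conceptually, the update applied every `apply_ema_every_n_steps` optimiser steps follows the standard EMA rule (an illustrative sketch, not the callback's actual code):

# Illustrative EMA update rule: ema <- decay * ema + (1 - decay) * current weights.
import torch


@torch.no_grad()
def ema_update(ema_params, model_params, decay: float = 0.9999) -> None:
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)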

Logging#

Model Summary (model_summary)#

config/callbacks/model_summary.yaml#
# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.RichModelSummary.html

# Generates a summary of all layers in a LightningModule with rich text formatting.
# Look at the above link for more detailed information.
model_summary:
  _target_: lightning.pytorch.callbacks.RichModelSummary
  max_depth: 1 # the maximum depth of layer nesting that the summary will include

Rich Progress Bar (rich_progress_bar)#

config/callbacks/rich_progress_bar.yaml#
# https://pytorch-lightning.readthedocs.io/en/latest/api/lightning.callbacks.RichProgressBar.html

# Create a progress bar with rich text formatting.
# Look at the above link for more detailed information.
rich_progress_bar:
  _target_: lightning.pytorch.callbacks.RichProgressBar

Learning Rate Monitor (learning_rate_monitor)#

This callback is automatically configured when a learning rate scheduler is used. See Schedulers.

config/callbacks/learning_rate_monitor.yaml#
learning_rate_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
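
The config instantiates the callback with Lightning's default arguments, so the logging interval is inferred from how each attached scheduler steps. For reference, the callback also accepts explicit arguments (the values below are illustrative, not what the config above sets):

# Explicit construction using Lightning's documented optional arguments.
from lightning.pytorch.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)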