proteinworkshop.utils#

class proteinworkshop.utils.EMA(decay: float, apply_ema_every_n_steps: int = 1, start_step: int = 0, save_ema_weights_in_callback_state: bool = False, evaluate_ema_weights_instead: bool = False)[source]#

Implements Exponential Moving Averaging (EMA). When training a model, this callback maintains moving averages of the trained parameters. When evaluating, the moving-average copy of the trained parameters is used. When saving, an additional set of parameters is saved with the prefix ema.

Parameters:
  • decay – The exponential decay used when calculating the moving average. Must be between 0 and 1.

  • apply_ema_every_n_steps – Apply EMA every n global steps.

  • start_step – Start applying EMA from start_step global step onwards.

  • save_ema_weights_in_callback_state – Enable saving EMA weights in callback state.

  • evaluate_ema_weights_instead – Validate the EMA weights instead of the original weights. Note that this means the validation metrics saved with the model are calculated with the EMA weights.

Adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
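The decay parameter controls how closely the shadow weights track the live weights: each update has the form ema = decay * ema + (1 - decay) * param. The sketch below shows how the callback might be attached to a Lightning Trainer; it is a minimal, hypothetical usage example, and the model, datamodule, and hyperparameter values are placeholders rather than part of the Protein Workshop API.

# Hypothetical usage sketch; model/datamodule names and hyperparameters are placeholders.
import lightning as L  # or pytorch_lightning, depending on the installed Lightning version

from proteinworkshop.utils import EMA

ema_callback = EMA(
    decay=0.999,                        # each step: ema = 0.999 * ema + 0.001 * param
    apply_ema_every_n_steps=1,          # update the shadow weights every global step
    start_step=0,                       # begin averaging from the first global step
    save_ema_weights_in_callback_state=False,
    evaluate_ema_weights_instead=True,  # run validation/test with the EMA weights
)

trainer = L.Trainer(max_epochs=10, callbacks=[ema_callback])
trainer.fit(model, datamodule=datamodule)  # `model` and `datamodule` are assumed to exist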

load_state_dict(state_dict: dict[str, Any]) None[source]#

Called when loading a checkpoint; implement to reload callback state given the callback’s state_dict.

Parameters:

state_dict – the callback state returned by state_dict.

on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None[source]#

Called when loading a model checkpoint, use to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.

on_test_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test ends.

on_test_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test begins.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) None[source]#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.

on_train_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train begins.

on_validation_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop ends.

on_validation_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop begins.

state_dict() dict[str, Any][source]#

Called when saving a checkpoint; implement to generate the callback’s state_dict.

Returns:

A dictionary containing callback state.

class proteinworkshop.utils.EMAModelCheckpoint(**kwargs)[source]#

Lightweight wrapper around Lightning’s ModelCheckpoint that, upon request, also saves an EMA copy of the model.

Adapted from: https://github.com/NVIDIA/NeMo/blob/be0804f61e82dd0f63da7f9fe8a4d8388e330b18/nemo/utils/exp_manager.py#L744
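As a rough sketch (assuming EMAModelCheckpoint simply forwards **kwargs to Lightning’s ModelCheckpoint, so the usual checkpoint arguments apply), the two callbacks would typically be used together; the directory, monitored metric, and trainer settings below are placeholders.

# Hypothetical sketch: EMAModelCheckpoint takes ModelCheckpoint's usual kwargs and,
# when an EMA callback is present, additionally saves an EMA copy of the model.
import lightning as L

from proteinworkshop.utils import EMA, EMAModelCheckpoint

callbacks = [
    EMA(decay=0.999, evaluate_ema_weights_instead=True),
    EMAModelCheckpoint(
        dirpath="checkpoints/",  # standard ModelCheckpoint arguments
        monitor="val/loss",      # metric name is a placeholder
        save_top_k=1,
    ),
]

trainer = L.Trainer(max_epochs=10, callbacks=callbacks)
trainer.fit(model, datamodule=datamodule)  # `model` and `datamodule` are assumed to exist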