proteinworkshop.utils#
- class proteinworkshop.utils.EMA(decay: float, apply_ema_every_n_steps: int = 1, start_step: int = 0, save_ema_weights_in_callback_state: bool = False, evaluate_ema_weights_instead: bool = False)[source]#
Implements Exponential Moving Averaging (EMA). When training a model, this callback maintains moving averages of the trained parameters. When evaluating, the moving-average copy of the trained parameters is used instead. When saving, an additional set of parameters with the prefix ema is saved.
- Parameters:
decay – The exponential decay used when calculating the moving average. Must be between 0 and 1.
apply_ema_every_n_steps – Apply EMA every n global steps.
start_step – Start applying EMA from start_step global step onwards.
save_ema_weights_in_callback_state – Enable saving EMA weights in callback state.
evaluate_ema_weights_instead – Validate the EMA weights instead of the original weights. Note that this means that, when saving the model, the validation metrics are calculated with the EMA weights.
Adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
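Below is a minimal usage sketch showing how this callback might be attached to a Lightning Trainer. Conventionally, each averaged parameter is updated as ema = decay * ema + (1 - decay) * param once start_step is reached. The model, datamodule, trainer arguments, and the exact Lightning import path are assumptions and may differ in your setup.
```python
from lightning.pytorch import Trainer  # or: from pytorch_lightning import Trainer

from proteinworkshop.utils import EMA

# Hypothetical trainer setup; `model` and `datamodule` are placeholders.
ema_callback = EMA(
    decay=0.999,                        # moving-average decay, must lie between 0 and 1
    apply_ema_every_n_steps=1,          # update the averages every global step
    start_step=0,                       # start averaging from the first global step
    evaluate_ema_weights_instead=True,  # run validation/test with the EMA weights
)

trainer = Trainer(max_epochs=10, callbacks=[ema_callback])
# trainer.fit(model, datamodule=datamodule)
```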
- load_state_dict(state_dict: dict[str, Any]) → None [source]#
Called when loading a checkpoint. Implement to reload callback state given the callback’s state_dict.
- Parameters:
state_dict – the callback state returned by state_dict.
- on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) → None [source]#
Called when loading a model checkpoint. Use to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
- on_test_start(trainer: Trainer, pl_module: LightningModule) → None [source]#
Called when the test begins.
- on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) → None [source]#
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_train_start(trainer: Trainer, pl_module: LightningModule) → None [source]#
Called when the train begins.
- on_validation_end(trainer: Trainer, pl_module: LightningModule) → None [source]#
Called when the validation loop ends.
- class proteinworkshop.utils.EMAModelCheckpoint(**kwargs)[source]#
A light wrapper around Lightning’s ModelCheckpoint that, upon request, also saves an EMA copy of the model.
Adapted from: https://github.com/NVIDIA/NeMo/blob/be0804f61e82dd0f63da7f9fe8a4d8388e330b18/nemo/utils/exp_manager.py#L744
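As a sketch of how the two callbacks might be combined: the keyword arguments below are standard ModelCheckpoint options (forwarded via **kwargs), and the wrapper is assumed to pick up the EMA callback attached to the same trainer; the monitored metric and paths are hypothetical.
```python
from lightning.pytorch import Trainer

from proteinworkshop.utils import EMA, EMAModelCheckpoint

# Hypothetical configuration; adjust monitor/dirpath to your logging setup.
ema_callback = EMA(decay=0.999, evaluate_ema_weights_instead=True)
checkpoint_callback = EMAModelCheckpoint(
    dirpath="checkpoints/",  # standard ModelCheckpoint kwargs are forwarded
    monitor="val/loss",
    save_top_k=1,
)

trainer = Trainer(callbacks=[ema_callback, checkpoint_callback])
# trainer.fit(model, datamodule=datamodule)
```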