proteinworkshop.train

Main module for loading and training a model. It is intended to be the program entry point.

proteinworkshop.train.train_model(cfg: DictConfig, encoder: PatchedModule | None = None)

Trains a model from a config.

If encoder is provided, it is used instead of the one specified in the config.

  1. The datamodule is instantiated from cfg.dataset.datamodule.

  2. The callbacks are instantiated from cfg.callbacks.

  3. The logger is instantiated from cfg.logger.

  4. The trainer is instantiated from cfg.trainer.

  5. (Optional) If the config contains a scheduler, the number of training steps is inferred from the datamodule and devices and set in the scheduler.

  6. The model is instantiated from cfg.model.

  7. The datamodule is set up and a dummy forward pass is run to initialise lazy layers, so that parameter counts are accurate.

  8. Hyperparameters are logged to wandb if a logger is present.

  9. The model is compiled if cfg.compile is True.

  10. The model is trained if cfg.task_name is "train".

  11. The model is tested if cfg.test is True.
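Step 5 only states that the number of training steps is inferred from the datamodule and devices. To illustrate that arithmetic, the sketch below computes the total number of optimizer steps from the number of training batches, the device count, gradient accumulation and the epoch count. It is a minimal sketch, not the proteinworkshop implementation. It assumes batches are sharded evenly across devices (rounded up) and that a partial accumulation window at the end of an epoch still triggers an optimizer step. The helper name and the "num_training_steps" key are hypothetical.

```python
import math


def estimate_total_training_steps(
    num_batches: int,
    num_devices: int,
    max_epochs: int,
    accumulate_grad_batches: int = 1,
) -> int:
    """Hypothetical helper: total optimizer steps for a training run."""
    if num_devices < 1 or accumulate_grad_batches < 1 or max_epochs < 0:
        raise ValueError("devices and accumulation must be >= 1; epochs must be >= 0")
    # Assumption: each device sees 1/num_devices of the batches, rounded up.
    batches_per_device = math.ceil(num_batches / num_devices)
    # One optimizer step per accumulation window; a trailing partial window still steps.
    steps_per_epoch = math.ceil(batches_per_device / accumulate_grad_batches)
    return steps_per_epoch * max_epochs


# A hypothetical scheduler config receiving the inferred value (key name is an assumption).
scheduler_cfg = {"num_training_steps": None}
scheduler_cfg["num_training_steps"] = estimate_total_training_steps(1000, 4, 10, 2)
print(scheduler_cfg)  # {'num_training_steps': 1250}
```

For example, 1000 batches on 4 devices with accumulate_grad_batches=2 over 10 epochs gives 250 batches per device, 125 steps per epoch and 1250 steps in total.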

Parameters:
  • cfg (DictConfig) – DictConfig containing the config for the experiment

  • encoder (Optional[nn.Module]) – Optional encoder to use instead of the one specified in the config
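The ordering of the steps and the role of the encoder parameter can be summarised as plain control flow. The sketch below is not the proteinworkshop implementation: a plain dict stands in for the DictConfig, the stage names and the "encoder" config key are hypothetical, and nothing is actually instantiated. It only records which stages a run would execute and shows that an explicitly passed encoder takes precedence over the configured one.

```python
from typing import Any, Optional


def resolve_encoder(cfg: dict, encoder: Optional[Any] = None) -> Any:
    # An explicitly passed encoder overrides the one in the config
    # ("encoder" is a hypothetical key used only for this sketch).
    return encoder if encoder is not None else cfg.get("encoder")


def planned_stages(cfg: dict) -> list:
    """Return the stage names a train_model-style run would execute, in order."""
    stages = ["datamodule", "callbacks", "logger", "trainer"]  # steps 1-4
    if cfg.get("scheduler"):
        stages.append("set_scheduler_steps")  # step 5 is optional
    stages += ["model", "dummy_forward_pass"]  # steps 6-7
    if cfg.get("logger"):
        stages.append("log_hyperparameters")  # step 8
    if cfg.get("compile", False):
        stages.append("compile")  # step 9
    if cfg.get("task_name") == "train":
        stages.append("fit")  # step 10
    if cfg.get("test", False):
        stages.append("test")  # step 11
    return stages


cfg = {"task_name": "train", "test": True, "encoder": "config_encoder"}
print(planned_stages(cfg))
print(resolve_encoder(cfg, encoder="custom_encoder"))  # custom_encoder
```

With no scheduler, logger or compile flag set, the run above skips steps 5, 8 and 9 and ends with fit followed by test.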