proteinworkshop.train
Main module to load and train the model. This should be the program entry point.
proteinworkshop.train.train_model(cfg: DictConfig, encoder: PatchedModule | None = None)
Trains a model from a config.
If `encoder` is provided, it is used instead of the one specified in the config.

1. The datamodule is instantiated from `cfg.dataset.datamodule`.
2. The callbacks are instantiated from `cfg.callbacks`.
3. The logger is instantiated from `cfg.logger`.
4. The trainer is instantiated from `cfg.trainer`.
5. (Optional) If the config contains a scheduler, the number of training steps is inferred from the datamodule and devices and set in the scheduler.
6. The model is instantiated from `cfg.model`.
7. The datamodule is set up and a dummy forward pass is run to initialise lazy layers for accurate parameter counts.
8. Hyperparameters are logged to wandb if a logger is present.
9. The model is compiled if `cfg.compile` is `True`.
10. The model is trained if `cfg.task_name` is `"train"`.
11. The model is tested if `cfg.test` is `True`.
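The ordering and config gating of the numbered steps can be traced with a plain-Python sketch. It is a hypothetical illustration, not the library's code: the config is a nested `dict` rather than an OmegaConf `DictConfig`, nothing is actually instantiated (each step is only recorded as a string), `plan_training_run` and `infer_num_training_steps` are made-up names, and the step-count rule for step 5 (batches per epoch split evenly across an integer `trainer.devices`, multiplied by `trainer.max_epochs`, and written to a `num_training_steps` key) is an assumption.

```python
import math
from typing import Any, Dict, List


def infer_num_training_steps(batches_per_epoch: int, num_devices: int, max_epochs: int) -> int:
    # Assumption: training batches are split evenly across devices each epoch.
    return math.ceil(batches_per_epoch / max(1, num_devices)) * max_epochs


def plan_training_run(cfg: Dict[str, Any], batches_per_epoch: int) -> List[str]:
    """Hypothetical sketch: list, in order, the actions steps 1-11 would take."""
    actions = [
        "instantiate datamodule from cfg.dataset.datamodule",  # step 1
        "instantiate callbacks from cfg.callbacks",  # step 2
        "instantiate logger from cfg.logger",  # step 3
        "instantiate trainer from cfg.trainer",  # step 4
    ]
    if cfg.get("scheduler") is not None:  # step 5, only when a scheduler is configured
        steps = infer_num_training_steps(
            batches_per_epoch, cfg["trainer"]["devices"], cfg["trainer"]["max_epochs"]
        )
        cfg["scheduler"]["num_training_steps"] = steps  # assumed key name
        actions.append(f"set scheduler num_training_steps={steps}")
    actions.append("instantiate model from cfg.model")  # step 6
    actions.append("set up datamodule and run dummy forward pass")  # step 7
    if cfg.get("logger") is not None:  # step 8
        actions.append("log hyperparameters")
    if cfg.get("compile") is True:  # step 9
        actions.append("compile model")
    if cfg.get("task_name") == "train":  # step 10
        actions.append("fit model")
    if cfg.get("test") is True:  # step 11
        actions.append("test model")
    return actions


# Usage: a run with a scheduler, a logger, no compilation, training and testing.
cfg = {
    "trainer": {"devices": 2, "max_epochs": 10},
    "scheduler": {},
    "logger": {"name": "wandb"},
    "compile": False,
    "task_name": "train",
    "test": True,
}
for action in plan_training_run(cfg, batches_per_epoch=101):
    print(action)
```

Recording actions instead of executing them keeps the conditional steps (5 and 8 to 11) easy to check without a dataset or GPU.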
Parameters:
- cfg (DictConfig) – DictConfig containing the config for the experiment.
- encoder (Optional[nn.Module]) – Optional encoder to use instead of the one specified in the config.
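The precedence rule for the `encoder` parameter can be sketched as follows. This is an assumption-laden illustration, not the library's implementation: `resolve_encoder` and `build_from_config` are hypothetical names, and the config is assumed to keep its encoder specification under an `encoder` key. In this sketch the config's encoder is only built when no explicit one is passed.

```python
from typing import Any, Callable, Dict, Optional


def resolve_encoder(
    cfg: Dict[str, Any],
    build_from_config: Callable[[Dict[str, Any]], Any],
    encoder: Optional[Any] = None,
) -> Any:
    """Hypothetical helper: an explicit encoder takes precedence over the config's."""
    if encoder is not None:
        # The caller-supplied module wins; the config's encoder is never built.
        return encoder
    # Assumption: the config stores its encoder spec under an "encoder" key.
    return build_from_config(cfg["encoder"])


# Usage: a stub builder that records which specs it was asked to build.
built = []


def stub_builder(spec: Dict[str, Any]) -> str:
    built.append(spec["name"])
    return f"encoder:{spec['name']}"


cfg = {"encoder": {"name": "my_encoder"}}
print(resolve_encoder(cfg, stub_builder))  # encoder:my_encoder
custom = object()
print(resolve_encoder(cfg, stub_builder, encoder=custom) is custom)  # True
```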