# Source code for proteinworkshop.train

"""
Main module to load and train the model. This should be the program entry
point.
"""
import copy
import sys
from typing import List, Optional

import graphein
import hydra
import lightning as L
import lovely_tensors as lt
import torch
import torch.nn as nn
import torch_geometric
from graphein.protein.tensor.dataloader import ProteinDataLoader
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger
from loguru import logger as log
from omegaconf import DictConfig

from proteinworkshop import (
    constants,
    register_custom_omegaconf_resolvers,
    utils,
)
from proteinworkshop.configs import config
from proteinworkshop.models.base import BenchMarkModel

# Import-time global configuration: these calls run whenever this module is
# imported, not only when it is executed as a script.
# Turn off graphein's verbose mode.
graphein.verbose(False)
# Apply lovely_tensors' monkey patch to torch. Presumably this replaces tensor
# reprs with compact summaries, which keeps the batch/output logging in
# train_model readable -- NOTE(review): confirm against lovely_tensors docs.
lt.monkey_patch()


def _num_training_steps(
    train_dataset: ProteinDataLoader, trainer: L.Trainer
) -> int:
    """
    Returns total training steps inferred from datamodule and devices.

    If ``trainer.max_steps`` is set (i.e. not ``-1``) it is returned
    unchanged. Otherwise the number of batches per epoch is read from the
    dataloader and capped by ``trainer.limit_train_batches`` following
    Lightning's convention: an ``int`` is an absolute number of batches and a
    ``float`` is a fraction of the dataloader. That count is divided by the
    number of batches consumed per optimizer step (gradient accumulation x
    devices) and multiplied by ``trainer.max_epochs``.

    :param train_dataset: Training dataloader
    :type train_dataset: ProteinDataLoader
    :param trainer: Lightning trainer
    :type trainer: L.Trainer
    :return: Total number of training steps
    :rtype: int
    """
    if trainer.max_steps != -1:
        return trainer.max_steps

    num_batches = len(train_dataset)
    limit = trainer.limit_train_batches
    # Lightning semantics: int -> absolute number of batches per epoch,
    # float -> fraction of the dataloader (the default 1.0 keeps all batches).
    if isinstance(limit, int):
        num_batches = min(num_batches, limit)
    else:
        num_batches = int(num_batches * limit)

    log.info(f"Training batches per epoch: {num_batches}")

    num_devices = max(1, trainer.num_devices)
    # Each optimizer step consumes ``accumulate_grad_batches`` batches on every
    # device, so this many dataloader batches make up one step.
    batches_per_step = trainer.accumulate_grad_batches * num_devices
    return (num_batches // batches_per_step) * trainer.max_epochs


def train_model(
    cfg: DictConfig, encoder: Optional[nn.Module] = None
):  # sourcery skip: extract-method
    """
    Trains a model from a config.

    If ``encoder`` is provided, it is used instead of the one specified in the
    config.

    1. The datamodule is instantiated from ``cfg.dataset.datamodule``.
    2. The callbacks are instantiated from ``cfg.callbacks``.
    3. The logger is instantiated from ``cfg.logger``.
    4. The trainer is instantiated from ``cfg.trainer``.
    5. (Optional) If the config contains a scheduler, the number of training
       steps is inferred from the datamodule and devices and set in the
       scheduler.
    6. The model is instantiated from ``cfg.model``.
    7. The datamodule is setup and a dummy forward pass is run to initialise
       lazy layers for accurate parameter counts.
    8. Hyperparameters are logged to wandb if a logger is present.
    9. The model is compiled if ``cfg.compile`` is True.
    10. The model is trained if ``cfg.task_name`` is ``"train"``.
    11. The model is tested if ``cfg.test`` is ``True``. For the fold
        classification datamodule, the ``fold``, ``family`` and
        ``superfamily`` test splits are evaluated separately. Their metrics
        are suffixed with ``/<split>`` and sent to every configured logger
        (if any).

    :param cfg: DictConfig containing the config for the experiment
    :type cfg: DictConfig
    :param encoder: Optional encoder to use instead of the one specified in
        the config
    :type encoder: Optional[nn.Module]
    :raises AttributeError: If the featurised example batch lacks an
        attribute listed in ``model.encoder.required_batch_attributes``.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    L.seed_everything(cfg.seed)

    log.info(
        f"Instantiating datamodule: <{cfg.dataset.datamodule._target_}>..."
    )
    datamodule: L.LightningDataModule = hydra.utils.instantiate(
        cfg.dataset.datamodule
    )

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.callbacks.instantiate_callbacks(
        cfg.get("callbacks")
    )

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.loggers.instantiate_loggers(cfg.get("logger"))

    log.info("Instantiating trainer...")
    trainer: L.Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    if cfg.get("scheduler"):
        if (
            cfg.scheduler.scheduler._target_
            == "flash.core.optimizers.LinearWarmupCosineAnnealingLR"
            and cfg.scheduler.interval == "step"
        ):
            datamodule.setup()  # type: ignore
            num_steps = _num_training_steps(
                datamodule.train_dataloader(), trainer
            )
            log.info(
                f"Setting number of training steps in scheduler to: {num_steps}"
            )
            # The scheduler is stepped per optimizer step, so its "epochs" are
            # really steps. When max_steps is unset, num_steps is
            # steps_per_epoch * max_epochs, so warm-up lasts one epoch's worth
            # of steps and annealing spans the whole run.
            cfg.scheduler.scheduler.warmup_epochs = (
                num_steps / trainer.max_epochs
            )
            cfg.scheduler.scheduler.max_epochs = num_steps
            log.info(cfg.scheduler)

    log.info("Instantiating model...")
    model: L.LightningModule = BenchMarkModel(cfg)

    if encoder is not None:
        log.info(f"Setting user-defined encoder {encoder}...")
        model.encoder = encoder

    log.info("Initializing lazy layers...")
    with torch.no_grad():
        datamodule.setup()  # type: ignore
        batch = next(iter(datamodule.val_dataloader()))
        log.info(f"Unfeaturized batch: {batch}")
        batch = model.featurise(batch)
        log.info(f"Featurized batch: {batch}")
        log.info(f"Example labels: {model.get_labels(batch)}")
        # Check batch has required attributes
        for attr in model.encoder.required_batch_attributes:  # type: ignore
            if not hasattr(batch, attr):
                raise AttributeError(
                    f"Batch {batch} does not have required attribute: {attr} ({model.encoder.required_batch_attributes})"
                )
        out = model(batch)
        log.info(f"Model output: {out}")
        del batch, out

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.logging_utils.log_hyperparameters(object_dict)

    if cfg.get("compile"):
        log.info("Compiling model!")
        model = torch_geometric.compile(model, dynamic=True)

    if cfg.get("task_name") == "train":
        log.info("Starting training!")
        trainer.fit(
            model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
        )

    if cfg.get("test"):
        log.info("Starting testing!")
        # Run test on all splits if using fold_classification dataset
        if (
            cfg.dataset.datamodule._target_
            == "proteinworkshop.datasets.fold_classification.FoldClassificationDataModule"
        ):
            splits = ["fold", "family", "superfamily"]
            # Loggers are detached while testing each split, presumably so
            # Lightning does not record un-suffixed (colliding) test metrics.
            # The suffixed per-split metrics are logged manually below. The
            # loggers are restored afterwards, even if a test run fails.
            # Without any configured logger this list is empty and the
            # metrics are only written to the console log.
            saved_loggers = list(trainer.loggers)
            try:
                trainer.loggers = []
                for split in splits:
                    dataloader = datamodule.get_test_loader(split)
                    results = trainer.test(
                        model=model, dataloaders=dataloader, ckpt_path="best"
                    )[0]
                    results = {f"{k}/{split}": v for k, v in results.items()}
                    log.info(f"{split}: {results}")
                    for split_logger in saved_loggers:
                        split_logger.log_metrics(results)
            finally:
                trainer.loggers = saved_loggers
        else:
            trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
# Load hydra config from yaml files and command line arguments.
@hydra.main(
    version_base="1.3",
    config_path=str(constants.HYDRA_CONFIG_PATH),
    config_name="train",
)
def _main(cfg: DictConfig) -> None:
    """
    Hydra task function: apply config extras, validate the config and train.

    :param cfg: Config composed by Hydra from the ``train`` config found in
        ``constants.HYDRA_CONFIG_PATH`` plus any command-line overrides.
    :type cfg: DictConfig
    """
    utils.extras(cfg)
    validated_cfg = config.validate_config(cfg)
    train_model(validated_cfg)


def _launch() -> None:
    """Register the custom OmegaConf resolvers, then hand control to Hydra."""
    register_custom_omegaconf_resolvers()
    # Hydra's decorator supplies ``cfg``, so the call takes no arguments.
    # pylint: disable=no-value-for-parameter
    _main()  # type: ignore


def _script_main(args: List[str]) -> None:
    """
    Entry point for the script dispatcher.

    Replaces ``sys.argv`` with ``args`` so Hydra reads them as the command
    line, then runs the training entry point.

    :param args: Command-line arguments to expose via ``sys.argv``.
    :type args: List[str]
    """
    sys.argv = args
    _launch()


if __name__ == "__main__":
    _launch()