"""
Main module to load and train the model. This should be the program entry
point.
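
Example invocation (a sketch; the overrides are illustrative Hydra config
groups and may differ in your installation)::

    python proteinworkshop/train.py dataset=cath encoder=schnet trainer=cpu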
"""
import copy
import sys
from typing import List, Optional
import graphein
import hydra
import lightning as L
import lovely_tensors as lt
import torch
import torch.nn as nn
import torch_geometric
from graphein.protein.tensor.dataloader import ProteinDataLoader
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger
from loguru import logger as log
from omegaconf import DictConfig
from proteinworkshop import (
constants,
register_custom_omegaconf_resolvers,
utils,
)
from proteinworkshop.configs import config
from proteinworkshop.models.base import BenchMarkModel
graphein.verbose(False)
lt.monkey_patch()
def _num_training_steps(
train_dataset: ProteinDataLoader, trainer: L.Trainer
) -> int:
"""
    Returns the total number of training steps, inferred from the training
    dataloader and the trainer's device, gradient-accumulation and epoch settings.
:param train_dataset: Training dataloader
:type train_dataset: ProteinDataLoader
:param trainer: Lightning trainer
:type trainer: L.Trainer
:return: Total number of training steps
:rtype: int
"""
if trainer.max_steps != -1:
return trainer.max_steps
dataset_size = (
trainer.limit_train_batches
if trainer.limit_train_batches not in {0, 1}
else len(train_dataset) * train_dataset.batch_size
)
log.info(f"Dataset size: {dataset_size}")
num_devices = max(1, trainer.num_devices)
effective_batch_size = (
train_dataset.batch_size
* trainer.accumulate_grad_batches
* num_devices
)
return (dataset_size // effective_batch_size) * trainer.max_epochs
def train_model(
cfg: DictConfig, encoder: Optional[nn.Module] = None
): # sourcery skip: extract-method
"""
Trains a model from a config.
If ``encoder`` is provided, it is used instead of the one specified in the
config.
1. The datamodule is instantiated from ``cfg.dataset.datamodule``.
2. The callbacks are instantiated from ``cfg.callbacks``.
3. The logger is instantiated from ``cfg.logger``.
4. The trainer is instantiated from ``cfg.trainer``.
    5. (Optional) If the config specifies a ``LinearWarmupCosineAnnealingLR``
       scheduler with a ``step`` interval, the number of training steps is
       inferred from the datamodule and devices and set on the scheduler.
6. The model is instantiated from ``cfg.model``.
    7. The datamodule is set up and a dummy forward pass is run to initialise
       lazy layers so that parameter counts are accurate.
8. Hyperparameters are logged to wandb if a logger is present.
9. The model is compiled if ``cfg.compile`` is True.
10. The model is trained if ``cfg.task_name`` is ``"train"``.
11. The model is tested if ``cfg.test`` is ``True``.
:param cfg: DictConfig containing the config for the experiment
:type cfg: DictConfig
:param encoder: Optional encoder to use instead of the one specified in
the config
:type encoder: Optional[nn.Module]
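
    Programmatic use (a minimal sketch; the ``encoder=schnet`` override is
    illustrative and assumes the bundled Hydra configs under
    ``constants.HYDRA_CONFIG_PATH``):

    .. code-block:: python

        from hydra import compose, initialize_config_dir

        from proteinworkshop import constants, register_custom_omegaconf_resolvers
        from proteinworkshop.configs import config

        register_custom_omegaconf_resolvers()
        with initialize_config_dir(
            config_dir=str(constants.HYDRA_CONFIG_PATH), version_base="1.3"
        ):
            cfg = compose(config_name="train", overrides=["encoder=schnet"])
        train_model(config.validate_config(cfg))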
"""
# set seed for random number generators in pytorch, numpy and python.random
L.seed_everything(cfg.seed)
log.info(
f"Instantiating datamodule: <{cfg.dataset.datamodule._target_}..."
)
datamodule: L.LightningDataModule = hydra.utils.instantiate(
cfg.dataset.datamodule
)
log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.callbacks.instantiate_callbacks(
cfg.get("callbacks")
)
log.info("Instantiating loggers...")
logger: List[Logger] = utils.loggers.instantiate_loggers(cfg.get("logger"))
log.info("Instantiating trainer...")
trainer: L.Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger
)
if cfg.get("scheduler"):
if (
cfg.scheduler.scheduler._target_
== "flash.core.optimizers.LinearWarmupCosineAnnealingLR"
and cfg.scheduler.interval == "step"
):
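            # The scheduler is stepped per optimiser step rather than per epoch,
            # so its warmup and total length must be expressed in training steps,
            # which requires setting up the datamodule to measure the dataset.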
datamodule.setup() # type: ignore
num_steps = _num_training_steps(
datamodule.train_dataloader(), trainer
)
log.info(
f"Setting number of training steps in scheduler to: {num_steps}"
)
cfg.scheduler.scheduler.warmup_epochs = (
num_steps / trainer.max_epochs
)
cfg.scheduler.scheduler.max_epochs = num_steps
log.info(cfg.scheduler)
log.info("Instantiating model...")
model: L.LightningModule = BenchMarkModel(cfg)
if encoder is not None:
log.info(f"Setting user-defined encoder {encoder}...")
model.encoder = encoder
log.info("Initializing lazy layers...")
with torch.no_grad():
datamodule.setup() # type: ignore
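        # Run a single validation batch through featurisation and the model so
        # that lazily-initialised layers are materialised and parameter counts
        # reported by loggers and callbacks are accurate.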
batch = next(iter(datamodule.val_dataloader()))
log.info(f"Unfeaturized batch: {batch}")
batch = model.featurise(batch)
log.info(f"Featurized batch: {batch}")
log.info(f"Example labels: {model.get_labels(batch)}")
# Check batch has required attributes
for attr in model.encoder.required_batch_attributes: # type: ignore
if not hasattr(batch, attr):
raise AttributeError(
f"Batch {batch} does not have required attribute: {attr} ({model.encoder.required_batch_attributes})"
)
out = model(batch)
log.info(f"Model output: {out}")
del batch, out
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
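    # Everything needed to reproduce the run is gathered above and recorded by
    # the hyperparameter logger when a logger is configured.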
if logger:
log.info("Logging hyperparameters!")
utils.logging_utils.log_hyperparameters(object_dict)
if cfg.get("compile"):
log.info("Compiling model!")
model = torch_geometric.compile(model, dynamic=True)
if cfg.get("task_name") == "train":
log.info("Starting training!")
trainer.fit(
model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
)
if cfg.get("test"):
log.info("Starting testing!")
# Run test on all splits if using fold_classification dataset
if (
cfg.dataset.datamodule._target_
== "proteinworkshop.datasets.fold_classification.FoldClassificationDataModule"
):
splits = ["fold", "family", "superfamily"]
wandb_logger = copy.deepcopy(trainer.logger)
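            # Disable the trainer's logger while testing each split so Lightning
            # does not log the raw per-split metrics under identical keys; results
            # are re-keyed with the split name and logged via the copied logger.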
for split in splits:
dataloader = datamodule.get_test_loader(split)
trainer.logger = False
results = trainer.test(
model=model, dataloaders=dataloader, ckpt_path="best"
)[0]
results = {f"{k}/{split}": v for k, v in results.items()}
log.info(f"{split}: {results}")
wandb_logger.log_metrics(results)
else:
trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
# Load hydra config from yaml files and command line arguments.
@hydra.main(
version_base="1.3",
config_path=str(constants.HYDRA_CONFIG_PATH),
config_name="train",
)
def _main(cfg: DictConfig) -> None:
"""Load and validate the hydra config."""
utils.extras(cfg)
cfg = config.validate_config(cfg)
train_model(cfg)
def _script_main(args: List[str]) -> None:
"""
    Provides an entry point for the script dispatcher: sets ``sys.argv`` to the
    provided args, registers the custom OmegaConf resolvers and calls the main
    train function.
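
    Example (illustrative overrides; the first element plays the role of the
    program name)::

        _script_main(["workshop", "encoder=schnet", "trainer=cpu"])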
"""
sys.argv = args
register_custom_omegaconf_resolvers()
_main()
if __name__ == "__main__":
# pylint: disable=no-value-for-parameter
register_custom_omegaconf_resolvers()
_main() # type: ignore