Architectural Overview#
The benchmark is constructed in a modular fashion, with each module configured via a .yaml config file using Hydra.
The predominant ingredients are:
Datasets: The benchmark supports a variety of datasets, described in Datasets and documented in proteinworkshop.datasets
Models: The benchmark supports a variety of models, described in Models and documented in proteinworkshop.models
Tasks: The benchmark supports a variety of tasks, described in Tasks and documented in proteinworkshop.tasks
Features: The benchmark supports a variety of features, described in Features and documented in proteinworkshop.features
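As a minimal sketch, a full training configuration can be composed programmatically with Hydra's compose API. The config path, config name and the group/option names in the overrides below are illustrative assumptions, not guaranteed names; consult the framework's configs directory for the actual values.

from hydra import compose, initialize

# Hypothetical config path, config name and overrides -- the real group
# and option names live in the framework's configs directory.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="train",
        overrides=["dataset=cath", "encoder=schnet", "task=inverse_folding"],
    )
print(cfg)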
Datasets#
Protein and ProteinBatch Data structures#
We make extensive use of Graphein for data processing and featurisation in the framework. To familiarise yourself with these data structures, please see the tutorials provided by Graphein. In essence, Protein and ProteinBatch inherit from torch_geometric.data.Data and torch_geometric.data.Batch respectively, and are used to represent a single protein or a batch of proteins.
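Because of this inheritance, the objects can be constructed and batched in the usual PyTorch Geometric way. A minimal sketch follows; the coords attribute name is illustrative only, not the framework's canonical feature name.

import torch
from graphein.protein.tensor.data import Protein, ProteinBatch

# Protein subclasses torch_geometric.data.Data, so arbitrary tensor
# attributes can be attached as with any Data object ("coords" is a
# hypothetical attribute name used for illustration).
protein = Protein(coords=torch.randn(10, 37, 3))

# ProteinBatch subclasses torch_geometric.data.Batch, so batches are
# assembled from a list of Protein objects in the usual PyG way.
batch = ProteinBatch.from_data_list([protein, protein])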
ProteinDataModule Base classes#
The framework provides base classes for datasets and datamodules, which can be extended to create new datasets. The datamodule is the only object that needs to be configured to add a new dataset. The proteinworkshop.datasets.base.ProteinDataModule class is a subclass of pytorch_lightning.LightningDataModule and is used to represent a datamodule for a dataset of proteins; it creates the dataloaders for training, validation and testing. To add a new dataset, the corresponding datamodule should inherit from proteinworkshop.datasets.base.ProteinDataModule and implement the following methods (a skeleton subclass follows the list):
proteinworkshop.datasets.base.ProteinDataModule.parse_dataset()
(optionally) proteinworkshop.datasets.base.ProteinDataModule.parse_labels()
(optionally) proteinworkshop.datasets.base.ProteinDataModule.exclude_pdbs()
proteinworkshop.datasets.base.ProteinDataModule.train_dataset()
proteinworkshop.datasets.base.ProteinDataModule.val_dataset()
proteinworkshop.datasets.base.ProteinDataModule.test_dataset()
proteinworkshop.datasets.base.ProteinDataModule.train_dataloader()
proteinworkshop.datasets.base.ProteinDataModule.val_dataloader()
proteinworkshop.datasets.base.ProteinDataModule.test_dataloader()
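The skeleton below sketches what such a subclass might look like. The class name and method bodies are hypothetical; only the overridden method names are prescribed by the base class.

from graphein.protein.tensor.dataloader import ProteinDataLoader
from proteinworkshop.datasets.base import ProteinDataModule, ProteinDataset

class MyDataModule(ProteinDataModule):
    """Hypothetical datamodule illustrating the required overrides."""

    def parse_dataset(self, split: str):
        ...  # parse ids/filenames, chains and labels for the given split

    def parse_labels(self):
        ...  # optional: parse labels separately from the structures

    def exclude_pdbs(self):
        ...  # optional: return structures to exclude from the dataset

    def train_dataset(self) -> ProteinDataset:
        ...

    def val_dataset(self) -> ProteinDataset:
        ...

    def test_dataset(self) -> ProteinDataset:
        ...

    def train_dataloader(self) -> ProteinDataLoader:
        ...

    def val_dataloader(self) -> ProteinDataLoader:
        ...

    def test_dataloader(self) -> ProteinDataLoader:
        ...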
The methods proteinworkshop.datasets.base.ProteinDataModule.train_dataset(), proteinworkshop.datasets.base.ProteinDataModule.val_dataset() and proteinworkshop.datasets.base.ProteinDataModule.test_dataset() should each return a proteinworkshop.datasets.base.ProteinDataset object, which is a subclass of torch.utils.data.Dataset and is used to represent a dataset of proteins.
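Continuing the MyDataModule sketch above, train_dataset() might construct and return a ProteinDataset directly. The constructor arguments shown here (root, pdb_codes, graph_labels) are assumptions for illustration; consult the proteinworkshop.datasets.base.ProteinDataset signature for the actual parameters.

def train_dataset(self) -> ProteinDataset:
    # Hypothetical constructor arguments -- check the ProteinDataset
    # API for the real signature.
    return ProteinDataset(
        root="data/my_dataset",
        pdb_codes=["1abc", "2xyz"],
        graph_labels=[0, 1],
    )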
The methods proteinworkshop.datasets.base.ProteinDataModule.train_dataloader(), proteinworkshop.datasets.base.ProteinDataModule.val_dataloader() and proteinworkshop.datasets.base.ProteinDataModule.test_dataloader() should each return a graphein.protein.tensor.dataloader.ProteinDataLoader object, which is used to represent a dataloader for a dataset of proteins.
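Correspondingly, the dataloader methods typically wrap the dataset in a ProteinDataLoader. A sketch, assuming ProteinDataLoader accepts the standard PyTorch/PyG dataloader keyword arguments:

def train_dataloader(self) -> ProteinDataLoader:
    # batch_size, shuffle and num_workers mirror the usual dataloader
    # options and are assumed to be supported here.
    return ProteinDataLoader(
        self.train_dataset(),
        batch_size=32,
        shuffle=True,
        num_workers=4,
    )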
The methods proteinworkshop.datasets.base.ProteinDataModule.download() and proteinworkshop.datasets.base.ProteinDataModule.parse_dataset() handle all of the dataset-specific logic for downloading and for parsing labels, ids/filenames and chains.
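A minimal parse_dataset() sketch, assuming a hypothetical on-disk layout in which each split ships as a CSV index of ids, chains and labels (both the signature and the column names are illustrative):

import pandas as pd

def parse_dataset(self, split: str) -> pd.DataFrame:
    # Hypothetical layout: one CSV per split with id/chain/label columns.
    df = pd.read_csv(f"data/my_dataset/{split}.csv")
    return df[["pdb", "chain", "label"]]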
Models#
proteinworkshop.models.base.BaseModel and proteinworkshop.models.base.BenchMarkModel Base classes#
These objects orchestrate model training and validation logic. The proteinworkshop.models.base.BaseModel class is a subclass of pytorch_lightning.LightningModule. The proteinworkshop.models.base.BenchMarkModel class is a subclass of proteinworkshop.models.base.BaseModel and is used as the primary orchestrator in the framework.
To use a different structural encoder, the user should override proteinworkshop.models.base.BenchMarkModel.encoder with a new encoder class. The encoder should be a subclass of torch.nn.Module, and its forward method should be of the form:
from typing import Union
from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch
from proteinworkshop.types import EncoderOutput

def forward(self, x: Union[Batch, ProteinBatch]) -> EncoderOutput:
    node_emb = x.x  # per-node representations computed by the encoder
    graph_emb = self.readout(node_emb, x.batch)  # pool node embeddings per graph
    return EncoderOutput({"node_embedding": node_emb, "graph_embedding": graph_emb})
That is, the forward method consumes a Batch (or ProteinBatch) object and returns a dictionary with the keys node_embedding and graph_embedding.
Note
Not both keys need to be present in the output dictionary: a node-level task requires only node_embedding, while a graph-level task requires only graph_embedding.
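Putting this together, a minimal custom encoder might look as follows. The class, its dimensions and the use of global_mean_pool as the readout are hypothetical choices for illustration; how the encoder is then attached to BenchMarkModel (direct attribute assignment versus the config system) should follow the framework's Hydra configs.

import torch
from torch_geometric.nn import global_mean_pool
from proteinworkshop.types import EncoderOutput

class MLPEncoder(torch.nn.Module):
    """Toy structural encoder: an MLP over node features with mean-pool readout."""

    def __init__(self, in_dim: int = 64, hidden_dim: int = 128):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x) -> EncoderOutput:
        node_emb = self.mlp(x.x)                         # per-node embeddings
        graph_emb = global_mean_pool(node_emb, x.batch)  # per-graph embeddings
        return EncoderOutput({"node_embedding": node_emb, "graph_embedding": graph_emb})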