Architectural Overview#

The benchmark is constructed in a modular fashion, with each module configured via a .yaml config using Hydra.

The main components are:

Datasets#

Protein and ProteinBatch Data structures#

We make extensive use of Graphein for data processing and featurisation in the framework. To familiarise yourself with the data structures used in the framework, please see the tutorials provided by Graphein. In essence, these objects inherit from torch_geometric.data.Data and torch_geometric.data.Batch respectively, and are used to represent a single protein or a batch of proteins.
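To make the distinction between a single protein and a batch concrete, the sketch below mimics the batching convention of torch_geometric in plain Python: node features from several proteins are concatenated, and a batch vector records which protein each node belongs to. ToyProtein, ToyProteinBatch and collate are hypothetical stand-ins written for illustration; they are not the Graphein classes and leave out nearly all of their functionality.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyProtein:
    """Hypothetical stand-in for a single protein (not Graphein's Protein)."""
    id: str
    x: List[List[float]]  # one feature row per residue (node)


@dataclass
class ToyProteinBatch:
    """Hypothetical stand-in for a batch of proteins."""
    ids: List[str]
    x: List[List[float]]  # node features of all proteins, concatenated
    batch: List[int]      # batch[i] is the index of the protein owning node i

    @classmethod
    def collate(cls, proteins: List[ToyProtein]) -> "ToyProteinBatch":
        ids, x, batch = [], [], []
        for idx, protein in enumerate(proteins):
            ids.append(protein.id)
            x.extend(protein.x)
            # Every node of this protein is tagged with the protein's index.
            batch.extend([idx] * len(protein.x))
        return cls(ids=ids, x=x, batch=batch)


proteins = [
    ToyProtein(id="toy_a", x=[[1.0, 0.0], [0.0, 1.0]]),
    ToyProtein(id="toy_b", x=[[0.5, 0.5], [1.0, 1.0], [2.0, 0.0]]),
]
toy_batch = ToyProteinBatch.collate(proteins)
print(toy_batch.batch)  # [0, 0, 1, 1, 1]
```

The batch vector is what lets a model recover per-protein quantities from the flat node tensor, for example when pooling node embeddings into one embedding per protein.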

ProteinDataModule Base classes#

The framework provides base classes for datasets and datamodules, which can be extended to create new datasets.

The datamodule is the only object that needs to be configured to add a new dataset. The proteinworkshop.datasets.base.ProteinDataModule class is a subclass of pytorch_lightning.LightningDataModule and is used to represent a datamodule for a dataset of proteins. This class is used to create dataloaders for training, validation and testing.

To do so, the datamodule for the new dataset should inherit from proteinworkshop.datasets.base.ProteinDataModule and implement the following methods:

The methods proteinworkshop.datasets.base.ProteinDataModule.train_dataset(), proteinworkshop.datasets.base.ProteinDataModule.val_dataset() and proteinworkshop.datasets.base.ProteinDataModule.test_dataset() should return a proteinworkshop.datasets.base.ProteinDataset object, which is a subclass of torch.utils.data.Dataset and is used to represent a dataset of proteins.

The methods proteinworkshop.datasets.base.ProteinDataModule.train_dataloader(), proteinworkshop.datasets.base.ProteinDataModule.val_dataloader() and proteinworkshop.datasets.base.ProteinDataModule.test_dataloader() should return a graphein.protein.tensor.dataloader.ProteinDataLoader object, which is used to represent a dataloader for a dataset of proteins.

The methods proteinworkshop.datasets.base.ProteinDataModule.download() and proteinworkshop.datasets.base.ProteinDataModule.parse_dataset() handle all of the dataset-specific logic for downloading the data and for parsing labels, IDs/filenames and chains.
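The division of responsibilities described above can be sketched without the framework installed. The example below is a plain-Python stand-in, not the real ProteinDataModule: ToyDataModuleBase, ToyEnzymeDataModule, the Record type, the split argument of parse_dataset and the batch_size argument of the dataloader methods are all assumptions made for illustration, and the "download" simply copies an in-memory table. A real datamodule would inherit from proteinworkshop.datasets.base.ProteinDataModule and return ProteinDataset and ProteinDataLoader objects instead of lists.

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List

Record = Dict[str, object]  # one parsed example: id, chain and label


class ToyDataModuleBase(ABC):
    """Hypothetical stand-in for the hooks a new datamodule implements."""

    @abstractmethod
    def download(self) -> None:
        """Fetch the raw data."""

    @abstractmethod
    def parse_dataset(self, split: str) -> List[Record]:
        """Return ids, chains and labels for one split."""

    @abstractmethod
    def train_dataset(self) -> List[Record]: ...

    @abstractmethod
    def val_dataset(self) -> List[Record]: ...

    @abstractmethod
    def test_dataset(self) -> List[Record]: ...

    @staticmethod
    def _loader(dataset: List[Record], batch_size: int) -> Iterator[List[Record]]:
        # Yield consecutive chunks, standing in for a real dataloader.
        for start in range(0, len(dataset), batch_size):
            yield dataset[start:start + batch_size]

    def train_dataloader(self, batch_size: int = 2) -> Iterator[List[Record]]:
        return self._loader(self.train_dataset(), batch_size)

    def val_dataloader(self, batch_size: int = 2) -> Iterator[List[Record]]:
        return self._loader(self.val_dataset(), batch_size)

    def test_dataloader(self, batch_size: int = 2) -> Iterator[List[Record]]:
        return self._loader(self.test_dataset(), batch_size)


class ToyEnzymeDataModule(ToyDataModuleBase):
    # In-memory table standing in for files fetched from a remote source.
    _RAW = [
        ("toy_1", "A", 0, "train"),
        ("toy_2", "A", 1, "train"),
        ("toy_3", "B", 0, "train"),
        ("toy_4", "A", 1, "val"),
        ("toy_5", "B", 0, "test"),
    ]

    def __init__(self) -> None:
        self.records = None

    def download(self) -> None:
        # Only "download" once; later calls reuse the cached table.
        if self.records is None:
            self.records = list(self._RAW)

    def parse_dataset(self, split: str) -> List[Record]:
        self.download()
        return [
            {"id": pid, "chain": chain, "label": label}
            for pid, chain, label, row_split in self.records
            if row_split == split
        ]

    def train_dataset(self) -> List[Record]:
        return self.parse_dataset("train")

    def val_dataset(self) -> List[Record]:
        return self.parse_dataset("val")

    def test_dataset(self) -> List[Record]:
        return self.parse_dataset("test")


dm = ToyEnzymeDataModule()
first_batch = next(iter(dm.train_dataloader(batch_size=2)))
print([record["id"] for record in first_batch])  # ['toy_1', 'toy_2']
```

In this sketch, all dataset-specific knowledge lives in download and parse_dataset, so the dataset and dataloader methods stay short and look the same from one dataset to the next.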

Models#

proteinworkshop.models.base.BaseModel and proteinworkshop.models.base.BenchMarkModel Base classes#

These objects orchestrate model training and validation logic. The proteinworkshop.models.base.BaseModel class is a subclass of pytorch_lightning.LightningModule. The proteinworkshop.models.base.BenchMarkModel class is a subclass of proteinworkshop.models.base.BaseModel and is used as the primary orchestrator in the framework.

To use a different structural encoder, the user should overwrite proteinworkshop.models.base.BenchMarkModel.encoder with a new encoder class. The encoder class should be a subclass of torch.nn.Module and should implement the following methods:

The forward method should be of the form:

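from typing import Union

from graphein.protein.tensor.data import ProteinBatch
from proteinworkshop.types import EncoderOutput
from torch_geometric.data import Batch

def forward(self, x: Union[Batch, ProteinBatch]) -> EncoderOutput:
    node_emb = x.x  # node features, one row per node
    graph_emb = self.readout(node_emb, x.batch)  # pool nodes into one embedding per graph
    return EncoderOutput({"node_embedding": node_emb, "graph_embedding": graph_emb})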

That is, the method consumes a Batch (or ProteinBatch) object and returns a dictionary with the keys node_embedding and graph_embedding.

Note

The output dictionary does not need to contain both keys; which key is required depends on whether the task is node-level or graph-level.
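As an illustration of this contract, the sketch below implements a trivial identity encoder in plain Python that returns only the keys needed for a given task level. ToyEncoder, its level argument and sum_readout are hypothetical names introduced here, and sum pooling is just one possible readout. A real encoder would subclass torch.nn.Module, operate on tensors, and take a single Batch or ProteinBatch argument rather than separate x and batch lists.

```python
from typing import Dict, List

Matrix = List[List[float]]


def sum_readout(node_emb: Matrix, batch: List[int]) -> Matrix:
    """Sum node embeddings per graph; batch[i] is the graph index of node i."""
    num_graphs = max(batch) + 1 if batch else 0
    dim = len(node_emb[0]) if node_emb else 0
    out = [[0.0] * dim for _ in range(num_graphs)]
    for row, graph_idx in zip(node_emb, batch):
        for j, value in enumerate(row):
            out[graph_idx][j] += value
    return out


class ToyEncoder:
    """Hypothetical identity encoder used only to show the output contract."""

    def __init__(self, level: str = "both") -> None:
        if level not in {"node", "graph", "both"}:
            raise ValueError(f"Unknown level: {level}")
        self.level = level

    def forward(self, x: Matrix, batch: List[int]) -> Dict[str, Matrix]:
        node_emb = x  # identity "encoding": embeddings equal the input features
        output: Dict[str, Matrix] = {}
        if self.level in {"node", "both"}:
            output["node_embedding"] = node_emb
        if self.level in {"graph", "both"}:
            output["graph_embedding"] = sum_readout(node_emb, batch)
        return output


x = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
batch = [0, 0, 1]  # first two nodes belong to graph 0, the last to graph 1

graph_only = ToyEncoder(level="graph").forward(x, batch)
print(graph_only)  # {'graph_embedding': [[1.0, 1.0], [0.5, 0.5]]}
```

A node-level task would instead use level="node" and receive only node_embedding, which avoids computing a readout that nothing downstream consumes.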