Architectural Overview#
The benchmark is constructed in a modular fashion, with each module configured via a .yaml config using Hydra.
The principal components are:
- Datasets: The benchmark supports a variety of datasets, described in Datasets and documented in proteinworkshop.datasets
- Models: The benchmark supports a variety of models, described in Models and documented in proteinworkshop.models
- Tasks: The benchmark supports a variety of tasks, described in Tasks and documented in proteinworkshop.tasks
- Features: The benchmark supports a variety of features, described in Features and documented in proteinworkshop.features
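For illustration, such a modular configuration can be composed programmatically with Hydra's Python API. The sketch below is an assumption for illustration only: the config path, config name, and the override values are placeholders, not names shipped with the benchmark.

from hydra import compose, initialize

# Compose a top-level config, swapping individual module configs via
# overrides. "configs", "train" and the override values are illustrative.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="train",
        overrides=["dataset=my_dataset", "encoder=my_encoder"],
    )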
Datasets#
Protein and ProteinBatch Data structures#
We make extensive use of Graphein for data processing and featurisation in the framework.
To familiarise yourself with the data structures used in the framework, please see the
tutorials provided by Graphein. In essence, the Protein and ProteinBatch objects inherit from torch_geometric.data.Data
and torch_geometric.data.Batch respectively, and represent a single protein and a batch of proteins.
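As a brief sketch of this relationship, assuming the keyword-attribute constructor inherited from torch_geometric.data.Data (the attribute names below are illustrative, not the framework's featurisation schema):

import torch
from graphein.protein.tensor.data import Protein, ProteinBatch

# Protein subclasses torch_geometric.data.Data, so tensor attributes
# can be attached and accessed in the usual attribute style.
protein = Protein(x=torch.randn(10, 16), coords=torch.randn(10, 3))

# ProteinBatch subclasses torch_geometric.data.Batch and collates several
# proteins into one disjoint graph with a `batch` node-to-graph index.
batch = ProteinBatch.from_data_list([protein, protein])
assert batch.num_graphs == 2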
ProteinDataModule Base classes#
The framework provides base classes for datasets and datamodules, which can be extended to create new datasets.
The datamodule is the only object that needs to be configured to add a new dataset. The proteinworkshop.datasets.base.ProteinDataModule class is a subclass of pytorch_lightning.LightningDataModule; it represents a datamodule for a dataset of proteins and is responsible for creating the dataloaders for training, validation and testing.
To do so, the datamodule for the new dataset should inherit from proteinworkshop.datasets.base.ProteinDataModule and implement the following methods:
- proteinworkshop.datasets.base.ProteinDataModule.parse_dataset()
- (optionally) - proteinworkshop.datasets.base.ProteinDataModule.parse_labels()
- (optionally) - proteinworkshop.datasets.base.ProteinDataModule.exclude_pdbs()
- proteinworkshop.datasets.base.ProteinDataModule.train_dataset()
- proteinworkshop.datasets.base.ProteinDataModule.val_dataset()
- proteinworkshop.datasets.base.ProteinDataModule.test_dataset()
- proteinworkshop.datasets.base.ProteinDataModule.train_dataloader()
- proteinworkshop.datasets.base.ProteinDataModule.val_dataloader()
- proteinworkshop.datasets.base.ProteinDataModule.test_dataloader()
The methods proteinworkshop.datasets.base.ProteinDataModule.train_dataset(), proteinworkshop.datasets.base.ProteinDataModule.val_dataset() and proteinworkshop.datasets.base.ProteinDataModule.test_dataset() should return a proteinworkshop.datasets.base.ProteinDataset object, which is a subclass of torch.utils.data.Dataset and is used to represent a dataset of proteins.
The methods proteinworkshop.datasets.base.ProteinDataModule.train_dataloader(), proteinworkshop.datasets.base.ProteinDataModule.val_dataloader() and proteinworkshop.datasets.base.ProteinDataModule.test_dataloader() should return a graphein.protein.tensor.dataloader.ProteinDataLoader object, which is used to represent a dataloader for a dataset of proteins.
The methods proteinworkshop.datasets.base.ProteinDataModule.download() and proteinworkshop.datasets.base.ProteinDataModule.parse_dataset() handle all of the dataset-specific logic for downloading the data and parsing labels, ids/filenames and chains.
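Putting this together, a new datamodule might look like the following skeleton. This is a sketch only: the parse_dataset signature and the batch size are assumptions, and the method bodies are elided.

from graphein.protein.tensor.dataloader import ProteinDataLoader
from proteinworkshop.datasets.base import ProteinDataModule, ProteinDataset


class MyProteinDataModule(ProteinDataModule):
    """Hypothetical datamodule for a custom dataset."""

    def download(self):
        ...  # fetch the raw structures and labels

    def parse_dataset(self, split: str):
        ...  # dataset-specific parsing of labels, ids/filenames and chains

    def train_dataset(self) -> ProteinDataset:
        ...  # build a ProteinDataset for the training split

    def val_dataset(self) -> ProteinDataset:
        ...

    def test_dataset(self) -> ProteinDataset:
        ...

    def train_dataloader(self) -> ProteinDataLoader:
        return ProteinDataLoader(self.train_dataset(), batch_size=32)

    def val_dataloader(self) -> ProteinDataLoader:
        return ProteinDataLoader(self.val_dataset(), batch_size=32)

    def test_dataloader(self) -> ProteinDataLoader:
        return ProteinDataLoader(self.test_dataset(), batch_size=32)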
Models#
proteinworkshop.models.base.BaseModel and proteinworkshop.models.base.BenchMarkModel Base classes#
These objects orchestrate model training and validation logic. The proteinworkshop.models.base.BaseModel class is a subclass of pytorch_lightning.LightningModule.
The proteinworkshop.models.base.BenchMarkModel class is a subclass of proteinworkshop.models.base.BaseModel and is used as the primary orchestrator in the framework.
To use a different structural encoder, the user should override proteinworkshop.models.base.BenchMarkModel.encoder with a new encoder class. The encoder should be a subclass of torch.nn.Module and implement a forward method of the following form:
from typing import Union

from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch
from proteinworkshop.types import EncoderOutput

def forward(self, x: Union[Batch, ProteinBatch]) -> EncoderOutput:
    node_emb = x.x  # node-level embeddings, one row per node
    graph_emb = self.readout(node_emb, x.batch)  # pooled per-graph embedding
    return EncoderOutput({"node_embedding": node_emb, "graph_embedding": graph_emb})
That is, the encoder consumes a Batch (or ProteinBatch) object and returns a dictionary with the keys node_embedding and graph_embedding.
Note
Not both keys of the output dictionary are required to be present; depending on whether the task is node-level or graph-level, only one of the two embeddings may be needed.
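For concreteness, here is a minimal sketch of an encoder satisfying this contract. Only EncoderOutput and the forward signature come from the framework; the class name, layer sizes and the mean-pool readout are arbitrary choices for illustration.

from typing import Union

import torch
from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch
from torch_geometric.nn import global_mean_pool

from proteinworkshop.types import EncoderOutput


class LinearEncoder(torch.nn.Module):
    """Toy structural encoder: one linear layer plus a mean-pool readout."""

    def __init__(self, in_dim: int = 64, hidden_dim: int = 128):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, hidden_dim)

    def readout(self, node_emb: torch.Tensor, batch_idx: torch.Tensor) -> torch.Tensor:
        # Average node embeddings per graph to obtain one embedding per protein
        return global_mean_pool(node_emb, batch_idx)

    def forward(self, x: Union[Batch, ProteinBatch]) -> EncoderOutput:
        node_emb = self.linear(x.x)
        graph_emb = self.readout(node_emb, x.batch)
        return EncoderOutput({"node_embedding": node_emb, "graph_embedding": graph_emb})

An instance of such a class can then replace the default encoder, e.g. model.encoder = LinearEncoder().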