Source code for proteinworkshop.datasets.go

import os
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Callable, Dict, Iterable, Literal, Optional

import omegaconf
import pandas as pd
import torch
import wget
from graphein.protein.tensor.data import Protein
from graphein.protein.tensor.dataloader import ProteinDataLoader
from loguru import logger as log
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from proteinworkshop.datasets.base import ProteinDataModule, ProteinDataset

LABEL_LINE: Dict[str, int] = {
    "MF": 1,
    "BP": 5,
    "CC": 9,
}
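
# The values above are line indices into nrPDB-GO_annot.tsv: the full GO-term
# vocabulary for each task sits on a fixed header line of that file, which
# parse_labels() below reads directly. A minimal sketch of that lookup
# (hypothetical local path; mirrors the logic in parse_labels):
#
#     with open("nrPDB-GO_annot.tsv") as f:
#         mf_terms = f.readlines()[LABEL_LINE["MF"]].strip("\n").split("\t")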


class GeneOntologyDataset(ProteinDataModule):
    """Data module for the Gene Ontology function prediction dataset.

    Statistics (test_cutoff=0.95):
    - #Train: 27,496
    - #Valid: 3,053
    - #Test: 2,991
    """

    def __init__(
        self,
        path: str,
        batch_size: int,
        split: str = "BP",
        obsolete: str = "drop",
        pdb_dir: Optional[str] = None,
        format: Literal["mmtf", "pdb"] = "mmtf",
        in_memory: bool = False,
        dataset_fraction: float = 1.0,
        shuffle_labels: bool = False,
        pin_memory: bool = True,
        num_workers: int = 16,
        transforms: Optional[Iterable[Callable]] = None,
        overwrite: bool = False,
    ) -> None:
        super().__init__()
        self.pdb_dir = pdb_dir
        self.data_dir = Path(path)
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        self.dataset_fraction = dataset_fraction
        self.split = split
        self.obsolete = obsolete
        self.format = format
        self.in_memory = in_memory
        self.overwrite = overwrite

        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.num_workers = num_workers
        self.prepare_data_per_node = True
        self.shuffle_labels = shuffle_labels

        if transforms is not None:
            self.transform = self.compose_transforms(
                omegaconf.OmegaConf.to_container(transforms, resolve=True)
            )
        else:
            self.transform = None

        self.train_fname = self.data_dir / "nrPDB-GO_train.txt"
        self.val_fname = self.data_dir / "nrPDB-GO_valid.txt"
        self.test_fname = self.data_dir / "nrPDB-GO_test.txt"
        self.label_fname = self.data_dir / "nrPDB-GO_annot.tsv"
        self.url = "https://zenodo.org/record/6622158/files/GeneOntology.zip"

        log.info(
            f"Setting up Gene Ontology dataset. Fraction {self.dataset_fraction}"
        )
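
    # Usage sketch (paths and batch size are hypothetical; mirrors the hydra
    # config used in the __main__ block at the bottom of this module):
    #
    #     dm = GeneOntologyDataset(
    #         path="data/GeneOntology",
    #         pdb_dir="data/pdb",
    #         batch_size=32,
    #         split="MF",
    #     )
    #     dm.download()
    #     batch = next(iter(dm.train_dataloader()))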

    @lru_cache
    def parse_labels(self) -> Dict[str, torch.Tensor]:
        """Parse the GO labels from the nrPDB-GO_annot.tsv file."""
        log.info(
            f"Loading GO labels for task {self.split} from file {self.label_fname}."
        )
        try:
            label_line = LABEL_LINE[self.split]
        except KeyError as e:
            raise ValueError(f"Task {self.split} not recognised.") from e

        # Load the list of all GO terms for this task
        with open(self.label_fname, "r") as f:
            all_labels = f.readlines()[label_line].strip("\n").split("\t")
        log.info(f"Found {len(all_labels)} labels for task {self.split}.")

        # Load labels for each PDB example
        df = pd.read_csv(self.label_fname, sep="\t", skiprows=12)
        df.columns = ["PDB", "MF", "BP", "CC"]
        df.set_index("PDB", inplace=True)

        # Remove rows with no labels for this task
        labels = df[self.split].dropna().to_dict()
        log.info(f"Found {len(labels)} examples for task {self.split}.")

        # Split GO-term strings into lists of individual terms
        labels = {k: v.split(",") for k, v in labels.items()}

        # Encode labels into numeric values
        log.info("Encoding labels...")
        label_encoder = LabelEncoder().fit(all_labels)
        labels = {
            k: torch.tensor(label_encoder.transform(v))
            for k, v in tqdm(labels.items())
        }
        log.info(f"Encoded {len(labels)} labels for task {self.split}.")
        return labels
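
    # The returned mapping is keyed by chain ID, with a tensor of encoded
    # GO-term indices per example, e.g. (hypothetical key and values):
    #
    #     {"1ABC-A": tensor([3, 17, 42]), ...}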

    def _get_dataset(
        self, split: Literal["training", "validation", "testing"]
    ) -> ProteinDataset:
        # parse_dataset also sets self.labeller, used in the transform below
        df = self.parse_dataset(split)
        log.info("Initialising Graphein dataset...")
        return ProteinDataset(
            root=str(self.data_dir),
            pdb_dir=str(self.pdb_dir),
            pdb_codes=list(df.pdb),
            chains=list(df.chain),
            graph_labels=list(df.label),
            overwrite=self.overwrite,
            transform=self.labeller
            if self.transform is None
            else self.compose_transforms([self.labeller, self.transform]),
            format=self.format,
            in_memory=self.in_memory,
        )

    def train_dataset(self) -> ProteinDataset:
        return self._get_dataset("training")

    def val_dataset(self) -> ProteinDataset:
        return self._get_dataset("validation")

    def test_dataset(self) -> ProteinDataset:
        return self._get_dataset("testing")

    def train_dataloader(self) -> ProteinDataLoader:
        return ProteinDataLoader(
            self.train_dataset(),
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> ProteinDataLoader:
        return ProteinDataLoader(
            self.val_dataset(),
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> ProteinDataLoader:
        return ProteinDataLoader(
            self.test_dataset(),
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def download(self):
        if not all(
            os.path.exists(f)
            for f in [
                self.train_fname,
                self.val_fname,
                self.test_fname,
                self.label_fname,
            ]
        ):
            log.info("Downloading dataset...")
            wget.download(self.url, out=str(self.data_dir))
            with zipfile.ZipFile(self.data_dir / "GeneOntology.zip") as f:
                f.extractall(self.data_dir.parent)
        else:
            log.info(f"Found dataset at {self.data_dir}")

    def exclude_pdbs(self):
        pass

    def parse_dataset(
        self, split: Literal["training", "validation", "testing"]
    ) -> pd.DataFrame:
        # sourcery skip: remove-unnecessary-else, swap-if-else-branches, switch
        """Parse the raw dataset files into a Pandas DataFrame, mapping GO
        classes to numerical values."""
        # Load ID -> label mapping
        class_map = self.parse_labels()

        # Read in IDs of structures in the split
        if split == "training":
            data = pd.read_csv(self.train_fname, sep="\t", header=None)
            data = data.sample(frac=self.dataset_fraction)
        elif split == "validation":
            data = pd.read_csv(self.val_fname, sep="\t", header=None)
        elif split == "testing":
            data = pd.read_csv(self.test_fname, sep="\t", header=None)
        else:
            raise ValueError(f"Unknown split: {split}")

        log.info(f"Found {len(data)} original examples in {split}")
        log.info("Removing unlabelled proteins for this task...")
        data = data.loc[data[0].isin(class_map.keys())]
        log.info(f"Found {len(data)} labelled examples in {split}")

        # Map labels to IDs in the dataframe
        log.info("Mapping labels to IDs...")
        data["label"] = data[0].map(class_map)
        data.columns = ["pdb", "label"]

        to_drop = ["5EXC-I"]
        data = data.loc[~data["pdb"].isin(to_drop)]

        # IDs have the form 1ABC-A: four-character PDB code, then chain
        data["chain"] = data["pdb"].str[5:]
        data["pdb"] = data["pdb"].str[:4].str.lower()

        if self.obsolete == "drop":
            log.info("Dropping obsolete PDBs")
            data = data.loc[
                ~data["pdb"].str.lower().isin(self.obsolete_pdbs.keys())
            ]
            log.info(
                f"Found {len(data)} examples in {split} after dropping obsolete PDBs"
            )
        else:
            raise NotImplementedError(
                "Obsolete PDB replacement not implemented"
            )

        if self.shuffle_labels:
            log.info("Shuffling labels. Expecting random performance.")
            data["label"] = data["label"].sample(frac=1).values

        self.labeller = GOLabeller(data)
        return data.sample(frac=1)  # Shuffle dataset for batching
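
    # The returned frame has one (shuffled) row per chain with columns
    # ["pdb", "label", "chain"], e.g. (hypothetical row):
    #
    #     pdb     label                 chain
    #     "1abc"  tensor([3, 17, 42])   "A"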

class GOLabeller:
    """Applies the graph labels to each example as a transform.

    This is required as chains can be used across tasks (e.g. CC, BP or MF)
    with different labels.
    """

    def __init__(self, label_df: pd.DataFrame):
        self.labels = label_df

    def __call__(self, data: Protein) -> Protein:
        pdb, chain = data.id.split("_")
        label = self.labels.loc[
            (self.labels.pdb == pdb) & (self.labels.chain == chain)
        ].label.item()
        data.graph_y = label
        return data
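
# Usage sketch (hypothetical Protein object; mirrors how _get_dataset wires
# the labeller in as a per-example transform):
#
#     labeller = GOLabeller(datamodule.parse_dataset("training"))
#     protein = labeller(protein)  # looks up the label by id, sets graph_y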

if __name__ == "__main__":
    import pathlib

    import hydra

    from proteinworkshop import constants

    log.info("Imported libs")

    cfg = omegaconf.OmegaConf.load(
        constants.SRC_PATH / "config" / "dataset" / "go-bp.yaml"
    )
    # cfg = omegaconf.OmegaConf.load(constants.SRC_PATH / "config" / "dataset" / "go-mf.yaml")
    cfg.datamodule.path = pathlib.Path(constants.DATA_PATH) / "GeneOntology"
    cfg.datamodule.pdb_dir = pathlib.Path(constants.DATA_PATH) / "pdb"
    cfg.datamodule.num_workers = 1
    cfg.datamodule.transforms = []
    log.info("Loaded config")

    ds = hydra.utils.instantiate(cfg)
    print(ds)
    ds["datamodule"].setup()

    dl = ds["datamodule"].train_dataloader()
    for batch in dl:
        print(batch)

    dl = ds["datamodule"].val_dataloader()
    for batch in dl:
        print(batch)

    dl = ds["datamodule"].test_dataloader()
    for batch in dl:
        print(batch)