Source code for proteinworkshop.tasks.binding_site_prediction

from typing import List, Set, Union

import torch
from graphein.protein.tensor.data import Protein
from scipy import spatial
from torch_geometric import transforms
from torch_geometric.data import Data

from proteinworkshop.features.representation import get_full_atom_coords


class BindingSiteTransform(transforms.BaseTransform):
    r"""
    Extracts binding site labels for a given set of HETATMs.

    This transform builds a KDTree from the protein coordinates. Atoms
    belonging to HETATMs (specified by the ``hetatms`` arg at initialization)
    are then queried against the KDTree to obtain indices of residues within
    ``threshold`` distance of the HETATM. These indices are used to assign
    node labels to the protein graph.

    If ``multilabel`` is set to ``True``, then each binding HETATM will be
    assigned a separate label (i.e. whether residue :math:`i` is proximal to
    HETATM :math:`j` is given by:
    :math:`\hat{y}_{ij} \in \mathbb{R}^{|V| \times |H|}`). Otherwise, the
    labels will be assigned as a single label (i.e. is residue :math:`i`
    proximal to any HETATM :math:`\hat{y} \in \mathbb{R}^{|V|}`).

    If ``ca_only`` is set to ``True``, then only the alpha carbon atoms will
    be used to determine proximity. If ``ca_only`` is set to ``False``, then
    all atoms will be used to determine proximity. I.e. if any atom in a
    residue is within ``threshold`` distance of a HETATM, then the residue
    will be labeled accordingly.

    Requested HETATMs that are missing from the input are skipped. The
    ``node_y`` attribute is always written, so an input containing none of
    the requested HETATMs receives all-zero labels.

    .. warning::
        This transform requires that the ``data.coords`` and ``data.hetatms``
        fields to be set on the input Data/Batch.

        See: :py:meth:`required_attributes`
    """

    def __init__(
        self,
        hetatms: List[str],
        threshold: float,
        ca_only: bool = False,
        multilabel: bool = True,
    ) -> None:
        """Initializes the BindingSiteTransform.

        :param hetatms: List of HETATM names to use for labeling.
        :type hetatms: List[str]
        :param threshold: Threshold distance for determining proximity in
            angstroms.
        :type threshold: float
        :param ca_only: Whether to use only the alpha carbon atoms for
            assigning proximity labels, defaults to ``False``.
        :type ca_only: bool, optional
        :param multilabel: Whether to assign multilabel labels, defaults to
            ``True``
        :type multilabel: bool, optional
        """
        self.hetatms = hetatms
        self.threshold = threshold
        self.ca_only = ca_only
        self.multilabel = multilabel
        # One label column per HETATM in multilabel mode, otherwise a single
        # "near any HETATM" label.
        self.num_classes = len(hetatms) if multilabel else 1

    @property
    def required_attributes(self) -> Set[str]:
        """Returns the required batch attributes that this transform requires.

        I.e. ``data.coords`` and ``data.hetatms`` must be set.

        :return: Set of required attributes
        :rtype: Set[str]
        """
        return {"coords", "hetatms"}

    def __call__(self, data: Union[Data, Protein]) -> Union[Data, Protein]:
        """Computes binding site labels and stores them on ``data.node_y``.

        :param data: Input structure providing ``coords`` and ``hetatms``.
        :type data: Union[Data, Protein]
        :return: The same object with ``node_y`` set. The label tensor has
            one row per entry along the first axis of ``data.coords``; in
            multilabel mode it has one column per requested HETATM.
        :rtype: Union[Data, Protein]
        """
        # Create a KDTree
        # If Ca only, we only see if the hetatms are within the
        # threshold distance of Ca atoms on the input structure
        if self.ca_only:
            res_idx = None
            kd_tree = spatial.KDTree(data.coords[:, 1, :])
        else:
            # If we are not using CA only, we need to flatten the coordinates
            # And keep track of the atom->residue mapping
            coords, res_idx, _ = get_full_atom_coords(data.coords)
            kd_tree = spatial.KDTree(coords)

        num_residues = data.coords.shape[0]
        if self.multilabel:
            label = torch.zeros((num_residues, self.num_classes))
        else:
            label = torch.zeros(num_residues)

        for hetatm_idx, hetatm in enumerate(self.hetatms):
            # Only the lookup is guarded: a requested HETATM that is absent
            # from this structure is skipped, while a KeyError raised anywhere
            # else is a genuine error and must not be hidden.
            try:
                hetatm_coords = data.hetatms[0][hetatm]
            except KeyError:
                continue

            neighbours = kd_tree.query_ball_point(
                hetatm_coords.numpy(), r=self.threshold, p=2.0
            )
            # One list of hits per HETATM atom; flatten into a single list.
            indices = [item for sublist in neighbours for item in sublist]
            indices = torch.tensor(indices, dtype=torch.long)

            # If not CA only, we need to map the atom indices back to residues
            if not self.ca_only:
                indices = torch.unique(res_idx[indices])

            if self.multilabel:
                label[indices, hetatm_idx] = 1
            else:
                label[indices] = 1

        # Assign once, after the loop, so that the attribute exists even when
        # none of the requested HETATMs were found.
        setattr(data, "node_y", label)
        return data

    def __repr__(self) -> str:
        """Returns a short description with the class name and key settings."""
        return (
            f"{self.__class__.__name__}"
            f"(hetatms: {self.hetatms}, threshold: {self.threshold})"
        )
if __name__ == "__main__":
    # Manual smoke test: load one structure, apply the transform and print
    # the intermediate objects for inspection.
    import graphein

    # Kept ahead of the graphein.protein import, matching the original order.
    graphein.verbose(False)
    from graphein.protein.tensor.io import protein_to_pyg

    example_pdb = "3eiy"
    # store_het=True is passed because the transform reads ``data.hetatms``
    # (presumably this is what populates that field -- confirm in graphein).
    protein = protein_to_pyg(pdb_code=example_pdb, store_het=True)
    print(protein)
    print(protein.hetatms)

    labeller = BindingSiteTransform(
        hetatms=["HOH", "POP", "SO4", "PEG"],
        threshold=7.0,
        ca_only=True,
    )
    print(labeller)

    protein = labeller(protein)
    print(protein)
    print(protein.node_y)