Source code for proteinworkshop.tasks.ppi_site_prediction
from typing import Dict, Union
import numpy as np
import scipy.spatial as spatial
import torch
from graphein.protein.tensor.data import Protein
from torch_geometric import transforms as T
from torch_geometric.data import Data
from proteinworkshop.features.representation import get_full_atom_coords
[docs]
class BindingSiteTransform(T.BaseTransform):
def __init__(self, radius: float = 3.5, ca_only: bool = True) -> None:
"""Extracts Protein-Protein interaction sites from a protein structure.
.. note::
The chains to be kept as inputs must be specified as
``data.graph_y``. This is typically set in the dataloader.
:param radius: Maximum distance between chains to be considered as
interacting, defaults to 3.5 angstrom
:type radius: float, optional
:param ca_only: Whether to use only the alpha carbon atoms for
determining interactions
:type ca_only: bool, optional
"""
self.radius = radius
self.fill_value = 1e-5
self.ca_only = ca_only
charstr: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
self.chain_map: Dict[str, int] = {
charstr[i]: i for i in range(len(charstr))
}
def __call__(self, data: Union[Protein, Data]):
# Map the chain labels to integers
target_chains = []
chain_strs = [res.split(":")[0] for res in data.residue_id]
chain_strs = list(np.unique(chain_strs))
for chain in data.graph_y:
target_chains.append(chain_strs.index(chain))
target_chains = torch.tensor(target_chains)
target_indices = torch.where(torch.isin(data.chains, target_chains))[0]
# Create a mask for the target chains
mask = torch.zeros(data.coords.shape[0], dtype=torch.bool)
mask[target_indices] = True
# Extract the target chains and the other chains
target_struct = data.coords[mask]
other_chains = data.coords[~mask]
N_TARGET_RESIDUES = target_struct.shape[0]
# Unwrap the coordinates
other_chains = other_chains.reshape(-1, 3)
# Remove any rows with 1e-5
other_chains = other_chains[
~torch.all(other_chains == self.fill_value, dim=1)
]
# Create a KDTree
# If Ca only, we only see if the interacting chains are within the
# threshold distance of Ca atoms on the input chains
if self.ca_only:
kd_tree = spatial.KDTree(target_struct[:, 1, :])
else:
# If we are not using CA only, we need to flatten the coordinates
# And keep track of the atom->residue mapping
coords, res_idx, _ = get_full_atom_coords(target_struct)
kd_tree = spatial.KDTree(coords)
indices = kd_tree.query_ball_point(other_chains, self.radius)
indices = [item for sublist in indices for item in sublist]
indices = torch.tensor(indices, dtype=torch.long)
# If not CA only, we need to map the atom indices back to residues
if not self.ca_only:
indices = torch.unique(res_idx[indices])
label = torch.zeros(N_TARGET_RESIDUES)
label[indices] = 1
data.node_y = label.long()
# Delete the graph label containing the chains to avoid the potential
# to incorrectly use them as label
del data.graph_y
# Subset the data to only the target chains
data.coords = target_struct
data.residue_type = data.residue_type[mask]
data.residues = np.array(data.residues)[mask]
data.residue_id = np.array(data.residue_id)[mask]
data.chains = data.chains[mask]
if data.x is not None:
data.x = data.x[mask]
if hasattr(data, "seq_pos"):
data.seq_pos = data.seq_pos[mask]
if hasattr(data, "amino_acid_one_hot"):
data.amino_acid_one_hot = data.amino_acid_one_hot[mask]
return data
if __name__ == "__main__":
from graphein.protein.tensor.data import get_random_protein
a = get_random_protein()
a.graph_label = "A"
t = BindingSiteTransform(radius=4, ca_only=False)
out = t(a)
print(out)
print(out.node_y.sum())