Source code for proteinworkshop.features.node_features

"""Node feature computation functions."""
from typing import List, Union

import torch
import torch.nn.functional as F
from beartype import beartype as typechecker
from graphein.protein.tensor.angles import (
    alpha,
    dihedrals,
    kappa,
    sidechain_torsion,
)
from graphein.protein.tensor.data import Protein, ProteinBatch
from graphein.protein.tensor.types import AtomTensor, CoordTensor
from jaxtyping import jaxtyped
from omegaconf import ListConfig
from torch_geometric.data import Batch, Data
from torch_geometric.nn.pool import knn_graph
from torch_geometric.utils import softmax

from proteinworkshop.models.utils import flatten_list
from proteinworkshop.types import OrientationTensor, ScalarNodeFeature

from .sequence_features import amino_acid_one_hot
from .utils import _normalize


[docs] @jaxtyped(typechecker=typechecker) def compute_scalar_node_features( x: Union[Batch, Data, Protein, ProteinBatch], node_features: Union[ListConfig, List[ScalarNodeFeature]], ) -> torch.Tensor: """ Factory function for node features. .. seealso:: :py:class:`proteinworkshop.types.ScalarNodeFeature` for a list of node features that can be computed. This function operates on a :py:class:`torch_geometric.data.Data` or :py:class:`torch_geometric.data.Batch` object and computes the requested node features. :param x: :py:class:`~torch_geometric.data.Data` or :py:class:`~torch_geometric.data.Batch` protein object. :type x: Union[Data, Batch] :param node_features: List of node features to compute. :type node_features: Union[List[str], ListConfig] :return: Tensor of node features of shape (``N x F``), where ``N`` is the number of nodes and ``F`` is the number of features. :rtype: torch.Tensor """ feats = [] for feature in node_features: if feature == "amino_acid_one_hot": feats.append(amino_acid_one_hot(x, num_classes=23)) elif feature == "alpha": feats.append(alpha(x.coords, x.batch, rad=True, embed=True)) elif feature == "kappa": feats.append(kappa(x.coords, x.batch, rad=True, embed=True)) elif feature == "dihedrals": feats.append(dihedrals(x.coords, x.batch, rad=True, embed=True)) elif feature == "sidechain_torsions": feats.append( sidechain_torsion( x.coords, res_types=flatten_list(x.residues), rad=True, embed=True, ) ) elif feature.startswith("surface"): # k = int(feature.split("_")[1]) # sigmas = [float(s) for s in feature.split("_")[2:]] # feats.append(compute_surface_feat(x.coords, k=k, sigma=sigmas)) raise NotImplementedError( "Surface features not implemented yet. Does not handle batched data." ) elif feature == "sequence_positional_encoding": continue else: raise ValueError(f"Node feature {feature} not recognised.") feats = [feat.unsqueeze(1) if feat.ndim == 1 else feat for feat in feats] # Return concatenated features or original features if no features were computed return torch.cat(feats, dim=1) if feats else x.x
[docs] @jaxtyped(typechecker=typechecker) def compute_vector_node_features( x: Union[Batch, Data, Protein, ProteinBatch], vector_features: Union[ListConfig, List[str]], ) -> Union[Batch, Data, Protein, ProteinBatch]: """Factory function for vector features. Currently implemented vector features are: - ``orientation``: Orientation of each node in the protein backbone - ``virtual_cb_vector``: Virtual CB vector for each node in the protein backbone """ vector_node_features = [] for feature in vector_features: if feature == "orientation": vector_node_features.append(orientations(x.coords, x._slice_dict["coords"])) elif feature == "virtual_cb_vector": raise NotImplementedError("Virtual CB vector not implemented yet.") else: raise ValueError(f"Vector feature {feature} not recognised.") x.x_vector_attr = torch.cat(vector_node_features, dim=0) return x
[docs] @jaxtyped(typechecker=typechecker) def compute_surface_feat( coords: Union[CoordTensor, AtomTensor], k: int, sigma: List[float] ): """Coords: (N, 3) k: number of neighbors to consider in KNN graph """ if coords.ndim == 3: coords = coords[:, 1, :] feat = [] for s in sigma: SIGMA = torch.tensor([s], device=coords.device).unsqueeze(1) N_DIM = coords.shape[-1] N_SIGMA = SIGMA.shape[0] edge_index = knn_graph(coords, k=k) node_in = coords[edge_index[0]] node_out = coords[edge_index[1]] dists = torch.pairwise_distance(node_in, node_out).unsqueeze(0) dists = -(dists**2) / SIGMA weights = softmax(dists, index=edge_index[1], dim=-1) weights = weights.reshape(-1, N_SIGMA, k) diff_vecs = node_in - node_out diff_vecs = diff_vecs.reshape(-1, k, N_DIM) mean_vec = torch.bmm(weights, diff_vecs) denom = torch.bmm( weights, torch.norm(diff_vecs, dim=-1).unsqueeze(-1) ).squeeze(-1) feat.append(torch.norm(mean_vec, dim=-1) / denom) return torch.cat(feat, dim=1)
@jaxtyped(typechecker=typechecker) def orientations( X: Union[CoordTensor, AtomTensor], coords_slice_index: torch.Tensor, ca_idx: int = 1 ) -> OrientationTensor: if X.ndim == 3: X = X[:, ca_idx, :] # NOTE: the first item in the coordinates slice index is always 0, # and the last item is always the node count of the batch batch_num_nodes = X.shape[0] slice_index = coords_slice_index[1:] - 1 last_node_index = slice_index[:-1] first_node_index = slice_index[:-1] + 1 slice_mask = torch.zeros(batch_num_nodes - 1, dtype=torch.bool) last_node_forward_slice_mask = slice_mask.clone() first_node_backward_slice_mask = slice_mask.clone() # NOTE: all of the last (first) nodes in a subgraph have their # forward (backward) vectors set to a padding value (i.e., 0.0) # to mimic feature construction behavior with single input graphs forward_slice = X[1:] - X[:-1] backward_slice = X[:-1] - X[1:] last_node_forward_slice_mask[last_node_index] = True first_node_backward_slice_mask[first_node_index - 1] = True # NOTE: for the backward slices, our indexing defaults to node index `1` forward_slice[last_node_forward_slice_mask] = 0.0 # NOTE: this handles all but the last node in the last subgraph backward_slice[first_node_backward_slice_mask] = 0.0 # NOTE: this handles all but the first node in the first subgraph # NOTE: padding first and last nodes with zero vectors does not impact feature normalization forward = _normalize(forward_slice) backward = _normalize(backward_slice) forward = F.pad(forward, [0, 0, 0, 1]) backward = F.pad(backward, [0, 0, 1, 0]) orientations = torch.cat((forward.unsqueeze(-2), backward.unsqueeze(-2)), dim=-2) # optionally debug/verify the orientations # last_node_indices = torch.cat((last_node_index, torch.tensor([batch_num_nodes - 1])), dim=0) # first_node_indices = torch.cat((torch.tensor([0]), first_node_index), dim=0) # intermediate_node_indices_mask = torch.ones(batch_num_nodes, device=X.device, dtype=torch.bool) # intermediate_node_indices_mask[last_node_indices] = False # intermediate_node_indices_mask[first_node_indices] = False # assert not orientations[last_node_indices][:, 0].any() and orientations[last_node_indices][:, 1].any() # assert orientations[first_node_indices][:, 0].any() and not orientations[first_node_indices][:, 1].any() # assert orientations[intermediate_node_indices_mask][:, 0].any() and orientations[intermediate_node_indices_mask][:, 1].any() return orientations