Source code for proteinworkshop.tasks.edge_distance_prediction

from typing import Set, Union

import torch
import torch_geometric.transforms as T
from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch


class EdgeDistancePredictionTransform(T.BaseTransform):
    """
    Self-supervision task to predict the pairwise distance between two nodes.

    We first sample ``num_samples`` edges uniformly at random from the input
    batch. Sampling is *with replacement*, so the same edge may be drawn more
    than once (and then fewer than ``num_samples`` distinct edges are
    removed). We then construct a mask to remove the sampled edges from the
    batch. We store the endpoint node indices of the sampled edges and their
    pairwise distance as ``batch.node_mask`` and
    ``batch.edge_distance_labels``, respectively. Finally, the edges (and
    their ``edge_type`` / ``edge_attr`` attributes, when present) are masked
    using the constructed mask and the modified batch is returned.
    """

    def __init__(self, num_samples: int):
        """Initialise the transform.

        :param num_samples: Number of edges to sample (with replacement) and
            mask
        :type num_samples: int
        """
        self.num_samples = num_samples

    @property
    def required_batch_attributes(self) -> Set[str]:
        """
        Returns the set of attributes that this transform requires to be
        present on the batch object for correct operation.

        :return: Set of required attributes
        :rtype: Set[str]
        """
        return {"num_edges", "edge_index", "pos"}

    def __call__(
        self, batch: Union[ProteinBatch, Batch]
    ) -> Union[Batch, ProteinBatch]:
        """Sample edges, record their distance labels and remove them.

        The batch object is modified in place and also returned.

        Attributes written on ``batch``:

        - ``node_mask``: ``batch.edge_index[:, indices]`` for the sampled
          edge indices, i.e. one column per sample; row 0 / row 1 hold the
          two endpoints of each sampled edge.
        - ``edge_distance_labels``: ``torch.pairwise_distance`` (default
          p=2) between the ``pos`` rows of the two endpoints, one value per
          sample.
        - ``edge_index`` (and ``edge_type`` / ``edge_attr`` if set): with the
          sampled edges removed.

        :param batch: Batch providing ``edge_index``, ``pos`` and
            ``num_edges``.
        :type batch: Union[ProteinBatch, Batch]
        :return: The (mutated) input batch.
        :rtype: Union[Batch, ProteinBatch]
        """
        device = batch.edge_index.device

        # Draw edge indices uniformly in [0, num_edges), with replacement.
        indices = torch.randint(
            0,
            batch.num_edges,
            (self.num_samples,),
            device=device,
        ).long()

        # Keep-mask over all edges: True = keep, False = sampled/removed.
        mask = torch.ones_like(batch.edge_index[0], device=device).bool()
        mask[indices] = False

        # Record the endpoints of the sampled edges and their distances
        # before the edges are removed below.
        nodes = batch.edge_index[:, indices]
        batch.node_mask = nodes
        batch.edge_distance_labels = torch.pairwise_distance(
            batch.pos[nodes[0]], batch.pos[nodes[1]]
        )

        # Remove the sampled edges and their per-edge attributes.
        batch.edge_index = batch.edge_index[:, mask]

        # ``edge_type`` may be stored as ``[E]`` or ``[1, E]``; indexing the
        # last dimension works for both (assumes the edge axis is last).
        edge_type = getattr(batch, "edge_type", None)
        if edge_type is not None:
            batch.edge_type = edge_type[..., mask]

        # PyG's ``Data`` exposes ``edge_attr`` as a property that yields
        # ``None`` when unset, so ``hasattr`` cannot detect its absence;
        # test for ``None`` instead. Indexing dim 0 handles both ``[E]`` and
        # ``[E, D]`` edge features.
        edge_attr = getattr(batch, "edge_attr", None)
        if edge_attr is not None:
            batch.edge_attr = edge_attr[mask]

        return batch