Source code for proteinworkshop.models.graph_encoders.gear_net

from typing import Optional, Set, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from graphein.protein.tensor.data import ProteinBatch
from loguru import logger as log
from torch_geometric.data import Batch
from torch_scatter import scatter_add

from proteinworkshop.models.graph_encoders.layers import gear_net
from proteinworkshop.models.utils import get_aggregation
from proteinworkshop.types import EncoderOutput


class GearNet(nn.Module):
    """GearNet graph encoder with optional edge message passing.

    Node representations are produced by a stack of
    ``gear_net.GeometricRelationalGraphConv`` layers. When ``num_angle_bin``
    is truthy, a parallel stack of the same layer type runs over the graph
    returned by ``gear_net.SpatialLineGraph`` and its output is scattered back
    onto the nodes after every layer.
    """

    def __init__(
        self,
        input_dim: int,
        num_relation: int,
        num_layers: int,
        emb_dim: int,
        short_cut: bool,
        concat_hidden: bool,
        batch_norm: bool,
        num_angle_bin: Optional[int],
        activation: str = "relu",
        pool: str = "sum",
    ) -> None:
        """Initializes an instance of the GearNet model.

        :param input_dim: Dimension of the input node features
        :type input_dim: int
        :param num_relation: Number of edge types
        :type num_relation: int
        :param num_layers: Number of layers in the model
        :type num_layers: int
        :param emb_dim: Dimension of the node embeddings
        :type emb_dim: int
        :param short_cut: Whether to use short cut connections
        :type short_cut: bool
        :param concat_hidden: Whether to concatenate hidden representations
        :type concat_hidden: bool
        :param batch_norm: Whether to use batch norm
        :type batch_norm: bool
        :param num_angle_bin: Number of angle bins for edge message passing.
            If ``None`` (or ``0``), edge message passing is disabled.
        :type num_angle_bin: Optional[int]
        :param activation: Name of a function in ``torch.nn.functional`` to
            use as activation, defaults to "relu"
        :type activation: str, optional
        :param pool: Pooling operation to use, defaults to "sum"
        :type pool: str, optional
        """
        super().__init__()
        # Base parameters
        self.num_relation = num_relation
        self.input_dim = input_dim

        # A truthy value enables edge message passing.
        self.num_angle_bin = num_angle_bin
        # Depends only on input_dim and num_relation, so it is computed once.
        self.edge_input_dim = self._get_num_edge_features()

        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        n_hid = [emb_dim] * num_layers
        self.dims = [self.input_dim] + n_hid
        self.activations = [getattr(F, activation) for _ in n_hid]
        self.batch_norm = batch_norm

        # Node layers: one convolution per consecutive pair in ``self.dims``.
        self.layers = nn.ModuleList(
            [
                gear_net.GeometricRelationalGraphConv(
                    input_dim=dim_in,
                    output_dim=dim_out,
                    num_relation=self.num_relation,
                    edge_input_dim=self.edge_input_dim,
                    batch_norm=batch_norm,
                    activation=act,
                )
                for dim_in, dim_out, act in zip(
                    self.dims[:-1], self.dims[1:], self.activations
                )
            ]
        )

        if self.num_angle_bin:
            log.info("Using Edge Message Passing")
            # The first edge layer consumes the raw edge features; every later
            # one consumes the previous edge layer's output.
            self.edge_dims = [self.edge_input_dim] + self.dims[:-1]
            self.spatial_line_graph = gear_net.SpatialLineGraph(
                self.num_angle_bin
            )
            self.edge_layers = nn.ModuleList(
                [
                    gear_net.GeometricRelationalGraphConv(
                        dim_in,
                        dim_out,
                        self.num_angle_bin,
                        None,
                        batch_norm=self.batch_norm,
                        activation=act,
                    )
                    for dim_in, dim_out, act in zip(
                        self.edge_dims[:-1],
                        self.edge_dims[1:],
                        self.activations,
                    )
                ]
            )

        # Batch Norm applied to the output of every node layer
        if self.batch_norm:
            self.batch_norms = nn.ModuleList(
                [nn.BatchNorm1d(dim) for dim in self.dims[1:]]
            )

        # Readout
        self.readout = get_aggregation(pool)

    @property
    def required_batch_attributes(self) -> Set[str]:
        """Required batch attributes for this encoder.

        - ``x`` Node features (shape ``[num_nodes, input_dim]``)
        - ``pos`` Node coordinates, used to compute edge lengths
        - ``edge_index`` Edge indices (shape ``[2, num_edges]``)
        - ``edge_type`` Edge types (shape ``[1, num_edges]`` or
          ``[num_edges]``)
        - ``edge_attr`` Edge attributes (shape
          ``[num_edges, num_edge_features]``); listed for compatibility, not
          read by this class itself
        - ``num_nodes`` Number of nodes (int)
        - ``batch`` Batch indices (shape ``[num_nodes]``)

        :return: Set of required batch attributes.
        :rtype: Set[str]
        """
        return {
            "x",
            "pos",
            "edge_index",
            "edge_type",
            "edge_attr",
            "num_nodes",
            "batch",
        }

    def forward(self, batch: Union[Batch, ProteinBatch]) -> EncoderOutput:
        """Implements the forward pass of the GearNet encoder.

        Returns the node embedding and graph embedding in a dictionary.

        .. note::
            The batch is modified in place: ``edge_weight`` and
            ``edge_feature`` are (re)assigned, and ``edge_index`` is extended
            with the edge type as a third row. The extension is skipped when
            ``edge_index`` no longer has exactly two rows, so calling the
            encoder repeatedly on the same batch is safe.

        :param batch: Batch of data to encode.
        :type batch: Union[Batch, ProteinBatch]
        :return: Dictionary of node and graph embeddings. Contains
            ``node_embedding`` and ``graph_embedding`` fields. The node
            embedding is of shape :math:`(|V|, d)` and the graph embedding is
            of shape :math:`(n, d)`, where :math:`|V|` is the number of nodes
            and :math:`n` is the number of graphs in the batch and :math:`d`
            is the dimension of the embeddings.
        :rtype: EncoderOutput
        """
        hiddens = []
        batch.edge_weight = torch.ones(
            batch.edge_index.shape[1], dtype=torch.float, device=batch.x.device
        )
        layer_input = batch.x

        # Rows of edge_index become (first node, second node, relation).
        # Only append when the index is still 2-row, otherwise a repeated call
        # on the same batch would keep stacking rows.
        if batch.edge_index.shape[0] == 2:
            batch.edge_index = torch.cat(
                [batch.edge_index, batch.edge_type.reshape(1, -1)]
            )
        batch.edge_feature = self.gear_net_edge_features(batch)

        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(batch)
            line_graph.edge_weight = torch.ones(
                line_graph.edge_index.shape[1],
                dtype=torch.float,
                device=batch.x.device,
            )
            edge_input = line_graph.x.float()

        for i, layer in enumerate(self.layers):
            hidden = layer(batch, layer_input)
            # The residual is only possible when the layer keeps the shape.
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input

            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_input)
                edge_weight = batch.edge_weight.unsqueeze(-1)
                # Flattened (second node, relation) slot for every edge.
                node_out = (
                    batch.edge_index[1, :] * self.num_relation
                    + batch.edge_index[2, :]
                )
                update = scatter_add(
                    edge_hidden * edge_weight,
                    node_out,
                    dim=0,
                    dim_size=batch.num_nodes * self.num_relation,
                )
                # Fold the per-relation slots back into one row per node.
                update = update.view(
                    batch.num_nodes, self.num_relation * edge_hidden.shape[1]
                )
                update = layer.linear(update)
                update = layer.activation(update)
                hidden = hidden + update
                edge_input = edge_hidden

            if self.batch_norm:
                hidden = self.batch_norms[i](hidden)

            hiddens.append(hidden)
            layer_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]

        return EncoderOutput(
            {
                "node_embedding": node_feature,
                "graph_embedding": self.readout(node_feature, batch.batch),
            }
        )

    def _get_num_edge_features(self) -> int:
        """Compute the number of edge features.

        Mirrors the columns built by :meth:`gear_net_edge_features`: the
        features of both endpoints, a one-hot edge type, one sequence-distance
        column and one spatial-distance column.
        """
        seq_dist = 1
        dist = 1
        return self.input_dim * 2 + self.num_relation + seq_dist + dist

    def gear_net_edge_features(
        self, b: Union[Batch, ProteinBatch]
    ) -> torch.Tensor:
        """Compute edge features for the gear net encoder.

        - Concatenate node features of the two nodes in each edge
        - Concatenate the one-hot encoded edge type
        - Compute the sequence distance between the two nodes in each edge
          (absolute difference of their node indices)
        - Compute the distance between the two nodes in each edge

        ``b.edge_type`` may be of shape ``[1, num_edges]`` or ``[num_edges]``.
        Only the first two rows of ``b.edge_index`` are read.

        :param b: Batch of data to encode.
        :type b: Union[Batch, ProteinBatch]
        :return: Edge features, one row per edge with columns ordered as
            ``[u, v, edge_type, seq_dist, dist]``
        :rtype: torch.Tensor
        """
        src = b.edge_index[0]
        dst = b.edge_index[1]
        u = b.x[src]
        v = b.x[dst]
        # Flatten first so both [1, E] and [E] layouts give an [E, R] one-hot.
        edge_type = F.one_hot(b.edge_type.reshape(-1), self.num_relation)
        dists = torch.pairwise_distance(b.pos[src], b.pos[dst]).unsqueeze(1)
        seq_dist = torch.abs(src - dst).unsqueeze(1)
        return torch.cat([u, v, edge_type, seq_dist, dists], dim=1)
if __name__ == "__main__":
    # Manual smoke test: build the encoder from its config file and print it.
    import hydra
    import omegaconf

    from proteinworkshop import constants

    config_path = constants.SRC_PATH / "config" / "encoder" / "gear_net.yaml"
    encoder_cfg = omegaconf.OmegaConf.load(config_path)
    encoder = hydra.utils.instantiate(encoder_cfg)
    print(encoder)