Source code for proteinworkshop.models.graph_encoders.schnet

from typing import Optional, Set, Union

import torch
import torch_scatter
from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch
from torch_geometric.nn.models import SchNet

from proteinworkshop.types import EncoderOutput


class SchNetModel(SchNet):
    def __init__(
        self,
        hidden_channels: int = 128,
        out_dim: int = 1,
        num_filters: int = 128,
        num_layers: int = 6,
        num_gaussians: int = 50,
        cutoff: float = 10,
        max_num_neighbors: int = 32,
        readout: str = "add",
        dipole: bool = False,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        atomref: Optional[torch.Tensor] = None,
    ):
        """
        Initializes an instance of the SchNetModel class with the provided
        parameters.

        :param hidden_channels: Number of channels in the hidden layers
            (default: ``128``)
        :type hidden_channels: int
        :param out_dim: Output dimension of the model (default: ``1``)
        :type out_dim: int
        :param num_filters: Number of filters used in convolutional layers
            (default: ``128``)
        :type num_filters: int
        :param num_layers: Number of convolutional layers in the model
            (default: ``6``)
        :type num_layers: int
        :param num_gaussians: Number of Gaussian functions used for radial
            filters (default: ``50``)
        :type num_gaussians: int
        :param cutoff: Cutoff distance for interactions (default: ``10``)
        :type cutoff: float
        :param max_num_neighbors: Maximum number of neighboring atoms to
            consider (default: ``32``)
        :type max_num_neighbors: int
        :param readout: Global pooling method to be used (default: ``"add"``)
        :type readout: str
        :param dipole: Whether to use the magnitude of the dipole moment to
            make the final prediction (default: ``False``)
        :type dipole: bool
        :param mean: Mean of the property to predict (default: ``None``)
        :type mean: Optional[float]
        :param std: Standard deviation of the property to predict
            (default: ``None``)
        :type std: Optional[float]
        :param atomref: Reference single-atom properties (default: ``None``)
        :type atomref: Optional[torch.Tensor]
        """
        super().__init__(
            hidden_channels,
            num_filters,
            num_layers,
            num_gaussians,
            cutoff,
            # None,  # Interaction graph is not used
            max_num_neighbors,
            readout,
            dipole,
            mean,
            std,
            atomref,
        )
        self.readout = readout
        # Overwrite embedding
        self.embedding = torch.nn.LazyLinear(hidden_channels)
        # Overwrite atom embedding and final predictor
        self.lin2 = torch.nn.LazyLinear(out_dim)

    @property
    def required_batch_attributes(self) -> Set[str]:
        """
        Required batch attributes for this encoder.

        - ``x``: Node features (shape: :math:`(n, d)`)
        - ``pos``: Node positions (shape: :math:`(n, 3)`)
        - ``edge_index``: Edge indices (shape: :math:`(2, e)`)
        - ``batch``: Batch indices (shape: :math:`(n,)`)

        :return: Set of required batch attributes
        :rtype: Set[str]
        """
        return {"pos", "edge_index", "x", "batch"}
    def forward(self, batch: Union[Batch, ProteinBatch]) -> EncoderOutput:
        """Implements the forward pass of the SchNet encoder.

        Returns the node embedding and graph embedding in a dictionary.

        :param batch: Batch of data to encode.
        :type batch: Union[Batch, ProteinBatch]
        :return: Dictionary of node and graph embeddings. Contains
            ``node_embedding`` and ``graph_embedding`` fields. The node
            embedding is of shape :math:`(|V|, d)` and the graph embedding is
            of shape :math:`(n, d)`, where :math:`|V|` is the number of nodes,
            :math:`n` is the number of graphs in the batch, and :math:`d` is
            the dimension of the embeddings.
        :rtype: EncoderOutput
        """
        h = self.embedding(batch.x)

        u, v = batch.edge_index
        edge_weight = (batch.pos[u] - batch.pos[v]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, batch.edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        return EncoderOutput(
            {
                "node_embedding": h,
                "graph_embedding": torch_scatter.scatter(
                    h, batch.batch, dim=0, reduce=self.readout
                ),
            }
        )
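

# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal sketch of how ``SchNetModel`` can be driven with a plain
# ``torch_geometric.data.Data`` object instead of a graphein ``ProteinBatch``;
# such an object duck-types the attributes accessed in ``forward``. The node
# count, feature dimension and k-NN edge construction below are assumptions
# made purely for illustration; real pipelines should build batches with the
# proteinworkshop featurisation configs. ``knn_graph`` additionally requires
# ``torch-cluster`` to be installed.
def _synthetic_batch_example(num_nodes: int = 20, feat_dim: int = 23):
    from torch_geometric.data import Data
    from torch_geometric.nn import knn_graph

    # Random node features, 3D positions and a single-graph batch vector,
    # covering every attribute listed in ``required_batch_attributes``.
    batch = Data(
        x=torch.randn(num_nodes, feat_dim),
        pos=torch.randn(num_nodes, 3),
        batch=torch.zeros(num_nodes, dtype=torch.long),
    )
    batch.edge_index = knn_graph(batch.pos, k=8, batch=batch.batch)

    model = SchNetModel(hidden_channels=128, out_dim=32)
    out = model(batch)
    # node_embedding: (num_nodes, 32); graph_embedding: (1, 32)
    return out["node_embedding"].shape, out["graph_embedding"].shape
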
if __name__ == "__main__":
    import hydra
    import omegaconf
    import pyrootutils
    from graphein.protein.tensor.data import get_random_protein

    root = pyrootutils.setup_root(__file__, pythonpath=True)
    cfg = omegaconf.OmegaConf.load(
        root / "configs" / "encoder" / "schnet.yaml"
    )
    print(cfg)

    encoder = hydra.utils.instantiate(cfg.schnet)
    print(encoder)

    batch = ProteinBatch().from_protein_list(
        [get_random_protein() for _ in range(4)], follow_batch=["coords"]
    )
    batch.batch = batch.coords_batch
    batch.edges("knn_8", cache="edge_index")
    batch.pos = batch.coords[:, 1, :]
    batch.x = batch.residue_type
    print(batch)

    out = encoder.forward(batch)
    print(out)
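

# --- Illustrative Hydra instantiation sketch (not part of the original module)
# The ``__main__`` block above instantiates the encoder from
# ``configs/encoder/schnet.yaml``. The inline config below is an assumed,
# minimal equivalent written out for illustration only; the real YAML file may
# define additional or different fields.
def _hydra_instantiation_example():
    import hydra
    import omegaconf

    cfg = omegaconf.OmegaConf.create(
        {
            "schnet": {
                "_target_": "proteinworkshop.models.graph_encoders.schnet.SchNetModel",
                "hidden_channels": 128,
                "num_layers": 6,
                "out_dim": 32,
                "readout": "add",
            }
        }
    )
    # ``hydra.utils.instantiate`` resolves ``_target_`` to ``SchNetModel`` and
    # passes the remaining keys as keyword arguments to ``__init__``.
    return hydra.utils.instantiate(cfg.schnet)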