from typing import Optional, Set, Union
import torch
import torch_scatter
from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch
from torch_geometric.nn.models import SchNet
from proteinworkshop.types import EncoderOutput
class SchNetModel(SchNet):
def __init__(
self,
hidden_channels: int = 128,
out_dim: int = 1,
num_filters: int = 128,
num_layers: int = 6,
num_gaussians: int = 50,
cutoff: float = 10,
max_num_neighbors: int = 32,
readout: str = "add",
dipole: bool = False,
mean: Optional[float] = None,
std: Optional[float] = None,
atomref: Optional[torch.Tensor] = None,
):
"""
Initializes an instance of the SchNetModel class with the provided
parameters.
:param hidden_channels: Number of channels in the hidden layers
(default: ``128``)
:type hidden_channels: int
:param out_dim: Output dimension of the model (default: ``1``)
:type out_dim: int
:param num_filters: Number of filters used in convolutional layers
(default: ``128``)
:type num_filters: int
:param num_layers: Number of convolutional layers in the model
(default: ``6``)
:type num_layers: int
:param num_gaussians: Number of Gaussian functions used for radial
filters (default: ``50``)
:type num_gaussians: int
:param cutoff: Cutoff distance for interactions (default: ``10``)
:type cutoff: float
:param max_num_neighbors: Maximum number of neighboring atoms to
consider (default: ``32``)
:type max_num_neighbors: int
        :param readout: Global pooling method to be used (default: ``"add"``)
        :type readout: str
        :param dipole: Whether to predict the magnitude of the dipole moment
            instead of a scalar target (default: ``False``)
        :type dipole: bool
        :param mean: Mean of the target property, used for normalization
            (default: ``None``)
        :type mean: Optional[float]
        :param std: Standard deviation of the target property, used for
            normalization (default: ``None``)
        :type std: Optional[float]
        :param atomref: Reference single-atom properties (default: ``None``)
        :type atomref: Optional[torch.Tensor]
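
        Example (an illustrative sketch; the hyperparameter values shown are
        arbitrary):

        .. code-block:: python

            # Hypothetical instantiation: a 6-layer SchNet encoder producing
            # 32-dimensional node and graph embeddings with sum pooling
            encoder = SchNetModel(hidden_channels=128, out_dim=32, readout="add")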
"""
super().__init__(
hidden_channels,
num_filters,
num_layers,
num_gaussians,
            cutoff,  # the parent's interaction graph is not used; edges are taken from the batch
max_num_neighbors,
readout,
dipole,
mean,
std,
atomref,
)
self.readout = readout
        # Overwrite the atom-type embedding so that arbitrary node features
        # can be used as input
        self.embedding = torch.nn.LazyLinear(hidden_channels)
        # Overwrite the final predictor to produce ``out_dim`` outputs
        self.lin2 = torch.nn.LazyLinear(out_dim)
@property
def required_batch_attributes(self) -> Set[str]:
"""
Required batch attributes for this encoder.
- ``x``: Node features (shape: :math:`(n, d)`)
- ``pos``: Node positions (shape: :math:`(n, 3)`)
- ``edge_index``: Edge indices (shape: :math:`(2, e)`)
- ``batch``: Batch indices (shape: :math:`(n,)`)
:return: Set of required batch attributes
:rtype: Set[str]
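
        A simple presence check (an illustrative sketch; ``encoder`` and
        ``batch`` are assumed to be a :class:`SchNetModel` instance and a
        PyTorch Geometric batch, respectively):

        .. code-block:: python

            # True only if every required attribute is present on the batch
            ok = all(hasattr(batch, attr) for attr in encoder.required_batch_attributes)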
"""
return {"pos", "edge_index", "x", "batch"}
def forward(self, batch: Union[Batch, ProteinBatch]) -> EncoderOutput:
"""Implements the forward pass of the SchNet encoder.
Returns the node embedding and graph embedding in a dictionary.
:param batch: Batch of data to encode.
:type batch: Union[Batch, ProteinBatch]
        :return: Dictionary of node and graph embeddings, with
            ``node_embedding`` and ``graph_embedding`` fields. The node
            embedding has shape :math:`(|V|, d)` and the graph embedding has
            shape :math:`(n, d)`, where :math:`|V|` is the number of nodes,
            :math:`n` is the number of graphs in the batch, and :math:`d` is
            the embedding dimension.
:rtype: EncoderOutput
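
        Illustrative usage (a sketch; ``encoder`` is an instance of this class
        and ``batch`` is assumed to provide the attributes listed in
        :meth:`required_batch_attributes`):

        .. code-block:: python

            out = encoder(batch)
            node_emb = out["node_embedding"]    # shape (|V|, out_dim)
            graph_emb = out["graph_embedding"]  # shape (n, out_dim)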
"""
h = self.embedding(batch.x)
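        # Compute per-edge distances and expand them in a Gaussian radial basis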
u, v = batch.edge_index
edge_weight = (batch.pos[u] - batch.pos[v]).norm(dim=-1)
edge_attr = self.distance_expansion(edge_weight)
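        # Residual message passing through the SchNet interaction blocks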
for interaction in self.interactions:
h = h + interaction(h, batch.edge_index, edge_weight, edge_attr)
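        # Project node embeddings to the output dimension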
h = self.lin1(h)
h = self.act(h)
h = self.lin2(h)
return EncoderOutput(
{
"node_embedding": h,
"graph_embedding": torch_scatter.scatter(
h, batch.batch, dim=0, reduce=self.readout
),
}
)
if __name__ == "__main__":
import hydra
import omegaconf
import pyrootutils
from graphein.protein.tensor.data import get_random_protein
root = pyrootutils.setup_root(__file__, pythonpath=True)
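    # Load the SchNet encoder config and instantiate the model via Hydra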
cfg = omegaconf.OmegaConf.load(
root / "configs" / "encoder" / "schnet.yaml"
)
print(cfg)
encoder = hydra.utils.instantiate(cfg.schnet)
print(encoder)
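    # Build a small random protein batch as a smoke test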
batch = ProteinBatch().from_protein_list(
[get_random_protein() for _ in range(4)], follow_batch=["coords"]
)
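    # Populate the attributes required by the encoder: batch indices, a kNN
    # edge index, C-alpha coordinates as ``pos``, and residue types as ``x``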
batch.batch = batch.coords_batch
batch.edges("knn_8", cache="edge_index")
batch.pos = batch.coords[:, 1, :]
batch.x = batch.residue_type
print(batch)
out = encoder.forward(batch)
print(out)