import abc
from typing import Callable, Dict, List, Literal, Optional, Set, Union
import hydra
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typechecker
from graphein.protein.tensor.angles import dihedrals
from graphein.protein.tensor.data import ProteinBatch, get_random_protein
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch
from proteinworkshop.models.utils import get_loss
from proteinworkshop.types import EncoderOutput, Label, ModelOutput
class BaseModel(L.LightningModule, abc.ABC):
config: DictConfig
featuriser: nn.Module
losses: Dict[str, Callable]
task_transform: Optional[Callable]
metric_names: List[str]
@abc.abstractmethod
def forward(self, batch: Batch) -> torch.Tensor:
"""Implement forward pass of model.
:param batch: Mini-batch of data.
:type batch: Batch
:return: Model output.
:rtype: torch.Tensor
"""
...
@abc.abstractmethod
def training_step(
        self, batch: Batch, batch_idx: int
) -> torch.Tensor:
"""Implement training step.
:param batch: Mini-batch of data.
:type batch: Batch
:param batch_idx: Index of batch.
        :type batch_idx: int
:return: Return loss.
:rtype: torch.Tensor
"""
...
@abc.abstractmethod
def validation_step(
        self, batch: Batch, batch_idx: int
) -> torch.Tensor:
"""Implement validation step.
:param batch: Mini-batch of data.
:type batch: Batch
:param batch_idx: Index of batch.
        :type batch_idx: int
:return: Return loss.
:rtype: torch.Tensor
"""
...
@abc.abstractmethod
    def test_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:
"""Implement test step.
:param batch: Mini-batch of data.
:type batch: Batch
:param batch_idx: Index of batch.
        :type batch_idx: int
:return: Return loss.
:rtype: torch.Tensor
"""
...
def featurise(
self, batch: Union[Batch, ProteinBatch]
) -> Union[Batch, ProteinBatch]:
"""Applies the featuriser (``self.featuriser``) to a batch of data.
.. seealso::
            :py:class:`proteinworkshop.features.factory.ProteinFeaturiser`
:param batch: Batch of data
:type batch: Union[Batch, ProteinBatch]
:return: Featurised batch
:rtype: Union[Batch, ProteinBatch]
"""
out = self.featuriser(batch)
if self.task_transform is not None:
out = self.task_transform(out)
return out
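    # Usage sketch (illustrative, not part of the library): given a composed
    # Hydra config and a ProteinBatch already on the model's device,
    #
    #     model = BenchMarkModel(cfg)
    #     feats = model.featurise(protein_batch)  # adds features such as x / pos / edge_index
    #
    # where ``protein_batch`` is a hypothetical batch built from ProteinBatch.from_data_list.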
def get_labels(self, batch: Union[Batch, ProteinBatch]) -> Label:
"""
Computes or retrieves labels from a batch of data.
Labels are returned as a dictionary of tensors indexed by output name.
:param batch: Batch of data to compute labels for
:type batch: Union[Batch, ProteinBatch]
:return: Dictionary of labels indexed by output name
:rtype: Label
"""
labels: Dict[str, torch.Tensor] = {}
for output in self.config.task.supervise_on:
if output == "node_label":
labels["node_label"] = batch.node_y
if isinstance(
self.losses["node_label"], torch.nn.BCEWithLogitsLoss
):
labels["node_label"] = F.one_hot(
labels["node_label"],
num_classes=self.config.dataset.num_classes,
).float()
elif output == "graph_label":
labels["graph_label"] = batch.graph_y
if (
isinstance(
self.losses["graph_label"], torch.nn.BCEWithLogitsLoss
)
and batch.graph_y.ndim == 1
):
labels["graph_label"] = F.one_hot(
labels["graph_label"],
num_classes=self.config.dataset.num_classes,
).float()
elif output == "dihedrals":
# If we have dihedral labels in the batch, use those
# These will have been stored by the torsional denoising
# transform
if hasattr(batch, "true_dihedrals"):
labels["dihedrals"] = batch.true_dihedrals
                # If we have stored uncorrupted coords, use those to compute dihedrals
elif hasattr(batch, "coords_uncorrupted"):
labels["dihedrals"] = dihedrals(
batch.coords_uncorrupted,
batch.batch,
rad=True,
embed=True,
)
# Otherwise, compute dihedrals from the batch coordinates
else:
labels["dihedrals"] = dihedrals(
batch.coords, batch.batch, rad=True, embed=True
)
elif output == "torsional_noise":
labels["torsional_noise"] = batch.torsional_noise
elif output == "residue_type":
# If we have stored uncorrupted labels, use those
if hasattr(batch, "residue_type_uncorrupted"):
labels["residue_type"] = batch.residue_type_uncorrupted
# Otherwise, use residue types
else:
labels["residue_type"] = batch.residue_type
# If we have stored a mask, apply it
if hasattr(batch, "sequence_corruption_mask"):
labels["residue_type"] = labels["residue_type"][
batch.sequence_corruption_mask
]
elif output == "pos":
labels["pos"] = batch.noise[
:, 1, :
] # TODO this is hardcoded to only handle CA
elif output == "edge_distance":
labels["edge_distance"] = batch.edge_distance_labels
elif output in {"b_factor", "plddt"}:
labels["b_factor"] = batch.b_factor
return Label(labels)
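    # Label sketch (illustrative only): with ``supervise_on: [graph_label]`` and a
    # cross-entropy loss, the returned dictionary looks roughly like
    #
    #     {"graph_label": tensor([3, 0, 7, ...])}  # one class index per graph
    #
    # while a BCEWithLogitsLoss entry is one-hot encoded to shape
    # ``[num_graphs, num_classes]`` as in the branches above.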
@typechecker
def compute_loss(
self, y_hat: ModelOutput, y: Label
) -> Dict[str, torch.Tensor]:
"""
Compute loss by iterating over all outputs.
In the case of multiple losses, the total loss is also included in the
output dictionary of losses.
:param y_hat: Output of model. This should be a dictionary of outputs
(torch.Tensor) indexed by the output name (str)
:type y_hat: ModelOutput
:param y: Labels. This should be a dictionary of labels (torch.Tensor)
indexed by the output name (str)
:type y: Label
:return: Dictionary of losses indexed by output name (str)
:rtype: Dict[str, torch.Tensor]
"""
loss = {k: v(y_hat[k], y[k]) for k, v in self.losses.items()}
# Scale loss terms by coefficient
        if OmegaConf.select(self.config, "task.aux_loss_coefficient"):
for (
output,
coefficient,
) in self.config.task.aux_loss_coefficient.items():
loss[output] = coefficient * loss[output]
loss["total"] = sum(loss.values())
return loss
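    # Loss-dictionary sketch (illustrative values), assuming two supervised outputs:
    #
    #     loss = {"graph_label": tensor(1.20), "dihedrals": tensor(0.40)}
    #     loss["total"] = loss["graph_label"] + loss["dihedrals"]  # tensor(1.60)
    #
    # Outputs listed under ``task.aux_loss_coefficient`` are scaled before the
    # total is formed.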
def _build_output_decoders(self) -> nn.ModuleDict:
"""
Instantiate output decoders.
Decoders are instantiated from their respective config files.
Decoders are stored in :py:class:`nn.ModuleDict`, indexed by output
name.
:return: ModuleDict of decoders indexed by output name
:rtype: nn.ModuleDict
"""
decoders = nn.ModuleDict()
for output_head in self.config.decoder.keys():
cfg = self.config.decoder.get(output_head)
logger.info(
f"Building {output_head} decoder. Output dim {cfg.get('out_dim')}"
)
logger.info(cfg)
decoders[output_head] = hydra.utils.instantiate(cfg)
return decoders
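    # Config sketch (hypothetical values and target path): every key under
    # ``decoder`` is a Hydra config instantiated into an ``nn.Module``, e.g.
    #
    #     decoder:
    #       graph_label:
    #         _target_: proteinworkshop.models.decoders.MLPDecoder  # illustrative path
    #         out_dim: 1195
    #         input: graph_embedding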
def on_after_batch_transfer(
self, batch: Union[Batch, ProteinBatch], dataloader_idx: int
) -> Union[Batch, ProteinBatch]:
"""
Featurise batch **after** it has been transferred to the correct device.
:param batch: Batch of data
        :type batch: Union[Batch, ProteinBatch]
:param dataloader_idx: Index of dataloader
:type dataloader_idx: int
:return: Featurised batch
:rtype: Union[Batch, ProteinBatch]
"""
return self.featurise(batch)
@typechecker
def log_metrics(
        self, loss: Dict[str, torch.Tensor], y_hat: ModelOutput, y: Label, stage: str, batch: Batch
):
"""
Logs metrics to logger.
:param loss: Dictionary of losses indexed by output name (str)
:type loss: Dict[str, torch.Tensor]
:param y_hat: Output of model. This should be a dictionary of outputs
indexed by the output name (str)
:type y_hat: ModelOutput
:param y: Labels. This should be a dictionary of labels (torch.Tensor)
indexed by the output name (str)
:type y: Label
:param stage: Stage of training (``"train"``, ``"val"``, ``"test"``)
:type stage: str
:param batch: Batch of data
:type batch: Batch
"""
# Log losses
log_dict = {f"{stage}/loss/{k}": v for k, v in loss.items()}
# Log metrics
for m in self.metric_names:
for output in self.config.task.output:
if hasattr(self, f"{stage}_{output}_{m}"):
try:
metric = getattr(self, f"{stage}_{output}_{m}")
pred = y_hat[output]
target = y[output]
if m == "perplexity":
pred = to_dense_batch(pred, batch.batch)[0]
target = to_dense_batch(
target, batch.batch, fill_value=-100
)[0]
# This is a hack for MSE-type metrics which fail on e.g. [4,1] & [4]
try:
val = metric(pred, target)
except RuntimeError:
val = metric(pred, target.unsqueeze(-1))
log_dict[f"{stage}/{output}/{m}"] = val
except (ValueError, RuntimeError):
continue
self.log_dict(log_dict, prog_bar=True)
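    # Note (convention inferred from the ``getattr`` lookup above): metric modules
    # are expected to be registered on ``self`` as ``{stage}_{output}_{metric}``,
    # e.g. ``self.val_graph_label_accuracy``, presumably in ``configure_metrics``
    # (see ``BenchMarkModel.__init__``).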
class BenchMarkModel(BaseModel):
def __init__(self, cfg: DictConfig) -> None:
super().__init__()
self.config = cfg
# self.encoder = get_protein_encoder(cfg)
logger.info("Instantiating encoder...")
self.encoder: nn.Module = hydra.utils.instantiate(cfg.encoder)
logger.info(self.encoder)
logger.info("Instantiating decoders...")
self.decoder: nn.ModuleDict = self._build_output_decoders()
logger.info(self.decoder)
logger.info("Instantiating losses...")
self.losses = self.configure_losses(cfg.task.losses)
logger.info(f"Using losses: {self.losses}")
        if OmegaConf.select(self.config, "task.aux_loss_coefficient"):
logger.info(
f"Using aux loss coefficient: {self.config.task.aux_loss_coefficient}"
)
else:
logger.info("Not using aux loss scaling")
logger.info("Configuring metrics...")
self.metrics = self.configure_metrics()
logger.info(self.metric_names)
logger.info("Instantiating featuriser...")
self.featuriser: nn.Module = hydra.utils.instantiate(cfg.features)
logger.info(self.featuriser)
logger.info("Instantiating task transform...")
        self.task_transform = hydra.utils.instantiate(
            OmegaConf.select(cfg, "task.transform")
        )
logger.info(self.task_transform)
self.save_hyperparameters()
self.example_input_array = self._create_example_batch()
def _create_example_batch(self) -> ProteinBatch:
"""Creates an example batch for model inspection (including
featurisation and transformation as specified by the config).
:return: Example batch of data, featurised and transformed as specified
by the config.
:rtype: ProteinBatch
"""
with torch.no_grad():
proteins = [
get_random_protein()
for _ in range(self.config.dataset.datamodule.batch_size)
]
for p in proteins:
setattr(p, "x", torch.zeros(p.coords.shape[0]))
setattr(
p, "seq_pos", torch.arange(p.coords.shape[0]).unsqueeze(-1)
)
batch = ProteinBatch.from_data_list(proteins)
return self.featurise(batch)
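    # Instantiation sketch (illustrative only; assumes a fully composed Hydra
    # config with ``encoder``, ``decoder``, ``task``, ``dataset`` and ``features``
    # groups):
    #
    #     cfg = OmegaConf.load("config.yaml")  # hypothetical path
    #     model = BenchMarkModel(cfg)
    #     model.example_input_array  # featurised batch of random proteins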
@typechecker
def forward(self, batch: Union[Batch, ProteinBatch]) -> ModelOutput:
"""
Implements the forward pass of the model.
1. Apply the model encoder (``self.encoder``) to the batch of data.
2. (Optionally) apply any transformations to the encoder output
(:py:meth:`BaseModel.transform_encoder_output`)
3. Iterate over the decoder heads (``self.decoder``) and apply each
decoder to the relevant part of the encoder output.
4. (Optionally) apply any post-processing to the model output.
(:py:meth:`BaseModel.compute_output`)
:param batch: Mini-batch of data.
:type batch: Union[Batch, ProteinBatch]
:return: Model output.
:rtype: ModelOutput
"""
output: EncoderOutput = self.encoder(batch)
output = self.transform_encoder_output(output, batch)
if self.decoder is not None:
for output_head in self.config.decoder.keys():
if hasattr(self.decoder[output_head], "requires_pos"):
output[output_head] = self.decoder[output_head](
edge_index=batch.edge_index,
scalar_features=output["node_embedding"],
pos=batch.pos,
)
else:
emb_type = self.decoder[
output_head
].input # node_embedding or graph_embedding
output[output_head] = self.decoder[output_head](
output[emb_type]
)
return self.compute_output(output, batch)
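    # Output sketch (illustrative only): the returned ModelOutput maps output names
    # to tensors alongside the encoder embeddings, e.g.
    #
    #     {
    #         "node_embedding": ...,   # [num_nodes, hidden_dim]
    #         "graph_embedding": ...,  # [num_graphs, hidden_dim]
    #         "graph_label": ...,      # [num_graphs, num_classes] logits
    #     }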
@typechecker
def compute_output(self, output: ModelOutput, batch: Batch) -> ModelOutput:
"""
Computes output from model output.
- For dihedral angle prediction, this involves normalising the
          'sin'/'cos' pairs for each angle such that they have unit norm.
- For sequence denoising, this masks the output such that we only
supervise on the corrupted residues.
:param output: Model output (dictionary mapping output name to the
output tensor)
        :type output: ModelOutput
:param batch: Batch of data
:type batch: Batch
:return: Model output (dictionary mapping output name to the
transformed output)
:rtype: ModelOutput
"""
if "dihedrals" in output.keys():
            # Normalize the output so each (sin(ang), cos(ang)) pair has unit norm.
output["dihedrals"] = F.normalize(
output["dihedrals"].view(-1, 3, 2), dim=-1
).view(-1, 6)
# If we have a mask, apply it
if hasattr(batch, "sequence_corruption_mask"):
output["residue_type"] = output["residue_type"][
batch.sequence_corruption_mask
]
return output
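    # Dihedral normalisation sketch (illustrative values): the six outputs are
    # treated as three (sin, cos) pairs and rescaled to unit norm, e.g.
    #
    #     raw = torch.tensor([[3.0, 4.0, 0.0, 2.0, 1.0, 0.0]])
    #     F.normalize(raw.view(-1, 3, 2), dim=-1).view(-1, 6)
    #     # -> tensor([[0.6, 0.8, 0.0, 1.0, 1.0, 0.0]])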
@typechecker
def _do_step(
self,
batch: Batch,
batch_idx: int,
stage: Literal["train", "val", "test"],
) -> torch.Tensor:
"""Performs a training/validation/test step.
1. Obtains labels from :py:meth:`get_labels`
2. Computes model output :py:meth:`forward`
3. Computes loss :py:meth:`compute_loss`
4. Logs metrics :py:meth:`log_metrics`
Returns the total loss.
:param batch: Mini-batch of data.
:type batch: Batch
:param batch_idx: Index of batch.
:type batch_idx: int
:param stage: Stage of training (``"train"``, ``"val"``, ``"test"``)
:type stage: Literal["train", "val", "test"]
:return: Loss
:rtype: torch.Tensor
"""
y = self.get_labels(batch)
y_hat = self(batch)
loss = self.compute_loss(y_hat, y)
self.log_metrics(loss, y_hat, y, stage, batch=batch)
return loss["total"]
def training_step(
self, batch: Union[Batch, ProteinBatch], batch_idx: int
) -> torch.Tensor:
"""
Perform training step.
1. Obtains labels from :py:meth:`get_labels`
2. Computes model output :py:meth:`forward`
3. Computes loss :py:meth:`compute_loss`
4. Logs metrics :py:meth:`log_metrics`
Returns the total loss.
:param batch: Mini-batch of data.
        :type batch: Union[Batch, ProteinBatch]
:param batch_idx: Index of batch.
:type batch_idx: int
:return: Loss
:rtype: torch.Tensor
"""
return self._do_step(batch, batch_idx, "train")
def validation_step(
self, batch: Union[Batch, ProteinBatch], batch_idx: int
) -> torch.Tensor:
"""
Perform validation step.
1. Obtains labels from :py:meth:`get_labels`
2. Computes model output :py:meth:`forward`
3. Computes loss :py:meth:`compute_loss`
4. Logs metrics :py:meth:`log_metrics`
Returns the total loss.
:param batch: Mini-batch of data.
        :type batch: Union[Batch, ProteinBatch]
:param batch_idx: Index of batch.
:type batch_idx: int
:return: Loss
:rtype: torch.Tensor
"""
return self._do_step(batch, batch_idx, "val")
def test_step(
self, batch: Union[Batch, ProteinBatch], batch_idx: int
) -> torch.Tensor:
"""Perform test step.
1. Obtains labels from :py:meth:`get_labels`
2. Computes model output :py:meth:`forward`
3. Computes loss :py:meth:`compute_loss`
4. Logs metrics :py:meth:`log_metrics`
Returns the total loss.
:param batch: Mini-batch of data.
        :type batch: Union[Batch, ProteinBatch]
:param batch_idx: Index of batch.
:type batch_idx: int
:return: Loss
:rtype: torch.Tensor
"""
return self._do_step(batch, batch_idx, "test")