# Source code for proteinworkshop.tasks.structural_denoising

"""Implements a transform for corrupting the Cartesian coordinates of a protein structure."""
import copy
from typing import Literal, Set, Union

import torch
from beartype import beartype as typechecker
from graphein.protein.tensor.data import Protein
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform


class StructuralNoiseTransform(BaseTransform):
    """Adds noise to the coordinates of a protein structure.

    Sets the following attributes on the protein data object:

    - ``coords_uncorrupted``: The original coordinates of the protein.
    - ``noise``: The noise added to the coordinates. Entries at padded
      positions are zero, because those positions are not perturbed.
    - ``coords``: The original coordinates + noise.

    Coordinate entries exactly equal to :attr:`PAD_VALUE` (``1e-5``) are
    treated as padding. They keep their value in ``coords``, and the matching
    entries of ``noise`` are zero, so ``noise`` reflects the perturbation that
    was actually applied.

    The input's ``coords`` tensor is never modified in place; a new tensor is
    assigned to ``x.coords`` instead. Callers may pass data objects whose
    tensors share storage with a dataset cache (e.g. shallow copies). In that
    case an in-place update would silently corrupt the cached coordinates and
    accumulate noise across accesses.

    :param corruption_rate: Magnitude of corruption to apply to the
        coordinates. For ``"uniform"`` the noise is drawn from
        ``U[-corruption_rate, corruption_rate)``. For ``"gaussian"`` it is
        standard-normal noise scaled by ``corruption_rate``.
    :type corruption_rate: float
    :param corruption_strategy: Noise strategy to use for corruption.
    :type corruption_strategy: Literal["uniform", "gaussian"]
    """

    # Sentinel marking padded / missing coordinate entries. Presumably the
    # upstream featurisation fills absent atoms with this value -- TODO
    # confirm against the data pipeline.
    PAD_VALUE: float = 1e-5

    def __init__(
        self,
        corruption_rate: float,
        corruption_strategy: Literal["uniform", "gaussian"],
    ):
        self.corruption_rate = corruption_rate
        self.corruption_strategy = corruption_strategy

    @property
    def required_attributes(self) -> Set[str]:
        return {"coords"}

    @typechecker
    def __call__(self, x: Union[Data, Protein]) -> Union[Data, Protein]:
        """Adds noise to the coordinates of a protein structure.

        :param x: Protein data object
        :type x: Union[Data, Protein]
        :raises ValueError: If the corruption strategy is not supported.
        :return: Protein data object with corrupted coordinates.
        :rtype: Union[Data, Protein]
        """
        coords = x.coords
        x.coords_uncorrupted = copy.deepcopy(coords)

        with torch.no_grad():
            if self.corruption_strategy == "uniform":
                # rand_like samples U[0, 1); shift and scale to
                # U[-corruption_rate, corruption_rate).
                noise = (torch.rand_like(coords) - 0.5) * 2 * self.corruption_rate
            elif self.corruption_strategy == "gaussian":
                noise = torch.randn_like(coords) * self.corruption_rate
            else:
                raise ValueError(
                    f"Corruption strategy: {self.corruption_strategy} not supported."
                )

        # Padding is located on the clean coordinates, before any noise is
        # added. Zeroing the noise there leaves padded entries at exactly
        # PAD_VALUE and keeps ``x.noise`` consistent with what was applied.
        pad_mask = coords == self.PAD_VALUE
        noise = noise.masked_fill(pad_mask, 0.0)

        x.noise = noise
        # Out-of-place addition: never mutate the caller's coordinate tensor.
        x.coords = coords + noise
        return x

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"corruption_strategy={self.corruption_strategy!r}, "
            f"corruption_rate={self.corruption_rate})"
        )
if __name__ == "__main__":
    # Smoke demo: corrupt a random protein and report how far it moved.
    from graphein.protein.tensor.data import get_random_protein

    p = get_random_protein()
    task = StructuralNoiseTransform(
        corruption_rate=5, corruption_strategy="uniform"
    )

    def rmsd(x, y):
        """Root-mean-square deviation between two coordinate tensors.

        Squared displacements are summed over the last axis before averaging.
        The result is therefore a per-point RMSD, not the RMS of individual
        coordinate components (which would be smaller by a factor of sqrt(3)
        for 3D coordinates). Assumes the last axis holds the Cartesian xyz
        components -- TODO confirm for the tensor layout used. Padded entries,
        which have zero displacement, are included in the average.
        """
        return torch.sqrt(torch.mean(torch.sum((x - y) ** 2, dim=-1)))

    p = task(p)
    print(rmsd(p.coords, p.coords_uncorrupted))