# Source code for proteinworkshop.tasks.sequence_denoising

"""Implements sequence denoising task."""
import copy
from typing import Literal, Set, Union

import torch
from beartype import beartype as typechecker
from graphein.protein.tensor.data import Protein
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform


class SequenceNoiseTransform(BaseTransform):
    """Corrupt a protein sequence by mutating or masking random residues.

    .. note::

        - The data object this is called on must have a ``residue_type``
          attribute. See: :py:meth:`required_attributes`
        - The original sequence is stored as
          ``data.residue_type_uncorrupted``.
        - A boolean mask that is ``True`` at every corrupted position is
          stored as ``data.sequence_corruption_mask``.
    """

    # Exclusive upper bound used when drawing a random replacement residue
    # type for the ``mutate`` strategy (replacements lie in ``[0, 23)``).
    # TODO: probably best to not hardcode this
    NUM_MUTATION_RESIDUE_TYPES: int = 23

    # Value written at corrupted positions by the ``mask`` strategy. The
    # original author's note maps index 23 to "UNK" -- verify against the
    # residue vocabulary in use.
    # TODO: probably best to not hardcode this
    MASK_RESIDUE_TYPE: int = 23

    def __init__(
        self,
        corruption_rate: float,
        corruption_strategy: Literal["mutate", "mask"],
    ):
        """Configure the sequence corruption.

        :param corruption_rate: Fraction of residues to corrupt.
        :type corruption_rate: float
        :param corruption_strategy: Strategy to use for corruption. Either:
            ``mutate`` or ``mask``.
        :type corruption_strategy: Literal[mutate, mask]
        """
        self.corruption_rate = corruption_rate
        self.corruption_strategy = corruption_strategy

    @property
    def required_attributes(self) -> Set[str]:
        """Names of the attributes the input data object must provide."""
        return {"residue_type"}

    @typechecker
    def __call__(self, x: Union[Data, Protein]) -> Union[Data, Protein]:
        """Corrupt ``x.residue_type`` in place and record what was changed.

        :param x: Data object with a ``residue_type`` tensor.
        :type x: Union[Data, Protein]
        :raises NotImplementedError: If ``corruption_strategy`` is neither
            ``mutate`` nor ``mask``.
        :return: The same object, with ``residue_type`` corrupted and the
            ``residue_type_uncorrupted`` / ``sequence_corruption_mask``
            attributes set.
        :rtype: Union[Data, Protein]
        """
        # Reject unsupported strategies before touching ``x`` so a failed
        # call does not leave a half-modified data object behind.
        if self.corruption_strategy not in ("mutate", "mask"):
            raise NotImplementedError(
                f"Corruption strategy: {self.corruption_strategy} not supported."
            )

        num_residues = x.residue_type.shape[0]
        device = x.residue_type.device

        # Keep an independent copy of the clean sequence.
        x.residue_type_uncorrupted = x.residue_type.clone()

        # Number of residues to corrupt, clamped to the valid range so an
        # out-of-range rate cannot over- or under-run the sequence.
        num_corrupt = min(
            max(int(num_residues * self.corruption_rate), 0), num_residues
        )

        # Sample positions WITHOUT replacement: every selected index is
        # distinct, so exactly ``num_corrupt`` residues are corrupted.
        indices = torch.randperm(num_residues, device=device)[:num_corrupt]

        # Apply corruption
        if self.corruption_strategy == "mutate":
            # Set indices to random residue type
            x.residue_type[indices] = torch.randint(
                0,
                self.NUM_MUTATION_RESIDUE_TYPES,
                (num_corrupt,),
                device=device,
            )
        else:  # "mask"
            x.residue_type[indices] = self.MASK_RESIDUE_TYPE

        # Build the mask on the same device as the indices; indexing a
        # default-device tensor with indices from another device fails.
        corruption_mask = torch.zeros(
            num_residues, dtype=torch.bool, device=device
        )
        corruption_mask[indices] = True
        x.sequence_corruption_mask = corruption_mask
        return x

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"corruption_strategy: {self.corruption_strategy} "
            f"corruption_rate: {self.corruption_rate})"
        )
if __name__ == "__main__":
    from graphein.protein.tensor.data import get_random_protein

    # Smoke test: run each corruption strategy and print the corrupted
    # sequence next to the stored uncorrupted one.
    for strategy in ("mutate", "mask"):
        # Use a fresh protein per strategy: the transform modifies
        # ``residue_type`` in place, so re-using an already-corrupted
        # protein would store the corrupted sequence as the "uncorrupted"
        # reference.
        p = get_random_protein()
        task = SequenceNoiseTransform(
            corruption_rate=0.99, corruption_strategy=strategy
        )
        p = task(p)
        print(p.residue_type)
        print(p.residue_type_uncorrupted)