Source code for proteinworkshop.tasks.remove_missing_ca

"""Implementation of a transform to remove residues with missing CA atoms."""
import numpy as np
import torch
from torch_geometric import transforms as T


[docs] class RemoveMissingCa(T.BaseTransform): """Removes residues with missing CA atoms from a protein structure.""" def __init__(self, fill_value: float = 1e-5, ca_idx: int = 1) -> None: """Initialise the transform. :param fill_value: Value used to denote missing atoms in the ``Protein`` data object. Defaults to ``1e-5``. :type fill_value: float, optional :param ca_idx: Index of the CA atom (in dimension 1) in the coords attribute of the Protein data object. By default this is 1, as the coords attribute is of shape ``(N, 37, 3)`` where ``N`` is the number of residues. :type ca_idx: int, optional """ self.fill_value = fill_value self.ca_idx = ca_idx def __call__(self, data): """Remove residues with missing CA atoms from a protein structure. :param data: Protein data object. :type data: Protein :return: Protein data object with missing residues removed. :rtype: Protein """ # Check for missing CA atoms # If there are no missing CA atoms, return the data mask = data.coords[:, self.ca_idx, 0] != self.fill_value if torch.all(mask): return data data.coords = data.coords[mask] data.residue_type = data.residue_type[mask] data.residues = np.array(data.residues)[mask] data.residue_id = np.array(data.residue_id)[mask] data.chains = data.chains[mask] if data.x is not None: data.x = data.x[mask] if hasattr(data, "amino_acid_one_hot"): data.amino_acid_one_hot = data.amino_acid_one_hot[mask] if hasattr(data, "seq_pos"): data.seq_pos = data.seq_pos[mask] return data
if __name__ == "__main__": from graphein.protein.tensor.data import get_random_protein a = get_random_protein() print(a) t = RemoveMissingCa() print(t(a)) a = torch.load( "../../../protein-workshop/data/FoldClassification/processed/d2vzsa2.pt" ) print(a) t = RemoveMissingCa() print(t(a))