from itertools import chain
from typing import Literal, Tuple
import torch
from beartype import beartype as typechecker
from graphein.protein.tensor.types import AtomTensor, CoordTensor
from jaxtyping import jaxtyped
from torch_geometric.data import Batch, Data
from torch_geometric.utils import unbatch
from proteinworkshop.configs.config import ExperimentConfigurationError
[docs]
@jaxtyped(typechecker=typechecker)
def get_full_atom_coords(
    atom_tensor: AtomTensor, fill_value: float = 1e-5
) -> Tuple[CoordTensor, torch.Tensor, torch.Tensor]:
    """Converts an AtomTensor to a full atom representation
    (e.g. dense to sparse).

    An atom slot is treated as present when its first (x) coordinate differs
    from ``fill_value``. The same per-atom mask is used to select the
    coordinates and to derive the indices, so the three returned tensors
    always have the same length and row order.

    :param atom_tensor: AtomTensor of shape (``N_residues x 37 x 3``)
    :type atom_tensor: AtomTensor
    :param fill_value: Value indicating missing atoms, defaults to ``1e-5``
    :type fill_value: float, optional
    :return: Tuple of coords (``N_atoms x 3``), residue_index (``N_atoms``),
        atom_type (``N_atoms`` (``[0-36]``))
    :rtype: Tuple[CoordTensor, torch.Tensor, torch.Tensor]
    """
    # Per-atom presence mask of shape (N_residues, N_atom_slots).
    filled = atom_tensor[:, :, 0] != fill_value

    # Each row of ``nonzero`` is a (residue, atom slot) pair of a present atom.
    nz = filled.nonzero()
    residue_index = nz[:, 0]
    atom_type = nz[:, 1]

    # Select whole (x, y, z) rows with the per-atom mask. Filtering
    # element-wise instead would drop single coordinates that happen to equal
    # ``fill_value`` and break the alignment with the indices above.
    coords = atom_tensor[filled]
    return coords, residue_index, atom_type
@typechecker
def _ca_to_fa_repr(x: Data) -> Data:
    """Expand a single per-residue (CA) graph into a per-atom graph.

    The dense ``x.coords`` tensor is flattened to its present atoms via
    :func:`get_full_atom_coords`; ``amino_acid_one_hot`` and ``dihedrals`` are
    gathered with the resulting residue index so that every atom carries the
    features of its residue. The object is modified in place and returned.

    :param x: Graph exposing ``coords``, ``amino_acid_one_hot`` and
        ``dihedrals``.
    :type x: Data
    :return: The same object with ``pos``, ``residue_index``, ``atom_type``
        and ``num_nodes`` set to their per-atom values.
    :rtype: Data
    """
    atom_coords, res_idx, atom_types = get_full_atom_coords(x.coords)

    # Right-hand sides are evaluated up front; each one only reads its own
    # attribute, so this matches assigning them one after another.
    updates = (
        ("amino_acid_one_hot", x.amino_acid_one_hot[res_idx]),
        ("dihedrals", x.dihedrals[res_idx]),
        ("pos", atom_coords),
        ("residue_index", res_idx),
        ("atom_type", atom_types),
        ("num_nodes", atom_coords.shape[0]),
    )
    for attr_name, attr_value in updates:
        setattr(x, attr_name, attr_value)
    return x
@typechecker
def _ca_to_bb_repr(x: Data) -> Data:
    """Converts CA representation to backbone representation.

    The first four atom slots of ``x.coords`` become individual nodes, and
    the per-residue ``dihedrals`` and ``amino_acid_one_hot`` features are
    tiled over them. ``atom_type`` and ``node_id`` are built with four entries
    per residue so their lengths agree with ``pos`` and ``num_nodes``. The
    object is modified in place and returned.

    :param x: Graph exposing ``coords``, ``dihedrals``,
        ``amino_acid_one_hot`` and ``node_id``.
    :type x: Data
    :return: The same object, now with four nodes per residue.
    :rtype: Data
    """
    # Take the residue count from the tensor the positions are derived from,
    # before any node-level attribute is overwritten.
    n_residues = x.coords.shape[0]

    x.pos = x.coords[:, :4, :].reshape(-1, 3)
    x.dihedrals = x.dihedrals.repeat_interleave(4, 0)
    x.amino_acid_one_hot = x.amino_acid_one_hot.repeat_interleave(4, 0)
    x.num_nodes = n_residues * 4

    # One type id per backbone slot, repeated once per residue (length 4 * N).
    # Created on the same device as the coordinates so that collating the
    # graphs does not mix devices.
    x.atom_type = torch.tensor(
        [0.0, 1.0, 2.0, 3.0], device=x.pos.device
    ).repeat(n_residues)

    # Residue-major ordering, matching the row order of ``x.pos``.
    # Slot labels follow the (N, CA, C, O) order documented in ca_to_bb_repr.
    x.node_id = [
        f"{n}:{atom}" for n in x.node_id for atom in ("N", "Ca", "C", "O")
    ]
    return x
[docs]
@typechecker
def ca_to_bb_repr(batch: Batch) -> Batch:
    """
    Converts a batch of CA representations to backbone representations. I.e.
    1 node per residue -> 4 nodes per residue (N, CA, C, O)

    This function tiles any existing node features on the CA atoms over the
    additional nodes in the backbone representation.

    :param batch: Batch of per-residue graphs.
    :type batch: Batch
    :return: Re-collated batch with four nodes per residue.
    :rtype: Batch
    """
    # Optional per-residue attributes that are tiled over the new nodes when
    # present. The order is the order in which they are written back.
    tiled_keys = (
        "sidechain_torsions",
        "chi1",
        "positional_encoding",
        "true_dihedrals",
        "mask",
        "x",
    )
    # ``key in batch`` is used rather than ``key in batch.keys`` so the check
    # does not depend on ``keys`` being a property rather than a method.
    tiled = {
        key: getattr(batch, key).repeat_interleave(4, 0)
        for key in tiled_keys
        if key in batch
    }
    batch_idx = batch.batch.repeat_interleave(4, 0)

    # Convert each graph individually, then collate them again.
    batch = Batch.from_data_list(
        [_ca_to_bb_repr(graph) for graph in batch.to_data_list()]
    )

    # Overwrite the collated values with the tiled tensors computed above.
    batch.batch = batch_idx
    for key, value in tiled.items():
        setattr(batch, key, value)
    return batch
[docs]
@typechecker
def ca_to_bb_sc_repr(batch: Batch) -> Batch:
    """Converts a batch of CA representations to backbone + sidechain representations.

    The side chain of each residue is replaced by a single centroid stored in
    atom slot 4. Slots 0-2 are kept unchanged, slot 3 is set to the fill value
    ``1e-5`` and all later slots are dropped. The reduced coordinates are then
    expanded with :func:`ca_to_fa_repr`.

    :param batch: Batch exposing a dense ``coords`` tensor.
    :type batch: Batch
    :return: Batch in which every remaining atom slot is a node.
    :rtype: Batch
    """
    # The centroid has to be computed while the side-chain coordinates are
    # still intact, i.e. before anything is overwritten with the fill value.
    centroids = coarsen_sidechain(batch, aggr="mean")

    # Work on a copy so the caller's coordinate tensor is not modified in
    # place. Slot 4 must be part of the slice or the centroid would be lost.
    coords = batch.coords[:, :5, :].clone()
    # NOTE(review): slot 3 is blanked as in the previous implementation, so
    # that atom is dropped by ca_to_fa_repr -- confirm this is intended.
    coords[:, 3, :] = 1e-5
    coords[:, 4, :] = centroids
    batch.coords = coords
    return ca_to_fa_repr(batch)
[docs]
@typechecker
def ca_to_ca_sc_repr(batch: Batch) -> Batch:
    """Converts a batch of CA representations to CA + sidechain representations.

    ``batch.coords`` is reduced to two atom slots per residue: slot 0 holds
    the side-chain centroid and slot 1 keeps its original coordinates. Only
    ``coords`` is changed; no other attribute of the batch is touched.

    :param batch: Batch exposing a dense ``coords`` tensor.
    :type batch: Batch
    :return: The same batch with ``coords`` of shape (``N_residues x 2 x 3``).
    :rtype: Batch
    """
    # The centroid has to be computed while the side-chain coordinates are
    # still intact, i.e. before the tensor is truncated or overwritten.
    centroids = coarsen_sidechain(batch, aggr="mean")

    # Work on a copy so the caller's coordinate tensor is not modified in
    # place.
    coords = batch.coords[:, :2, :].clone()
    coords[:, 0, :] = centroids
    batch.coords = coords
    return batch
[docs]
@typechecker
def coarsen_sidechain(
    x: Data, aggr: str = "mean", fill_value: float = 1e-5
) -> CoordTensor:
    """Returns tensor of sidechain centroids: L x 3

    Atom slots from index 4 onwards of ``x.coords`` are treated as the side
    chain. Slots whose first (x) coordinate equals ``fill_value`` are
    considered missing and are excluded from the average, so placeholder
    values do not pull the centroid towards the origin. Residues without any
    side-chain atom get ``fill_value`` in all three coordinates.

    :param x: Graph or batch exposing a dense ``coords`` tensor.
    :type x: Data
    :param aggr: Aggregation method; only ``"mean"`` is implemented.
    :type aggr: str, optional
    :param fill_value: Value indicating missing atoms, defaults to ``1e-5``
    :type fill_value: float, optional
    :raises NotImplementedError: If ``aggr`` is not ``"mean"``.
    :return: Side-chain centroid of every residue (``L x 3``).
    :rtype: CoordTensor
    """
    if aggr != "mean":
        raise NotImplementedError(
            f"Aggregation method {aggr} not implemented."
        )

    sc_points = x.coords[:, 4:]

    # Same presence test as get_full_atom_coords: compare the x-coordinate
    # against the fill value.
    present = sc_points[:, :, 0] != fill_value
    n_present = present.sum(dim=1, keepdim=True)

    # Masked mean; the clamp only guards the division for empty side chains,
    # whose result is replaced by the fill value just below.
    summed = (sc_points * present.unsqueeze(-1)).sum(dim=1)
    centroids = summed / n_present.clamp(min=1)
    return torch.where(
        n_present > 0, centroids, torch.full_like(centroids, fill_value)
    )
[docs]
@typechecker
def ca_to_fa_repr(batch: Batch) -> Batch:
    """Converts a batch of CA representations to full atom representations.

    Every graph in the batch is expanded with ``_ca_to_fa_repr`` and the
    graphs are collated again. Optional per-residue attributes are split per
    graph beforehand and afterwards gathered with each graph's residue index,
    so that every atom carries the value of the residue it belongs to.

    :param batch: Batch of per-residue graphs.
    :type batch: Batch
    :return: Re-collated batch with one node per present atom.
    :rtype: Batch
    """
    # Optional per-residue attributes, in the order they are written back.
    # NOTE(review): the key here is ``sidechain_torsion`` (singular) while
    # ca_to_bb_repr looks up ``sidechain_torsions`` -- confirm which name the
    # featuriser actually sets.
    per_residue_keys = (
        "sidechain_torsion",
        "chi1",
        "true_dihedrals",
        "true_amino_acid_one_hot",
        "mask",
        "positional_encoding",
    )
    # ``key in batch`` is used rather than ``key in batch.keys`` so the check
    # does not depend on ``keys`` being a property rather than a method.
    per_graph = {
        key: unbatch(getattr(batch, key), batch.batch)
        for key in per_residue_keys
        if key in batch
    }

    batch = Batch.from_data_list(
        [_ca_to_fa_repr(graph) for graph in batch.to_data_list()]
    )

    if per_graph:
        # Shift each graph's residue indices to start at zero so they can
        # index that graph's own feature chunk. This assumes the smallest
        # index in a graph belongs to its first residue, i.e. the first
        # residue has at least one atom -- TODO confirm.
        local_idxs = [
            res_idx - torch.min(res_idx)
            for res_idx in unbatch(batch.residue_index, batch.batch)
        ]
        for key, chunks in per_graph.items():
            setattr(
                batch,
                key,
                torch.cat(
                    [chunk[idx] for chunk, idx in zip(chunks, local_idxs)]
                ),
            )
    return batch