class ProteinFeaturiser(nn.Module):
    """Attach node features, a structure representation and edges to a batch.

    The featuriser is configured once with the lists of features to compute
    and is then applied to batches through :meth:`forward`.

    :param representation: Representation to use for the protein. One of
        ``"ca", "ca_bb", "full_atom"``.
    :type representation: StructureRepresentation
    :param scalar_node_features: Scalar-valued node features to compute.
        Options: ``"amino_acid_one_hot", "sequence_positional_encoding",
        "alpha", "kappa", "dihedrals", "sidechain_torsions"``.
    :type scalar_node_features: List[ScalarNodeFeature]
    :param vector_node_features: Vector-valued node features to compute.
        (TODO: document the available options.)
    :type vector_node_features: List[VectorNodeFeature]
    :param edge_types: Edge types to compute.
        (TODO: document the available options.)
    :type edge_types: List[str]
    :param scalar_edge_features: Scalar-valued edge features to compute.
        (TODO: document the available options.)
    :type scalar_edge_features: List[ScalarEdgeFeature]
    :param vector_edge_features: Vector-valued edge features to compute.
        (TODO: document the available options.)
    :type vector_edge_features: List[VectorEdgeFeature]
    """

    def __init__(
        self,
        representation: StructureRepresentation,
        scalar_node_features: List[ScalarNodeFeature],
        vector_node_features: List[VectorNodeFeature],
        edge_types: List[str],
        scalar_edge_features: List[ScalarEdgeFeature],
        vector_edge_features: List[VectorEdgeFeature],
    ):
        super().__init__()

        # Structure representation and node-level configuration.
        self.representation = representation
        self.scalar_node_features = scalar_node_features
        self.vector_node_features = vector_node_features

        # Edge-level configuration.
        self.edge_types = edge_types
        self.scalar_edge_features = scalar_edge_features
        self.vector_edge_features = vector_edge_features

        # The positional-encoding submodule only exists when that feature was
        # requested; ``forward`` detects it by attribute presence.
        if "sequence_positional_encoding" in self.scalar_node_features:
            self.positional_encoding = PositionalEncoding(16)

    @jaxtyped(typechecker=typechecker)
    def forward(
        self, batch: Union[Batch, ProteinBatch]
    ) -> Union[Batch, ProteinBatch]:
        """Featurise ``batch`` according to the configured feature lists.

        Each stage is skipped when its configuration list is empty, except
        the representation transform, which always runs:

        1. Scalar node features are written to ``batch.x``. When the
           positional encoding is present it is computed from
           ``batch.seq_pos`` and any other requested features are
           concatenated after it along the last dimension. NaN and infinite
           entries of ``batch.x`` are then replaced with ``0.0``.
        2. ``transform_representation`` is applied with
           ``self.representation``.
        3. Vector node features are computed.
        4. Edges are computed into ``batch.edge_index`` and
           ``batch.edge_type``, and ``batch.num_relation`` is set to the
           number of configured edge types.
        5. Scalar edge features are written to ``batch.edge_attr``.
        6. Vector edge features are computed.

        :param batch: Batch to featurise. Attributes are assigned directly
            on the object that is passed in.
        :type batch: Union[Batch, ProteinBatch]
        :return: The featurised batch, as returned by the last helper that
            produced it.
        :rtype: Union[Batch, ProteinBatch]
        """
        # --- Scalar node features ---
        if self.scalar_node_features:
            has_positional = hasattr(self, "positional_encoding")
            if has_positional:
                # Assigned before the remaining features are computed so the
                # helper below sees the same batch state as it always has.
                batch.x = self.positional_encoding(batch.seq_pos)

            if self.scalar_node_features != ["sequence_positional_encoding"]:
                computed = compute_scalar_node_features(
                    batch, self.scalar_node_features
                )
                # With a positional encoding present, the other features are
                # appended to it; otherwise they are the node features.
                batch.x = (
                    torch.cat([batch.x, computed], dim=-1)
                    if has_positional
                    else computed
                )

            batch.x = torch.nan_to_num(
                batch.x, nan=0.0, posinf=0.0, neginf=0.0
            )

        # --- Representation ---
        batch = transform_representation(batch, self.representation)

        # --- Vector node features ---
        if self.vector_node_features:
            batch = compute_vector_node_features(
                batch, self.vector_node_features
            )

        # --- Edges ---
        if self.edge_types:
            batch.edge_index, batch.edge_type = compute_edges(
                batch, self.edge_types
            )
            batch.num_relation = len(self.edge_types)

        # --- Scalar edge features ---
        if self.scalar_edge_features:
            batch.edge_attr = compute_scalar_edge_features(
                batch, self.scalar_edge_features
            )

        # --- Vector edge features ---
        if self.vector_edge_features:
            batch = compute_vector_edge_features(
                batch, self.vector_edge_features
            )

        return batch