Tasks

[Figure: auxiliary tasks]
[Figure: downstream tasks]

Task configs are ‘high-level’ configs that configure the transforms, outputs, losses and metrics of the model.

Specific training objectives are implemented using Transforms.

Note

To change the task, use a command with a format like:

workshop train task=<TASK_NAME> dataset=cath encoder=gvp trainer=cpu ...
# or
python proteinworkshop/train.py task=<TASK_NAME> dataset=cath encoder=gvp trainer=cpu ... # or trainer=gpu

Here, <TASK_NAME> is one of the tasks listed below.

See also

Transforms

Denoising Tasks

Sequence Denoising (sequence_denoising)

This config trains a model to predict the original identities of the corrupted residues in a protein sequence.

The corruption is configured by the sequence noise transform (proteinworkshop.tasks.sequence_denoising.SequenceNoiseTransform), which can apply masking or mutation corruption in varying amounts to the input sequence.
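The two corruption strategies can be illustrated with a minimal plain-Python sketch. This is not the code of SequenceNoiseTransform: the function name `corrupt_sequence`, the 20-residue standard vocabulary and the `MASK_TOKEN` index are hypothetical assumptions chosen for illustration. The corrupted positions are returned so that the original residue types at those positions can serve as the cross-entropy targets.

```python
import random
from typing import List, Optional, Tuple

# Hypothetical token layout (assumption): indices 0-19 are the standard amino
# acids and MASK_TOKEN is a reserved "masked" class. The real transform's
# vocabulary and mask index may differ.
NUM_STANDARD = 20
MASK_TOKEN = 22


def corrupt_sequence(
    residue_types: List[int],
    corruption_rate: float = 0.25,
    strategy: str = "mask",
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Corrupt a fraction of residues; return (corrupted sequence, corrupted indices)."""
    rng = random.Random(seed)
    n = len(residue_types)
    num_corrupt = min(n, max(1, round(corruption_rate * n))) if n else 0
    indices = sorted(rng.sample(range(n), num_corrupt))
    corrupted = list(residue_types)
    for i in indices:
        if strategy == "mask":
            corrupted[i] = MASK_TOKEN
        elif strategy == "mutate":
            # Replace with a different, randomly chosen standard residue.
            choices = [t for t in range(NUM_STANDARD) if t != residue_types[i]]
            corrupted[i] = rng.choice(choices)
        else:
            raise ValueError(f"unknown strategy: {strategy}")
    return corrupted, indices
```

The model sees `corrupted` as input and is supervised on `residue_types[i]` for each `i` in `indices`.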

config/task/sequence_denoising.yaml
# @package _global_

defaults:
  - override /metrics:
      - accuracy
      - f1_score
      - perplexity
  - override /decoder:
      - residue_type
  - override /transforms:
      - remove_missing_ca
      - sequence_denoising

dataset:
  num_classes: 23

callbacks:
  early_stopping:
    monitor: val/residue_type/accuracy
    mode: "max"
  model_checkpoint:
    monitor: val/residue_type/accuracy
    mode: "max"

task:
  task: "sequence_denoising"
  classification_type: "multiclass"
  metric_average: "micro"

  losses:
    residue_type: cross_entropy
  label_smoothing: 0.0

  output:
    - residue_type
  supervise_on:
    - residue_type

Structure Denoising (structure_denoising)

This config trains a model to predict the original Cartesian coordinates of the corrupted residues in a protein structure.

Noise is applied to the Cartesian coordinates and the model is tasked with predicting either the per-residue Cartesian noise or the original Cartesian coordinates.

The corruption is configured by the structure noise transform (proteinworkshop.tasks.structure_denoising.StructureNoiseTransform), which can apply uniform or Gaussian random noise in varying amounts to the input structure.
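The sketch below illustrates the two noise types and the two possible regression targets (the noise itself, or the clean coordinates) scored with a mean squared error, as in the config's mse_loss. It is a hypothetical plain-Python stand-in for StructureNoiseTransform; `add_structure_noise`, `scale` and `mse` are assumed names, and real coordinates would be tensors rather than tuples.

```python
import random
from typing import List, Optional, Tuple

Coord = Tuple[float, float, float]


def add_structure_noise(
    coords: List[Coord],
    scale: float = 0.1,
    noise_type: str = "gaussian",
    seed: Optional[int] = None,
) -> Tuple[List[Coord], List[Coord]]:
    """Perturb each coordinate; return (noisy coordinates, per-residue noise)."""
    rng = random.Random(seed)
    noisy, noise = [], []
    for x, y, z in coords:
        if noise_type == "gaussian":
            eps = (rng.gauss(0.0, scale), rng.gauss(0.0, scale), rng.gauss(0.0, scale))
        elif noise_type == "uniform":
            eps = (rng.uniform(-scale, scale), rng.uniform(-scale, scale), rng.uniform(-scale, scale))
        else:
            raise ValueError(f"unknown noise_type: {noise_type}")
        noise.append(eps)
        noisy.append((x + eps[0], y + eps[1], z + eps[2]))
    return noisy, noise


def mse(pred: List[Coord], target: List[Coord]) -> float:
    """Mean squared error over all coordinate components."""
    errors = [(p - t) ** 2 for pc, tc in zip(pred, target) for p, t in zip(pc, tc)]
    return sum(errors) / len(errors)
```

Subtracting the stored noise from the noisy coordinates recovers the originals, which is why either the noise or the clean coordinates can serve as the target.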

See also

proteinworkshop.tasks.structure_denoising.StructureNoiseTransform

config/task/structure_denoising.yaml
# @package _global_

defaults:
  - override /metrics:
      - rmse
  - override /decoder:
      - pos_equiv
  - override /transforms:
      - remove_missing_ca
      - structure_denoising

dataset:
  num_classes: null

callbacks:
  early_stopping:
    monitor: val/loss/pos
    mode: "min"
  model_checkpoint:
    monitor: val/loss/pos
    mode: "min"

task:
  task: "structure_denoising"
  losses:
    pos: mse_loss
  label_smoothing: 0.0
  output:
    - pos
  supervise_on:
    - pos

Sequence & Structure Denoising (sequence_structure_denoising)

This config trains a model to predict the original identities of the corrupted residues in a protein sequence and the original Cartesian coordinates of the corrupted residues in a protein structure.

This config demonstrates how we can compose transforms in a modular fashion to create new training regimes.
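Composition can be made concrete with a toy sketch in which each transform takes a batch dictionary and returns a new one, and the transforms are applied in the order they appear in the transforms list. The functions `drop_missing_ca`, `toy_structure_noise` and `toy_sequence_mask` are hypothetical stand-ins (a fixed coordinate shift and a blanket mask), not the library's transforms.

```python
from typing import Any, Callable, Dict, List

Batch = Dict[str, Any]
Transform = Callable[[Batch], Batch]


def compose(transforms: List[Transform]) -> Transform:
    """Chain transforms, applying them left to right."""

    def apply(batch: Batch) -> Batch:
        for transform in transforms:
            batch = transform(batch)
        return batch

    return apply


# Hypothetical toy transforms for illustration only.
def drop_missing_ca(batch: Batch) -> Batch:
    """Remove residues whose CA coordinate is missing (None)."""
    keep = [i for i, ca in enumerate(batch["pos"]) if ca is not None]
    return {
        **batch,
        "pos": [batch["pos"][i] for i in keep],
        "residue_type": [batch["residue_type"][i] for i in keep],
    }


def toy_structure_noise(batch: Batch) -> Batch:
    """Store clean coordinates as the target, then shift the inputs by 0.1 along x."""
    return {
        **batch,
        "pos_target": batch["pos"],
        "pos": [(x + 0.1, y, z) for x, y, z in batch["pos"]],
    }


def toy_sequence_mask(batch: Batch) -> Batch:
    """Store clean residue types as the target, then mask every input residue with -1."""
    return {
        **batch,
        "residue_type_target": batch["residue_type"],
        "residue_type": [-1] * len(batch["residue_type"]),
    }


pipeline = compose([drop_missing_ca, toy_structure_noise, toy_sequence_mask])
```

Filtering residues first keeps both stored targets aligned with the surviving residues, so the composed pipeline yields one target per supervised output (pos and residue_type).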

See also

proteinworkshop.tasks.sequence_denoising.SequenceNoiseTransform
proteinworkshop.tasks.structure_denoising.StructureNoiseTransform

config/task/sequence_structure_denoising.yaml
# @package _global_

defaults:
  - override /metrics:
      - rmse
      - accuracy
  - override /decoder:
      - pos_equiv
      - residue_type
  - override /transforms:
      - remove_missing_ca
      - structure_denoising
      - sequence_denoising

dataset:
  num_classes: 23

callbacks:
  early_stopping:
    monitor: val/loss/total
    mode: "min"
  model_checkpoint:
    monitor: val/loss/total
    mode: "min"

task:
  task: "sequence_structure_denoising"
  classification_type: "multiclass"
  metric_average: "micro"

  losses:
    pos: mse_loss
    residue_type: cross_entropy

  label_smoothing: 0.0
  output:
    - pos
    - residue_type
  supervise_on:
    - pos
    - residue_type

Torsional Denoising (torsional_denoising)

This config trains a model to predict the original dihedral angles of the corrupted residues in a protein structure.

The torsional noise transform applies noise in dihedral angle space. The Cartesian coordinates are then recomputed using pNeRF to enable structure-based featurisation.

config/task/torsional_denoising.yaml
# @package _global_

defaults:
  - override /metrics:
      - rmse
  - override /decoder:
      #- dihedrals
      - torsional_noise
  - override /transforms:
      - remove_missing_ca
      - torsional_denoising

dataset:
  num_classes: null

callbacks:
  early_stopping:
    #monitor: val/loss/dihedrals
    monitor: val/loss/torsional_noise
    mode: "min"
  model_checkpoint:
    #monitor: val/loss/dihedrals
    monitor: val/loss/torsional_noise
    mode: "min"

task:
  task: "torsional_denoising"
  losses:
    #dihedrals: mse_loss
    torsional_noise: mse_loss
  label_smoothing: 0.0
  output:
    #- dihedrals # Or torsional_noise
    - torsional_noise # or dihedrals
  supervise_on:
    # - dihedrals
    - torsional_noise

Node-level Tasks

Protein-Protein Interaction Site Prediction (ppi_site_prediction)

config/task/ppi_site_prediction.yaml
# @package _global_

defaults:
  - _self_
  - override /metrics:
      - accuracy
      - f1_score
      - f1_max
      - auprc
      - rocauc
  - override /decoder:
      - node_label
  - override /transforms:
      - remove_missing_ca
      - ppi_site_prediction

dataset:
  num_classes: 2

callbacks:
  early_stopping:
    monitor: val/node_label/accuracy
    mode: "max"
  model_checkpoint:
    monitor: val/node_label/accuracy
    mode: "max"

task:
  task: "classification"
  classification_type: "binary"
  metric_average: "micro"

  losses:
    node_label: bce
  label_smoothing: 0.0

  output:
    - "node_label"
  supervise_on:
    - "node_label"

Ligand Binding Site Prediction (binding_site_identification)

See also

proteinworkshop.tasks.binding_site_identification.BindingSiteIdentificationTransform

config/task/binding_site_identification.yaml
# @package _global_

defaults:
  - override /metrics:
      - accuracy
      - f1_score
      - auprc
      - rocauc
  - override /decoder:
      - node_label
  - override /transforms:
      - remove_missing_ca
      - binding_site_prediction

transforms:
  binding_site_prediction:
    hetatms: [HOH, SO4, PEG] # Types of hetatms to be considered as binding sites
    threshold: 3.5 # Threshold for binding site prediction
    ca_only: False # Whether to use only CA atoms for assigning labels
    multilabel: True # Whether to use multilabel or binary labels

dataset:
  num_classes: 3 # This needs to match the number of hetatms above

callbacks:
  early_stopping:
    monitor: val/node_label/accuracy
    mode: "max"
  model_checkpoint:
    monitor: val/node_label/accuracy
    mode: "max"

task:
  task: "binding_site_identification"
  classification_type: "multilabel" # Check this aligns with binding site config above
  metric_average: "macro"

  losses:
    node_label: cross_entropy # Check this aligns with binding site config above
  label_smoothing: 0.0

  output:
    - "node_label"
  supervise_on:
    - "node_label"
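The comments in this config suggest how labels are assigned: residues within `threshold` ångströms of a listed hetatm type are marked as binding sites, with one label column per hetatm type when `multilabel` is set (hence `num_classes` must match the length of `hetatms`). The sketch below is a hypothetical plain-Python reading of those comments, not the code of BindingSiteIdentificationTransform; the function name, its arguments and the input layout are assumptions.

```python
import math
from typing import Dict, List, Sequence, Tuple

Coord = Tuple[float, float, float]


def binding_site_labels(
    residues: List[Dict[str, Coord]],
    hetatms: List[Tuple[str, Coord]],
    het_types: Sequence[str] = ("HOH", "SO4", "PEG"),
    threshold: float = 3.5,
    ca_only: bool = False,
    multilabel: bool = True,
) -> List[List[int]]:
    """Label residues lying within `threshold` of each hetatm type.

    residues: one dict per residue mapping atom name -> coordinate.
    hetatms: (residue name, coordinate) for every heteroatom.
    Returns one column per hetatm type (multilabel) or a single binary column.
    """
    labels = []
    for atoms in residues:
        coords = [atoms["CA"]] if ca_only else list(atoms.values())
        row = [
            int(any(math.dist(p, q) <= threshold
                    for p in coords
                    for name, q in hetatms if name == het))
            for het in het_types
        ]
        labels.append(row if multilabel else [int(any(row))])
    return labels
```

With multilabel targets each column is an independent binary label, which is why the config comments ask that `classification_type` and the loss be kept consistent with these settings.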

Multiclass Node Classification (multiclass_node_classification)

config/task/multiclass_node_classification.yaml
# @package _global_

defaults:
  - override /metrics:
      - accuracy
      - f1_score
      - f1_max
  - override /decoder:
      - node_label

callbacks:
  early_stopping:
    monitor: val/node_label/accuracy
    mode: "max"
  model_checkpoint:
    monitor: val/node_label/accuracy
    mode: "max"

task:
  task: "classification"
  classification_type: "multiclass"
  metric_average: "micro"

  losses:
    node_label: cross_entropy
  label_smoothing: 0.0

  output:
    - "node_label"
  supervise_on:
    - "node_label"

pLDDT Prediction (plddt_prediction)

This config specifies a self-supervision task to predict the per-residue pLDDT score of each node.

Warning

This task requires the input data to have a b_factor attribute.

If the input structures are not predicted structures, this task becomes a B-factor prediction task.
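The reason both readings share one config is that structure predictors such as AlphaFold2 write the per-residue pLDDT into the B-factor column of their PDB output. As an illustration only (proteinworkshop loads structures through its own data pipeline, not this code), a hypothetical helper `ca_b_factors` can read that column for each CA atom using the fixed-width PDB layout.

```python
from typing import List


def ca_b_factors(pdb_text: str) -> List[float]:
    """Return the B-factor (tempFactor) of each CA atom in PDB-format text.

    In the fixed-width PDB format the atom name occupies columns 13-16 and the
    tempFactor columns 61-66 (1-based), i.e. the slices [12:16] and [60:66].
    """
    values = []
    for line in pdb_text.splitlines():
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            values.append(float(line[60:66]))
    return values
```

For a predicted structure these values are pLDDT scores; for an experimental structure they are crystallographic B-factors, and the regression target changes meaning accordingly.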

config/task/plddt_prediction.yaml
# @package _global_

defaults:
  - override /metrics:
      - rmse
  - override /decoder:
      - b_factor

dataset:
  num_classes: 1

callbacks:
  early_stopping:
    monitor: val/loss/b_factor
    mode: "min"
  model_checkpoint:
    monitor: val/loss/b_factor
    mode: "min"

task:
  task: "plddt_prediction"
  losses:
    b_factor: mse_loss
  label_smoothing: 0.0
  output:
    - b_factor
  supervise_on:
    - b_factor

Edge-level Tasks

Edge Distance Prediction (edge_distance_prediction)

This config specifies a self-supervision task to predict the pairwise distance between two nodes.

The transform first randomly samples num_samples edges from the input batch and constructs a mask that removes the sampled edges from the batch. The node indices of the sampled edges and their pairwise distances are stored as batch.node_mask and batch.edge_distance_labels, respectively. Finally, the transform applies the mask to the edges (and their attributes) and returns the modified batch. The distance is then predicted from the concatenated embeddings of the two nodes.
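The sampling and masking steps can be sketched with plain lists in place of PyTorch Geometric tensors. This is a hypothetical illustration, not EdgeDistancePredictionTransform: `sample_and_mask_edges` and `pair_inputs` are assumed names, and clamping `num_samples` to the number of available edges is an assumption of the sketch.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

EdgeIndex = Tuple[List[int], List[int]]  # (source nodes, target nodes)


def sample_and_mask_edges(
    edge_index: EdgeIndex,
    pos: Sequence[Tuple[float, float, float]],
    num_samples: int,
    seed: Optional[int] = None,
) -> Tuple[EdgeIndex, List[Tuple[int, int]], List[float]]:
    """Return (remaining edges, sampled node pairs, pairwise distance labels)."""
    rng = random.Random(seed)
    src, dst = edge_index
    num_edges = len(src)
    sampled = set(rng.sample(range(num_edges), min(num_samples, num_edges)))
    keep = [i for i in range(num_edges) if i not in sampled]
    remaining = ([src[i] for i in keep], [dst[i] for i in keep])
    pairs = [(src[i], dst[i]) for i in sorted(sampled)]
    labels = [math.dist(pos[s], pos[d]) for s, d in pairs]
    return remaining, pairs, labels


def pair_inputs(node_embeddings: Sequence[List[float]],
                pairs: List[Tuple[int, int]]) -> List[List[float]]:
    """Concatenate the two node embeddings of each sampled pair (decoder input)."""
    return [list(node_embeddings[s]) + list(node_embeddings[d]) for s, d in pairs]
```

Removing the sampled edges before encoding prevents the encoder from reading the target distance directly from edge features.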

config/task/edge_distance_prediction.yaml
# @package _global_

defaults:
  - override /metrics:
      - rmse
  - override /decoder:
      - edge_distance

dataset:
  num_classes: 1

callbacks:
  early_stopping:
    monitor: val/edge_distance/rmse
    mode: "min"
  model_checkpoint:
    monitor: val/edge_distance/rmse
    mode: "min"

task:
  task: "edge_distance_prediction"
  transform:
    _target_: proteinworkshop.tasks.edge_distance_prediction.EdgeDistancePredictionTransform
    num_samples: 256
  losses:
    edge_distance: mse_loss
  label_smoothing: 0.0
  output:
    - edge_distance
  supervise_on:
    - edge_distance

Auxiliary Tasks

Auxiliary tasks define training objectives in addition to the main training task. For instance, they can define auxiliary denoising objectives over sequence, coordinate or angle space.

Auxiliary tasks are implemented by modifying the experiment config to inject the additional metrics, decoders and losses.
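The auxiliary configs below set an `aux_loss_coefficient` per auxiliary output (for example `residue_type: 0.1`). One plausible way such a coefficient enters the objective is as a weight in a sum of per-output losses; the sketch assumes exactly that, with a default weight of 1.0 for the main task, and `total_loss` is a hypothetical helper rather than the library's training loop.

```python
from typing import Dict


def total_loss(losses: Dict[str, float], aux_loss_coefficient: Dict[str, float]) -> float:
    """Weighted sum of per-output losses; outputs without a coefficient get weight 1.0."""
    return sum(aux_loss_coefficient.get(name, 1.0) * value for name, value in losses.items())
```

For example, combining task=plddt_prediction with +aux_task=nn_sequence would, under this assumption, give `b_factor_loss + 0.1 * residue_type_loss`.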

Note

Aux tasks are specified with the following syntax:

python proteinworkshop/train.py ... +aux_task=nn_sequence

Sequence Denoising (nn_sequence)

# Example:
python proteinworkshop/train.py dataset=cath encoder=gvp task=plddt_prediction +aux_task=nn_sequence
config/aux_task/nn_sequence.yaml
# @package _global_

defaults:
  - _self_
  - /decoder:
      - residue_type
  - /metrics@residue_type:
  - /task@losses@residue_type: null
  - /transforms:
      - remove_missing_ca
      - sequence_denoising

task:
  losses:
    residue_type: cross_entropy

  output: ${oc.dict.keys:task.losses}
  supervise_on: ${oc.dict.keys:task.losses}
  aux_loss_coefficient:
    residue_type: 0.1

Structure Denoising (nn_structure_r3)

# Example:
python proteinworkshop/train.py dataset=cath encoder=gvp task=plddt_prediction +aux_task=nn_structure_r3
config/aux_task/nn_structure_r3.yaml
# @package _global_

defaults:
  - _self_
  - /decoder:
      - pos_equiv
  - /metrics: null
  - /task@losses@pos: null
  - /transforms:
      - remove_missing_ca
      - structure_denoising

task:
  losses:
    pos: mse_loss

  output: ${oc.dict.keys:task.losses}
  supervise_on: ${oc.dict.keys:task.losses}
aux_loss_coefficient:
  pos: 0.1

Torsional Denoising (nn_structure_torsion)

Noise is applied to the backbone torsion angles, and Cartesian coordinates are recomputed using pNeRF and the uncorrupted bond lengths and angles prior to feature computation. As with the coordinate denoising task, the model is then tasked with predicting either the per-residue angular noise or the original dihedral angles.
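Because torsion angles are periodic, an angular noise target needs care near ±π: a raw difference between a noisy and a clean angle can be off by a full turn. The sketch below wraps both the noisy angles and the noise target back into [-π, π]. It is an assumption about how such a target could be built, not a description of the library's transform; `wrap_angle` and `corrupt_torsions` are hypothetical names.

```python
import math
import random
from typing import List, Optional, Tuple


def wrap_angle(theta: float) -> float:
    """Map an angle (radians) to the interval [-pi, pi]."""
    return math.atan2(math.sin(theta), math.cos(theta))


def corrupt_torsions(
    torsions: List[float], scale: float = 0.1, seed: Optional[int] = None
) -> Tuple[List[float], List[float]]:
    """Add Gaussian noise to torsion angles; return (noisy angles, wrapped noise target)."""
    rng = random.Random(seed)
    noisy = [wrap_angle(t + rng.gauss(0.0, scale)) for t in torsions]
    # Wrapping the difference keeps the target small even when a noisy angle crosses +/-pi.
    noise = [wrap_angle(n - t) for n, t in zip(noisy, torsions)]
    return noisy, noise
```

Without the final wrap, a clean angle of 3.1 rad nudged past π would produce a target near -2π instead of a small perturbation.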

Warning

This will subset the data to only include the backbone atoms (N, Ca, C). The backbone oxygen can be placed with: graphein.protein.tensor.reconstruction.place_fourth_coord().

This will break, for example, sidechain torsion angle computation for the first few chi angles that are partially defined by backbone atoms.

# Example:
python proteinworkshop/train.py dataset=cath encoder=gvp task=plddt_prediction +aux_task=nn_structure_torsion
config/aux_task/nn_structure_torsion.yaml
# @package _global_

defaults:
  - _self_
  - /decoder:
      - dihedrals
  - /metrics: null
  - /task@losses@pos: null
  - /transforms:
      - remove_missing_ca
      - torsional_denoising
task:
  losses:
    dihedrals: mse_loss

  output: ${oc.dict.keys:task.losses}
  supervise_on: ${oc.dict.keys:task.losses}
aux_loss_coefficient:
  pos: 0.1

Inverse Folding (inverse_folding)

This adds an inverse folding objective to the model, i.e. the model is trained to predict the sequence of the input structure.

Warning

This will remove the residue-type node feature and sidechain torsion angles to avoid leaking information about the target sequence into the input.
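Sidechain torsions leak the sequence because the number of chi angles depends on the amino acid (glycine and alanine have none, for example), so their presence alone reveals residue identity. A minimal sketch of the input/target split, assuming a dictionary of features with hypothetical key names (the library's actual attribute names may differ):

```python
from typing import Any, Dict, Tuple

# Hypothetical feature names that would reveal the target sequence.
LEAKY_FEATURES = ("residue_type", "sidechain_torsions")


def prepare_inverse_folding(features: Dict[str, Any]) -> Tuple[Dict[str, Any], Any]:
    """Split a feature dict into leakage-free inputs and the sequence target."""
    target = features["residue_type"]
    inputs = {k: v for k, v in features.items() if k not in LEAKY_FEATURES}
    return inputs, target
```

The model then receives only backbone-derived inputs and is scored against `target` with cross-entropy, as in the config.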

# Example:
python proteinworkshop/train.py dataset=cath encoder=gvp task=plddt_prediction +aux_task=inverse_folding
config/aux_task/inverse_folding.yaml
# @package _global_

defaults:
  - _self_
  - /decoder:
      - residue_type
  - /metrics@pos: null
  - /task@losses@pos: null
  - /transforms:
      - remove_missing_ca
      - inverse_folding

task:
  losses:
    residue_type: cross_entropy

  output: ${oc.dict.keys:task.losses}
  supervise_on: ${oc.dict.keys:task.losses}

None (none)

# Example:
python proteinworkshop/train.py dataset=cath encoder=gvp task=plddt_prediction