A deeper look at AlphaFold2 and its neural architecture
In this series of articles, I will go through protein folding and deep learning models such as AlphaFold, OmegaFold, and ESMFold. We will start with AlphaFold2!
Proteins are molecules that perform most of the biochemical functions in living organisms. They are involved in digestion (enzymes), structural processes (keratin in skin), and photosynthesis, and they are also used extensively in the pharmaceutical industry.
The 3D structure of a protein is fundamental to its function. Proteins are built from 20 types of subunits called amino acids (or residues), each with different properties such as charge, polarity, length, and number of atoms. Each amino acid consists of a backbone, common to all amino acids, and a side chain, unique to each amino acid type. Consecutive amino acids are connected by peptide bonds.
Along the backbone, each residue can rotate around two torsion angles, called φ and ψ. The combination of these angles along the chain gives rise to the protein's 3D shape.
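To make φ and ψ concrete, here is a minimal NumPy sketch of how a signed dihedral (torsion) angle can be computed from four consecutive backbone atoms. The function name `dihedral` is my own. This is the standard geometric formula, not code taken from AlphaFold.

```python
import numpy as np


def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle in degrees defined by four 3D points.

    For the backbone of residue i:
      phi uses C(i-1), N(i), CA(i), C(i)
      psi uses N(i), CA(i), C(i), N(i+1)
    """
    b0 = p0 - p1
    b1 = (p2 - p1) / np.linalg.norm(p2 - p1)  # unit vector along the central bond
    b2 = p3 - p2
    # Project the outer bonds onto the plane perpendicular to the central bond.
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return np.degrees(np.arctan2(y, x))


# Toy example: four points forming a 90 degree torsion.
p = [np.array(v, dtype=float) for v in ([1, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1])]
print(dihedral(*p))
```

A torsion of 0° places the first and last atoms on the same side of the central bond (cis). A torsion of ±180° places them on opposite sides (trans).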
A central problem for biologists is obtaining this 3D shape, which usually requires growing a crystal of the protein and performing X-ray crystallography. Proteins have diverse properties. Membrane proteins, for example, tend to be hydrophobic, which makes it hard to identify the conditions under which they crystallize. Obtaining crystals is therefore a tedious and (arguably) highly random process that can take anywhere from days to decades, and it is often regarded as more of an art than a science. As a result, many biologists may spend the entire duration of their Ph.D. trying to crystallize a protein.
If you are lucky enough to get a crystal of your protein and solve its structure, you can deposit the structure in the Protein Data Bank (PDB), a large database of protein structures.
This raises the question: can we simulate folding to obtain a 3D structure from a sequence? Short answer: yes, kind of. Long answer: we can use molecular simulations to try to fold proteins, but these simulations are computationally very expensive. Hence, projects like Folding@home distribute the problem over many computers to obtain a dynamics simulation of a protein.
The Critical Assessment of Protein Structure Prediction (CASP) is a competition in which experimentally determined protein structures are held out from the public, so that participants can test their structure prediction models blindly. In 2020, DeepMind entered CASP14 with AlphaFold2, beating the state of the art by a wide margin.
In this blog post, I will go over AlphaFold2 and explain its inner workings. I will conclude with how it has revolutionized my work as a Ph.D. student in Protein Design and Machine Learning.
Before we start, I would like to give a shoutout to OpenFold by the AQ Laboratory. OpenFold is an open-source implementation of AlphaFold that includes training code, and I used it to double-check the dimensions of the tensors I refer to in this article. Most of the information in this article comes from the Supplementary Information of the original AlphaFold2 paper.
Let’s begin with an overview. This is what the overall structure of the model looks like:
Typically, you start with the amino acid sequence of your protein of interest. Note that a crystal is not necessary to obtain the amino acid sequence. The sequence is usually obtained from DNA sequencing (if you know the gene of the protein) or from protein sequencing. For example, proteins can be broken into smaller peptides and analysed by mass spectrometry.
The aim is to prepare two key pieces of data: the Multiple Sequence Alignment (MSA) representation and the pair representation. For simplicity, I will skip the use of templates.
The MSA representation is obtained by searching genetic databases for sequences similar to the input. As the picture shows, these sequences may come from different organisms, e.g., a fish. Here we are trying to get general information about each position of the protein. We also want to understand, in the context of evolution, how the protein has changed across organisms. Proteins like Rubisco (involved in photosynthesis) are highly conserved and therefore differ little across plants. Others, like the spike protein of a virus, are very variable.
In the pair representation, we are trying to infer relationships between the sequence elements. For example, position 54 of the protein may interact with position 1.
Throughout the network, these representations are updated several times. First, they are embedded by the input embedder. Then they pass through the EvoFormer, which extracts information about sequences and pairs. Finally, they reach the structure module, which builds the 3D structure of the protein.
The input embedder converts the raw input features into learned representations. For the MSA, AlphaFold clusters the sequences and keeps only a fixed maximum number of cluster centers (N_clust) rather than the full MSA. This reduces the number of sequences that go through the transformer and thus decreases computation. The MSA input msa_feat (N_clust, N_res, 49) is composed of:
- cluster_msa (N_clust, N_res, 23): a one-hot encoding of the MSA cluster center sequences (20 amino acids + 1 unknown + 1 gap + 1 masked_msa_token)
- cluster_profile (N_clust, N_res, 23): amino acid type distribution for each residue in the MSA (20 amino acids + 1 unknown + 1 gap + 1 masked_msa_token)
- cluster_deletion_mean (N_clust, N_res, 1): average deletions of every residue in every cluster (ranges 0–1)
- cluster_deletion_value (N_clust, N_res, 1): number of deletions to the left of each position in the cluster center sequences (ranges 0–1)
- cluster_has_deletion (N_clust, N_res, 1): binary feature indicating whether there is a deletion to the left of each position
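Here is a minimal sketch of how these five features can be stacked into the 49 channels of msa_feat. It assumes the clustering has already been done, so the cluster centers, profiles, and deletion counts are given as inputs. The function and argument names are hypothetical. The arctan squashing of deletion counts into the 0 to 1 range follows my reading of the AlphaFold2 supplement. The channel order follows the list above and may differ from AlphaFold's actual implementation.

```python
import numpy as np

N_TYPES = 23  # 20 amino acids + unknown + gap + masked_msa_token


def squash_deletions(d):
    """Map raw deletion counts to the range [0, 1)."""
    return 2.0 / np.pi * np.arctan(d / 3.0)


def build_msa_feat(center_tokens, profile, deletion_counts, deletion_mean):
    """Concatenate per-cluster features into msa_feat with shape (N_clust, N_res, 49).

    center_tokens:   (N_clust, N_res) integer tokens of the cluster centers
    profile:         (N_clust, N_res, 23) amino acid distribution in each cluster
    deletion_counts: (N_clust, N_res) raw deletion counts of the cluster centers
    deletion_mean:   (N_clust, N_res) mean raw deletions within each cluster
    """
    cluster_msa = np.eye(N_TYPES)[center_tokens]                      # one-hot, 23
    deletion_mean_f = squash_deletions(deletion_mean)[..., None]      # 1
    deletion_value = squash_deletions(deletion_counts)[..., None]     # 1
    has_deletion = (deletion_counts > 0).astype(float)[..., None]     # 1
    return np.concatenate(
        [cluster_msa, profile, deletion_mean_f, deletion_value, has_deletion],
        axis=-1,
    )


rng = np.random.default_rng(0)
tokens = rng.integers(0, N_TYPES, size=(4, 10))
profile = rng.random((4, 10, N_TYPES))
profile /= profile.sum(axis=-1, keepdims=True)
deletions = rng.integers(0, 5, size=(4, 10)).astype(float)
mean_deletions = rng.random((4, 10)) * 3
msa_feat = build_msa_feat(tokens, profile, deletions, mean_deletions)
print(msa_feat.shape)  # (4, 10, 49)
```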
For the pair representation, each residue is assigned its index in the sequence, and RelPos encodes the relative position (the index offset) between every pair of residues. This gives a matrix of offsets of each residue against every other residue, clipped to the range -32 to 32. Offsets larger than 32 in either direction are therefore capped at ±32. Each clipped offset is one-hot encoded, so the encoding has 2 × 32 + 1 = 65 channels.
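RelPos can be sketched in a few lines of NumPy. This follows the clipping and one-hot scheme in the text. In AlphaFold the one-hot vectors are then passed through a linear layer to produce the pair channels; that step is omitted here. The offset sign convention (i - j) and the function name are my assumptions.

```python
import numpy as np


def relpos_one_hot(n_res, max_offset=32):
    """One-hot relative positions with shape (n_res, n_res, 2 * max_offset + 1)."""
    idx = np.arange(n_res)
    offset = idx[:, None] - idx[None, :]            # i - j for every pair of residues
    clipped = np.clip(offset, -max_offset, max_offset)
    bins = clipped + max_offset                     # shift -32..32 into 0..64
    return np.eye(2 * max_offset + 1)[bins]


rel = relpos_one_hot(100)
print(rel.shape)  # (100, 100, 65)
```

Residues 40 and 90 positions away from residue 0 land in the same bin, because both offsets are clipped to -32.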
Both the MSA representation and the pair representation go through several independent linear layers and are then passed to the EvoFormer.
The EvoFormer consists of 48 blocks that use self-attention to let the MSA and pair representations communicate. Within each block, we first process the MSA and then merge its information into the pair representation.
2.1 MSA Stack
This is composed of row-wise gated self-attention with pair bias, column-wise gated self-attention, transition and outer product mean blocks.
2.1A Row-Wise Gated Self-Attention with Pair Bias
The key point here is to allow the MSA and pair representations to communicate information with each other.
First, queries and keys are computed from each row of the MSA representation (one aligned sequence). Multi-head attention uses their scaled dot products to compute affinities of shape (N_res, N_res, N_heads). This means the amino acids in the sequence learn a “conceptual importance” between pairs: in essence, how important one amino acid is for another amino acid.
Then, the pair representation goes through a layer normalization and a linear layer without bias, meaning only a weight parameter is learned. The linear layer outputs N_heads channels, producing the pair bias matrix (N_res, N_res, N_heads). Recall that the pair representation was initialized from relative positions clipped at ±32. Residues more than 32 positions apart therefore start with the same positional encoding.
At this point, we have two tensors of shape (N_res, N_res, N_heads). We add them together and apply a softmax over the second residue dimension, which gives attention weights between 0 and 1. These weights are used to take a weighted sum of the values, which are obtained by passing the MSA row through another linear layer.
Finally, the output is gated. The MSA row is passed through a linear layer followed by a sigmoid, producing a gate with values between 0 and 1. This gate is multiplied element-wise with the attention output before a final linear layer. The gate lets the model learn how much of each channel of the attention output to let through.
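The steps above can be put together in a compact NumPy sketch of row-wise gated self-attention with pair bias. The weight names, toy sizes, and random initialization are hypothetical. AlphaFold uses 8 heads of 32 channels and adds the result residually to the MSA representation; the residual step is omitted here.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)


def row_attention_with_pair_bias(m, z, p, n_heads):
    """m: (N_seq, N_res, c_m) MSA rep; z: (N_res, N_res, c_z) pair rep."""
    m = layer_norm(m)
    n_seq, n_res, _ = m.shape

    def split(x):  # (N_seq, N_res, H * c) -> (N_seq, N_res, H, c)
        return x.reshape(n_seq, n_res, n_heads, -1)

    q, k, v = split(m @ p["w_q"]), split(m @ p["w_k"]), split(m @ p["w_v"])
    g = split(sigmoid(m @ p["w_g"] + p["b_g"]))              # gate in (0, 1)
    # Affinities between residues i and j within each row: (N_seq, H, N_res, N_res)
    logits = np.einsum("sihc,sjhc->shij", q, k) / np.sqrt(q.shape[-1])
    # Pair bias: LayerNorm + linear layer without bias, shared by all rows
    bias = np.einsum("ijc,ch->hij", layer_norm(z), p["w_b"])
    a = softmax(logits + bias[None], axis=-1)               # normalize over j
    o = np.einsum("shij,sjhc->sihc", a, v)                   # weighted sum of values
    o = (g * o).reshape(n_seq, n_res, -1)                    # element-wise gating
    return o @ p["w_o"]


rng = np.random.default_rng(0)
n_seq, n_res, c_m, c_z, n_heads, c = 5, 12, 16, 8, 4, 8
shapes = {"w_q": (c_m, n_heads * c), "w_k": (c_m, n_heads * c),
          "w_v": (c_m, n_heads * c), "w_g": (c_m, n_heads * c),
          "b_g": (n_heads * c,), "w_b": (c_z, n_heads), "w_o": (n_heads * c, c_m)}
params = {name: 0.1 * rng.normal(size=s) for name, s in shapes.items()}
m = rng.normal(size=(n_seq, n_res, c_m))
z = rng.normal(size=(n_res, n_res, c_z))
out = row_attention_with_pair_bias(m, z, params, n_heads)
print(out.shape)  # (5, 12, 16)
```

Each row (sequence) only attends within itself. The pair representation is the only channel through which information about residue pairs enters.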
2.1B Column-Wise Gated Self-Attention
The key point here is that the MSA aligns all sequences related to the input sequence. This means that column X corresponds to the same position in the protein for every sequence.
By attending along columns, i.e., between different sequences at the same position, the model builds a general understanding of which residues are likely at each position. This also makes the model robust to small differences between similar sequences, which are expected to produce similar 3D shapes.
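Column-wise attention can reuse the same machinery by swapping the sequence and residue axes. Each residue position then becomes an independent batch, and the sequences attend to each other. The sketch below is simplified to a single head without layer normalization, and there is no pair bias in this step. All weight names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def column_attention(m, w_q, w_k, w_v, w_g, w_o):
    """Single-head gated attention along the sequence axis of m: (N_seq, N_res, c_m)."""
    mt = np.swapaxes(m, 0, 1)                  # (N_res, N_seq, c_m): one batch per column
    q, k, v = mt @ w_q, mt @ w_k, mt @ w_v
    g = sigmoid(mt @ w_g)
    logits = np.einsum("rsc,rtc->rst", q, k) / np.sqrt(q.shape[-1])
    a = softmax(logits, axis=-1)               # sequence s attends to sequences t
    o = g * np.einsum("rst,rtc->rsc", a, v)
    return np.swapaxes(o @ w_o, 0, 1)          # back to (N_seq, N_res, c_m)


rng = np.random.default_rng(1)
n_seq, n_res, c_m, c = 6, 10, 16, 8
w_q, w_k, w_v, w_g = (0.1 * rng.normal(size=(c_m, c)) for _ in range(4))
w_o = 0.1 * rng.normal(size=(c, c_m))
m = rng.normal(size=(n_seq, n_res, c_m))
out = column_attention(m, w_q, w_k, w_v, w_g, w_o)
print(out.shape)  # (6, 10, 16)
```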
2.1C MSA Transition
This is a simple 2-layer MLP that first increases the channel dimensions by a factor of 4 and then reduces it down to the original dimensions.
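Here is a minimal sketch of the transition block. Following the AlphaFold2 supplement, it applies a layer normalization followed by Linear, ReLU, Linear, and the output is added back residually. The weight names are hypothetical.

```python
import numpy as np


def transition(x, w1, b1, w2, b2, eps=1e-5):
    """LayerNorm -> Linear(c -> 4c) -> ReLU -> Linear(4c -> c), applied per position."""
    x = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)
    hidden = np.maximum(0.0, x @ w1 + b1)   # expand channels by a factor of 4
    return hidden @ w2 + b2                 # project back to the original width


rng = np.random.default_rng(0)
c = 16
w1, b1 = 0.1 * rng.normal(size=(c, 4 * c)), np.zeros(4 * c)
w2, b2 = 0.1 * rng.normal(size=(4 * c, c)), np.zeros(c)
m = rng.normal(size=(5, 12, c))            # e.g. an MSA representation
m = m + transition(m, w1, b1, w2, b2)      # residual update
print(m.shape)  # (5, 12, 16)
```

Because every position is processed independently, the same function works unchanged on the pair representation.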
2.1D Outer Product Mean
This operation keeps a continuous flow of information from the MSA to the pair representation. Each column in the MSA corresponds to one position in the protein sequence.
- Here, we select columns i and j and send them independently through two linear layers. These layers project to c = 32 channels, which is lower than c_m.
- For every sequence in the MSA, the outer product of the two projected vectors is computed (a 32 × 32 matrix). These outer products are averaged over the sequences, flattened, and passed through another linear layer.
We now have an update for entry ij of the pair representation. We repeat this for all pairs.
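The outer product mean can be written as a single einsum. This is a sketch with toy sizes and hypothetical weight names. The layer normalization that AlphaFold applies to the MSA beforehand is omitted. With c = 32, the flattened vector has 1,024 channels before the final linear layer.

```python
import numpy as np


def outer_product_mean(m, w_a, w_b, w_out, b_out):
    """m: (N_seq, N_res, c_m) -> pair update (N_res, N_res, c_z)."""
    a = m @ w_a                                        # (N_seq, N_res, c)
    b = m @ w_b                                        # (N_seq, N_res, c)
    # Outer product of columns i and j for each sequence, averaged over sequences
    o = np.einsum("sic,sjd->ijcd", a, b) / m.shape[0]  # (N_res, N_res, c, c)
    o = o.reshape(o.shape[0], o.shape[1], -1)          # flatten to c * c channels
    return o @ w_out + b_out


rng = np.random.default_rng(0)
n_seq, n_res, c_m, c, c_z = 7, 9, 16, 4, 8
w_a, w_b = rng.normal(size=(c_m, c)), rng.normal(size=(c_m, c))
w_out, b_out = rng.normal(size=(c * c, c_z)), np.zeros(c_z)
m = rng.normal(size=(n_seq, n_res, c_m))
update = outer_product_mean(m, w_a, w_b, w_out, b_out)
print(update.shape)  # (9, 9, 8)
```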
2.2 Pairs Stack
Our pair representation can loosely be interpreted as encoding distances between residues. Since every pair (i, j) has an entry, we can view it as a fully connected graph in which any three residues form a triangle.
For example, nodes i, j, and k will have edges ij, ik, and jk. Each edge is updated with information from the other two edges of all the triangles it is part of.
2.2A Triangular Multiplicative Update
We have two types of updates, one for outgoing edges and one for incoming edges.
For outgoing edges, the pair representation is passed through two independent linear layers, producing representations of the left edges and the right edges. Each of these is gated element-wise by a sigmoid of another linear projection of the pair representation.
Then, to update edge ij, we combine the edges that leave i and j towards every third node k. The element-wise products of the left edge ik and the right edge jk are summed over all k.
Finally, this sum is normalized, passed through a linear layer, and multiplied element-wise by a sigmoid gate computed from the ij entry of the pair representation.
For incoming edges, the algorithm is very similar, but bear in mind that the edges now point in the opposite direction: instead of ik and jk, we use ki and kj. In the OpenFold code, this is implemented simply as a permute function.
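Both variants of the triangular multiplicative update fit in a short NumPy sketch; the only difference between them is which indices the einsum contracts over. The weight names and toy sizes are hypothetical, and the biases of the linear layers are omitted for brevity.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)


def triangle_multiply(z, p, outgoing=True):
    """z: (N_res, N_res, c_z) pair representation -> update of the same shape."""
    z = layer_norm(z)
    a = sigmoid(z @ p["w_a_gate"]) * (z @ p["w_a"])   # gated "left" edges
    b = sigmoid(z @ p["w_b_gate"]) * (z @ p["w_b"])   # gated "right" edges
    if outgoing:
        x = np.einsum("ikc,jkc->ijc", a, b)           # edges i->k and j->k
    else:
        x = np.einsum("kic,kjc->ijc", a, b)           # edges k->i and k->j
    g = sigmoid(z @ p["w_g"])                         # gate from edge ij itself
    return g * (layer_norm(x) @ p["w_o"])


rng = np.random.default_rng(0)
n_res, c_z, c = 10, 8, 16
shapes = {"w_a": (c_z, c), "w_a_gate": (c_z, c), "w_b": (c_z, c),
          "w_b_gate": (c_z, c), "w_g": (c_z, c_z), "w_o": (c, c_z)}
params = {name: 0.1 * rng.normal(size=s) for name, s in shapes.items()}
z = rng.normal(size=(n_res, n_res, c_z))
out_update = triangle_multiply(z, params, outgoing=True)
in_update = triangle_multiply(z, params, outgoing=False)
print(out_update.shape, in_update.shape)  # (10, 10, 8) (10, 10, 8)
```

For a symmetric pair representation, the outgoing and incoming updates coincide. This is a quick sanity check that the two einsums only differ in edge direction.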
2.2B Triangular Self-Attention
This operation updates the pair representation using self-attention. The main goal is to update each edge with the most relevant edges, i.e., to find which amino acids in the protein are more likely to interact with the current node.
With self-attention, we learn the best way to update the edge through:
- (query-key) Similarity between edges that share the node of interest. For instance, for node i, all edges that share that node (e.g., ij, ik).
- A bias from a third edge (e.g., jk), which does not directly connect to node i but closes the triangle.
This last operation is similar in style to a graph message-passing algorithm, where even if nodes are not directly connected, information from other nodes in the graph is weighted and passed on.
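Here is a single-head sketch of triangular self-attention around the starting node, following the description above. Edge ij attends to every edge ik, and the third edge jk enters as an additive bias. AlphaFold also has a mirrored version around the ending node and uses multiple heads; both are omitted here. The weight names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)


def triangle_attention_start(z, p):
    """Single-head triangle attention around the starting node; z: (N_res, N_res, c_z)."""
    z = layer_norm(z)
    q, k, v = z @ p["w_q"], z @ p["w_k"], z @ p["w_v"]
    bias = (z @ p["w_b"])[..., 0]          # one scalar per edge: (N_res, N_res)
    g = sigmoid(z @ p["w_g"])
    # logits[i, j, k]: query edge ij vs key edge ik, plus bias from edge jk
    logits = np.einsum("ijc,ikc->ijk", q, k) / np.sqrt(q.shape[-1]) + bias[None, :, :]
    a = softmax(logits, axis=-1)           # normalize over the third node k
    o = np.einsum("ijk,ikc->ijc", a, v)    # weighted sum over edges ik
    return (g * o) @ p["w_o"]


rng = np.random.default_rng(0)
n_res, c_z, c = 10, 8, 16
shapes = {"w_q": (c_z, c), "w_k": (c_z, c), "w_v": (c_z, c),
          "w_b": (c_z, 1), "w_g": (c_z, c), "w_o": (c, c_z)}
params = {name: 0.1 * rng.normal(size=s) for name, s in shapes.items()}
z = rng.normal(size=(n_res, n_res, c_z))
update = triangle_attention_start(z, params)
print(update.shape)  # (10, 10, 8)
```

Without the triangle bias, the update for row i would only see edges starting at i. The bias from jk is what lets information flow around the triangle, much like message passing.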
2.2C Transition Block
This block is equivalent to the transition block in the MSA stack: a 2-layer MLP in which the channel dimension is first expanded by a factor of 4 and then reduced to its original size.
The output of the EvoFormer block is an updated representation of both the MSA and the pairs, with the same dimensions as its input.
The structure module is the final part of the model. It converts the pair representation and the single representation into a 3D structure. The single representation is derived from the first row of the MSA representation, which corresponds to the input sequence. The structure module consists of 8 layers with shared weights, and the pair representation is used to bias the attention operations in the Invariant Point Attention (IPA) module.
The outputs are:
- Backbone frames (N_res, a 3×3 rotation plus a translation vector): a frame is a Euclidean transform that maps atomic positions from a residue's local frame of reference to the global one. Each residue is treated as a free-floating rigid body (blue triangles) formed by its N, Cα, and C atoms. Thus, each residue r_i has three sets of (x, y, z) backbone coordinates.
- χ angles of the side chains (N_res, 4): the torsion angle of each rotatable bond of the side chain, up to χ1, χ2, χ3, χ4. These angles define the rotational isomer (rotamer) of a residue, from which one can derive the exact position of its atoms.
Note that χ refers to the dihedral angle of each rotatable bond of the side chains. Residues with shorter side chains do not have all four χ angles. For example, glycine and alanine have none, and serine has only χ1.
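The number of χ angles per residue type is small enough to write out as a table. Below is a sketch that builds a (N_res, 4) mask of which χ angles exist for a given sequence. The counts follow the standard side-chain definitions, which, to my knowledge, match the ones in AlphaFold's residue constants. The helper name is hypothetical.

```python
import numpy as np

# Number of side-chain chi angles per residue type (standard definitions).
N_CHI = {
    "ALA": 0, "ARG": 4, "ASN": 2, "ASP": 2, "CYS": 1,
    "GLN": 3, "GLU": 3, "GLY": 0, "HIS": 2, "ILE": 2,
    "LEU": 2, "LYS": 4, "MET": 3, "PHE": 2, "PRO": 2,
    "SER": 1, "THR": 1, "TRP": 2, "TYR": 2, "VAL": 1,
}


def chi_mask(sequence):
    """(N_res, 4) mask marking which of chi1..chi4 exist for each residue."""
    mask = np.zeros((len(sequence), 4))
    for i, residue in enumerate(sequence):
        mask[i, :N_CHI[residue]] = 1.0
    return mask


print(chi_mask(["GLY", "SER", "LYS"]))
```

A mask like this is useful wherever predicted χ angles are consumed, so that angles which do not exist for a residue are ignored.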
3.1 Invariant Point Attention (IPA)
Generally, this type of attention is designed to be invariant to Euclidean transformations such as translations and rotations.
- We first update the single representation with self-attention, as explained in previous sections.
- We also feed information about the backbone frames of each residue to produce query points, key points, and value points for the local frame. These are then projected into a global frame where they interact with other residues and then projected back to the local frame.
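- The word “invariant” refers to the fact that the output does not change when the same rotation and translation are applied to all frames. The point terms depend only on squared distances between points in the global frame, and the results are transformed back into each residue's local frame.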
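To see why this works, here is a stripped-down sketch of only the point part of IPA. As a simplification on my part, it has no scalar attention, no pair bias, and no learned weights. Query, key, and value points are defined in each residue's local frame and moved to the global frame to compute squared distances. The output is then mapped back to the local frame. Applying the same global rotation and translation to every frame leaves both the logits and the outputs unchanged.

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def to_global(R, t, x):
    """Apply frames R: (N, 3, 3), t: (N, 3) to local points x: (N, P, 3)."""
    return np.einsum("nab,npb->npa", R, x) + t[:, None, :]


def point_attention(R, t, q_pts, k_pts, v_pts):
    """Point part of IPA: squared-distance logits, outputs in local frames."""
    qg, kg, vg = to_global(R, t, q_pts), to_global(R, t, k_pts), to_global(R, t, v_pts)
    d2 = ((qg[:, None] - kg[None, :]) ** 2).sum(axis=(-1, -2))  # (N, N)
    logits = -0.5 * d2
    a = np.exp(logits - logits.max(axis=-1, keepdims=True))
    a /= a.sum(axis=-1, keepdims=True)
    out_global = np.einsum("ij,jpa->ipa", a, vg)               # (N, P, 3)
    # Back to each residue's local frame: R_i^T (x - t_i)
    out_local = np.einsum("nba,npb->npa", R, out_global - t[:, None, :])
    return logits, out_local


rng = np.random.default_rng(0)
n_res, n_pts = 6, 4
R = np.stack([random_rotation(rng) for _ in range(n_res)])
t = rng.normal(size=(n_res, 3)) * 5
q_pts, k_pts, v_pts = (rng.normal(size=(n_res, n_pts, 3)) for _ in range(3))

# Apply the same global rotation and translation to every frame
R_g, t_g = random_rotation(rng), rng.normal(size=3) * 10
R2 = np.einsum("ab,nbc->nac", R_g, R)
t2 = t @ R_g.T + t_g

logits1, out1 = point_attention(R, t, q_pts, k_pts, v_pts)
logits2, out2 = point_attention(R2, t2, q_pts, k_pts, v_pts)
print(np.allclose(logits1, logits2), np.allclose(out1, out2))  # True True
```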
3.2 Predict side chain and backbone torsion angles
The single representation goes through a couple of MLPs and outputs the torsion angles ω, φ, ψ, χ1, χ2, χ3, χ4.
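In AlphaFold2, each torsion angle is predicted as an unnormalized 2D vector, which is projected onto the unit circle and read as a (sin, cos) pair. This presumably avoids the discontinuity at ±180°. Below is a minimal sketch of that final step. It assumes the (sin, cos) ordering and the angle order ω, φ, ψ, χ1 to χ4 from the text. The function name is hypothetical.

```python
import numpy as np

ANGLE_NAMES = ["omega", "phi", "psi", "chi1", "chi2", "chi3", "chi4"]


def to_angles(raw):
    """raw: (N_res, 7, 2) unnormalized 2D vectors -> angles in radians, (N_res, 7)."""
    unit = raw / np.linalg.norm(raw, axis=-1, keepdims=True)  # project onto unit circle
    sin, cos = unit[..., 0], unit[..., 1]
    return np.arctan2(sin, cos)                               # in (-pi, pi]


rng = np.random.default_rng(0)
true = rng.uniform(-np.pi, np.pi, size=(3, 7))
scale = rng.uniform(0.5, 2.0, size=(3, 7, 1))                 # arbitrary vector lengths
raw = scale * np.stack([np.sin(true), np.cos(true)], axis=-1)
print(np.allclose(to_angles(raw), true))  # True
```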
3.3 Backbone Update
There are two updates returned by this block. The first is a rotation, represented by a quaternion (1, a, b, c): the first value is fixed to 1, and a, b, and c are predicted by the network and correspond to the Euler axis. The quaternion is normalized before use. The second is a translation, represented by a 3D vector.
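Here is a sketch of how such an update can be turned into a rotation matrix and composed with the current frame. The quaternion-to-matrix formula is the standard one for a unit quaternion. Composing the update on the right, so that it is expressed in the residue's local frame, reflects my understanding of the AlphaFold2 supplement. The function names are hypothetical.

```python
import numpy as np


def quat_to_rot(b, c, d):
    """Rotation matrix from the non-unit quaternion (1, b, c, d), normalized first."""
    a, b, c, d = np.array([1.0, b, c, d]) / np.sqrt(1.0 + b * b + c * c + d * d)
    return np.array([
        [a*a + b*b - c*c - d*d, 2*(b*c - a*d),         2*(b*d + a*c)],
        [2*(b*c + a*d),         a*a - b*b + c*c - d*d, 2*(c*d - a*b)],
        [2*(b*d - a*c),         2*(c*d + a*b),         a*a - b*b - c*c + d*d],
    ])


def compose(R, t, R_upd, t_upd):
    """Apply the update inside the current frame: T <- T o T_update."""
    return R @ R_upd, R @ t_upd + t


# (1, 0, 0, 1) normalizes to a 90 degree rotation about the z axis,
# which maps the x axis onto the y axis.
R_upd = quat_to_rot(0.0, 0.0, 1.0)
print(R_upd @ np.array([1.0, 0.0, 0.0]))
R, t = compose(np.eye(3), np.zeros(3), R_upd, np.array([1.0, 2.0, 3.0]))
```

Fixing the first quaternion component to 1 means a zero prediction gives the identity rotation. This seems a convenient starting point for iterative refinement.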
3.4 All Atom Coordinates
At this point, we have both the backbone frames and the torsion angles, and we would like to obtain the exact coordinates of every atom. Each amino acid type has a fixed set of atoms with a well-known geometry, and we know the identity of each residue from the input sequence. We therefore place the atoms by applying the predicted torsion angles to this idealized geometry, starting from each residue's backbone frame.
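AlphaFold does this with idealized rigid groups rotated by the predicted angles. As a simpler illustration of the same idea, here is a sketch of the classic NeRF construction. This is a swapped-in technique, not AlphaFold's implementation. Given three placed atoms and an ideal bond length, bond angle, and torsion angle, it places the next atom. The function name and the example geometry values are hypothetical.

```python
import numpy as np


def place_atom(a, b, c, bond_length, bond_angle, torsion):
    """Place atom d so that |cd| = bond_length, angle(b, c, d) = bond_angle,
    and dihedral(a, b, c, d) = torsion. Angles are in radians."""
    bc = (c - b) / np.linalg.norm(c - b)
    n = np.cross(b - a, bc)
    n = n / np.linalg.norm(n)              # normal to the plane of a, b, c
    m = np.cross(n, bc)                    # completes a right-handed local frame
    d_local = bond_length * np.array([
        -np.cos(bond_angle),
        np.sin(bond_angle) * np.cos(torsion),
        np.sin(bond_angle) * np.sin(torsion),
    ])
    return c + d_local[0] * bc + d_local[1] * m + d_local[2] * n


a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 0.0, 0.0])
c = np.array([0.0, 0.0, 1.0])
d = place_atom(a, b, c, bond_length=1.5, bond_angle=np.radians(110), torsion=np.radians(90))
print(d)
```

Applied atom by atom along a side chain, each χ angle becomes the torsion argument for the next heavy atom.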
Note that you will often find structural violations in the output of AlphaFold, such as the ones depicted below. This is because the model itself does not strictly enforce physical energy constraints. To alleviate this problem, we run a relaxation with the AMBER force field to minimize the energy of the protein.
The AlphaFold model contains many self-attention layers and large activations due to the size of the MSAs. Standard backpropagation stores every intermediate activation from the forward pass, so that nothing has to be computed twice. In the case of AlphaFold, however, storing all of them would require more than the memory available on a TPU core (16 GiB), assuming a protein of 384 residues.
Instead, AlphaFold used gradient checkpointing (also known as rematerialization). Only a subset of activations is stored during the forward pass, and the rest are recomputed one layer at a time during the backward pass. This brings memory consumption down to around 0.4 GiB.
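To get a feel for why memory becomes the bottleneck, here is a back-of-the-envelope calculation of the size of a few individual activation tensors for a 384-residue crop. The number of MSA sequences (512), the head counts, and 4-byte float32 values are my assumptions for illustration. The channel widths 256 and 128 are, as far as I know, the ones reported for AlphaFold2. These numbers are not AlphaFold's measured memory usage.

```python
GIB = 1024 ** 3


def size_gib(*shape, bytes_per_value=4):
    """Memory of one dense tensor with the given shape, in GiB."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_value / GIB


N_RES, N_SEQ = 384, 512   # crop size and assumed number of MSA sequences
C_M, C_Z = 256, 128       # channels of the MSA and pair representations
H_ROW, H_TRI = 8, 4       # assumed attention head counts

sizes = {
    "MSA representation": size_gib(N_SEQ, N_RES, C_M),
    "pair representation": size_gib(N_RES, N_RES, C_Z),
    "row attention logits": size_gib(N_SEQ, H_ROW, N_RES, N_RES),
    "triangle attention logits": size_gib(H_TRI, N_RES, N_RES, N_RES),
}
for name, gib in sizes.items():
    print(f"{name:>26}: {gib:.3f} GiB")
```

Under these assumptions, the row attention logits alone take 2.25 GiB per block. Keeping them for all 48 EvoFormer blocks would need 108 GiB, far beyond a 16 GiB TPU core.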
This GIF shows what backpropagation usually looks like:
By checkpointing, we reduce memory usage, though this has the unfortunate side effect of increasing training time by 33%:
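The trade-off can be illustrated with a toy chain of scalar layers. Standard backpropagation keeps every activation. The checkpointed version keeps only the input of each segment and recomputes that segment during the backward pass. This is a hypothetical, framework-free sketch, not AlphaFold's implementation.

```python
import math


def layer(w, x):
    """Toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_grad(w, y):
    """dy/dx of the toy layer, written in terms of its output y."""
    return w * (1.0 - y * y)


def backprop_store_all(ws, x):
    """Standard backprop: keep every activation from the forward pass."""
    acts, calls = [x], 0
    for w in ws:
        acts.append(layer(w, acts[-1]))
        calls += 1
    grad = 1.0
    for w, y in zip(reversed(ws), reversed(acts[1:])):
        grad *= layer_grad(w, y)          # chain rule, from the output backwards
    return grad, calls, len(acts)


def backprop_checkpointed(ws, x, segment):
    """Keep only segment inputs; recompute each segment during the backward pass."""
    checkpoints, h, calls = [], x, 0
    for start in range(0, len(ws), segment):
        checkpoints.append((start, h))
        for w in ws[start:start + segment]:
            h = layer(w, h)
            calls += 1
    grad, peak = 1.0, 0
    for start, h0 in reversed(checkpoints):
        seg = ws[start:start + segment]
        acts = [h0]
        for w in seg:                     # rematerialize this segment only
            acts.append(layer(w, acts[-1]))
            calls += 1
        peak = max(peak, len(checkpoints) + len(acts))  # values held at once
        for w, y in zip(reversed(seg), reversed(acts[1:])):
            grad *= layer_grad(w, y)
    return grad, calls, peak


ws = [0.9 + 0.01 * i for i in range(48)]  # 48 layers, like the 48 EvoFormer blocks
g_full, calls_full, stored_full = backprop_store_all(ws, 0.5)
g_ckpt, calls_ckpt, stored_ckpt = backprop_checkpointed(ws, 0.5, segment=8)
print(calls_full, calls_ckpt)    # 48 96
print(stored_full, stored_ckpt)  # 49 15
```

Both versions return the same gradient, but the checkpointed one runs every forward layer twice. If the backward pass costs roughly twice the forward pass, one extra forward pass raises the cost from about 3 to about 4 forward-pass units. That is roughly a third more, consistent with the 33% increase mentioned above.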
What if, rather than a sequence of amino acids, you had a 3D model of a cool protein you designed with a dynamics simulation? Or one that you modeled to bind another protein, like the SARS-CoV-2 spike protein? Ideally, you would want to predict the sequence that folds into an input 3D shape that may or may not exist in nature (i.e., it could be a completely new protein). Let me introduce you to the world of protein design, which is also the topic of my Ph.D. project TIMED (Three-dimensional Inference Method for Efficient Design):
This problem is arguably harder than the folding problem, as multiple sequences can fold to the same shape. This is because there is redundancy in amino acid types, and there are also areas of a protein that are less critical for the actual fold.
A useful aspect of AlphaFold is that we can use it to double-check whether our design models work well, by predicting the structure of the designed sequences:
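One common way to run such a check is to predict the structure of a designed sequence and compare it with the target backbone. A simple comparison metric is the Cα RMSD after optimal superposition. The sketch below uses the Kabsch algorithm. It is my illustration of such a comparison, not the evaluation pipeline used in TIMED, and the coordinates are random stand-ins.

```python
import numpy as np


def kabsch_rmsd(P, Q):
    """RMSD between two (N, 3) coordinate sets after optimal rigid superposition."""
    P = P - P.mean(axis=0)                    # remove translation
    Q = Q - Q.mean(axis=0)
    U, _, Vt = np.linalg.svd(P.T @ Q)         # SVD of the 3x3 covariance matrix
    d = np.sign(np.linalg.det(Vt.T @ U.T))    # guard against reflections
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T   # rotation mapping P onto Q
    diff = P @ R.T - Q
    return np.sqrt((diff ** 2).sum() / len(P))


rng = np.random.default_rng(0)
target = rng.normal(size=(50, 3)) * 10        # stand-in for target C-alpha coordinates
theta = np.radians(30)
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta), np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
predicted = target @ rot.T + np.array([5.0, -3.0, 2.0])  # same shape, moved
noisy = predicted + rng.normal(size=predicted.shape) * 0.5
print(kabsch_rmsd(predicted, target), kabsch_rmsd(noisy, target))
```

A rigidly moved copy gives an RMSD of essentially zero, while added noise or a mirrored structure does not.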
If you would like to know more about this model, have a look at my GitHub repository, which also includes a little UI Demo!
In this article, we saw how AlphaFold (partially) solves a clear problem for biologists, namely obtaining 3D structures from an amino acid sequence.
We broke down the structure of the model into Input Embedder, EvoFormer, and Structure module. Each of these uses several self-attention layers, including many tricks to optimize the performance.
AlphaFold works well, but is this it for biology? No. AlphaFold is still computationally very expensive, and there isn’t an easy way to use it (No, Google Colab is not easy — it’s clunky). Several alternatives, like OmegaFold and ESMFold, attempt to solve these problems.
These models still do not explain how a protein folds over time. There are also many open challenges in protein design, where inverse folding models can use AlphaFold to double-check that designed proteins fold into a specific shape.
In the next articles of this series, we will look into OmegaFold and ESMFold!
Jumper J, Evans R, Pritzel A, Green T, Figurnov M, Ronneberger O, Tunyasuvunakool K, Bates R, Žídek A, Potapenko A, et al. Highly accurate protein structure prediction with AlphaFold. Nature (2021). DOI: 10.1038/s41586-021-03819-2
Alberts B. Molecular Biology of the Cell, sixth edition. New York, NY: Garland Science, Taylor and Francis Group (2015).
Ahdritz G, Bouatta N, Kadyan S, Xia Q, Gerecke W, O’Donnell TJ, Berenberg D, Fisk I, Zanichelli N, Zhang B, et al. OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization. bioRxiv (2022). DOI: 10.1101/2022.11.20.517210
Callaway E. “It will change everything”: DeepMind’s AI makes gigantic leap in solving protein structures. Nature 588(7837):203–204 (2020). DOI: 10.1038/d41586-020-03348-4