Commit 1936f1e

complete the main alphafold2 flow, sans diffusion module and losses + sampling

lucidrains committed May 19, 2024
1 parent b0198d4
Showing 3 changed files with 322 additions and 9 deletions.
50 changes: 50 additions & 0 deletions README.md
@@ -6,6 +6,56 @@ Implementation of <a href="https://www.nature.com/articles/s41586-024-07487-w">A

Getting a fair number of emails. You can chat with me about this work <a href="https://discord.gg/x6FuzQPQXY">here</a>

## Install

```bash
$ pip install alphafold3-pytorch
```

## Usage

```python
import torch
from alphafold3_pytorch import Alphafold3

alphafold3 = Alphafold3(
dim_atom_inputs = 77,
dim_additional_residue_feats = 33,
dim_template_feats = 44
)

# mock inputs

seq_len = 16
atom_seq_len = seq_len * 27

atom_inputs = torch.randn(2, atom_seq_len, 77)
atom_mask = torch.ones((2, atom_seq_len)).bool()
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
additional_residue_feats = torch.randn(2, seq_len, 33)

template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
template_mask = torch.ones((2, 2)).bool()

msa = torch.randn(2, 7, seq_len, 64)

# train

loss = alphafold3(
num_recycling_steps = 2,
atom_inputs = atom_inputs,
atom_mask = atom_mask,
atompair_feats = atompair_feats,
additional_residue_feats = additional_residue_feats,
msa = msa,
templates = template_feats,
template_mask = template_mask
)

loss.backward()

```
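Per the commit title, the diffusion module, losses, and sampling are still to come. Once sampling lands, inference would presumably be the same call without labels, returning atomic coordinates (the `Float['b m c']` branch of the forward return type). A hedged sketch, reusing the mock inputs above:

```python
# hypothetical inference sketch: with no labels passed, the forward pass is
# meant to return sampled atom positions of shape (batch, atom_seq_len, 3).
# not functional at this commit, since sampling is not yet implemented
sampled_atom_pos = alphafold3(
    num_recycling_steps = 4,
    atom_inputs = atom_inputs,
    atom_mask = atom_mask,
    atompair_feats = atompair_feats,
    additional_residue_feats = additional_residue_feats,
    msa = msa,
    templates = template_feats,
    template_mask = template_mask
)
```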

## Citations

247 changes: 238 additions & 9 deletions alphafold3_pytorch/alphafold3.py
@@ -7,6 +7,7 @@
i - residue sequence length (source)
j - residue sequence length (target)
m - atom sequence length
c - coordinates (3 for spatial)
d - feature dimension
ds - feature dimension (single)
dp - feature dimension (pairwise)
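For orientation, these axis letters are the ones used in the file's shape annotations (e.g. `Float['b t n n dt']`). A toy illustration with arbitrary sizes:

```python
# toy tensors matching the axis legend above (sizes are arbitrary)
import torch

b, t, n, dp, dt = 2, 2, 16, 128, 44
pairwise_repr = torch.randn(b, n, n, dp)  # 'b n n dp'
templates = torch.randn(b, t, n, n, dt)   # 'b t n n dt'
```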
@@ -861,7 +862,7 @@ def __init__(
        # final projection of mean pooled repr -> out

        self.to_out = nn.Sequential(
-            LinearNoBias(dim, dim),
+            LinearNoBias(dim, dim_pairwise),
            nn.ReLU()
        )
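Projecting to `dim_pairwise` rather than `dim` is what lets the template embedder's output be added residually to the pairwise representation in the trunk below (`pairwise = embedded_template + pairwise`), which requires matching feature dimensions.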

@@ -873,7 +874,7 @@ def forward(
template_mask: Bool['b t'],
pairwise_repr: Float['b n n dp'],
mask: Bool['b n'] | None = None,
) -> Float['b n n d']:
) -> Float['b n n dp']:

        num_templates = templates.shape[1]

@@ -884,7 +885,8 @@

        v, merged_batch_ps = pack_one(v, '* i j d')

-        mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> (b t) n', t = num_templates)

        for block in self.pairformer_stack:
            v = block(
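The added guard matters because `mask` is optional; when present, it is copied once per template and folded into the batch axis so the packed template stack can be processed as a single batch. A minimal standalone illustration of the `repeat` pattern, with toy sizes:

```python
# minimal sketch of the einops repeat used above (toy sizes)
import torch
from einops import repeat

mask = torch.ones(2, 16).bool()                   # (batch, residues)
expanded = repeat(mask, 'b n -> (b t) n', t = 4)  # one copy per template, folded into batch
assert expanded.shape == (4 * 2, 16)
```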
@@ -1815,7 +1817,7 @@ def forward(
        pairwise_repr: Float['b n n dp'],
        pred_atom_pos: Float['b n c'],
        mask: Bool['b n'] | None = None,
-        calc_pae_logits_and_loss = True
+        return_pae_logits = True

    ) -> ConfidenceHeadLogits[
        Float['b pae n n'] | None,
@@ -1854,7 +1856,7 @@

        pae_logits = None

-        if calc_pae_logits_and_loss:
+        if return_pae_logits:
            pae_logits = self.to_pae_logits(pairwise_repr)

        # return all logits
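The rename reflects the head's narrowed responsibility: it now only computes logits, and the PAE projection can be skipped by passing `return_pae_logits = False`, in which case the `pae_logits` slot of the returned tuple is `None`. This presumably exists because PAE supervision is only switched on in the latter part of training, as the removed `include_pae_loss` comment below notes.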
@@ -1863,21 +1865,248 @@

# main class

LossBreakdown = namedtuple('LossBreakdown', [
    'distogram',
    'pae',
    'pde',
    'plddt',
    'resolved'
])
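A hypothetical sketch of how the per-term weights defined in `__init__` below might fold this breakdown into a single scalar; the actual loss assembly is explicitly left out of this commit, and the dummy values here are placeholders:

```python
# hypothetical sketch only: loss assembly is not implemented in this commit.
# distogram gets its own weight, the four confidence terms share one, and the
# (future) diffusion loss gets the largest weight
import torch

loss_confidence_weight, loss_distogram_weight, loss_diffusion_weight = 1e-4, 1e-2, 4.

breakdown = LossBreakdown(*torch.rand(5))  # distogram, pae, pde, plddt, resolved
diffusion_loss = torch.rand(())            # stand-in; diffusion is not in this commit

total_loss = (
    breakdown.distogram * loss_distogram_weight
    + (breakdown.pae + breakdown.pde + breakdown.plddt + breakdown.resolved) * loss_confidence_weight
    + diffusion_loss * loss_diffusion_weight
)
```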

class Alphafold3(Module):
    """ Algorithm 1 """

    @typecheck
    def __init__(
        self,
        *,
        dim_atom_inputs,
        dim_additional_residue_feats,
        dim_template_feats,
        dim_template_model = 64,
        atoms_per_window = 27,
        dim_atom = 128,
        dim_atompair = 16,
        dim_input_embedder_token = 384,
        dim_single = 384,
        dim_pairwise = 128,
        atompair_dist_bins: Float[' dist_bins'] = torch.linspace(3, 20, 37),
        ignore_index = -1,
        num_dist_bins = 38,
        num_plddt_bins = 50,
        num_pde_bins = 64,
        num_pae_bins = 64,
        loss_confidence_weight = 1e-4,
        loss_distogram_weight = 1e-2,
-        loss_diffusion = 4.
+        loss_diffusion_weight = 4.,
        input_embedder_kwargs: dict = dict(
            atom_transformer_blocks = 3,
            atom_transformer_heads = 4,
            atom_transformer_kwargs = dict()
        ),
        confidence_head_kwargs: dict = dict(
            pairformer_depth = 4
        ),
        template_embedder_kwargs: dict = dict(
            pairformer_stack_depth = 2,
            pairwise_block_kwargs = dict(),
        ),
        msa_module_kwargs: dict = dict(
            depth = 4,
            dim_msa = 64,
            dim_msa_input = None,
            outer_product_mean_dim_hidden = 32,
            msa_pwa_dropout_row_prob = 0.15,
            msa_pwa_heads = 8,
            msa_pwa_dim_head = 32,
            pairwise_block_kwargs = dict()
        ),
        pairformer_stack: dict = dict(
            depth = 48,
            pair_bias_attn_dim_head = 64,
            pair_bias_attn_heads = 16,
            dropout_row_prob = 0.25,
            pairwise_block_kwargs = dict()
        )
    ):
        super().__init__()

        self.atoms_per_window = atoms_per_window

        # input feature embedder

        self.input_embedder = InputFeatureEmbedder(
            dim_atom_inputs = dim_atom_inputs,
            dim_additional_residue_feats = dim_additional_residue_feats,
            atoms_per_window = atoms_per_window,
            dim_atom = dim_atom,
            dim_atompair = dim_atompair,
            dim_token = dim_input_embedder_token,
            dim_single = dim_single,
            dim_pairwise = dim_pairwise,
            **input_embedder_kwargs
        )

        dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats

        # templates

        self.template_embedder = TemplateEmbedder(
            dim_template_feats = dim_template_feats,
            dim = dim_template_model,
            dim_pairwise = dim_pairwise,
            **template_embedder_kwargs
        )

        # msa

        self.msa_module = MSAModule(
            dim_single = dim_single,
            dim_pairwise = dim_pairwise,
            **msa_module_kwargs
        )

        # main pairformer trunk, 48 layers

        self.pairformer = PairformerStack(
            dim_single = dim_single,
            dim_pairwise = dim_pairwise,
            **pairformer_stack
        )

        # recycling related

        self.recycle_single = nn.Sequential(
            nn.LayerNorm(dim_single),
            LinearNoBias(dim_single, dim_single)
        )

        self.recycle_pairwise = nn.Sequential(
            nn.LayerNorm(dim_pairwise),
            LinearNoBias(dim_pairwise, dim_pairwise)
        )

        # logit heads

        self.distogram_head = DistogramHead(
            dim_pairwise = dim_pairwise,
            num_dist_bins = num_dist_bins
        )

        self.confidence_head = ConfidenceHead(
            dim_single_inputs = dim_single_inputs,
            atompair_dist_bins = atompair_dist_bins,
            dim_single = dim_single,
            dim_pairwise = dim_pairwise,
            num_plddt_bins = num_plddt_bins,
            num_pde_bins = num_pde_bins,
            num_pae_bins = num_pae_bins,
            **confidence_head_kwargs
        )

        # loss related

        self.ignore_index = ignore_index
        self.loss_distogram_weight = loss_distogram_weight
        self.loss_confidence_weight = loss_confidence_weight
        self.loss_diffusion_weight = loss_diffusion_weight
    @typecheck
    def forward(
        self,
        *,
-        include_pae_loss = False # turned on in latter part of training
-    ):
-        return
        atom_inputs: Float['b m dai'],
        atom_mask: Bool['b m'],
        atompair_feats: Float['b m m dap'],
        additional_residue_feats: Float['b n rf'],
        msa: Float['b s n d'],
        templates: Float['b t n n dt'],
        template_mask: Bool['b t'],
        num_recycling_steps: int = 1,
        distance_labels: Int['b n n'] | None = None,
        pae_labels: Int['b n n'] | None = None,
        pde_labels: Int['b n n'] | None = None,
        plddt_labels: Int['b n'] | None = None,
        resolved_labels: Int['b n'] | None = None,
    ) -> Float['b m c'] | Float['']:

        # embed inputs

        (
            single_inputs,
            single_init,
            pairwise_init,
            atom_feats,
            atompair_feats
        ) = self.input_embedder(
            atom_inputs = atom_inputs,
            atom_mask = atom_mask,
            atompair_feats = atompair_feats,
            additional_residue_feats = additional_residue_feats
        )

        w = self.atoms_per_window

        mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
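Since every residue owns a fixed window of `w = 27` atom slots, this reduction marks a residue as present when any atom in its window is present. A toy standalone illustration, with `w = 3` for brevity:

```python
# toy illustration of the windowed 'any' reduction above
import torch
from einops import reduce

atom_mask = torch.zeros(1, 6).bool()  # 2 residues x 3 atom slots
atom_mask[0, 4] = True                # one atom present in the second residue
mask = reduce(atom_mask, 'b (n w) -> b n', w = 3, reduction = 'any')
assert mask.tolist() == [[False, True]]
```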

        # init recycled single and pairwise

        recycled_pairwise = recycled_single = None
        single = pairwise = None

        # for each recycling step

        for _ in range(num_recycling_steps):

            # handle recycled single and pairwise if not first step

            recycled_single = recycled_pairwise = 0.

            if exists(single):
                recycled_single = self.recycle_single(single)

            if exists(pairwise):
                recycled_pairwise = self.recycle_pairwise(pairwise)

            single = single_init + recycled_single
            pairwise = pairwise_init + recycled_pairwise

            # go through the main transformer trunk, as in alphafold2

            # templates

            embedded_template = self.template_embedder(
                templates = templates,
                template_mask = template_mask,
                pairwise_repr = pairwise,
                mask = mask
            )

            pairwise = embedded_template + pairwise

            # msa

            embedded_msa = self.msa_module(
                msa = msa,
                single_repr = single,
                pairwise_repr = pairwise,
                mask = mask
            )

            pairwise = embedded_msa + pairwise

            # main attention trunk (pairformer)

            single, pairwise = self.pairformer(
                single_repr = single,
                pairwise_repr = pairwise,
                mask = mask
            )

        # determine whether to return loss if any labels were passed in,
        # otherwise will sample the atomic coordinates

        labels = (distance_labels, pae_labels, pde_labels, plddt_labels, resolved_labels)
        return_loss = any([exists(label) for label in labels])

        # losses and sampling are not part of this commit; placeholder return
        return torch.tensor(0.)
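For a sense of where this stub is headed, a hypothetical sketch of the distogram term it would eventually compute, assuming `self.distogram_head(pairwise)` produces logits with the bin axis first (`Float['b dist_bins n n']`) so they can feed `F.cross_entropy` directly; the tensors below are stand-ins for the real representations:

```python
# hypothetical sketch only -- loss computation is not part of this commit.
# assumes distogram logits come out with the bin axis first: 'b dist_bins n n'
import torch
import torch.nn.functional as F

b, n, num_dist_bins, ignore_index = 2, 16, 38, -1
distogram_logits = torch.randn(b, num_dist_bins, n, n)        # stand-in for self.distogram_head(pairwise)
distance_labels = torch.randint(0, num_dist_bins, (b, n, n))  # Int['b n n'], bin index per residue pair
distogram_loss = F.cross_entropy(distogram_logits, distance_labels, ignore_index = ignore_index)
```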