complete token bonds to spec
lucidrains committed May 23, 2024
1 parent 20962b0 commit b63a0d9
Showing 3 changed files with 39 additions and 3 deletions.
37 changes: 35 additions & 2 deletions alphafold3_pytorch/alphafold3.py
@@ -2187,7 +2187,7 @@ def __init__(
         self,
         *,
         dim_atom_inputs,
-        dim_additional_residue_feats,
+        dim_additional_residue_feats = 10,
         atoms_per_window = 27,
         dim_atom = 128,
         dim_atompair = 16,
@@ -2558,6 +2558,14 @@ def __init__(
             **relative_position_encoding_kwargs
         )
 
+        # token bonds
+        # Algorithm 1 - line 5
+
+        self.token_bond_to_pairwise_feat = nn.Sequential(
+            Rearrange('... -> ... 1'),
+            LinearNoBias(1, dim_pairwise)
+        )
+
         # templates
 
         self.template_embedder = TemplateEmbedder(
@@ -2654,7 +2662,8 @@ def forward(
         atom_inputs: Float['b m dai'],
         atom_mask: Bool['b m'],
         atompair_feats: Float['b m m dap'],
-        additional_residue_feats: Float['b n rf'],
+        additional_residue_feats: Float['b n 10'],
+        token_bond: Bool['b n n'] | None = None,
         msa: Float['b s n d'] | None = None,
         msa_mask: Bool['b s'] | None = None,
         templates: Float['b t n n dt'] | None = None,
@@ -2673,7 +2682,13 @@ def forward(
         return_loss_breakdown = False
     ) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
+        # get atom sequence length and residue sequence length
+
         w = self.atoms_per_window
+        atom_seq_len = atom_inputs.shape[-2]
+
+        assert divisible_by(atom_seq_len, w)
+        seq_len = atom_inputs.shape[-2] // w
 
         # embed inputs
 
@@ -2698,6 +2713,24 @@ def forward(
 
         pairwise_init = pairwise_init + relative_position_encoding
 
+        # token bond features
+
+        if exists(token_bond):
+            # we'll do some precautionary standardization
+            # (1) mask out diagonal - token to itself does not count as a bond
+            # (2) symmetrize, in case it is not already symmetrical (could also throw an error)
+
+            token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
+            diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
+            token_bond.masked_fill_(diagonal, False)
+        else:
+            seq_arange = torch.arange(seq_len, device = self.device)
+            token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
+
+        token_bond_feats = self.token_bond_to_pairwise_feat(token_bond.float())
+
+        pairwise_init = pairwise_init + token_bond_feats
+
         # pairwise mask
 
         mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
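For reference, here is a minimal standalone sketch (not part of the commit) of the token bond behaviour added above: by default consecutive tokens are treated as bonded (|i - j| == 1), a user supplied bond matrix is symmetrized and its diagonal cleared, and a single bias-free linear layer (standing in here for the repository's LinearNoBias) projects the 0/1 matrix into pairwise features. The sizes and the example bond are illustrative assumptions.

import torch
from torch import nn
import einx
from einops import rearrange
from einops.layers.torch import Rearrange

n, dim_pairwise = 5, 8

# default token bonds: adjacent tokens along the sequence are bonded
seq_arange = torch.arange(n)
default_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1

# a user supplied matrix is symmetrized and self-bonds on the diagonal are removed
token_bond = torch.zeros(1, n, n, dtype = torch.bool)
token_bond[:, 0, 1] = True                                          # directed bond 0 -> 1
token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')   # now also 1 -> 0
token_bond.masked_fill_(torch.eye(n, dtype = torch.bool), False)    # clear the diagonal

# one linear projection turns the boolean matrix into a pairwise feature map
to_pairwise_feat = nn.Sequential(
    Rearrange('... -> ... 1'),
    nn.Linear(1, dim_pairwise, bias = False)
)

token_bond_feats = to_pairwise_feat(token_bond.float())             # (1, n, n, dim_pairwise)
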
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.22"
+version = "0.0.23"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
3 changes: 3 additions & 0 deletions tests/test_af3.py
@@ -368,6 +368,8 @@ def test_alphafold3():
     seq_len = 16
     atom_seq_len = seq_len * 27
 
+    token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
+
     atom_inputs = torch.randn(2, atom_seq_len, 77)
     atom_mask = torch.ones((2, atom_seq_len)).bool()
     atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
@@ -418,6 +420,7 @@ def test_alphafold3():
         atom_mask = atom_mask,
         atompair_feats = atompair_feats,
         additional_residue_feats = additional_residue_feats,
+        token_bond = token_bond,
         msa = msa,
         msa_mask = msa_mask,
         templates = template_feats,
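As a usage note, the test above just feeds a random boolean matrix. The hypothetical helper below (not part of the repository; name and signature are illustrative assumptions) sketches how a caller might instead build token_bond from explicit pairs of bonded token indices.

import torch

def bonds_to_matrix(bond_pairs, seq_len, batch_size = 1):
    # bond_pairs: iterable of (i, j) token index pairs that are covalently bonded
    token_bond = torch.zeros(batch_size, seq_len, seq_len, dtype = torch.bool)
    for i, j in bond_pairs:
        token_bond[:, i, j] = True
        token_bond[:, j, i] = True   # keep the matrix symmetric up front
    return token_bond

# e.g. tokens 3-4 and 10-11 bonded, matching the seq_len = 16 and batch of 2 used in the test
token_bond = bonds_to_matrix([(3, 4), (10, 11)], seq_len = 16, batch_size = 2)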
