From b63a0d93279ef3a586f0e93ef75f240ca2cb7a60 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 23 May 2024 08:04:38 -0700
Subject: [PATCH] complete token bonds to spec

---
 alphafold3_pytorch/alphafold3.py | 37 ++++++++++++++++++++++++++++++--
 pyproject.toml                   |  2 +-
 tests/test_af3.py                |  3 +++
 3 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py
index 497c7f7a..8f06217c 100644
--- a/alphafold3_pytorch/alphafold3.py
+++ b/alphafold3_pytorch/alphafold3.py
@@ -2187,7 +2187,7 @@ def __init__(
         self,
         *,
         dim_atom_inputs,
-        dim_additional_residue_feats,
+        dim_additional_residue_feats = 10,
         atoms_per_window = 27,
         dim_atom = 128,
         dim_atompair = 16,
@@ -2558,6 +2558,14 @@ def __init__(
             **relative_position_encoding_kwargs
         )
 
+        # token bonds
+        # Algorithm 1 - line 5
+
+        self.token_bond_to_pairwise_feat = nn.Sequential(
+            Rearrange('... -> ... 1'),
+            LinearNoBias(1, dim_pairwise)
+        )
+
         # templates
 
         self.template_embedder = TemplateEmbedder(
@@ -2654,7 +2662,8 @@ def forward(
         atom_inputs: Float['b m dai'],
         atom_mask: Bool['b m'],
         atompair_feats: Float['b m m dap'],
-        additional_residue_feats: Float['b n rf'],
+        additional_residue_feats: Float['b n 10'],
+        token_bond: Bool['b n n'] | None = None,
         msa: Float['b s n d'] | None = None,
         msa_mask: Bool['b s'] | None = None,
         templates: Float['b t n n dt'] | None = None,
@@ -2673,7 +2682,13 @@ def forward(
         return_loss_breakdown = False
     ) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
+        # get atom sequence length and residue sequence length
+
         w = self.atoms_per_window
         atom_seq_len = atom_inputs.shape[-2]
+
+        assert divisible_by(atom_seq_len, w)
+        seq_len = atom_inputs.shape[-2] // w
 
         # embed inputs
 
@@ -2698,6 +2713,24 @@ def forward(
 
         pairwise_init = pairwise_init + relative_position_encoding
 
+        # token bond features
+
+        if exists(token_bond):
+            # we'll do some precautionary standardization
+            # (1) mask out diagonal - token to itself does not count as a bond
+            # (2) symmetrize, in case it is not already symmetrical (could also throw an error)
+
+            token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
+            diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
+            token_bond.masked_fill_(diagonal, False)
+        else:
+            seq_arange = torch.arange(seq_len, device = self.device)
+            token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
+
+        token_bond_feats = self.token_bond_to_pairwise_feat(token_bond.float())
+
+        pairwise_init = pairwise_init + token_bond_feats
+
         # pairwise mask
 
         mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
diff --git a/pyproject.toml b/pyproject.toml
index a94292d4..de6a96fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.22"
+version = "0.0.23"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_af3.py b/tests/test_af3.py
index 489dfd35..688e53fe 100644
--- a/tests/test_af3.py
+++ b/tests/test_af3.py
@@ -368,6 +368,8 @@ def test_alphafold3():
     seq_len = 16
     atom_seq_len = seq_len * 27
 
+    token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
+
    atom_inputs = torch.randn(2, atom_seq_len, 77)
    atom_mask = torch.ones((2, atom_seq_len)).bool()
    atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
@@ -418,6 +420,7 @@ def test_alphafold3():
         atom_mask = atom_mask,
         atompair_feats = atompair_feats,
         additional_residue_feats = additional_residue_feats,
+        token_bond = token_bond,
         msa = msa,
         msa_mask = msa_mask,
         templates = template_feats,
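
Note for anyone trying the new input outside the test suite: below is a minimal sketch of building a `token_bond` tensor and of the precautionary standardization `forward` now applies to it. The batch size, sequence length, and bonded token indices (3 and 7) are arbitrary illustration values, not from the patch.

    import torch
    from einops import rearrange

    batch, seq_len = 2, 16

    # token_bond is a boolean (b, n, n) adjacency over tokens,
    # True wherever two tokens share a covalent bond
    token_bond = torch.zeros(batch, seq_len, seq_len, dtype = torch.bool)
    token_bond[:, 3, 7] = True  # hypothetical bond between tokens 3 and 7

    # the same standardization forward() performs internally:
    # symmetrize, then zero the diagonal (a token is not bonded to itself)
    token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
    diagonal = torch.eye(seq_len, dtype = torch.bool)
    token_bond = token_bond.masked_fill(diagonal, False)

    # then passed through exactly as in the updated test:
    # alphafold3(..., token_bond = token_bond, ...)

When `token_bond` is omitted, the patch falls back to bonding consecutive tokens (`|i - j| == 1`), i.e. a simple polymer chain.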