diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py
index 6424a67e..3456788f 100644
--- a/alphafold3_pytorch/alphafold3.py
+++ b/alphafold3_pytorch/alphafold3.py
@@ -131,12 +131,13 @@ def mean_pool_with_lens(
 
 @typecheck
 def repeat_consecutive_with_lens(
-    feats: Float['b n d'],
+    feats: Float['b n ...'] | Bool['b n'],
     lens: Int['b n'],
     max_length: int | None = None,
     return_mask = False
-) -> Float['b m d'] | Tuple[Float['b m d'], Bool['b m']]:
+) -> Float['b m d'] | Bool['b m'] | Tuple[Float['b m d'] | Bool['b m'], Bool['b m']]:
 
+    is_bool = feats.dtype == torch.bool
     device = feats.device
 
     # derive arange from the max length
@@ -165,8 +166,11 @@
 
     # now broadcast and sum for consecutive features
 
-    feats = einx.multiply('b n d, b n m -> b n m d', feats, consecutive_mask.float())
-    feats = reduce(feats, 'b n m d -> b m d', 'sum')
+    feats = einx.multiply('b n ..., b n m -> b n m ...', feats, consecutive_mask.float())
+    feats = reduce(feats, 'b n m ... -> b m ...', 'sum')
+
+    if is_bool:
+        feats = feats.bool()
 
     if not return_mask:
         return feats
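# A minimal reference sketch (not taken from the diff) of what the generalized
# `repeat_consecutive_with_lens` computes: feats[b, i] is repeated lens[b, i] times
# along a packed dimension, right-padded per batch to the longest packed length.
# Boolean inputs stay boolean, which is what the `is_bool` round-trip above preserves.
# Assumes only torch; `repeat_consecutive_reference` is a hypothetical helper name.

import torch
import torch.nn.functional as F

def repeat_consecutive_reference(feats, lens):
    max_packed_len = int(lens.sum(dim = -1).max())

    out = []
    for single_feats, single_lens in zip(feats, lens):
        repeated = torch.repeat_interleave(single_feats, single_lens, dim = 0)

        # right-pad the packed (first) dimension to the batch-wide max length
        pad = (0, 0) * (repeated.ndim - 1) + (0, max_packed_len - repeated.shape[0])
        out.append(F.pad(repeated, pad))

    return torch.stack(out)

feats = torch.randn(1, 2, 4)
lens = torch.tensor([[2, 1]])
assert repeat_consecutive_reference(feats, lens).shape == (1, 3, 4)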
@@ -1373,12 +1377,24 @@ def forward(
         self,
         *,
         atom_feats: Float['b m da'],
-        atom_mask: Bool['b m']
+        atom_mask: Bool['b m'],
+        residue_atom_lens: Int['b n'] | None = None
     ) -> Float['b n ds']:
 
+        w = self.atoms_per_window
+        is_unpacked_repr = exists(w)
+
+        assert is_unpacked_repr ^ exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
+
         atom_feats = self.proj(atom_feats)
 
+        # packed atom representation
+
+        if exists(residue_atom_lens):
+            tokens = mean_pool_with_lens(atom_feats, residue_atom_lens)
+            return tokens
+
+        # unpacked atom representation
 
         # masked mean pool the atom feats for each residue, for the token transformer
         # this is basically a simple 2-level hierarchical transformer
@@ -1541,13 +1557,18 @@ def forward(
         single_inputs_repr: Float['b n dsi'],
         pairwise_trunk: Float['b n n dpt'],
         pairwise_rel_pos_feats: Float['b n n dpr'],
+        residue_atom_lens: Int['b n'] | None = None
     ):
         w = self.atoms_per_window
+        is_unpacked_repr = exists(w)
+
+        assert is_unpacked_repr ^ exists(residue_atom_lens)
 
         # in the paper, it seems they pack the atom feats
         # but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
 
-        assert divisible_by(noised_atom_pos.shape[-2], w)
+        if is_unpacked_repr:
+            assert divisible_by(noised_atom_pos.shape[-2], w)
 
         conditioned_single_repr = self.single_conditioner(
             times = times,
@@ -1572,14 +1593,20 @@
 
         single_repr_cond = self.single_repr_to_atom_feat_cond(conditioned_single_repr)
-        single_repr_cond = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = w)
+
+        if is_unpacked_repr:
+            single_repr_cond = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = w)
+        else:
+            single_repr_cond = repeat_consecutive_with_lens(single_repr_cond, residue_atom_lens)
+
         atom_feats_cond = single_repr_cond + atom_feats_cond
 
         # condition atompair feats with pairwise repr
 
         pairwise_repr_cond = self.pairwise_repr_to_atompair_feat_cond(conditioned_pairwise_repr)
-        pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
-        atompair_feats = pairwise_repr_cond + atompair_feats
+
+        if is_unpacked_repr:
+            pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
+            atompair_feats = pairwise_repr_cond + atompair_feats
 
         # condition atompair feats further with single atom repr
@@ -1603,8 +1630,10 @@
 
         tokens = self.atom_feats_to_pooled_token(
             atom_feats = atom_feats,
-            atom_mask = atom_mask
+            atom_mask = atom_mask,
+            residue_atom_lens = residue_atom_lens
         )
+
         # token transformer
 
         tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens
@@ -1621,7 +1650,11 @@
         # atom decoder
 
         atom_decoder_input = self.tokens_to_atom_decoder_input_cond(tokens)
-        atom_decoder_input = repeat(atom_decoder_input, 'b n da -> b (n w) da', w = w)
+
+        if is_unpacked_repr:
+            atom_decoder_input = repeat(atom_decoder_input, 'b n da -> b (n w) da', w = w)
+        else:
+            atom_decoder_input = repeat_consecutive_with_lens(atom_decoder_input, residue_atom_lens)
 
         atom_decoder_input = atom_decoder_input + atom_feats_skip
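# Sketch (not part of the diff) contrasting the two broadcast modes the diffusion
# module now switches between: a fixed window of w atoms per token versus a
# per-token atom count. Assumes einops is installed and that
# `repeat_consecutive_with_lens` is importable from `alphafold3_pytorch.alphafold3`.

import torch
from einops import repeat
from alphafold3_pytorch.alphafold3 import repeat_consecutive_with_lens

single_repr_cond = torch.randn(1, 3, 4)                        # (b, n, ds)

# unpacked: every token owns exactly w = 2 atom slots -> (1, 6, 4)
unpacked = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = 2)

# packed: token i owns residue_atom_lens[..., i] atom slots -> also (1, 6, 4),
# but with rows distributed 1 / 3 / 2 instead of 2 / 2 / 2
residue_atom_lens = torch.tensor([[1, 3, 2]])
packed = repeat_consecutive_with_lens(single_repr_cond, residue_atom_lens)

assert unpacked.shape == packed.shape == (1, 6, 4)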
@@ -1771,7 +1804,7 @@ def sample_schedule(self, num_sample_steps = None):
     @torch.no_grad()
     def sample(
         self,
-        atom_mask: Bool['b m'],
+        atom_mask: Bool['b m'] | None = None,
         num_sample_steps = None,
         clamp = True,
         **network_condition_kwargs
@@ -1847,6 +1880,7 @@ def forward(
         pairwise_trunk: Float['b n n dpt'],
         pairwise_rel_pos_feats: Float['b n n dpr'],
         return_denoised_pos = False,
+        residue_atom_lens: Int['b n'] | None = None,
         additional_residue_feats: Float['b n 10'] | None = None,
         add_smooth_lddt_loss = False,
         add_bond_loss = False,
@@ -1878,6 +1912,7 @@
                 single_inputs_repr = single_inputs_repr,
                 pairwise_trunk = pairwise_trunk,
                 pairwise_rel_pos_feats = pairwise_rel_pos_feats,
+                residue_atom_lens = residue_atom_lens
             )
         )
@@ -1890,9 +1925,14 @@
         if exists(additional_residue_feats):
             w = self.net.atoms_per_window
+            is_unpacked_repr = exists(w)
 
-            is_nucleotide_or_ligand_fields = additional_residue_feats[..., 7:] != 0.
-            atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat(t != 0., 'b n -> b (n w)', w = w) for t in is_nucleotide_or_ligand_fields.unbind(dim = -1))
+            is_nucleotide_or_ligand_fields = (additional_residue_feats[..., 7:] != 0.).unbind(dim = -1)
+
+            if is_unpacked_repr:
+                atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat(t != 0., 'b n -> b (n w)', w = w) for t in is_nucleotide_or_ligand_fields)
+            else:
+                atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat_consecutive_with_lens(t, residue_atom_lens) for t in is_nucleotide_or_ligand_fields)
 
         # section 3.7.1 equation 4
@@ -2315,6 +2355,8 @@
         atom_mask: Bool['b m'],
         atompair_feats: Float['b m m dap'],
         additional_residue_feats: Float['b n rf'],
+        residue_atom_lens: Int['b n'] | None = None,
+
     ) -> EmbeddedInputs:
 
         assert additional_residue_feats.shape[-1] == self.dim_additional_residue_feats
@@ -2336,7 +2378,8 @@
         single_inputs = self.atom_feats_to_pooled_token(
             atom_feats = atom_feats,
-            atom_mask = atom_mask
+            atom_mask = atom_mask,
+            residue_atom_lens = residue_atom_lens
         )
 
         single_inputs = torch.cat((single_inputs, additional_residue_feats), dim = -1)
@@ -2527,6 +2570,7 @@
         num_pae_bins = 64,
         sigma_data = 16,
         diffusion_num_augmentations = 4,
+        packed_atom_repr = False,
         loss_confidence_weight = 1e-4,
         loss_distogram_weight = 1e-2,
         loss_diffusion_weight = 4.,
@@ -2593,6 +2637,15 @@
     ):
         super().__init__()
 
+        # whether a packed atom representation is being used
+
+        self.packed_atom_repr = packed_atom_repr
+
+        # atoms per window if using unpacked representation
+
+        if packed_atom_repr:
+            atoms_per_window = None
+
         self.atoms_per_window = atoms_per_window
 
         # augmentation
@@ -2732,9 +2785,10 @@ def forward(
         self,
         *,
         atom_inputs: Float['b m dai'],
-        atom_mask: Bool['b m'],
         atompair_feats: Float['b m m dap'],
         additional_residue_feats: Float['b n 10'],
+        residue_atom_lens: Int['b n'] | None = None,
+        atom_mask: Bool['b m'] | None = None,
         token_bond: Bool['b n n'] | None = None,
         msa: Float['b s n d'] | None = None,
         msa_mask: Bool['b s'] | None = None,
@@ -2754,13 +2808,33 @@
         return_loss_breakdown = False
     ) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
-        # get atom sequence length and residue sequence length
-
-        w = self.atoms_per_window
 
         atom_seq_len = atom_inputs.shape[-2]
-        assert divisible_by(atom_seq_len, w)
-        seq_len = atom_inputs.shape[-2] // w
+
+        # determine whether using packed or unpacked atom rep
+
+        assert exists(residue_atom_lens) ^ exists(atom_mask), 'either `residue_atom_lens` or `atom_mask` must be given, depending on whether the `packed_atom_repr` kwarg is True or False'
+
+        if exists(residue_atom_lens):
+            assert self.packed_atom_repr, '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `residue_atom_lens`'
+
+            # handle atom mask
+
+            atom_mask = lens_to_mask(residue_atom_lens)
+            atom_mask = atom_mask[:, :atom_seq_len]
+
+            # handle offsets for residue atom indices
+
+            if exists(residue_atom_indices):
+                residue_atom_indices += F.pad(residue_atom_lens.cumsum(dim = -1), (1, -1), value = 0)
+
+        # get atom sequence length and residue sequence length depending on whether using packed atomic seq
+
+        if self.packed_atom_repr:
+            seq_len = residue_atom_lens.shape[-1]
+        else:
+            w = self.atoms_per_window
+            assert divisible_by(atom_seq_len, w)
+            seq_len = atom_inputs.shape[-2] // w
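# Sketch (not part of the diff) of the index arithmetic above: with a packed atom
# dimension, a per-residue atom index has to be shifted by the atom count of all
# preceding residues, i.e. an exclusive cumulative sum of `residue_atom_lens`.
# Only torch is assumed; the values are illustrative.

import torch
import torch.nn.functional as F

residue_atom_lens = torch.tensor([[2, 3, 1]])       # residues own atoms 0-1, 2-4, 5
residue_atom_indices = torch.tensor([[1, 0, 0]])    # index of an atom within its residue

# pad a 0 on the left, drop the total on the right -> [[0, 2, 5]]
offsets = F.pad(residue_atom_lens.cumsum(dim = -1), (1, -1), value = 0)

flat_indices = residue_atom_indices + offsets       # [[1, 2, 5]] into the packed atom axis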
@@ -2774,7 +2848,8 @@
             atom_inputs = atom_inputs,
             atom_mask = atom_mask,
             atompair_feats = atompair_feats,
-            additional_residue_feats = additional_residue_feats
+            additional_residue_feats = additional_residue_feats,
+            residue_atom_lens = residue_atom_lens
         )
 
         # relative positional encoding
@@ -2805,7 +2880,12 @@
 
         # pairwise mask
 
-        mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
+        if self.packed_atom_repr:
+            mask = lens_to_mask(residue_atom_lens)
+            mask = mask[:, :seq_len]
+        else:
+            mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
+
         pairwise_mask = einx.logical_and('b i, b j -> b i j', mask, mask)
 
         # init recycled single and pairwise
@@ -2889,7 +2969,8 @@
             single_trunk_repr = single,
             single_inputs_repr = single_inputs,
             pairwise_trunk = pairwise,
-            pairwise_rel_pos_feats = relative_position_encoding
+            pairwise_rel_pos_feats = relative_position_encoding,
+            residue_atom_lens = residue_atom_lens
         )
 
         # losses default to 0
@@ -2903,7 +2984,12 @@
         # distogram head
 
         if not exists(distance_labels) and atom_pos_given and exists(residue_atom_indices):
-            residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
+
+            if self.packed_atom_repr:
+                residue_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, residue_atom_indices)
+            else:
+                residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
+
             residue_dist = torch.cdist(residue_pos, residue_pos, p = 2)
             dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', residue_dist, self.distance_bins).abs()
             distance_labels = dist_from_dist_bins.argmin(dim = -1)
@@ -2938,6 +3024,7 @@
                 relative_position_encoding,
                 additional_residue_feats,
                 residue_atom_indices,
+                residue_atom_lens,
                 pae_labels,
                 pde_labels,
                 plddt_labels,
@@ -2958,6 +3045,7 @@
                 relative_position_encoding,
                 additional_residue_feats,
                 residue_atom_indices,
+                residue_atom_lens,
                 pae_labels,
                 pde_labels,
                 plddt_labels,
@@ -2980,6 +3068,7 @@
                 single_inputs_repr = single_inputs,
                 pairwise_trunk = pairwise,
                 pairwise_rel_pos_feats = relative_position_encoding,
+                residue_atom_lens = residue_atom_lens,
                 return_denoised_pos = True,
             )
@@ -2990,7 +3079,10 @@
 
         if calc_diffusion_loss and should_call_confidence_head:
 
-            pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
+            if self.packed_atom_repr:
+                pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
+            else:
+                pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
 
             logits = self.confidence_head(
                 single_repr = single,
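# Sketch (not part of the diff) of the packed gather used for `residue_pos` and
# `pred_atom_pos` above: einx.get_at with 'b [m] c, b n -> b n c' selects one
# coordinate per residue out of the flat atom axis. Assumes einx is installed;
# the torch.gather line is an equivalent spelled out by hand.

import torch
import einx

atom_pos = torch.randn(2, 6, 3)                       # (b, m, c) packed atom positions
residue_atom_indices = torch.randint(0, 6, (2, 4))    # (b, n) flat indices

residue_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, residue_atom_indices)

gathered = atom_pos.gather(1, residue_atom_indices[..., None].expand(-1, -1, 3))
assert torch.allclose(residue_pos, gathered)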
diff --git a/pyproject.toml b/pyproject.toml
index cba87336..092d9c2c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.25"
+version = "0.0.26"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_af3.py b/tests/test_af3.py
index f4ff0e0e..de033e7f 100644
--- a/tests/test_af3.py
+++ b/tests/test_af3.py
@@ -505,3 +505,91 @@ def test_alphafold3_without_msa_and_templates():
     )
 
     loss.backward()
+
+def test_alphafold3_with_packed_atom_repr():
+    seq_len = 16
+    residue_atom_lens = torch.randint(1, 3, (2, seq_len))
+
+    atom_seq_len = residue_atom_lens.sum(dim = -1).amax()
+
+    token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
+
+    atom_inputs = torch.randn(2, atom_seq_len, 77)
+
+    atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
+    additional_residue_feats = torch.randn(2, seq_len, 10)
+
+    template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
+    template_mask = torch.ones((2, 2)).bool()
+
+    msa = torch.randn(2, 7, seq_len, 64)
+    msa_mask = torch.ones((2, 7)).bool()
+
+    atom_pos = torch.randn(2, atom_seq_len, 3)
+    residue_atom_indices = torch.randint(0, 2, (2, seq_len))
+
+    pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
+    pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
+    plddt_labels = torch.randint(0, 50, (2, seq_len))
+    resolved_labels = torch.randint(0, 2, (2, seq_len))
+
+    alphafold3 = Alphafold3(
+        dim_atom_inputs = 77,
+        dim_additional_residue_feats = 10,
+        dim_template_feats = 44,
+        num_dist_bins = 38,
+        packed_atom_repr = True,
+        confidence_head_kwargs = dict(
+            pairformer_depth = 1
+        ),
+        template_embedder_kwargs = dict(
+            pairformer_stack_depth = 1
+        ),
+        msa_module_kwargs = dict(
+            depth = 1
+        ),
+        pairformer_stack = dict(
+            depth = 2
+        ),
+        diffusion_module_kwargs = dict(
+            atom_encoder_depth = 1,
+            token_transformer_depth = 1,
+            atom_decoder_depth = 1,
+        ),
+    )
+
+    loss, breakdown = alphafold3(
+        num_recycling_steps = 2,
+        atom_inputs = atom_inputs,
+        residue_atom_lens = residue_atom_lens,
+        atompair_feats = atompair_feats,
+        additional_residue_feats = additional_residue_feats,
+        token_bond = token_bond,
+        msa = msa,
+        msa_mask = msa_mask,
+        templates = template_feats,
+        template_mask = template_mask,
+        atom_pos = atom_pos,
+        residue_atom_indices = residue_atom_indices,
+        pae_labels = pae_labels,
+        pde_labels = pde_labels,
+        plddt_labels = plddt_labels,
+        resolved_labels = resolved_labels,
+        return_loss_breakdown = True
+    )
+
+    loss.backward()
+
+    sampled_atom_pos = alphafold3(
+        num_sample_steps = 16,
+        atom_inputs = atom_inputs,
+        residue_atom_lens = residue_atom_lens,
+        atompair_feats = atompair_feats,
+        additional_residue_feats = additional_residue_feats,
+        msa = msa,
+        templates = template_feats,
+        template_mask = template_mask,
+    )
+
+    assert sampled_atom_pos.ndim == 3