complete the packed atom representation and add a end2end test
lucidrains committed May 24, 2024
1 parent 291f715 commit 8399c29
Showing 3 changed files with 207 additions and 27 deletions.
144 changes: 118 additions & 26 deletions alphafold3_pytorch/alphafold3.py
@@ -131,12 +131,13 @@ def mean_pool_with_lens(
 
 @typecheck
 def repeat_consecutive_with_lens(
-    feats: Float['b n d'],
+    feats: Float['b n ...'] | Bool['b n'],
     lens: Int['b n'],
     max_length: int | None = None,
     return_mask = False
-) -> Float['b m d'] | Tuple[Float['b m d'], Bool['b m']]:
+) -> Float['b m d'] | Bool['b m'] | Tuple[Float['b m d'] | Bool['b m'], Bool['b m']]:
 
+    is_bool = feats.dtype == torch.bool
     device = feats.device
 
     # derive arange from the max length
@@ -165,8 +166,11 @@ def repeat_consecutive_with_lens(
 
     # now broadcast and sum for consecutive features
 
-    feats = einx.multiply('b n d, b n m -> b n m d', feats, consecutive_mask.float())
-    feats = reduce(feats, 'b n m d -> b m d', 'sum')
+    feats = einx.multiply('b n ..., b n m -> b n m ...', feats, consecutive_mask.float())
+    feats = reduce(feats, 'b n m ... -> b m ...', 'sum')
+
+    if is_bool:
+        feats = feats.bool()
 
     if not return_mask:
        return feats
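
Note for readers following the diff: `repeat_consecutive_with_lens` expands residue-level features to atom-level by repeating each residue's feature vector `lens[b, n]` times, packed consecutively along the atom axis; the generalized einx pattern above additionally lets boolean masks pass through. Below is a minimal loop sketch of the intended semantics, not the vectorized implementation; `repeat_consecutive_reference` is a name introduced here for illustration.

import torch

def repeat_consecutive_reference(feats, lens, max_length = None):
    # feats: (b, n, ...) float or (b, n) bool, lens: (b, n) int
    b, n = lens.shape
    m = max_length if max_length is not None else int(lens.sum(dim = -1).max())
    out = torch.zeros((b, m, *feats.shape[2:]), dtype = feats.dtype)

    for bi in range(b):
        pos = 0
        for ni in range(n):
            length = int(lens[bi, ni])
            out[bi, pos:pos + length] = feats[bi, ni]  # repeat residue ni once per atom
            pos += length

    return out

feats = torch.tensor([[[1.], [2.], [3.]]])  # (b = 1, n = 3, d = 1)
lens = torch.tensor([[2, 0, 3]])            # residue 1 contributes no atoms
print(repeat_consecutive_reference(feats, lens).squeeze(-1))
# tensor([[1., 1., 3., 3., 3.]])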
@@ -1373,12 +1377,24 @@ def forward(
         self,
         *,
         atom_feats: Float['b m da'],
-        atom_mask: Bool['b m']
+        atom_mask: Bool['b m'],
+        residue_atom_lens: Int['b n'] | None = None
     ) -> Float['b n ds']:
 
+        w = self.atoms_per_window
+        is_unpacked_repr = exists(w)
+
+        assert is_unpacked_repr ^ exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
+
         atom_feats = self.proj(atom_feats)
 
+        # packed atom representation
+
+        if exists(residue_atom_lens):
+            tokens = mean_pool_with_lens(atom_feats, residue_atom_lens)
+            return tokens
+
+        # unpacked atom representation
         # masked mean pool the atom feats for each residue, for the token transformer
         # this is basically a simple 2-level hierarchical transformer
@@ -1541,13 +1557,18 @@ def forward(
         single_inputs_repr: Float['b n dsi'],
         pairwise_trunk: Float['b n n dpt'],
         pairwise_rel_pos_feats: Float['b n n dpr'],
+        residue_atom_lens: Int['b n'] | None = None
     ):
         w = self.atoms_per_window
+        is_unpacked_repr = exists(w)
+
+        assert is_unpacked_repr ^ exists(residue_atom_lens)
+
         # in the paper, it seems they pack the atom feats
         # but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
 
-        assert divisible_by(noised_atom_pos.shape[-2], w)
+        if is_unpacked_repr:
+            assert divisible_by(noised_atom_pos.shape[-2], w)
 
         conditioned_single_repr = self.single_conditioner(
             times = times,
@@ -1572,14 +1593,20 @@ def forward(
 
         single_repr_cond = self.single_repr_to_atom_feat_cond(conditioned_single_repr)
 
-        single_repr_cond = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = w)
+        if is_unpacked_repr:
+            single_repr_cond = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = w)
+        else:
+            single_repr_cond = repeat_consecutive_with_lens(single_repr_cond, residue_atom_lens)
+
         atom_feats_cond = single_repr_cond + atom_feats_cond
 
         # condition atompair feats with pairwise repr
 
         pairwise_repr_cond = self.pairwise_repr_to_atompair_feat_cond(conditioned_pairwise_repr)
-        pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
-        atompair_feats = pairwise_repr_cond + atompair_feats
+
+        if is_unpacked_repr:
+            pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
+            atompair_feats = pairwise_repr_cond + atompair_feats
 
         # condition atompair feats further with single atom repr

@@ -1603,8 +1630,10 @@ def forward(
 
         tokens = self.atom_feats_to_pooled_token(
             atom_feats = atom_feats,
-            atom_mask = atom_mask
+            atom_mask = atom_mask,
+            residue_atom_lens = residue_atom_lens
         )
+
         # token transformer
 
         tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens
@@ -1621,7 +1650,11 @@ def forward(
         # atom decoder
 
         atom_decoder_input = self.tokens_to_atom_decoder_input_cond(tokens)
-        atom_decoder_input = repeat(atom_decoder_input, 'b n da -> b (n w) da', w = w)
+
+        if is_unpacked_repr:
+            atom_decoder_input = repeat(atom_decoder_input, 'b n da -> b (n w) da', w = w)
+        else:
+            atom_decoder_input = repeat_consecutive_with_lens(atom_decoder_input, residue_atom_lens)
 
         atom_decoder_input = atom_decoder_input + atom_feats_skip
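
The two branches above broadcast residue-level conditioning down to atoms in equivalent ways: the unpacked path tiles by a fixed window, the packed path repeats by per-residue lengths. When every residue happens to have exactly `w` atoms the results coincide, which this check illustrates (it reuses the `repeat_consecutive_reference` sketch introduced earlier in these notes):

import torch
from einops import repeat

b, n, w, d = 1, 3, 4, 8
tokens = torch.randn(b, n, d)

# unpacked: fixed window of w atoms per residue
windowed = repeat(tokens, 'b n d -> b (n w) d', w = w)

# packed: per-residue lengths, here uniform, so the two agree
lens = torch.full((b, n), w, dtype = torch.long)
packed = repeat_consecutive_reference(tokens, lens)

assert torch.allclose(windowed, packed)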

@@ -1771,7 +1804,7 @@ def sample_schedule(self, num_sample_steps = None):
     @torch.no_grad()
     def sample(
         self,
-        atom_mask: Bool['b m'],
+        atom_mask: Bool['b m'] | None = None,
         num_sample_steps = None,
         clamp = True,
         **network_condition_kwargs
@@ -1847,6 +1880,7 @@ def forward(
         pairwise_trunk: Float['b n n dpt'],
         pairwise_rel_pos_feats: Float['b n n dpr'],
         return_denoised_pos = False,
+        residue_atom_lens: Int['b n'] | None = None,
         additional_residue_feats: Float['b n 10'] | None = None,
         add_smooth_lddt_loss = False,
         add_bond_loss = False,
@@ -1878,6 +1912,7 @@ def forward(
                 single_inputs_repr = single_inputs_repr,
                 pairwise_trunk = pairwise_trunk,
                 pairwise_rel_pos_feats = pairwise_rel_pos_feats,
+                residue_atom_lens = residue_atom_lens
             )
         )

@@ -1890,9 +1925,14 @@ def forward(
 
         if exists(additional_residue_feats):
             w = self.net.atoms_per_window
+            is_unpacked_repr = exists(w)
 
-            is_nucleotide_or_ligand_fields = additional_residue_feats[..., 7:] != 0.
-            atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat(t != 0., 'b n -> b (n w)', w = w) for t in is_nucleotide_or_ligand_fields.unbind(dim = -1))
+            is_nucleotide_or_ligand_fields = (additional_residue_feats[..., 7:] != 0.).unbind(dim = -1)
+
+            if is_unpacked_repr:
+                atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat(t != 0., 'b n -> b (n w)', w = w) for t in is_nucleotide_or_ligand_fields)
+            else:
+                atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat_consecutive_with_lens(t, residue_atom_lens) for t in is_nucleotide_or_ligand_fields)
 
         # section 3.7.1 equation 4

@@ -2315,6 +2355,8 @@ def forward(
         atom_mask: Bool['b m'],
         atompair_feats: Float['b m m dap'],
         additional_residue_feats: Float['b n rf'],
+        residue_atom_lens: Int['b n'] | None = None,
+
     ) -> EmbeddedInputs:
 
         assert additional_residue_feats.shape[-1] == self.dim_additional_residue_feats
@@ -2336,7 +2378,8 @@ def forward(
 
         single_inputs = self.atom_feats_to_pooled_token(
             atom_feats = atom_feats,
-            atom_mask = atom_mask
+            atom_mask = atom_mask,
+            residue_atom_lens = residue_atom_lens
         )
 
         single_inputs = torch.cat((single_inputs, additional_residue_feats), dim = -1)
@@ -2527,6 +2570,7 @@ def __init__(
         num_pae_bins = 64,
         sigma_data = 16,
         diffusion_num_augmentations = 4,
+        packed_atom_repr = False,
         loss_confidence_weight = 1e-4,
         loss_distogram_weight = 1e-2,
         loss_diffusion_weight = 4.,
@@ -2593,6 +2637,15 @@ def __init__(
     ):
         super().__init__()
 
+        # whether a packed atom representation is being used
+
+        self.packed_atom_repr = packed_atom_repr
+
+        # atoms per window if using unpacked representation
+
+        if packed_atom_repr:
+            atoms_per_window = None
+
         self.atoms_per_window = atoms_per_window
 
         # augmentation
@@ -2732,9 +2785,10 @@ def forward(
         self,
         *,
         atom_inputs: Float['b m dai'],
-        atom_mask: Bool['b m'],
         atompair_feats: Float['b m m dap'],
         additional_residue_feats: Float['b n 10'],
+        residue_atom_lens: Int['b n'] | None = None,
+        atom_mask: Bool['b m'] | None = None,
         token_bond: Bool['b n n'] | None = None,
         msa: Float['b s n d'] | None = None,
         msa_mask: Bool['b s'] | None = None,
@@ -2754,13 +2808,33 @@ def forward(
         return_loss_breakdown = False
     ) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
 
-        # get atom sequence length and residue sequence length
-
-        w = self.atoms_per_window
         atom_seq_len = atom_inputs.shape[-2]
 
-        assert divisible_by(atom_seq_len, w)
-        seq_len = atom_inputs.shape[-2] // w
+        # determine whether using packed or unpacked atom rep
+
+        assert exists(residue_atom_lens) ^ exists(atom_mask), 'either atom_lens or atom_mask must be given depending on whether packed_atom_repr kwarg is True or False'
+
+        if exists(residue_atom_lens):
+            assert self.packed_atom_repr, '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `atom_lens`'
+
+            # handle atom mask
+
+            atom_mask = lens_to_mask(residue_atom_lens)
+            atom_mask = atom_mask[:, :atom_seq_len]
+
+            # handle offsets for residue atom indices
+
+            if exists(residue_atom_indices):
+                residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
+
+        # get atom sequence length and residue sequence length depending on whether using packed atomic seq
+
+        if self.packed_atom_repr:
+            seq_len = residue_atom_lens.shape[-1]
+        else:
+            w = self.atoms_per_window
+            assert divisible_by(atom_seq_len, w)
+            seq_len = atom_inputs.shape[-2] // w
 
         # embed inputs
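
The helper `lens_to_mask` is not shown in this diff. Under the assumption that it marks each batch element's first `sum(lens)` packed-atom slots as valid, the sketch below shows how `atom_mask` and `seq_len` are derived in each mode; `lens_to_mask_sketch` is a name introduced for illustration, not the library's implementation.

import torch

def lens_to_mask_sketch(lens):
    # assumption: True for each batch element's first sum(lens) packed-atom slots
    total = lens.sum(dim = -1)
    return torch.arange(int(total.max())) < total[:, None]

residue_atom_lens = torch.tensor([[3, 7, 2, 5, 4, 3]])  # 6 residues, 24 atoms

# packed: residue count is read straight off the lens tensor
atom_seq_len = int(residue_atom_lens.sum())
seq_len = residue_atom_lens.shape[-1]
atom_mask = lens_to_mask_sketch(residue_atom_lens)[:, :atom_seq_len]

# unpacked: a fixed window w, so the atom axis must divide evenly
w = 4
assert atom_seq_len % w == 0 and atom_seq_len // w == seq_len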

@@ -2774,7 +2848,8 @@ def forward(
             atom_inputs = atom_inputs,
             atom_mask = atom_mask,
             atompair_feats = atompair_feats,
-            additional_residue_feats = additional_residue_feats
+            additional_residue_feats = additional_residue_feats,
+            residue_atom_lens = residue_atom_lens
         )
 
         # relative positional encoding
@@ -2805,7 +2880,12 @@ def forward(
 
         # pairwise mask
 
-        mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
+        if self.packed_atom_repr:
+            mask = lens_to_mask(residue_atom_lens)
+            mask = mask[:, :seq_len]
+        else:
+            mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
+
         pairwise_mask = einx.logical_and('b i, b j -> b i j', mask, mask)
 
         # init recycled single and pairwise
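
The pairwise mask is an outer logical-and of the residue mask with itself; in plain PyTorch the `einx.logical_and('b i, b j -> b i j', mask, mask)` call corresponds to simple broadcasting, as this small check shows:

import torch

mask = torch.tensor([[True, True, False]])

# outer logical-and: pair (i, j) is valid only if both residues are valid
pairwise_mask = mask[:, :, None] & mask[:, None, :]
assert pairwise_mask.shape == (1, 3, 3)
assert not pairwise_mask[0, 0, 2]  # any pair touching residue 2 is masked out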
@@ -2889,7 +2969,8 @@ def forward(
                 single_trunk_repr = single,
                 single_inputs_repr = single_inputs,
                 pairwise_trunk = pairwise,
-                pairwise_rel_pos_feats = relative_position_encoding
+                pairwise_rel_pos_feats = relative_position_encoding,
+                residue_atom_lens = residue_atom_lens
             )
 
         # losses default to 0
@@ -2903,7 +2984,12 @@ def forward(
 
         # distogram head
 
         if not exists(distance_labels) and atom_pos_given and exists(residue_atom_indices):
-            residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
+
+            if self.packed_atom_repr:
+                residue_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, residue_atom_indices)
+            else:
+                residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
+
             residue_dist = torch.cdist(residue_pos, residue_pos, p = 2)
             dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', residue_dist, self.distance_bins).abs()
             distance_labels = dist_from_dist_bins.argmin(dim = -1)
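
In the packed branch, `einx.get_at('b [m] c, b n -> b n c', ...)` gathers one representative atom per residue along the packed atom axis (the unpacked pattern instead indexes within each residue's window). A plain-PyTorch equivalent of the packed gather, with hypothetical sizes:

import torch

b, m, n = 2, 16, 5
atom_pos = torch.randn(b, m, 3)
residue_atom_indices = torch.randint(0, m, (b, n))

# index the bracketed [m] axis with one absolute atom index per residue
residue_pos = atom_pos.gather(1, residue_atom_indices[..., None].expand(b, n, 3))
assert residue_pos.shape == (b, n, 3)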
Expand Down Expand Up @@ -2938,6 +3024,7 @@ def forward(
relative_position_encoding,
additional_residue_feats,
residue_atom_indices,
residue_atom_lens,
pae_labels,
pde_labels,
plddt_labels,
@@ -2958,6 +3045,7 @@ def forward(
                 relative_position_encoding,
                 additional_residue_feats,
                 residue_atom_indices,
+                residue_atom_lens,
                 pae_labels,
                 pde_labels,
                 plddt_labels,
@@ -2980,6 +3068,7 @@ def forward(
                 single_inputs_repr = single_inputs,
                 pairwise_trunk = pairwise,
                 pairwise_rel_pos_feats = relative_position_encoding,
+                residue_atom_lens = residue_atom_lens,
                 return_denoised_pos = True,
             )

@@ -2990,7 +3079,10 @@ def forward(
 
         if calc_diffusion_loss and should_call_confidence_head:
 
-            pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
+            if self.packed_atom_repr:
+                pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
+            else:
+                pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
 
             logits = self.confidence_head(
                 single_repr = single,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.25"
+version = "0.0.26"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }