
Commit

complete the average pooling by atom lengths function needed for packed repr of atoms in diffusion module when going from atoms -> tokens
lucidrains committed May 23, 2024
1 parent 54c4cf7 commit b93e278
Showing 2 changed files with 30 additions and 40 deletions.
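The commit message describes mean pooling of packed per-atom features into per-token features, averaged over each token's atom count. A minimal usage sketch of the new function (shapes and variable names here are illustrative, not taken from the repository):

import torch
from alphafold3_pytorch.alphafold3 import mean_pool_with_lens

# illustrative only: 1 structure, 9 atoms packed contiguously, feature dim 4
atom_feats = torch.randn(1, 9, 4)      # (batch, total atoms, dim)
atom_lens = torch.tensor([[3, 4, 2]])  # number of atoms belonging to each of 3 tokens

token_feats = mean_pool_with_lens(atom_feats, atom_lens)
assert token_feats.shape == (1, 3, 4)  # (batch, tokens, dim)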
50 changes: 24 additions & 26 deletions alphafold3_pytorch/alphafold3.py
@@ -87,35 +87,33 @@ def inner(t, *args, **kwargs):
         return fn(t, *args, **kwargs)
     return inner
 
-# Loss functions
+# packed atom representation functions
 
 @typecheck
-def calc_smooth_lddt_loss(
-    denoised: Float['b m 3'],
-    ground_truth: Float['b m 3'],
-    is_rna_per_atom: Float['b m'],
-    is_dna_per_atom: Float['b m']
-) -> Float[' b']:
-
-    m, device = is_rna_per_atom.shape[-1], denoised.device
-
-    dx_denoised = torch.cdist(denoised, denoised)
-    dx_gt = torch.cdist(ground_truth, ground_truth)
-
-    ddx = torch.abs(dx_gt - dx_denoised)
-    eps = 0.25 * (
-        sigmoid(0.5 - ddx) + sigmoid(1 - ddx) + sigmoid(2 - ddx) + sigmoid(4 - ddx)
-    )
-
-    is_nuc = is_rna_per_atom + is_dna_per_atom
-    mask = einx.multiply('b i, b j -> b i j', is_nuc, is_nuc)
-    c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)
-
-    eye = torch.eye(m, device = device)
-    num = einx.sum('b [...]', c * eps * (1 - eye)) / (m**2 - m)
-    den = einx.sum('b [...]', c * (1 - eye)) / (m**2 - m)
-
-    return 1. - num/den
+def mean_pool_with_lens(
+    feats: Float['b m d'],
+    lens: Int['b n']
+) -> Float['b n d']:
+
+    seq_len = feats.shape[1]
+
+    mask = lens > 0
+    assert (lens.sum(dim = -1) <= seq_len).all(), 'one of the lengths given exceeds the total sequence length of the features passed in'
+
+    cumsum_feats = feats.cumsum(dim = 1)
+    cumsum_feats = F.pad(cumsum_feats, (0, 0, 1, 0), value = 0.)
+
+    cumsum_indices = lens.cumsum(dim = 1)
+    cumsum_indices = F.pad(cumsum_indices, (1, 0), value = 0)
+
+    sel_cumsum = einx.get_at('b [m] d, b n -> b n d', cumsum_feats, cumsum_indices)
+
+    # subtract cumsum at one index from the previous one
+    summed = sel_cumsum[:, 1:] - sel_cumsum[:, :-1]
+
+    avg = einx.divide('b n d, b n', summed, lens.clamp(min = 1))
+    avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
+    return avg
 
 # linear and outer sum
 # for single repr -> pairwise pattern throughout this architecture
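The cumulative-sum trick in the added function is equivalent to averaging each contiguous run of atoms directly. A naive reference sketch, not part of this commit and shown only to clarify what the vectorized version computes:

import torch

def mean_pool_with_lens_naive(feats, lens):
    # feats: (b, m, d) packed per-atom features, lens: (b, n) atom count per token
    batch, num_tokens = lens.shape
    out = feats.new_zeros(batch, num_tokens, feats.shape[-1])
    for b in range(batch):
        start = 0
        for i in range(num_tokens):
            length = int(lens[b, i])
            if length > 0:
                out[b, i] = feats[b, start:start + length].mean(dim = 0)
            start += length
    return out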
20 changes: 6 additions & 14 deletions tests/test_af3.py
@@ -27,23 +27,15 @@
 )
 
 from alphafold3_pytorch.alphafold3 import (
-    calc_smooth_lddt_loss
+    mean_pool_with_lens
 )
 
-def test_calc_smooth_lddt_loss():
-    denoised = torch.randn(8, 100, 3)
-    ground_truth = torch.randn(8, 100, 3)
-    is_rna_per_atom = torch.randint(0, 2, (8, 100)).float()
-    is_dna_per_atom = torch.randint(0, 2, (8, 100)).float()
-
-    loss = calc_smooth_lddt_loss(
-        denoised,
-        ground_truth,
-        is_rna_per_atom,
-        is_dna_per_atom
-    )
+def test_mean_pool_with_lens():
+    seq = torch.tensor([[[1.], [1.], [1.], [2.], [2.], [2.], [2.], [1.], [1.]]])
+    lens = torch.tensor([[3, 4, 2]]).long()
+    pooled = mean_pool_with_lens(seq, lens)
 
-    assert torch.all(loss <= 1) and torch.all(loss >= 0)
+    assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
 
 def test_smooth_lddt_loss():
     pred_coords = torch.randn(2, 100, 3)
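One behaviour worth noting, inferred from the mask and einx.where lines in the diff above (this check is not part of the commit's tests): a token with a length of zero is pooled to an all-zero vector instead of producing a division by zero.

import torch
from alphafold3_pytorch.alphafold3 import mean_pool_with_lens

seq = torch.tensor([[[1.], [1.], [2.], [2.]]])
lens = torch.tensor([[2, 0, 2]]).long()

pooled = mean_pool_with_lens(seq, lens)
# expected: tensor([[[1.], [0.], [2.]]]), the empty middle token is zeroed out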

0 comments on commit b93e278
