hook up the smooth lddt loss end2end, thanks to @joseph-c-kim
lucidrains committed May 20, 2024
1 parent 4233f89 commit ad23ecc
Showing 3 changed files with 44 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md

@@ -8,7 +8,7 @@
 
 ## Appreciation
 
-- <a href="https://github.com/joseph-c-kim">Joseph</a> for contributing the relative positional encoding module!
+- <a href="https://github.com/joseph-c-kim">Joseph</a> for contributing the Relative Positional Encoding and the Smooth LDDT Loss!
 
 ## Install
 
45 changes: 35 additions & 10 deletions alphafold3_pytorch/alphafold3.py

@@ -23,7 +23,7 @@
 from collections import namedtuple
 
 import torch
-from torch import nn
+from torch import nn, sigmoid
 from torch import Tensor
 import torch.nn.functional as F
@@ -80,15 +80,15 @@ def unpack_one(t, ps, pattern):
 
 # Loss functions
 
-def smoothlddtloss(
+@typecheck
+def calc_smooth_lddt_loss(
     denoised: Float['b m 3'],
     ground_truth: Float['b m 3'],
     is_rna_per_atom: Float['b m'],
     is_dna_per_atom: Float['b m']
-) -> Float['b']:
-    from torch import sigmoid
+) -> Float[' b']:
 
-    m = is_rna_per_atom.shape[-1]
+    m, device = is_rna_per_atom.shape[-1], denoised.device
 
     dx_denoised = torch.cdist(denoised, denoised)
     dx_gt = torch.cdist(ground_truth, ground_truth)
@@ -102,10 +102,11 @@ def smoothlddtloss(
     mask = einx.multiply('b i, b j -> b i j', is_nuc, is_nuc)
     c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)
 
-    num = einx.sum('b [...]', c * eps * (1 - torch.eye(m))) / (m**2 - m)
-    den = einx.sum('b [...]', c * (1 - torch.eye(m))) / (m**2 - m)
-
-    return 1 - num/den
+    eye = torch.eye(m, device = device)
+    num = einx.sum('b [...]', c * eps * (1 - eye)) / (m**2 - m)
+    den = einx.sum('b [...]', c * (1 - eye)) / (m**2 - m)
+
+    return 1. - num/den
 
 # linear and outer sum
 # for single repr -> pairwise pattern throughout this architecture
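The hunk above is hard to read in isolation because `eps` and `is_nuc` are defined in the collapsed region of the diff. Below is a minimal, self-contained sketch of the whole loss for orientation: the `eps` and `is_nuc` definitions are reconstructed from the smooth LDDT description in the AlphaFold 3 supplement (sigmoid-relaxed distance-difference thresholds at 0.5, 1, 2 and 4 Å; a 30 Å inclusion radius for nucleic-acid pairs, 15 Å otherwise), and the function name is invented here. Treat it as an assumption, not as this repository's code.

```python
import torch
from torch import sigmoid

def smooth_lddt_loss_sketch(denoised, ground_truth, is_rna_per_atom, is_dna_per_atom):
    # pairwise atom distances in the denoised and ground-truth structures: (b, m, m)
    m, device = is_rna_per_atom.shape[-1], denoised.device
    dx_denoised = torch.cdist(denoised, denoised)
    dx_gt = torch.cdist(ground_truth, ground_truth)

    # sigmoid-relaxed per-pair lDDT score, averaged over the four
    # distance-difference thresholds (assumed, from the AF3 supplement)
    ddx = torch.abs(dx_denoised - dx_gt)
    eps = sum(sigmoid(t - ddx) for t in (0.5, 1.0, 2.0, 4.0)) / 4

    # pair inclusion radius: 30 A when both atoms are nucleic acid, 15 A otherwise
    is_nuc = is_rna_per_atom + is_dna_per_atom  # assumed definition
    mask = is_nuc[:, :, None] * is_nuc[:, None, :]
    c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)

    # mean score over included off-diagonal pairs, exactly as in the hunk above
    eye = torch.eye(m, device = device)
    num = (c * eps * (1 - eye)).sum(dim = (-1, -2)) / (m ** 2 - m)
    den = (c * (1 - eye)).sum(dim = (-1, -2)) / (m ** 2 - m)
    return 1. - num / den
```

Since `eps` and `c` are non-negative and the numerator can never exceed the denominator, the result lands in `[0, 1]`, which is what the updated test at the bottom of this commit asserts.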
@@ -1699,6 +1700,8 @@ def forward(
         normalized_atom_pos: Float['b m 3'],
         atom_mask: Bool['b m'],
         return_denoised_pos = False,
+        additional_residue_feats: Float['b n rf'] | None = None,
+        add_smooth_lddt_loss = False,
         **network_condition_kwargs
     ) -> Float[''] | Tuple[Float[''], Float['b m 3']]:
 
@@ -1726,6 +1729,22 @@
 
         loss = losses.mean()
 
+        if add_smooth_lddt_loss:
+            assert exists(additional_residue_feats)
+            w = self.net.atoms_per_window
+
+            is_dna, is_rna = additional_residue_feats[..., 7], additional_residue_feats[..., 8]
+            atom_is_dna, atom_is_rna = tuple(repeat(t, 'b n -> b (n w)', w = w) for t in (is_dna, is_rna))
+
+            smooth_lddt_loss = calc_smooth_lddt_loss(
+                denoised,
+                normalized_atom_pos,
+                atom_is_dna,
+                atom_is_rna
+            ).mean()
+
+            loss = loss + smooth_lddt_loss
+
         if not return_denoised_pos:
             return loss
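The windowed broadcast in the new block turns per-residue DNA / RNA flags (columns 7 and 8 of `additional_residue_feats`) into per-atom flags by repeating each residue's flag `atoms_per_window` times. A toy illustration of that einops pattern, with invented values:

```python
import torch
from einops import repeat

w = 3                              # stands in for self.net.atoms_per_window
is_dna = torch.tensor([[1., 0.]])  # (b = 1, n = 2) per-residue flag
atom_is_dna = repeat(is_dna, 'b n -> b (n w)', w = w)

print(atom_is_dna)  # tensor([[1., 1., 1., 0., 0., 0.]]) -- one flag per atom
```

One quirk worth flagging: the call site passes `atom_is_dna` into the `is_rna_per_atom` slot and `atom_is_rna` into `is_dna_per_atom`. Since the loss appears to use the two flags only symmetrically (combined into `is_nuc`), the swap should be harmless, but passing them in signature order would read better.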

@@ -2354,7 +2373,13 @@ def forward(
         # otherwise, noise and make it learn to denoise
 
         if exists(atom_pos):
-            diffusion_loss, denoised_atom_pos = self.edm(atom_pos, return_denoised_pos = True, **diffusion_cond)
+            diffusion_loss, denoised_atom_pos = self.edm(
+                atom_pos,
+                additional_residue_feats = additional_residue_feats,
+                add_smooth_lddt_loss = True,
+                return_denoised_pos = True,
+                **diffusion_cond
+            )
 
         # calculate all logits and losses
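Note that with this change the top-level model enables the smooth LDDT term unconditionally (`add_smooth_lddt_loss = True`) whenever ground-truth atom positions are available, so `additional_residue_feats` must be supplied on this path or the assertion in the diffusion module's forward will fire; the flag only defaults to `False` when that module is called directly.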

14 changes: 8 additions & 6 deletions tests/test_af3.py

@@ -19,21 +19,23 @@
     Alphafold3,
 )
 
-from alphafold3_pytorch.alphafold3 import smoothlddtloss
+from alphafold3_pytorch.alphafold3 import (
+    calc_smooth_lddt_loss
+)
 
-def test_smoothlddtloss():
+def test_calc_smooth_lddt_loss():
     denoised = torch.randn(8, 100, 3)
     ground_truth = torch.randn(8, 100, 3)
-    is_rna_per_atom = torch.randint(0, 2, (8, 100))
-    is_dna_per_atom = torch.randint(0, 2, (8, 100))
+    is_rna_per_atom = torch.randint(0, 2, (8, 100)).float()
+    is_dna_per_atom = torch.randint(0, 2, (8, 100)).float()
 
-    loss = smoothlddtloss(
+    loss = calc_smooth_lddt_loss(
         denoised,
         ground_truth,
         is_rna_per_atom,
         is_dna_per_atom
     )
 
     assert torch.all(loss <= 1) and torch.all(loss >= 0)
 
 def test_pairformer():
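The new `.float()` casts are what keep this test compatible with the `@typecheck` decorator added above: the flag arguments are now annotated as `Float['b m']`, and the integer tensors produced by `torch.randint` would fail that check.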