Skip to content

Commit

Permalink
fix test and export all modules crafted by @engelberger
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 21, 2024
1 parent 65a9e84 commit df9c7ad
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 33 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Getting a fair number of emails. You can chat with me about this work <a href="h

- <a href="https://github.com/joseph-c-kim">Joseph</a> for contributing the Relative Positional Encoding and the Smooth LDDT Loss!

- <a href="https://github.com/engelberger">Felipe</a> for contributing Weighted Rigid Align, Express Coordinates In Frame, Compute Alignment Error, and Centre Random Augmentation modules!

## Install

```bash
Expand Down
10 changes: 10 additions & 0 deletions alphafold3_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

from alphafold3_pytorch.alphafold3 import (
RelativePositionEncoding,
SmoothLDDTLoss,
WeightedRigidAlign,
ExpressCoordinatesInFrame,
ComputeAlignmentError,
CentreRandomAugmentation,
TemplateEmbedder,
PreLayerNorm,
AdaptiveLayerNorm,
Expand All @@ -30,6 +35,11 @@
Attention,
Attend,
RelativePositionEncoding,
SmoothLDDTLoss,
WeightedRigidAlign,
ExpressCoordinatesInFrame,
ComputeAlignmentError,
CentreRandomAugmentation,
TemplateEmbedder,
PreLayerNorm,
AdaptiveLayerNorm,
Expand Down
55 changes: 27 additions & 28 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,8 +1753,9 @@ def forward(

# modules todo

class SmoothLDDTLoss(torch.nn.Module):
"""Alg 27"""
class SmoothLDDTLoss(Module):
""" Algorithm 27 """

@typecheck
def __init__(self, nucleic_acid_cutoff: float = 30.0, other_cutoff: float = 15.0):
super().__init__()
Expand Down Expand Up @@ -1807,8 +1808,8 @@ def forward(

return 1 - lddt.mean()

class WeightedRigidAlign(torch.nn.Module):
"""Alg 28"""
class WeightedRigidAlign(Module):
""" Algorithm 28 """
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -1853,45 +1854,45 @@ def forward(

return aligned_coords.detach()

class ExpressCoordinatesInFrame(Module):
    """ Algorithm 29

    Express atom coordinates in a local frame defined by three points
    (a, b, c), with b taken as the frame origin.
    """

    def __init__(self, eps = 1e-8):
        super().__init__()
        # epsilon for numerically-stable unit-vector normalization
        self.eps = eps

    @typecheck
    def forward(
        self,
        coords: Float['b m 3'],
        frame: Float['b m 3 3']
    ) -> Float['b m 3']:
        """
        coords: coordinates to be expressed in the given frame (b, m, 3)
        frame: frame defined by three points (b, m, 3, 3)
        returns: coordinates in the frame basis (b, m, 3)
        """

        # Extract the three frame points, each (b, m, 3).
        # NOTE: unbind along the *point* dimension (-2), matching the
        # pre-batched indexing frame[:, 0], frame[:, 1], frame[:, 2].
        # Unbinding dim = -1 would instead yield the x/y/z columns of
        # the frame matrix, silently producing wrong basis vectors.
        a, b, c = frame.unbind(dim = -2)

        # Compute orthonormal-ish basis of the frame:
        # e1, e2 are unit vectors from the origin b; e3 completes a
        # right-handed basis via the cross product
        e1 = F.normalize(a - b, dim = -1, eps = self.eps)
        e2 = F.normalize(c - b, dim = -1, eps = self.eps)
        e3 = torch.cross(e1, e2, dim = -1)

        # Express coordinates relative to the frame origin
        v = coords - b

        # Project onto each basis vector (batched dot products)
        transformed_coords = torch.stack([
            einsum(v, e1, '... i, ... i -> ...'),
            einsum(v, e2, '... i, ... i -> ...'),
            einsum(v, e3, '... i, ... i -> ...')
        ], dim = -1)

        return transformed_coords

class ComputeAlignmentError(torch.nn.Module):
"""Alg 30"""
class ComputeAlignmentError(Module):
""" Algorithm 30 """
@typecheck
def __init__(self, eps: float = 1e-8):
super().__init__()
Expand Down Expand Up @@ -1925,8 +1926,8 @@ def forward(

return alignment_errors

class CentreRandomAugmentation(torch.nn.Module):
"""Alg 19"""
class CentreRandomAugmentation(Module):
""" Algorithm 19 """
@typecheck
def __init__(self, trans_scale: float = 1.0):
super().__init__()
Expand Down Expand Up @@ -1980,8 +1981,6 @@ def _random_translation_vector(self, device: torch.device) -> Float['3']:
translation_vector = torch.randn(3, device=device) * self.trans_scale
return translation_vector



# input embedder

EmbeddedInputs = namedtuple('EmbeddedInputs', [
Expand Down
10 changes: 5 additions & 5 deletions tests/test_af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import pytest

from alphafold3_pytorch import (
SmoothLDDTLoss,
WeightedRigidAlign,
ExpressCoordinatesInFrame,
ComputeAlignmentError,
CentreRandomAugmentation,
PairformerStack,
MSAModule,
DiffusionTransformer,
Expand All @@ -17,11 +22,6 @@
ConfidenceHead,
DistogramHead,
Alphafold3,
SmoothLDDTLoss,
WeightedRigidAlign,
ExpressCoordinatesInFrame,
ComputeAlignmentError,
CentreRandomAugmentation
)

from alphafold3_pytorch.alphafold3 import (
Expand Down

0 comments on commit df9c7ad

Please sign in to comment.