Skip to content

Commit

Permalink
calculate fibonacci sphere points only once
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 30, 2024
1 parent bf1d580 commit 0b70ab9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
55 changes: 30 additions & 25 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,6 +5288,7 @@ def __init__(
contact_mask_threshold: float = 8.0,
is_fine_tuning: bool = False,
weight_dict_config: dict = None,
fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute
):
super().__init__()
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
Expand All @@ -5301,6 +5302,8 @@ def __init__(
self.register_buffer("dist_breaks", dist_breaks)
self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0]))

# for rsa calculation

atom_type_radii = tensor([
1.65, # 0 - nitrogen
1.87, # 1 - carbon alpha
Expand All @@ -5319,6 +5322,31 @@ def __init__(

self.register_buffer('atom_radii', atom_type_radii, persistent = False)

# constitute the fibonacci sphere

num_surface_dots = fibonacci_sphere_n * 2 + 1
golden_ratio = 1. + sqrt(5.) / 2
weight = (4. * pi) / num_surface_dots

arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]

lat = torch.asin((2. * arange) / num_surface_dots)
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio

# ein:
# sd - surface dots
# c - coordinate (3)
# i, j - source and target atom

unit_surface_dots: Float['sd 3'] = torch.stack((
lon.sin() * lat.cos(),
lon.cos() * lat.cos(),
lat.sin()
), dim = -1)

self.register_buffer('unit_surface_dots', unit_surface_dots)
self.surface_weight = weight

@property
def device(self):
return self.atom_radii.device
Expand Down Expand Up @@ -5651,7 +5679,6 @@ def calc_atom_access_surface_score(
atom_pos: Float['m 3'],
atom_type: Int['m'],
molecule_atom_lens: Int['n'] | None = None,
fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute
atom_distance_min_thres = 1e-4
) -> Float['m'] | Float['n']:

Expand All @@ -5666,28 +5693,6 @@ def calc_atom_access_surface_score(

# write custom RSA function here

# first constitute the fibonacci sphere

num_surface_dots = fibonacci_sphere_n * 2 + 1
golden_ratio = 1. + sqrt(5.) / 2
weight = (4. * pi) / num_surface_dots

arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]

lat = torch.asin((2. * arange) / num_surface_dots)
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio

# ein:
# sd - surface dots
# c - coordinate (3)
# i, j - source and target atom

unit_surface_dots: Float['sd 3'] = torch.stack((
lon.sin() * lat.cos(),
lon.cos() * lat.cos(),
lat.sin()
), dim = -1)

# get atom relative positions + distance
# for determining whether to include pairs of atom in calculation for the `free` adjective

Expand Down Expand Up @@ -5715,7 +5720,7 @@ def calc_atom_access_surface_score(

# overall logic

surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)
surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, self.unit_surface_dots)

dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)

Expand All @@ -5725,7 +5730,7 @@ def calc_atom_access_surface_score(

is_free = reduce(target_atom_close_or_not_included, 'i sd j -> i sd', 'all') # basically the most important line, calculating whether an atom is free by some distance measure

score = reduce(is_free.float() * weight, 'm sd -> m', 'sum')
score = reduce(is_free.float() * self.surface_weight, 'm sd -> m', 'sum')

per_atom_access_surface_score = score * atom_radii_sq

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.5.55"
version = "0.5.56"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down

0 comments on commit 0b70ab9

Please sign in to comment.