diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index 8b445f2b..b64f9717 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -5288,6 +5288,7 @@ def __init__( contact_mask_threshold: float = 8.0, is_fine_tuning: bool = False, weight_dict_config: dict = None, + fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute ): super().__init__() self.compute_confidence_score = ComputeConfidenceScore(eps=eps) @@ -5301,6 +5302,8 @@ def __init__( self.register_buffer("dist_breaks", dist_breaks) self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0])) + # for rsa calculation + atom_type_radii = tensor([ 1.65, # 0 - nitrogen 1.87, # 1 - carbon alpha @@ -5319,6 +5322,31 @@ def __init__( self.register_buffer('atom_radii', atom_type_radii, persistent = False) + # constitute the fibonacci sphere + + num_surface_dots = fibonacci_sphere_n * 2 + 1 + golden_ratio = 1. + sqrt(5.) / 2 + weight = (4. * pi) / num_surface_dots + + arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3] + + lat = torch.asin((2. * arange) / num_surface_dots) + lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio + + # ein: + # sd - surface dots + # c - coordinate (3) + # i, j - source and target atom + + unit_surface_dots: Float['sd 3'] = torch.stack(( + lon.sin() * lat.cos(), + lon.cos() * lat.cos(), + lat.sin() + ), dim = -1) + + self.register_buffer('unit_surface_dots', unit_surface_dots) + self.surface_weight = weight + @property def device(self): return self.atom_radii.device @@ -5651,7 +5679,6 @@ def calc_atom_access_surface_score( atom_pos: Float['m 3'], atom_type: Int['m'], molecule_atom_lens: Int['n'] | None = None, - fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute atom_distance_min_thres = 1e-4 ) -> Float['m'] | Float['n']: @@ -5666,28 +5693,6 @@ def calc_atom_access_surface_score( # write custom RSA function here - # first constitute the fibonacci sphere - - num_surface_dots = fibonacci_sphere_n * 2 + 1 - golden_ratio = 1. + sqrt(5.) / 2 - weight = (4. * pi) / num_surface_dots - - arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3] - - lat = torch.asin((2. * arange) / num_surface_dots) - lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio - - # ein: - # sd - surface dots - # c - coordinate (3) - # i, j - source and target atom - - unit_surface_dots: Float['sd 3'] = torch.stack(( - lon.sin() * lat.cos(), - lon.cos() * lat.cos(), - lat.sin() - ), dim = -1) - # get atom relative positions + distance # for determining whether to include pairs of atom in calculation for the `free` adjective @@ -5715,7 +5720,7 @@ def calc_atom_access_surface_score( # overall logic - surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots) + surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, self.unit_surface_dots) dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1) @@ -5725,7 +5730,7 @@ def calc_atom_access_surface_score( is_free = reduce(target_atom_close_or_not_included, 'i sd j -> i sd', 'all') # basically the most important line, calculating whether an atom is free by some distance measure - score = reduce(is_free.float() * weight, 'm sd -> m', 'sum') + score = reduce(is_free.float() * self.surface_weight, 'm sd -> m', 'sum') per_atom_access_surface_score = score * atom_radii_sq diff --git a/pyproject.toml b/pyproject.toml index 7ebda9d5..f6814b0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.5.55" +version = "0.5.56" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" },