Skip to content

Commit

Permalink
able to compute rsa on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 30, 2024
1 parent 5f97fcb commit bf1d580
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5319,6 +5319,10 @@ def __init__(

self.register_buffer('atom_radii', atom_type_radii, persistent = False)

@property
def device(self):
return self.atom_radii.device

@typecheck
def compute_gpde(
self,
Expand Down Expand Up @@ -5629,10 +5633,10 @@ def calc_atom_access_surface_score_from_structure(
structure_atom_pos.append(one_atom_pos)
structure_atom_type_for_radii.append(one_atom_type)

structure_atom_pos: Float['m 3'] = tensor(structure_atom_pos)
structure_atom_type_for_radii: Int['m'] = tensor(structure_atom_type_for_radii)
structure_atom_pos: Float['m 3'] = tensor(structure_atom_pos, device = self.device)
structure_atom_type_for_radii: Int['m'] = tensor(structure_atom_type_for_radii, device = self.device)

structure_atoms_per_residue: Int['n'] = tensor([len([*residue.get_atoms()]) for residue in structure.get_residues()]).long()
structure_atoms_per_residue: Int['n'] = tensor([len([*residue.get_atoms()]) for residue in structure.get_residues()], device = self.device).long()

return self.calc_atom_access_surface_score(
atom_pos = structure_atom_pos,
Expand Down Expand Up @@ -5668,7 +5672,7 @@ def calc_atom_access_surface_score(
golden_ratio = 1. + sqrt(5.) / 2
weight = (4. * pi) / num_surface_dots

arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]

lat = torch.asin((2. * arange) / num_surface_dots)
lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.5.54"
version = "0.5.55"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down

0 comments on commit bf1d580

Please sign in to comment.