From bf1d58052e4ba6835203855a591f9f70a43236e0 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 29 Sep 2024 17:05:10 -0700 Subject: [PATCH] able to compute rsa on gpu --- alphafold3_pytorch/alphafold3.py | 12 ++++++++---- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index 077181d4..8b445f2b 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -5319,6 +5319,10 @@ def __init__( self.register_buffer('atom_radii', atom_type_radii, persistent = False) + @property + def device(self): + return self.atom_radii.device + @typecheck def compute_gpde( self, @@ -5629,10 +5633,10 @@ def calc_atom_access_surface_score_from_structure( structure_atom_pos.append(one_atom_pos) structure_atom_type_for_radii.append(one_atom_type) - structure_atom_pos: Float['m 3'] = tensor(structure_atom_pos) - structure_atom_type_for_radii: Int['m'] = tensor(structure_atom_type_for_radii) + structure_atom_pos: Float['m 3'] = tensor(structure_atom_pos, device = self.device) + structure_atom_type_for_radii: Int['m'] = tensor(structure_atom_type_for_radii, device = self.device) - structure_atoms_per_residue: Int['n'] = tensor([len([*residue.get_atoms()]) for residue in structure.get_residues()]).long() + structure_atoms_per_residue: Int['n'] = tensor([len([*residue.get_atoms()]) for residue in structure.get_residues()], device = self.device).long() return self.calc_atom_access_surface_score( atom_pos = structure_atom_pos, @@ -5668,7 +5672,7 @@ def calc_atom_access_surface_score( golden_ratio = 1. + sqrt(5.) / 2 weight = (4. * pi) / num_surface_dots - arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3] + arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1, device = self.device) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3] lat = torch.asin((2. * arange) / num_surface_dots) lon = torch.fmod(arange, golden_ratio) * 2 * pi / golden_ratio diff --git a/pyproject.toml b/pyproject.toml index f43cc23b..7ebda9d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.5.54" +version = "0.5.55" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" },