Skip to content

Commit

Permalink
cache the calculation of has_bond in `molecule_lengthed_molecule_input_to_atom_input`
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 23, 2024
1 parent fa1a065 commit f585a6a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 34 deletions.
94 changes: 61 additions & 33 deletions alphafold3_pytorch/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,18 @@
# simple caching

# module-level memo tables handed to `maybe_cache` through its `cache` argument

# presumably holds per-molecule atompair id results - usage not visible in this section, confirm
ATOMPAIR_IDS_CACHE: dict = {}

# holds `get_mol_has_bond` results, keyed by the stringified molecule id
HAS_BOND_CACHE: dict = {}

@typecheck
def maybe_cache(
fn,
*,
cache: dict,
key: str,
key: str | None,
should_cache: bool = True
) -> Callable:

if not should_cache:
if not should_cache or not exists(key):
return fn

@wraps(fn)
Expand All @@ -246,7 +247,7 @@ def inner(*args, **kwargs):
def get_atompair_ids(
mol: Mol,
directed_bonds: bool
) -> Tensor | None:
) -> Int['m m'] | None:

coordinates = []
updates = []
Expand Down Expand Up @@ -304,6 +305,46 @@ def get_atompair_ids(

return mol_atompair_ids

@typecheck
def get_mol_has_bond(
    mol: Mol
) -> Bool['m m'] | None:
    """
    Build the symmetric atom-to-atom bond adjacency matrix of a molecule.

    Args:
        mol: molecule exposing `GetBonds()` and `GetNumAtoms()`, whose bonds
            expose `GetBeginAtomIdx()` and `GetEndAtomIdx()`.

    Returns:
        A boolean tensor of shape (num_atoms, num_atoms) where entry [i, j]
        is True when a bond connects atoms i and j (set in both directions),
        or None when the molecule has no bonds.
    """

    bonds = mol.GetBonds()

    # nothing to mark - callers treat None as "no bond information"

    if len(bonds) == 0:
        return None

    begin_indices = []
    end_indices = []

    for bond in bonds:
        begin_indices.append(bond.GetBeginAtomIdx())
        end_indices.append(bond.GetEndAtomIdx())

    begin_indices = tensor(begin_indices).long()
    end_indices = tensor(end_indices).long()

    # allocate directly as bool, avoiding an intermediate float matrix

    num_atoms = mol.GetNumAtoms()
    has_bond = torch.zeros(num_atoms, num_atoms, dtype = torch.bool)

    # bonds are undirected, so mark both (begin, end) and (end, begin)

    has_bond[begin_indices, end_indices] = True
    has_bond[end_indices, begin_indices] = True

    return has_bond

# functions


Expand Down Expand Up @@ -1182,49 +1223,36 @@ def molecule_lengthed_molecule_input_to_atom_input(

for (
mol,
mol_id,
mol_is_chainable_biomolecule,
mol_is_first_mol_in_chain,
mol_is_one_token_per_atom,
) in zip(molecules, is_chainable_biomolecules, is_first_mol_in_chains, one_token_per_atom):
) in zip(
molecules,
molecule_ids,
is_chainable_biomolecules,
is_first_mol_in_chains,
one_token_per_atom
):
num_atoms = mol.GetNumAtoms()

if mol_is_chainable_biomolecule and not mol_is_first_mol_in_chain:
token_bonds[offset, offset - 1] = True
token_bonds[offset - 1, offset] = True

if mol_is_one_token_per_atom:
coordinates = []

bonds = mol.GetBonds()
num_bonds = len(bonds)

for bond in bonds:
atom_start_index = bond.GetBeginAtomIdx()
atom_end_index = bond.GetEndAtomIdx()

coordinates.extend(
[
[atom_start_index, atom_end_index],
[atom_end_index, atom_start_index],
]
)

if num_bonds > 0:
has_bond = torch.zeros(num_atoms, num_atoms).bool()

coordinates = tensor(coordinates).long()

# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)

has_bond_stride = tensor(has_bond.stride())
flattened_coordinates = (coordinates * has_bond_stride).sum(dim = -1)
packed_has_bond, unpack_has_bond = pack_one(has_bond, '*')

packed_has_bond[flattened_coordinates] = True
has_bond = unpack_has_bond(packed_has_bond, '*')
maybe_cached_get_mol_has_bond = maybe_cache(
get_mol_has_bond,
cache = HAS_BOND_CACHE,
key = str(mol_id),
should_cache = mol_is_chainable_biomolecule.item()
)

# / ein.set_at
has_bond = maybe_cached_get_mol_has_bond(mol)

if exists(has_bond) and has_bond.numel() > 0:
num_atoms = mol.GetNumAtoms()
row_col_slice = slice(offset, offset + num_atoms)
token_bonds[row_col_slice, row_col_slice] = has_bond

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.5.37"
version = "0.5.38"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down

0 comments on commit f585a6a

Please sign in to comment.