Skip to content

Commit

Permalink
tests for new network planner functions
Browse files Browse the repository at this point in the history
  • Loading branch information
richardjgowers committed Jul 12, 2023
1 parent cf3ea3a commit 103704a
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
36 changes: 33 additions & 3 deletions openfe/setup/ligand_network_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
from typing import Iterable, Callable, Optional, Union
import itertools
from collections import Counter
import functools

import networkx as nx
Expand Down Expand Up @@ -206,17 +207,37 @@ def generate_network_from_names(
Returns
-------
LigandNetwork
Raises
------
KeyError
if an invalid name is requested
ValueError
if multiple molecules have the same name (this would otherwise be
problematic)
"""
nm2idx = {l.name: i for i, l in enumerate(ligands)}

ids = [(nm2idx[nm1], nm2idx[nm2]) for nm1, nm2 in names]
if len(nm2idx) < len(ligands):
dupes = Counter((l.name for l in ligands))
dupe_names = [k for k, v in dupes.items() if v > 1]
raise ValueError(f"Duplicate names: {dupe_names}")

try:
ids = [(nm2idx[nm1], nm2idx[nm2]) for nm1, nm2 in names]
except KeyError:
badnames = [nm for nm in itertools.chain.from_iterable(names)
if nm not in nm2idx]
available = [ligand.name for ligand in ligands]
raise KeyError(f"Invalid name(s) requested {badnames}. "
f"Available: {available}")

return generate_network_from_indices(ligands, mapper, ids)


def generate_network_from_indices(
ligands: list[SmallMoleculeComponent],
mapper: Union[AtomMapper, Iterable[AtomMapper]],
mapper: AtomMapper,
indices: list[tuple[int, int]],
) -> LigandNetwork:
"""Generate a LigandNetwork
Expand All @@ -235,11 +256,20 @@ def generate_network_from_indices(
Returns
-------
LigandNetwork
Raises
------
IndexError
if an invalid ligand index is requested
"""
edges = []

for i, j in indices:
m1, m2 = ligands[i], ligands[j]
try:
m1, m2 = ligands[i], ligands[j]
except IndexError:
raise IndexError(f"Invalid ligand id, requested {i} {j} "
f"with {len(ligands)} available")

mapping = next(mapper.suggest_mappings(m1, m2))

Expand Down
89 changes: 89 additions & 0 deletions openfe/tests/setup/test_network_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,92 @@ def scorer(mapping):
mappers=[openfe.setup.atom_mapping.LomapAtomMapper()],
scorer=scorer
)


def test_network_from_names(atom_mapping_basic_test_files):
ligs = list(atom_mapping_basic_test_files.values())

requested = [
('toluene', '2-naftanol'),
('2-methylnaphthalene', '2-naftanol'),
]

network = openfe.setup.ligand_network_planning.generate_network_from_names(
ligands=ligs,
names=requested,
mapper=openfe.LomapAtomMapper(),
)

assert len(network.nodes) == len(ligs)
assert len(network.edges) == 2
actual_edges = [(e.componentA.name, e.componentB.name)
for e in network.edges]
assert set(requested) == set(actual_edges)


def test_network_from_names_bad_name(atom_mapping_basic_test_files):
ligs = list(atom_mapping_basic_test_files.values())

requested = [
('hank', '2-naftanol'),
('2-methylnaphthalene', '2-naftanol'),
]

with pytest.raises(KeyError, match="Invalid name"):
_ = openfe.setup.ligand_network_planning.generate_network_from_names(
ligands=ligs,
names=requested,
mapper=openfe.LomapAtomMapper(),
)


def test_network_from_names_duplicate_name(atom_mapping_basic_test_files):
ligs = list(atom_mapping_basic_test_files.values())
ligs = ligs + [ligs[0]]

requested = [
('toluene', '2-naftanol'),
('2-methylnaphthalene', '2-naftanol'),
]

with pytest.raises(ValueError, match="Duplicate names"):
_ = openfe.setup.ligand_network_planning.generate_network_from_names(
ligands=ligs,
names=requested,
mapper=openfe.LomapAtomMapper(),
)


def test_network_from_indices(atom_mapping_basic_test_files):
ligs = list(atom_mapping_basic_test_files.values())

requested = [(0, 1), (2, 3)]

network = openfe.setup.ligand_network_planning.generate_network_from_indices(
ligands=ligs,
indices=requested,
mapper=openfe.LomapAtomMapper(),
)

assert len(network.nodes) == len(ligs)
assert len(network.edges) == 2

edges = list(network.edges)
expected_edges = {(ligs[0], ligs[1]), (ligs[2], ligs[3])}
actual_edges = {(edges[0].componentA, edges[0].componentB),
(edges[1].componentA, edges[1].componentB)}

assert actual_edges == expected_edges


def test_network_from_indices_indexerror(atom_mapping_basic_test_files):
ligs = list(atom_mapping_basic_test_files.values())

requested = [(20, 1), (2, 3)]

with pytest.raises(IndexError, match="Invalid ligand id"):
network = openfe.setup.ligand_network_planning.generate_network_from_indices(
ligands=ligs,
indices=requested,
mapper=openfe.LomapAtomMapper(),
)

0 comments on commit 103704a

Please sign in to comment.