Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add function to create a minimal network with redundancy for all nodes #559

Merged
merged 24 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b43c153
add empty function with adjusted docstring
JenkeScheen Sep 15, 2023
a91f92d
remove bulk of function, will add back gradually.
JenkeScheen Sep 15, 2023
372d219
drop in redundant network fn
JenkeScheen Sep 18, 2023
a1cd502
add tests
JenkeScheen Sep 18, 2023
a2dd148
Merge branch 'main' into main
dwhswenson Sep 20, 2023
2ef3289
Ah, missed this one (thanks mypy): Suggested change [mapping…
JenkeScheen Sep 21, 2023
7863505
more descriptive test fixture
JenkeScheen Sep 21, 2023
8f92cde
set 'test_' for actual test function..
JenkeScheen Sep 21, 2023
a106976
Merge branch 'main' of https://github.com/JenkeScheen/openfe_redundan…
JenkeScheen Sep 21, 2023
532b910
add test for number of edges in redundant networks
JenkeScheen Sep 21, 2023
8caec4f
add mst_num = 2 by default
JenkeScheen Sep 21, 2023
229537a
fix function names in redundant ntwk tests
JenkeScheen Sep 21, 2023
8457a3d
add `mst_num` to default test
JenkeScheen Sep 21, 2023
a0bf736
test that redundant(n=1) is same as spanning
JenkeScheen Sep 21, 2023
f0fbdcc
pep8 shenanigans
JenkeScheen Sep 21, 2023
a494fdf
redundant network into original MST test scope
JenkeScheen Sep 21, 2023
548568c
revert scope back to `session` so we can pull data
JenkeScheen Sep 21, 2023
deac790
test fix (had wrong network refs)
JenkeScheen Sep 21, 2023
68a46a1
more pep8
JenkeScheen Sep 21, 2023
36a9f39
simplify test
JenkeScheen Sep 22, 2023
0a97327
rename duped fn name
JenkeScheen Sep 25, 2023
0334030
Merge branch 'main' into main
mikemhenry Sep 25, 2023
578b7a1
Merge branch 'main' into main
IAlibay Oct 4, 2023
633837f
Update ligand_network_planning.py
richardjgowers Oct 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions openfe/setup/ligand_network_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def generate_maximal_network(
total = len(nodes) * (len(nodes) - 1) // 2
progress = functools.partial(tqdm, total=total, delay=1.5)
elif progress is False:
progress = lambda x: x
def progress(x): return x
# otherwise, it should be a user-defined callable

mapping_generator = itertools.chain.from_iterable(
Expand Down Expand Up @@ -229,6 +229,74 @@ def generate_minimal_spanning_network(
return min_network


def generate_minimal_redundant_network(
ligands: Iterable[SmallMoleculeComponent],
mappers: Union[AtomMapper, Iterable[AtomMapper]],
scorer: Callable[[LigandAtomMapping], float],
progress: Union[bool, Callable[[Iterable], Iterable]] = True,
mst_num: int = 2,
) -> LigandNetwork:
"""
Plan a network with a specified amount of redundancy for each node

Creates a network with as few edges as possible with maximum total score,
ensuring that every node is connected to two edges to introduce
statistical redundancy.

Parameters
----------
ligands : Iterable[SmallMoleculeComponent]
the ligands to include in the LigandNetwork
mappers : AtomMapper or Iterable[AtomMapper]
the AtomMapper(s) to use to propose mappings. At least 1 required,
but many can be given, in which case all will be tried to find the
highest score edges
scorer : Scoring function
any callable which takes a LigandAtomMapping and returns a float
progress : Union[bool, Callable[Iterable], Iterable]
progress bar: if False, no progress bar will be shown. If True, use a
tqdm progress bar that only appears after 1.5 seconds. You can also
provide a custom progress bar wrapper as a callable.
mst_num: int
Minimum Spanning Tree number: the number of minimum spanning trees to
generate. If two, the second-best edges are included in the returned
network. If three, the third-best edges are also included, etc.
"""
if isinstance(mappers, AtomMapper):
mappers = [mappers]
mappers = [_hasten_lomap(m, ligands) if isinstance(m, LomapAtomMapper)
else m for m in mappers]

# First create a network with all the proposed mappings (scored)
network = generate_maximal_network(ligands, mappers, scorer, progress)

# Flip network scores so we can use minimal algorithm
g2 = nx.MultiGraph()
for e1, e2, d in network.graph.edges(data=True):
g2.add_edge(e1, e2, weight=-d['score'], object=d['object'])

# As in .generate_minimal_spanning_network(), use nx to get the minimal
# network. But now also remove those edges from the fully-connected
# network, then get the minimal network again. Add mappings from all
# minimal networks together.
mappings = []
for _ in range(mst_num): # can increase range here for more redundancy
# get list from generator so that we don't adjust network by calling it:
current_best_edges = list(nx.minimum_spanning_edges(g2))

g2.remove_edges_from(current_best_edges)
for _, _, _, edge_data in current_best_edges:
mappings.append(edge_data['object'])

redund_network = LigandNetwork(mappings)
missing_nodes = set(network.nodes) - set(redund_network.nodes)
if missing_nodes:
raise RuntimeError("Unable to create edges to some nodes: "
f"{list(missing_nodes)}")

return redund_network


def generate_network_from_names(
ligands: list[SmallMoleculeComponent],
mapper: AtomMapper,
Expand Down Expand Up @@ -353,7 +421,7 @@ def load_orion_network(
KeyError
If an unexpected line format is encountered.
"""

with open(network_file, 'r') as f:
network_lines = [l.strip().split(' ') for l in f
if not l.startswith('#')]
Expand Down
113 changes: 108 additions & 5 deletions openfe/tests/setup/test_network_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def scorer(mapping):
assert len(network.edges) == len(others)

for edge in network.edges:
assert len(edge.componentA_to_componentB) > 1 # we didn't take the bad mapper
# we didn't take the bad mapper
assert len(edge.componentA_to_componentB) > 1
assert 'score' in edge.annotations
assert edge.annotations['score'] == len(edge.componentA_to_componentB)

Expand Down Expand Up @@ -196,7 +197,8 @@ def test_minimal_spanning_network(minimal_spanning_network, toluene_vs_others):
tol, others = toluene_vs_others
assert len(minimal_spanning_network.nodes) == len(others) + 1
for edge in minimal_spanning_network.edges:
assert edge.componentA_to_componentB != {0: 0} # lomap should find something
assert edge.componentA_to_componentB != {
0: 0} # lomap should find something


def test_minimal_spanning_network_connectedness(minimal_spanning_network):
Expand Down Expand Up @@ -245,6 +247,106 @@ def scorer(mapping):
)


@pytest.fixture(scope='session')
def minimal_redundant_network(toluene_vs_others):
toluene, others = toluene_vs_others
mappers = [BadMapper(), openfe.setup.atom_mapping.LomapAtomMapper()]

def scorer(mapping):
return len(mapping.componentA_to_componentB)

network = openfe.setup.ligand_network_planning.generate_minimal_redundant_network(
ligands=others + [toluene],
mappers=mappers,
scorer=scorer,
mst_num=2
)
return network


def test_minimal_redundant_network(minimal_redundant_network, toluene_vs_others):
tol, others = toluene_vs_others

# test for correct number of nodes
assert len(minimal_redundant_network.nodes) == len(others) + 1

# test for correct number of edges
assert len(minimal_redundant_network.edges) == 2 * \
(len(minimal_redundant_network.nodes) - 1)

for edge in minimal_redundant_network.edges:
assert edge.componentA_to_componentB != {
0: 0} # lomap should find something


def test_minimal_redundant_network_connectedness(minimal_redundant_network):
found_pairs = set()
for edge in minimal_redundant_network.edges:
pair = frozenset([edge.componentA, edge.componentB])
assert pair not in found_pairs
found_pairs.add(pair)

assert nx.is_connected(nx.MultiGraph(minimal_redundant_network.graph))


def test_redundant_vs_spanning_network(minimal_redundant_network, minimal_spanning_network):
# when setting minimal redundant network to only take one MST, it should have as many
# edges as the regular minimum spanning network
assert 2 * len(minimal_spanning_network.edges) == len(
minimal_redundant_network.edges)


def test_minimal_redundant_network_edges(minimal_redundant_network):
# issue #244, this was previously giving non-reproducible (yet valid)
# networks when scores were tied.
edge_ids = sorted(
(edge.componentA.name, edge.componentB.name)
for edge in minimal_redundant_network.edges
)
ref = sorted([
('1,3,7-trimethylnaphthalene', '2,6-dimethylnaphthalene'),
('1,3,7-trimethylnaphthalene', '2-methyl-6-propylnaphthalene'),
('1-butyl-4-methylbenzene', '2,6-dimethylnaphthalene'),
('1-butyl-4-methylbenzene', '2-methyl-6-propylnaphthalene'),
('1-butyl-4-methylbenzene', 'toluene'),
('2,6-dimethylnaphthalene', '2-methyl-6-propylnaphthalene'),
('2,6-dimethylnaphthalene', '2-methylnaphthalene'),
('2,6-dimethylnaphthalene', '2-naftanol'),
('2,6-dimethylnaphthalene', 'methylcyclohexane'),
('2,6-dimethylnaphthalene', 'toluene'),
('2-methyl-6-propylnaphthalene', '2-methylnaphthalene'),
('2-methylnaphthalene', '2-naftanol'),
('2-methylnaphthalene', 'methylcyclohexane'),
('2-methylnaphthalene', 'toluene')
])

assert len(edge_ids) == len(ref)
assert edge_ids == ref


def test_minimal_redundant_network_redundant(minimal_redundant_network):
# test that each node is connected to 2 edges.
network = minimal_redundant_network
for node in network.nodes:
assert len(network.graph.in_edges(node)) + \
len(network.graph.out_edges(node)) >= 2


def test_minimal_redundant_network_unreachable(toluene_vs_others):
toluene, others = toluene_vs_others
nimrod = openfe.SmallMoleculeComponent(mol_from_smiles("N"))

def scorer(mapping):
return len(mapping.componentA_to_componentB)

with pytest.raises(RuntimeError, match="Unable to create edges"):
network = openfe.setup.ligand_network_planning.generate_minimal_redundant_network(
ligands=others + [toluene, nimrod],
mappers=[openfe.setup.atom_mapping.LomapAtomMapper()],
scorer=scorer
)


def test_network_from_names(atom_mapping_basic_test_files):
ligs = list(atom_mapping_basic_test_files.values())

Expand Down Expand Up @@ -366,10 +468,12 @@ def test_network_from_external(file_fixture, loader, request,
expected_edges = {
(benzene_modifications['benzene'], benzene_modifications['toluene']),
(benzene_modifications['benzene'], benzene_modifications['phenol']),
(benzene_modifications['benzene'], benzene_modifications['benzonitrile']),
(benzene_modifications['benzene'],
benzene_modifications['benzonitrile']),
(benzene_modifications['benzene'], benzene_modifications['anisole']),
(benzene_modifications['benzene'], benzene_modifications['styrene']),
(benzene_modifications['benzene'], benzene_modifications['benzaldehyde']),
(benzene_modifications['benzene'],
benzene_modifications['benzaldehyde']),
}

actual_edges = {(e.componentA, e.componentB) for e in list(network.edges)}
Expand Down Expand Up @@ -423,7 +527,6 @@ def test_bad_orion_network(benzene_modifications, tmpdir):
)



BAD_EDGES = """\
1c91235:9c91235 benzene -> toluene
1c91235:7876633 benzene -> phenol
Expand Down