Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix RTHP results serialization and test results load #550

Merged
merged 6 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ recursive-include openfe/tests/data/ *.xml
recursive-include openfe/tests/data/ *.graphml
recursive-include openfe/tests/data/ *.edge
recursive-include openfe/tests/data/ *.dat
recursive-include openfe/tests/data/ *json.gz
include openfecli/tests/data/*.json
include openfecli/tests/data/*.tar.gz
recursive-include openfecli/tests/ *.sdf
Expand Down
5 changes: 3 additions & 2 deletions openfe/protocols/openmm_utils/multistate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@
"""
# Do things that get badly cached later
self._replica_states = self.analyzer.reporter.read_replica_thermodynamic_states()
self._equil_iters = self.analyzer.n_equilibration_iterations
self._prod_iters = self.analyzer._equilibration_data[2]
# float conversions to avoid having to deal with numpy dtype serialization
self._equil_iters = float(self.analyzer.n_equilibration_iterations)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should these be int instead of float? (Variable name looks to me like an int)

Copy link
Member Author

@IAlibay IAlibay Sep 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've already raised that issue in openmmtools; currently they return a float, and it's been agreed that they will eventually change it.

self._prod_iters = float(self.analyzer._equilibration_data[2])

# Gather estimate of free energy
self._free_energy, self._free_energy_err = self.get_equil_free_energy()
Expand Down Expand Up @@ -185,10 +186,10 @@
try:
# pymbar 3
DF_ij, dDF_ij = mbar.getFreeEnergyDifferences()
except AttributeError:
r = mbar.compute_free_energy_differences()
DF_ij = r['Delta_f']
dDF_ij = r['dDelta_f']

Check warning on line 192 in openfe/protocols/openmm_utils/multistate_analysis.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/multistate_analysis.py#L189-L192

Added lines #L189 - L192 were not covered by tests

DG = DF_ij[0, -1] * analyzer.kT
dDG = dDF_ij[0, -1] * analyzer.kT
Expand Down Expand Up @@ -246,9 +247,9 @@

# Check that the N_l is the same across all states
if not np.all(N_l == N_l[0]):
errmsg = ("The number of samples is not equivalent across all "

Check warning on line 250 in openfe/protocols/openmm_utils/multistate_analysis.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/multistate_analysis.py#L250

Added line #L250 was not covered by tests
f"states {N_l}")
raise ValueError(errmsg)

Check warning on line 252 in openfe/protocols/openmm_utils/multistate_analysis.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/multistate_analysis.py#L252

Added line #L252 was not covered by tests

# Get the chunks of N_l going from 10% to ~ 100%
# Note: you always lose out a few data points but it's fine
Expand Down Expand Up @@ -311,8 +312,8 @@
try:
# pymbar 3
overlap_matrix = self.analyzer.mbar.computeOverlap()
except AttributeError:
overlap_matrix = self.analyzer.mbar.compute_overlap()

Check warning on line 316 in openfe/protocols/openmm_utils/multistate_analysis.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/multistate_analysis.py#L315-L316

Added lines #L315 - L316 were not covered by tests

return overlap_matrix

Expand Down Expand Up @@ -417,4 +418,4 @@
return results_dict

def close(self):
self.analyzer.clear()

Check warning on line 421 in openfe/protocols/openmm_utils/multistate_analysis.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/multistate_analysis.py#L421

Added line #L421 was not covered by tests
Binary file added openfe/tests/data/openmm_rfe/vac_results.json.gz
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm missing something really obvious. I'm trying to load the results JSON back into a RelativeHybridTopologyProtocolResult, but my unit results end up without the .outputs attribute.

Current testing looks something like this:

 r = json.load(gzip.open('vac_results.json.gz', 'r'), cls=JSON_HANDLER.decoder)
 rhtpr = RelativeHybridTopologyProtocolResult._from_dict(r['protocol_result']['data'])

It looks like the protocol unit results haven't been sufficiently deserialized. Is this even something we're meant to be able to do?

Loads fine, has the

Binary file not shown.
12 changes: 11 additions & 1 deletion openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import gzip
import pytest
import importlib
from importlib import resources
from rdkit import Chem
from rdkit.Geometry import Point3D
import openfe
Expand Down Expand Up @@ -139,3 +140,12 @@ def toluene_many_solv_system(benzene_modifications):
"bar": benzo,
"solvent": openfe.SolventComponent()},
)


@pytest.fixture
def transformation_json() -> str:
    """Return the text of the packaged ``vac_results.json.gz`` file.

    The archive is looked up in the ``openfe.tests.data.openmm_rfe``
    package, decompressed, and decoded to ``str``. The content is the
    string of a result of quickrun.

    Returns
    -------
    str
        The decompressed, decoded content of the archive.
    """
    archive = (
        resources.files('openfe.tests.data.openmm_rfe') / 'vac_results.json.gz'
    )

    # ``as_file`` hands back a real filesystem path for any resource
    # backend, so we do not rely on ``as_posix``, which is not part of the
    # ``Traversable`` interface returned by ``resources.files``.
    with resources.as_file(archive) as path:
        with gzip.open(path, 'rb') as f:
            return f.read().decode()
10 changes: 10 additions & 0 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import gufe
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
import json
import pytest
from unittest import mock
from openff.units import unit
Expand Down Expand Up @@ -1164,11 +1165,11 @@

@pytest.fixture(scope='session')
def tyk2_xml(tmp_path_factory):
with resources.files('openfe.tests.data.openmm_rfe') as d:
fn1 = str(d / 'ligand_23.sdf')
fn2 = str(d / 'ligand_55.sdf')
lig23 = openfe.SmallMoleculeComponent.from_sdf_file(fn1)
lig55 = openfe.SmallMoleculeComponent.from_sdf_file(fn2)

Check warning on line 1172 in openfe/tests/protocols/test_openmm_equil_rfe_protocols.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/test_openmm_equil_rfe_protocols.py#L1168-L1172

Added lines #L1168 - L1172 were not covered by tests

mapping = setup.LigandAtomMapping(
componentA=lig23, componentB=lig55,
Expand Down Expand Up @@ -1206,8 +1207,8 @@

@pytest.fixture(scope='session')
def tyk2_reference_xml():
with resources.files('openfe.tests.data.openmm_rfe') as d:
f = d / 'reference.xml'

Check warning on line 1211 in openfe/tests/protocols/test_openmm_equil_rfe_protocols.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/test_openmm_equil_rfe_protocols.py#L1210-L1211

Added lines #L1210 - L1211 were not covered by tests
with open(f, 'r') as i:
xmldata = i.read()
return ET.fromstring(xmldata)
Expand Down Expand Up @@ -1238,3 +1239,12 @@
assert a.get('p1') == b.get('p1')
assert a.get('p2') == b.get('p2')
assert float(a.get('d')) == pytest.approx(float(b.get('d')))


def test_reload_protocol_result(transformation_json):
    """Rebuild a protocol result from the ``transformation_json`` fixture.

    The fixture string is parsed with the gufe JSON decoder, and the
    ``'protocol_result'`` entry is passed to
    ``RelativeHybridTopologyProtocolResult.from_dict``. The rebuilt
    object must be truthy.
    """
    decoder_cls = gufe.tokenization.JSON_HANDLER.decoder
    loaded = json.loads(transformation_json, cls=decoder_cls)

    result_cls = openmm_rfe.RelativeHybridTopologyProtocolResult
    pr = result_cls.from_dict(loaded['protocol_result'])

    assert pr
Loading