Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemhenry authored Sep 25, 2023
2 parents 0a97327 + f403763 commit 0334030
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 11 deletions.
16 changes: 11 additions & 5 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,15 @@ def get_individual_estimates(self) -> list[tuple[unit.Quantity, unit.Quantity]]:
def get_forward_and_reverse_energy_analysis(self) -> list[dict[str, Union[npt.NDArray, unit.Quantity]]]:
"""
Get a list of forward and reverse analysis of the free energies
for each repeat using uncorrolated production samples.
for each repeat using uncorrelated production samples.
The returned dicts have keys:
'fractions' - the fraction of data used for this estimate
'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate
'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty
The 'fractions' values are a numpy array, while the other arrays are
Quantity arrays, with units attached.
Returns
-------
Expand Down Expand Up @@ -231,9 +239,7 @@ def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]:
return overlap_stats

def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]:
"""
Returns the replica lambda state transition statistics for each
repeat.
"""The replica lambda state transition statistics for each repeat.
Note
----
Expand All @@ -246,7 +252,7 @@ def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]:
A list of dictionaries containing the following:
* ``eigenvalues``: The sorted (descending) eigenvalues of the
lambda state transition matrix
* ``matrix``: The transition matrix estimate of a replica switchin
* ``matrix``: The transition matrix estimate of a replica switching
from state i to state j.
"""
try:
Expand Down
Binary file not shown.
Binary file removed openfe/tests/data/openmm_rfe/vac_results.json.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,5 +147,5 @@ def transformation_json() -> str:
"""string of a result of quickrun"""
d = resources.files('openfe.tests.data.openmm_rfe')

with gzip.open((d / 'vac_results.json.gz').as_posix(), 'r') as f: # type: ignore
with gzip.open((d / 'Transformation-e1702a3efc0fa735d5c14fc7572b5278_results.json.gz').as_posix(), 'r') as f: # type: ignore
return f.read().decode() # type: ignore
103 changes: 98 additions & 5 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,10 +1241,103 @@ def test_constraints(tyk2_xml, tyk2_reference_xml):
assert float(a.get('d')) == pytest.approx(float(b.get('d')))


def test_reload_protocol_result(transformation_json):
d = json.loads(transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)
class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, transformation_json):
d = json.loads(transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)

pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d['protocol_result'])
pr = openfe.ProtocolResult.from_dict(d['protocol_result'])

assert pr
return pr

def test_reload_protocol_result(self, transformation_json):
d = json.loads(transformation_json,
cls=gufe.tokenization.JSON_HANDLER.decoder)

pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d['protocol_result'])

assert pr

def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()

assert est
assert est.m == pytest.approx(-15.768768285032115)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()

assert est
assert est.m == pytest.approx(0.03662634237353985)
assert isinstance(est, unit.Quantity)
assert est.is_compatible_with(unit.kilojoule_per_mole)

def test_get_individual(self, protocolresult):
inds = protocolresult.get_individual_estimates()

assert isinstance(inds, list)
assert len(inds) == 3
for e, u in inds:
assert e.is_compatible_with(unit.kilojoule_per_mole)
assert u.is_compatible_with(unit.kilojoule_per_mole)

def test_get_forwards_etc(self, protocolresult):
far = protocolresult.get_forward_and_reverse_energy_analysis()

assert isinstance(far, list)
far1 = far[0]
assert isinstance(far1, dict)
for k in ['fractions', 'forward_DGs', 'forward_dDGs',
'reverse_DGs', 'reverse_dDGs']:
assert k in far1

if k == 'fractions':
assert isinstance(far1[k], np.ndarray)
else:
assert isinstance(far1[k], unit.Quantity)
assert far1[k].is_compatible_with(unit.kilojoule_per_mole)

def test_get_overlap_matrices(self, protocolresult):
ovp = protocolresult.get_overlap_matrices()

assert isinstance(ovp, list)
assert len(ovp) == 3

ovp1 = ovp[0]
assert isinstance(ovp1['matrix'], np.ndarray)
assert ovp1['matrix'].shape == (11,11)

def test_get_replica_transition_statistics(self, protocolresult):
rpx = protocolresult.get_replica_transition_statistics()

assert isinstance(rpx, list)
assert len(rpx) == 3
rpx1 = rpx[0]
assert 'eigenvalues' in rpx1
assert 'matrix' in rpx1
assert rpx1['eigenvalues'].shape == (11,)
assert rpx1['matrix'].shape == (11, 11)

def test_get_replica_states(self, protocolresult):
rep = protocolresult.get_replica_states()

assert isinstance(rep, list)
assert len(rep) == 3
assert rep[0].shape == (6, 11)

def test_equilibration_iterations(self, protocolresult):
eq = protocolresult.equilibration_iterations()

assert isinstance(eq, list)
assert len(eq) == 3
assert all(isinstance(v, float) for v in eq)

def test_production_iterations(self, protocolresult):
prod = protocolresult.production_iterations()

assert isinstance(prod, list)
assert len(prod) == 3
assert all(isinstance(v, float) for v in prod)

0 comments on commit 0334030

Please sign in to comment.