diff --git a/CHANGES b/CHANGES index 2264e8d6..63161515 100644 --- a/CHANGES +++ b/CHANGES @@ -68,6 +68,7 @@ Fixes * fix rcoulomb in CHARMM energy minimization MDP template file (PR #210) * fix ensemble.EnsembleAnalysis.check_groups_from_common_ensemble (#212) * updated versioneer (#285) +* fix that simulation stages cannot be restarted after error (#272) 2022-01-03 0.8.0 diff --git a/mdpow/restart.py b/mdpow/restart.py index 2b2e362a..25c09133 100644 --- a/mdpow/restart.py +++ b/mdpow/restart.py @@ -116,7 +116,7 @@ def incomplete(self): def incomplete(self, stage): if not stage in self.stages: raise ValueError( - "can only assign a registered stage from %(stages)r" % vars(self) + "Can only assign a registered stage from %(stages)r" % vars(self) ) self.__incomplete = stage @@ -143,7 +143,7 @@ def completed(self, stage): def start(self, stage): """Record that *stage* is starting.""" - if self.current is not None: + if self.current is not None and self.current != stage: errmsg = ( "Cannot start stage %s because previously started stage %s " "has not been completed." % (stage, self.current) @@ -157,7 +157,16 @@ def has_completed(self, stage): return stage in self.history def has_not_completed(self, stage): - """Returns ``True`` if the *stage* had been started but not completed yet.""" + """Returns ``True`` if the *stage* had been started but not completed yet. + + This is subtly different from ``not`` :func:`has_completed` in + that two things have to be true: + + 1. No stage is active (which is the case when a restart is attempted). + 2. The `stage` has not been completed previously (i.e., + :func:`has_completed` returns ``False``) + + """ return self.current is None and not self.has_completed(stage) def clear(self): @@ -190,13 +199,14 @@ def __init__(self, *args, **kwargs): len(self.journal.history) except AttributeError: self.journal = Journal(self.protocols) - super(Journalled, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def get_protocol(self, protocol): """Return method for *protocol*. - If *protocol* is a real method of the class then the method is - returned. + returned. This method should implement its own use of + :meth:`Journal.start` and :meth:`Journal.completed`. - If *protocol* is a registered protocol name but no method of the name exists (i.e. *protocol* is a "dummy protocol") then @@ -205,9 +215,12 @@ def get_protocol(self, protocol): .. function:: dummy_protocol(func, *args, **kwargs) - Runs *func* with the arguments and keywords between calls - to :meth:`Journal.start` and :meth:`Journal.completed`, - with the stage set to *protocol*. + Runs *func* with the arguments and keywords between calls to + :meth:`Journal.start` and :meth:`Journal.completed`, with the + stage set to *protocol*. + + The function should return ``True`` on success and ``False`` on + failure. - Raises a :exc:`ValueError` if the *protocol* is not registered (i.e. not found in :attr:`Journalled.protocols`). diff --git a/mdpow/tests/test_journals.py b/mdpow/tests/test_journals.py new file mode 100644 index 00000000..f324fa7e --- /dev/null +++ b/mdpow/tests/test_journals.py @@ -0,0 +1,175 @@ +import pytest + +from mdpow import restart + + +@pytest.fixture +def journal(): + return restart.Journal(["pre", "main", "post"]) + + +class TestJournal: + def test_full_sequence(self, journal): + journal.start("pre") + assert journal.current == "pre" + journal.completed("pre") + + journal.start("main") + assert journal.current == "main" + journal.completed("main") + + journal.start("post") + assert journal.current == "post" + journal.completed("post") + + def test_set_wrong_stage_ValueError(self, journal): + with pytest.raises(ValueError, match="Can only assign a registered stage"): + journal.start("BEGIN !") + + def test_JournalSequenceError_no_completion(self, journal): + with pytest.raises(restart.JournalSequenceError, match="Cannot start stage"): + journal.start("pre") + assert journal.current == "pre" + + journal.start("main") + + @pytest.mark.xfail + def test_JournalSequenceError_skip_stage(self, journal): + # Currently allows skipping a stage and does not enforce ALL previous + # stages to have completed. + with pytest.raises(restart.JournalSequenceError, match="Cannot start stage"): + journal.start("pre") + assert journal.current == "pre" + journal.completed("pre") + + journal.start("post") + + def test_start_idempotent(self, journal): + # test that start() can be called multiple time (#278) + journal.start("pre") + journal.start("pre") + assert journal.current == "pre" + + def test_incomplete_known_stage(self, journal): + journal.incomplete = "main" + assert journal.incomplete == "main" + + def test_incomplete_unknown_stage_ValueError(self, journal): + with pytest.raises(ValueError, match="Can only assign a registered stage from"): + journal.incomplete = "BEGIN !" + + def test_clear(self, journal): + journal.start("pre") + journal.completed("pre") + journal.start("main") + # manually setting incomplete + journal.incomplete = journal.current + + assert journal.current == "main" + assert journal.incomplete == journal.current + + journal.clear() + assert journal.current is None + assert journal.incomplete is None + + def test_history(self, journal): + journal.start("pre") + journal.completed("pre") + journal.start("main") + journal.completed("main") + journal.start("post") + + # completed stages + assert journal.history == ["pre", "main"] + + def test_history_del(self, journal): + journal.start("pre") + journal.completed("pre") + journal.start("main") + journal.completed("main") + assert journal.history + + del journal.history + assert journal.history == [] + + def test_has_completed(self, journal): + journal.start("pre") + journal.completed("pre") + + assert journal.has_completed("pre") + assert not journal.has_completed("main") + + def test_has_not_completed(self, journal): + journal.start("pre") + journal.completed("pre") + journal.start("main") + # simulate crash/restart + del journal.current + + assert journal.has_not_completed("main") + assert not journal.has_not_completed("pre") + + +# need a real class so that it can be pickled later +class JournalledMemory(restart.Journalled): + # divide is a dummy protocol + protocols = ["divide", "multiply"] + + def __init__(self): + self.memory = 1 + super().__init__() + + def multiply(self, x): + self.journal.start("multiply") + self.memory *= x + self.journal.completed("multiply") + + +@pytest.fixture +def journalled(): + return JournalledMemory() + + +class TestJournalled: + @staticmethod + def divide(m, x): + return m.memory / x + + def test_get_protocol_of_class(self, journalled): + f = journalled.get_protocol("multiply") + f(10) + assert journalled.memory == 10 + assert journalled.journal.has_completed("multiply") + + def test_get_protocol_dummy(self, journalled): + dummy_protocol = journalled.get_protocol("divide") + result = dummy_protocol(self.divide, journalled, 10) + + assert result == 1 / 10 + assert journalled.journal.has_completed("divide") + + def test_get_protocol_dummy_incomplete(self, journalled): + dummy_protocol = journalled.get_protocol("divide") + with pytest.raises(ZeroDivisionError): + result = dummy_protocol(self.divide, journalled, 0) + assert not journalled.journal.has_completed("divide") + + def test_save_load(self, tmp_path): + # instantiate a class that can be pickled (without pytest magic) + journalled = JournalledMemory() + f = journalled.get_protocol("multiply") + f(10) + assert journalled.memory == 10 + + pickle = tmp_path / "memory.pkl" + journalled.save(pickle) + + assert pickle.exists() + + # change instance + f(99) + assert journalled.memory == 10 * 99 + + # reload previous state + journalled.load(pickle) + assert journalled.memory == 10