diff --git a/doc/sphinx/source/analysis/ensemble.txt b/doc/sphinx/source/analysis/ensemble.txt index 50f6122f..5e9a907a 100644 --- a/doc/sphinx/source/analysis/ensemble.txt +++ b/doc/sphinx/source/analysis/ensemble.txt @@ -43,3 +43,9 @@ object when the :meth:`~mdpow.analysis.ensemble.EnsembleAtomGroup.ensemble` is r .. autoclass:: mdpow.analysis.ensemble.EnsembleAtomGroup :members: + + +ensemble_wrapper Decorator +__________________________ + +.. autofunction:: mdpow.analysis.ensemble_wrapper diff --git a/mdpow/analysis/ensemble.py b/mdpow/analysis/ensemble.py index 00dfbd27..9cacdd73 100644 --- a/mdpow/analysis/ensemble.py +++ b/mdpow/analysis/ensemble.py @@ -3,7 +3,7 @@ import os import errno -from typing import Optional, List +from typing import Optional, List, Union import numpy as np @@ -550,3 +550,48 @@ def check_groups_from_common_ensemble(groups: List[EnsembleAtomGroup]): from the same Ensemble.''' logger.error(msg) raise ValueError(msg) + + +def ensemble_wrapper(cls): + """A decorator for :class:`MDAnalysis.Universe ` subclasses modifying + them to accept an :class:`~mdpow.analysis.ensemble.Ensemble` or + :class:`~mdpow.analysis.ensemble.EnsembleAtomGroup`. + + .. rubric:: Example Analysis + + @ensemble_wrapper + class Example(AnalysisBase): + pass + + Ens = Ensemble(dirname='mol_dir) + ExRun = Example(Ens) + + """ + class EnsembleWrapper: + def __init__(self, ensemble: Union[Ensemble, EnsembleAtomGroup], *args, **kwargs): + self._ensemble = ensemble + self._args = args + self._kwargs = kwargs + self._Analysis = cls + + def _prepare_ensemble(self): + # Defined separately so user can modify behavior + self._results_dict = {x: None for x in self._ensemble.keys()} + + def _conclude_system(self): + # Defined separately so user can modify behavior + self._results_dict[self._key] = self._SystemRun.results + + def _conclude_ensemble(self): + self.results = self._results_dict + + def run(self, start=0, stop=0, step=1): + self._prepare_ensemble() + for self._key in self._ensemble.keys(): + self._SystemRun = self._Analysis(self._ensemble[self._key], *self._args, **self._kwargs) + self._SystemRun.run(start=start, step=step, stop=stop) + self._conclude_system() + self._conclude_ensemble() + return self + + return EnsembleWrapper diff --git a/mdpow/tests/test_ensemble.py b/mdpow/tests/test_ensemble.py index 9d8b2462..1d9c0caf 100644 --- a/mdpow/tests/test_ensemble.py +++ b/mdpow/tests/test_ensemble.py @@ -13,10 +13,11 @@ import MDAnalysis as mda from MDAnalysis.exceptions import NoDataError, SelectionError +from MDAnalysis.analysis.base import AnalysisBase from gromacs.utilities import in_dir -from ..analysis.ensemble import Ensemble, EnsembleAnalysis, EnsembleAtomGroup +from ..analysis.ensemble import Ensemble, EnsembleAnalysis, EnsembleAtomGroup, ensemble_wrapper from ..analysis.dihedral import DihedralAnalysis from pkg_resources import resource_filename @@ -161,3 +162,29 @@ def test_value_error(self): dh4 = ens.select_atoms('name C4 or name C17 or name S2 or name N3') with pytest.raises(ValueError): dh_run = DihedralAnalysis([dh1, dh2, dh4, dh3]).run(start=0, stop=4, step=1) + + def test_ensemble_wrapper1(self): + + class BaseTest(AnalysisBase): + def __init__(self, system: mda.Universe): + super(BaseTest, self).__init__(system.trajectory) + self.system = system + + def _prepare(self): + self._res_arr = [] + + def _single_frame(self): + self._res_arr.append(len(self.system.select_atoms('not resname SOL'))) + assert self._res_arr[-1] == 42 + + def _conclude(self): + self.results = self._res_arr + + @ensemble_wrapper + class EnsembleTest(BaseTest): + pass + + Sim = Ensemble(dirname=self.tmpdir.name, solvents=['water']) + SolvCount = EnsembleTest(Sim).run(stop=10) + assert isinstance(SolvCount, EnsembleTest) + assert SolvCount.results[('water', 'VDW', '0000')] == ([42] * 10)