From 35a673dc47e1fb22b8e23541cf263fd4287908cf Mon Sep 17 00:00:00 2001 From: Christian O'Reilly Date: Wed, 13 Nov 2024 11:44:20 -0500 Subject: [PATCH] Fix issue #154. (#179) * Fix issue #154. * codespell * Fix example bug. --- pylossless/config/rejection.py | 21 +++++++++++++++++++-- pylossless/pipeline.py | 13 +++++++++++++ pylossless/tests/test_rejection.py | 14 +++++++++++++- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/pylossless/config/rejection.py b/pylossless/config/rejection.py index 6c094f7..a7f33b6 100644 --- a/pylossless/config/rejection.py +++ b/pylossless/config/rejection.py @@ -4,6 +4,8 @@ # License: MIT import numpy as np +from importlib.metadata import version +import warnings from .config import ConfigMixin @@ -90,7 +92,7 @@ def __init__( ) def __repr__(self): - """Return a summary of the Calibration object.""" + """Return a summary of the RejectionPolicy object.""" return ( f"RejectionPolicy: |\n" f" config_fname: {self['config_fname']}\n" @@ -101,7 +103,7 @@ def __repr__(self): f" remove_flagged_ics: {self['remove_flagged_ics']}\n" ) - def apply(self, pipeline, return_ica=False): + def apply(self, pipeline, return_ica=False, version_mismatch="raise"): """Return a cleaned new raw object based on the rejection policy. Parameters @@ -119,6 +121,21 @@ def apply(self, pipeline, return_ica=False): An :class:`~mne.io.Raw` instance with the appropriate channels and ICs added to mne bads, interpolated, or dropped. """ + if pipeline.config["version"] != version("pylossless"): + error_message = ( + "The output of the pipeline was saved with pylossless version " + f"{pipeline.config['version']} and you are currently using " + f"version {version('pylossless')}. The behavior is undefined." + ) + if version_mismatch == "raise": + raise RuntimeError(error_message) + elif version_mismatch == "warning": + warnings.warn(error_message, RuntimeWarning) + elif version_mismatch != "ignore": + raise ValueError("version_mismatch can take values 'raise', " + "'warning', or 'ignore'. Received " + f"{version_mismatch}.") + # Get the raw object raw = pipeline.raw.copy() diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index 7956e0a..07cc4b4 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -10,6 +10,7 @@ from copy import deepcopy from pathlib import Path from functools import partial +from importlib.metadata import version # Math and data structures import numpy as np @@ -461,6 +462,8 @@ def __init__(self, config_path=None, config=None): "epoch": FlaggedEpochs(self), "ic": FlaggedICs(), } + self._config = None + if config: self.config = config if config_path is None: @@ -526,6 +529,15 @@ def _repr_html_(self): return html + @property + def config(self): + return self._config + + @config.setter + def config(self, config): + self._config = config + self._config["version"] = version("pylossless") + @property def config_fname(self): warn('config_fname is deprecated and will be removed from future versions.', @@ -1094,6 +1106,7 @@ def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None): config_bidspath = bpath.update( extension=".yaml", suffix="ll_config", check=False ) + self.config.save(config_bidspath) # Save flag["ch"] diff --git a/pylossless/tests/test_rejection.py b/pylossless/tests/test_rejection.py index 478b71d..88ca889 100644 --- a/pylossless/tests/test_rejection.py +++ b/pylossless/tests/test_rejection.py @@ -16,7 +16,19 @@ def test_rejection_policy(clean_ch_mode, pipeline_fixture): want_flags = ["noisy", "uncorrelated", "bridged"] assert rejection_config["ch_flags_to_reject"] == want_flags - raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True) + pipeline_fixture.config["version"] = "-1" + with pytest.raises(RuntimeError, match="The output of the pipeline was"): + raw, ica = rejection_config.apply(pipeline_fixture, + version_mismatch="raise") + with pytest.raises(RuntimeWarning, match="The output of the pipeline was"): + raw, ica = rejection_config.apply(pipeline_fixture, + version_mismatch="warning") + with pytest.raises(ValueError, match="version_mismatch can take values"): + raw, ica = rejection_config.apply(pipeline_fixture, + version_mismatch="sdfdf") + raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True, + version_mismatch="ignore") + flagged_chs = [] for key in rejection_config["ch_flags_to_reject"]: flagged_chs.extend(pipeline_fixture.flags["ch"][key].tolist())