From bec5043a617271fc045ee18478776d4054e6a25b Mon Sep 17 00:00:00 2001 From: Christian O'Reilly Date: Tue, 12 Nov 2024 21:47:50 -0500 Subject: [PATCH] Fix issue #154. --- pylossless/config/rejection.py | 22 ++++++++++++++++++++-- pylossless/pipeline.py | 3 +++ pylossless/tests/test_rejection.py | 14 +++++++++++++- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pylossless/config/rejection.py b/pylossless/config/rejection.py index 6c094f7..3cd6be8 100644 --- a/pylossless/config/rejection.py +++ b/pylossless/config/rejection.py @@ -4,6 +4,8 @@ # License: MIT import numpy as np +from importlib.metadata import version +import warnings from .config import ConfigMixin @@ -90,7 +92,7 @@ def __init__( ) def __repr__(self): - """Return a summary of the Calibration object.""" + """Return a summary of the RejectionPolicy object.""" return ( f"RejectionPolicy: |\n" f" config_fname: {self['config_fname']}\n" @@ -101,7 +103,7 @@ def __repr__(self): f" remove_flagged_ics: {self['remove_flagged_ics']}\n" ) - def apply(self, pipeline, return_ica=False): + def apply(self, pipeline, return_ica=False, version_mismatch="raise"): """Return a cleaned new raw object based on the rejection policy. Parameters @@ -119,6 +121,22 @@ def apply(self, pipeline, return_ica=False): An :class:`~mne.io.Raw` instance with the appropriate channels and ICs added to mne bads, interpolated, or dropped. """ + + if pipeline.config["version"] != version("pylossless"): + error_message = ( + "The output of the pipeline was saved with pylossless version " + f"{pipeline.config['version']} and you are currently using " + f"version {version('pylossless')}. The behavior is undefined." + ) + if version_mismatch == "raise": + raise RuntimeError(error_message) + elif version_mismatch == "warning": + warnings.warn(error_message, RuntimeWarning) + elif version_mismatch != "ignore": + raise ValueError("version_mismatch can take values 'raise', " + "'warning', or 'ignore'. Received " + f"{version_mismatch}.") + # Get the raw object raw = pipeline.raw.copy() diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index 7956e0a..2b6a2d8 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -10,6 +10,7 @@ from copy import deepcopy from pathlib import Path from functools import partial +from importlib.metadata import version # Math and data structures import numpy as np @@ -1094,6 +1095,8 @@ def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None): config_bidspath = bpath.update( extension=".yaml", suffix="ll_config", check=False ) + + self.config["version"] = version("pylossless") self.config.save(config_bidspath) # Save flag["ch"] diff --git a/pylossless/tests/test_rejection.py b/pylossless/tests/test_rejection.py index 478b71d..88ca889 100644 --- a/pylossless/tests/test_rejection.py +++ b/pylossless/tests/test_rejection.py @@ -16,7 +16,19 @@ def test_rejection_policy(clean_ch_mode, pipeline_fixture): want_flags = ["noisy", "uncorrelated", "bridged"] assert rejection_config["ch_flags_to_reject"] == want_flags - raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True) + pipeline_fixture.config["version"] = "-1" + with pytest.raises(RuntimeError, match="The output of the pipeline was"): + raw, ica = rejection_config.apply(pipeline_fixture, + version_mismatch="raise") + with pytest.raises(RuntimeWarning, match="The output of the pipeline was"): + raw, ica = rejection_config.apply(pipeline_fixture, + version_mismatch="warning") + with pytest.raises(ValueError, match="version_mismatch can take values"): + raw, ica = rejection_config.apply(pipeline_fixture, + version_mismatch="sdfdf") + raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True, + version_mismatch="ignore") + flagged_chs = [] for key in rejection_config["ch_flags_to_reject"]: flagged_chs.extend(pipeline_fixture.flags["ch"][key].tolist())