Skip to content

Commit

Permalink
Fix issue #154. (#179)
Browse files Browse the repository at this point in the history
* Fix issue #154.

* codespell

* Fix example bug.
  • Loading branch information
christian-oreilly authored Nov 13, 2024
1 parent ec1cedb commit 35a673d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
21 changes: 19 additions & 2 deletions pylossless/config/rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# License: MIT

import numpy as np
from importlib.metadata import version
import warnings

from .config import ConfigMixin

Expand Down Expand Up @@ -90,7 +92,7 @@ def __init__(
)

def __repr__(self):
"""Return a summary of the Calibration object."""
"""Return a summary of the RejectionPolicy object."""
return (
f"RejectionPolicy: |\n"
f" config_fname: {self['config_fname']}\n"
Expand All @@ -101,7 +103,7 @@ def __repr__(self):
f" remove_flagged_ics: {self['remove_flagged_ics']}\n"
)

def apply(self, pipeline, return_ica=False):
def apply(self, pipeline, return_ica=False, version_mismatch="raise"):
"""Return a cleaned new raw object based on the rejection policy.
Parameters
Expand All @@ -119,6 +121,21 @@ def apply(self, pipeline, return_ica=False):
An :class:`~mne.io.Raw` instance with the appropriate channels and ICs
added to mne bads, interpolated, or dropped.
"""
if pipeline.config["version"] != version("pylossless"):
error_message = (
"The output of the pipeline was saved with pylossless version "
f"{pipeline.config['version']} and you are currently using "
f"version {version('pylossless')}. The behavior is undefined."
)
if version_mismatch == "raise":
raise RuntimeError(error_message)
elif version_mismatch == "warning":
warnings.warn(error_message, RuntimeWarning)
elif version_mismatch != "ignore":
raise ValueError("version_mismatch can take values 'raise', "
"'warning', or 'ignore'. Received "
f"{version_mismatch}.")

# Get the raw object
raw = pipeline.raw.copy()

Expand Down
13 changes: 13 additions & 0 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from copy import deepcopy
from pathlib import Path
from functools import partial
from importlib.metadata import version

# Math and data structures
import numpy as np
Expand Down Expand Up @@ -461,6 +462,8 @@ def __init__(self, config_path=None, config=None):
"epoch": FlaggedEpochs(self),
"ic": FlaggedICs(),
}
self._config = None

if config:
self.config = config
if config_path is None:
Expand Down Expand Up @@ -526,6 +529,15 @@ def _repr_html_(self):

return html

@property
def config(self):
return self._config

@config.setter
def config(self, config):
self._config = config
self._config["version"] = version("pylossless")

@property
def config_fname(self):
warn('config_fname is deprecated and will be removed from future versions.',
Expand Down Expand Up @@ -1094,6 +1106,7 @@ def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
config_bidspath = bpath.update(
extension=".yaml", suffix="ll_config", check=False
)

self.config.save(config_bidspath)

# Save flag["ch"]
Expand Down
14 changes: 13 additions & 1 deletion pylossless/tests/test_rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@ def test_rejection_policy(clean_ch_mode, pipeline_fixture):
want_flags = ["noisy", "uncorrelated", "bridged"]
assert rejection_config["ch_flags_to_reject"] == want_flags

raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True)
pipeline_fixture.config["version"] = "-1"
with pytest.raises(RuntimeError, match="The output of the pipeline was"):
raw, ica = rejection_config.apply(pipeline_fixture,
version_mismatch="raise")
with pytest.raises(RuntimeWarning, match="The output of the pipeline was"):
raw, ica = rejection_config.apply(pipeline_fixture,
version_mismatch="warning")
with pytest.raises(ValueError, match="version_mismatch can take values"):
raw, ica = rejection_config.apply(pipeline_fixture,
version_mismatch="sdfdf")
raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True,
version_mismatch="ignore")

flagged_chs = []
for key in rejection_config["ch_flags_to_reject"]:
flagged_chs.extend(pipeline_fixture.flags["ch"][key].tolist())
Expand Down

0 comments on commit 35a673d

Please sign in to comment.