From cc885795b662275fa83e1bc80cc2238ad2f67424 Mon Sep 17 00:00:00 2001 From: Christian O'Reilly Date: Tue, 12 Nov 2024 19:29:13 -0500 Subject: [PATCH] Fix loading epoch flags. --- pylossless/flagging.py | 127 ++++++++++++++++-------------- pylossless/pipeline.py | 2 +- pylossless/tests/test_pipeline.py | 3 +- 3 files changed, 69 insertions(+), 63 deletions(-) diff --git a/pylossless/flagging.py b/pylossless/flagging.py index ce14727..cbd3f0f 100644 --- a/pylossless/flagging.py +++ b/pylossless/flagging.py @@ -17,8 +17,53 @@ from .utils._utils import _icalabel_to_data_frame +IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE +CH_LABELS: dict[str, str] = { + "Noisy": "ch_sd", + "Bridged": "bridge", + "Uncorrelated": "low_r", + "Rank": "rank" +} +EPOCH_LABELS: dict[str, str] = { + "Noisy": "noisy", + "Noisy ICs": "noisy_ICs", + "Uncorrelated": "uncorrelated", +} + + +class _Flagged(dict): + + def __init__(self, key_map, kind_str, ll, *args, **kwargs): + """Initialize class.""" + super().__init__(*args, **kwargs) + self.ll = ll + self._key_map = key_map + self._kind_str = kind_str -class FlaggedChs(dict): + @property + def valid_keys(self): + """Return the valid keys.""" + return tuple(self._key_map.values()) + + def __repr__(self): + """Return a string representation.""" + ret_str = f"Flagged {self._kind_str}s: |\n" + for key, val in self._key_map.items(): + ret_str += f" {key}: {self.get(val, None)}\n" + return ret_str + + def __eq__(self, other): + for key in self.valid_keys: + if not np.array_equal(self.get(key, np.array([])), + other.get(key, np.array([]))): + return False + return True + + def __ne__(self, other): + return not self == other + + +class FlaggedChs(_Flagged): """Object for handling flagged channels in an instance of mne.io.Raw. Attributes @@ -47,32 +92,9 @@ class FlaggedChs(dict): and methods for python dictionaries. """ - def __init__(self, ll, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize class.""" - super().__init__(*args, **kwargs) - self.ll = ll - - @property - def valid_keys(self): - """Return the valid keys for FlaggedChs objects.""" - return ('ch_sd', 'bridge', 'low_r', 'rank') - - def __repr__(self): - """Return a string representation of the FlaggedChs object.""" - return ( - f"Flagged channels: |\n" - f" Noisy: {self.get('ch_sd', None)}\n" - f" Bridged: {self.get('bridge', None)}\n" - f" Uncorrelated: {self.get('low_r', None)}\n" - f" Rank: {self.get('rank', None)}\n" - ) - - def __eq__(self, other): - for key in self.valid_keys: - if not np.array_equal(self.get(key, np.array([])), - other.get(key, np.array([]))): - return False - return True + super().__init__(CH_LABELS, "channel", *args, **kwargs) def add_flag_cat(self, kind, bad_ch_names, *args): """Store channel names that have been flagged by pipeline. @@ -152,7 +174,7 @@ def load_tsv(self, fname): self[label] = grp_df.ch_names.values -class FlaggedEpochs(dict): +class FlaggedEpochs(_Flagged): """Object for handling flagged Epochs in an instance of mne.Epochs. Methods @@ -171,7 +193,7 @@ class FlaggedEpochs(dict): and methods for python dictionaries. """ - def __init__(self, ll, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize class. Parameters @@ -183,30 +205,7 @@ def __init__(self, ll, *args, **kwargs): kwargs : dict keyword arguments accepted by python's dictionary class. """ - super().__init__(*args, **kwargs) - - self.ll = ll - - @property - def valid_keys(self): - """Return the valid keys for FlaggedEpochs objects.""" - return ('noisy', 'uncorrelated', 'noisy_ICs') - - def __repr__(self): - """Return a string representation of the FlaggedEpochs object.""" - return ( - f"Flagged channels: |\n" - f" Noisy: {self.get('noisy', None)}\n" - f" Noisy ICs: {self.get('noisy_ICs', None)}\n" - f" Uncorrelated: {self.get('uncorrelated', None)}\n" - ) - - def __eq__(self, other): - for key in self.valid_keys: - if not np.array_equal(self.get(key, np.array([])), - other.get(key, np.array([]))): - return False - return True + super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs) def add_flag_cat(self, kind, bad_epoch_inds, epochs): """Add information on time periods flagged by pyLossless. @@ -227,17 +226,25 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs): self[kind] = bad_epoch_inds self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs) - def load_from_raw(self, raw): + def load_from_raw(self, raw, events, config): """Load pylossless annotations from raw object.""" sfreq = raw.info["sfreq"] + tmax = config["epoching"]["epochs_args"]["tmax"] + tmin = config["epoching"]["epochs_args"]["tmin"] + starts = events[:, 0]/sfreq - tmin + stops = events[:, 0]/sfreq + tmax for annot in raw.annotations: - if annot["description"].upper().startswith("BAD_LL"): - ind_onset = int(np.round(annot["onset"] * sfreq)) - ind_dur = int(np.round(annot["duration"] * sfreq)) - inds = np.arange(ind_onset, ind_onset + ind_dur) - if annot["description"] not in self: - self[annot["description"]] = list() - self[annot["description"]].append(inds) + if annot["description"].upper().startswith("BAD_LL_"): + onset = annot["onset"] + offset = annot["onset"]+annot["duration"] + mask = ((starts >= onset) & (starts < offset) | + (stops > onset) & (stops <= offset) | + (onset <= starts) & (offset >= stops)) + inds = np.where(mask)[0] + desc = annot["description"].lower().replace("bad_ll_", "") + if desc not in self: + self[desc] = np.array([]) + self[desc] = np.concatenate((self[desc], inds)) class FlaggedICs(pd.DataFrame): diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index fa64d16..3cb6953 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -1258,7 +1258,7 @@ def load_ll_derivative(self, derivatives_path): self.flags["ch"].load_tsv(flagged_chs_fpath.fpath) # Load Flagged Epochs - self.flags["epoch"].load_from_raw(self.raw) + self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config) return self diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index 74059c2..cc76d21 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -79,6 +79,7 @@ def test_load_flags(pipeline_fixture, tmp_path): pipeline_fixture.save(bids_path, overwrite=False, format="EDF", event_id=None) pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path) + assert pipeline_fixture.flags['ch'] == pipeline.flags['ch'] pipeline.flags['ch']["bridge"] = ["xx"] assert pipeline_fixture.flags['ch'] != pipeline.flags['ch'] @@ -86,5 +87,3 @@ def test_load_flags(pipeline_fixture, tmp_path): assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch'] pipeline.flags['epoch']["bridge"] = ["noisy"] assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch'] - - assert pipeline_fixture.flags['ic'] == pipeline.flags['ic']