Skip to content

Commit

Permalink
Fix loading epoch flags.
Browse files Browse the repository at this point in the history
  • Loading branch information
christian-oreilly committed Nov 13, 2024
1 parent 159a148 commit cc88579
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 63 deletions.
127 changes: 67 additions & 60 deletions pylossless/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,53 @@

from .utils._utils import _icalabel_to_data_frame

IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE
CH_LABELS: dict[str, str] = {
"Noisy": "ch_sd",
"Bridged": "bridge",
"Uncorrelated": "low_r",
"Rank": "rank"
}
EPOCH_LABELS: dict[str, str] = {
"Noisy": "noisy",
"Noisy ICs": "noisy_ICs",
"Uncorrelated": "uncorrelated",
}


class _Flagged(dict):

def __init__(self, key_map, kind_str, ll, *args, **kwargs):
"""Initialize class."""
super().__init__(*args, **kwargs)
self.ll = ll
self._key_map = key_map
self._kind_str = kind_str

class FlaggedChs(dict):
@property
def valid_keys(self):
"""Return the valid keys."""
return tuple(self._key_map.values())

def __repr__(self):
"""Return a string representation."""
ret_str = f"Flagged {self._kind_str}s: |\n"
for key, val in self._key_map.items():
ret_str += f" {key}: {self.get(val, None)}\n"
return ret_str

def __eq__(self, other):
for key in self.valid_keys:
if not np.array_equal(self.get(key, np.array([])),
other.get(key, np.array([]))):
return False
return True

def __ne__(self, other):
return not self == other


class FlaggedChs(_Flagged):
"""Object for handling flagged channels in an instance of mne.io.Raw.
Attributes
Expand Down Expand Up @@ -47,32 +92,9 @@ class FlaggedChs(dict):
and methods for python dictionaries.
"""

def __init__(self, ll, *args, **kwargs):
def __init__(self, *args, **kwargs):
"""Initialize class."""
super().__init__(*args, **kwargs)
self.ll = ll

@property
def valid_keys(self):
"""Return the valid keys for FlaggedChs objects."""
return ('ch_sd', 'bridge', 'low_r', 'rank')

def __repr__(self):
"""Return a string representation of the FlaggedChs object."""
return (
f"Flagged channels: |\n"
f" Noisy: {self.get('ch_sd', None)}\n"
f" Bridged: {self.get('bridge', None)}\n"
f" Uncorrelated: {self.get('low_r', None)}\n"
f" Rank: {self.get('rank', None)}\n"
)

def __eq__(self, other):
for key in self.valid_keys:
if not np.array_equal(self.get(key, np.array([])),
other.get(key, np.array([]))):
return False
return True
super().__init__(CH_LABELS, "channel", *args, **kwargs)

def add_flag_cat(self, kind, bad_ch_names, *args):
"""Store channel names that have been flagged by pipeline.
Expand Down Expand Up @@ -152,7 +174,7 @@ def load_tsv(self, fname):
self[label] = grp_df.ch_names.values


class FlaggedEpochs(dict):
class FlaggedEpochs(_Flagged):
"""Object for handling flagged Epochs in an instance of mne.Epochs.
Methods
Expand All @@ -171,7 +193,7 @@ class FlaggedEpochs(dict):
and methods for python dictionaries.
"""

def __init__(self, ll, *args, **kwargs):
def __init__(self, *args, **kwargs):
"""Initialize class.
Parameters
Expand All @@ -183,30 +205,7 @@ def __init__(self, ll, *args, **kwargs):
kwargs : dict
keyword arguments accepted by python's dictionary class.
"""
super().__init__(*args, **kwargs)

self.ll = ll

@property
def valid_keys(self):
"""Return the valid keys for FlaggedEpochs objects."""
return ('noisy', 'uncorrelated', 'noisy_ICs')

def __repr__(self):
"""Return a string representation of the FlaggedEpochs object."""
return (
f"Flagged channels: |\n"
f" Noisy: {self.get('noisy', None)}\n"
f" Noisy ICs: {self.get('noisy_ICs', None)}\n"
f" Uncorrelated: {self.get('uncorrelated', None)}\n"
)

def __eq__(self, other):
for key in self.valid_keys:
if not np.array_equal(self.get(key, np.array([])),
other.get(key, np.array([]))):
return False
return True
super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs)

def add_flag_cat(self, kind, bad_epoch_inds, epochs):
"""Add information on time periods flagged by pyLossless.
Expand All @@ -227,17 +226,25 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
self[kind] = bad_epoch_inds
self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs)

def load_from_raw(self, raw):
def load_from_raw(self, raw, events, config):
"""Load pylossless annotations from raw object."""
sfreq = raw.info["sfreq"]
tmax = config["epoching"]["epochs_args"]["tmax"]
tmin = config["epoching"]["epochs_args"]["tmin"]
starts = events[:, 0]/sfreq - tmin
stops = events[:, 0]/sfreq + tmax
for annot in raw.annotations:
if annot["description"].upper().startswith("BAD_LL"):
ind_onset = int(np.round(annot["onset"] * sfreq))
ind_dur = int(np.round(annot["duration"] * sfreq))
inds = np.arange(ind_onset, ind_onset + ind_dur)
if annot["description"] not in self:
self[annot["description"]] = list()
self[annot["description"]].append(inds)
if annot["description"].upper().startswith("BAD_LL_"):
onset = annot["onset"]
offset = annot["onset"]+annot["duration"]
mask = ((starts >= onset) & (starts < offset) |
(stops > onset) & (stops <= offset) |
(onset <= starts) & (offset >= stops))
inds = np.where(mask)[0]
desc = annot["description"].lower().replace("bad_ll_", "")
if desc not in self:
self[desc] = np.array([])
self[desc] = np.concatenate((self[desc], inds))


class FlaggedICs(pd.DataFrame):
Expand Down
2 changes: 1 addition & 1 deletion pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ def load_ll_derivative(self, derivatives_path):
self.flags["ch"].load_tsv(flagged_chs_fpath.fpath)

# Load Flagged Epochs
self.flags["epoch"].load_from_raw(self.raw)
self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config)

return self

Expand Down
3 changes: 1 addition & 2 deletions pylossless/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,11 @@ def test_load_flags(pipeline_fixture, tmp_path):
pipeline_fixture.save(bids_path,
overwrite=False, format="EDF", event_id=None)
pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path)

assert pipeline_fixture.flags['ch'] == pipeline.flags['ch']
pipeline.flags['ch']["bridge"] = ["xx"]
assert pipeline_fixture.flags['ch'] != pipeline.flags['ch']

assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
pipeline.flags['epoch']["bridge"] = ["noisy"]
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']

assert pipeline_fixture.flags['ic'] == pipeline.flags['ic']

0 comments on commit cc88579

Please sign in to comment.