Skip to content

Commit

Permalink
Merge adjacent epoch flags into a single annotation (#151)
Browse files Browse the repository at this point in the history
* ENH: Merge Close flags

* Remove gap flagging

* FIX: black formatting
  • Loading branch information
scott-huberty authored Nov 7, 2023
1 parent a7842b2 commit 6197698
Showing 1 changed file with 21 additions and 83 deletions.
104 changes: 21 additions & 83 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,72 +346,6 @@ def chan_neighbour_r(epochs, nneigbr, method):
return m_neigbr_r.rename(ref_chan="ch")


# TODO: check that annot type contains all unique flags
def marks_flag_gap(
raw, min_gap_ms, included_annot_type=None, out_annot_name="bad_pylossless_gap"
):
"""Mark small gaps in time between pylossless annotations.
Parameters
----------
raw : mne.Raw
An instance of mne.Raw
min_gap_ms : int
Time in milleseconds. If the time between two consecutive pylossless
annotations is less than this value, that time period will be
annotated.
included_annot_type : str (Default None)
Descriptions of the `mne.Annotations` in the `mne.Raw` to be included.
If `None`, includes ('bad_pylossless_ch_sd', 'bad_pylossless_low_r',
'bad_pylossless_ic_sd1', 'bad_pylossless_gap').
out_annot_name : str (default 'bad_pylossless_gap')
The description for the `mne.Annotation` That is created for any gaps.
Returns
-------
Annotations : `mne.Annotations`
An instance of `mne.Annotations`
"""
if included_annot_type is None:
included_annot_type = (
"bad_pylossless_ch_sd",
"bad_pylossless_low_r",
"bad_pylossless_ic_sd1",
"bad_pylossless_gap",
)

if len(raw.annotations) == 0:
return mne.Annotations([], [], [], orig_time=raw.annotations.orig_time)

ret_val = np.array(
[
[annot["onset"], annot["duration"]]
for annot in raw.annotations
if annot["description"] in included_annot_type
]
).T

if len(ret_val) == 0:
return mne.Annotations([], [], [], orig_time=raw.annotations.orig_time)

onsets, durations = ret_val
offsets = onsets + durations
gaps = np.array(
[
min(onset - offsets[offsets < onset]) if np.sum(offsets < onset) else np.inf
for onset in onsets[1:]
]
)
gap_mask = gaps < min_gap_ms / 1000

return mne.Annotations(
onset=onsets[1:][gap_mask] - gaps[gap_mask],
duration=gaps[gap_mask],
description=out_annot_name,
orig_time=raw.annotations.orig_time,
)


def coregister(
raw_edf,
fiducials="estimated", # get fiducials from fsaverage
Expand Down Expand Up @@ -645,14 +579,29 @@ def add_pylossless_annotations(self, inds, event_type, epochs):
"""
# Concatenate epoched data back to continuous data
t_onset = epochs.events[inds, 0] / epochs.info["sfreq"]
df = pd.DataFrame(t_onset, columns=["onset"])
# We exclude the last sample from the duration because
# if the annot lasts the whole duration of the epoch
# it's end will coincide with the first sample of the
# next epoch, causing it to erroneously be rejected.
duration = np.ones_like(t_onset) / epochs.info["sfreq"] * len(epochs.times[:-1])
description = [f"bad_pylossless_{event_type}"] * len(t_onset)
df["duration"] = 1 / epochs.info["sfreq"] * len(epochs.times[:-1])
df["description"] = f"bad_pylossless_{event_type}"

# Merge close onsets to prevent a bunch of 1-second annotations of the same name
# find onsets close enough to be considered the same
df["close"] = df.sort_values("onset")["onset"].diff().le(1)
df["group"] = ~df["close"]
df["group"] = df["group"].cumsum()
# group the close onsets and merge them
df["onset"] = df.groupby("group")["onset"].transform("first")
df["duration"] = df.groupby("group")["duration"].transform("sum")
df = df.drop_duplicates(subset=["onset", "duration"])

annotations = mne.Annotations(
t_onset, duration, description, orig_time=self.raw.annotations.orig_time
df["onset"],
df["duration"],
df["description"],
orig_time=self.raw.annotations.orig_time,
)
self.raw.set_annotations(self.raw.annotations + annotations)

Expand Down Expand Up @@ -1025,11 +974,6 @@ def flag_epoch_low_r(self):
logger.info(f"📋 LOSSLESS: Uncorrelated epochs: {bad_epoch_inds}")
self.flags["epoch"].add_flag_cat("low_r", bad_epoch_inds, epochs)

def flag_epoch_gap(self):
"""Flag small time periods between pylossless annotations."""
annots = marks_flag_gap(self.raw, self.config["epoch_gap"]["min_gap_ms"])
self.raw.set_annotations(self.raw.annotations + annots)

@lossless_logger
def run_ica(self, run):
"""Run ICA.
Expand Down Expand Up @@ -1218,21 +1162,15 @@ def _run(self):
# 9. Calculate nearest neighbour R values for epochs
self.flag_epoch_low_r(message="Flagging Uncorrelated epochs")

# 10. Flag very small time periods between flagged time
self.flag_epoch_gap()

# 11. Run ICA
# 10. Run ICA
self.run_ica("run1", message="Running Initial ICA")

# 12. Calculate IC SD
# 11. Calculate IC SD
self.flag_epoch_ic_sd1(message="Flagging time periods with noisy" " IC's.")

# 13. TODO: integrate labels from IClabels to self.flags["ic"]
# 12. TODO: integrate labels from IClabels to self.flags["ic"]
self.run_ica("run2", message="Running Final ICA.")

# 14. Flag very small time periods between flagged time
self.flag_epoch_gap()

def run_dataset(self, paths):
"""Run a full dataset.
Expand Down

0 comments on commit 6197698

Please sign in to comment.