Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incorporate Christian revisions & fix a bug with creating events from annotations #8

Merged
merged 4 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions eoglearn/datasets/mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from mne.utils import logger


def read_mne_eyetracking_raw(return_events=False):
def read_mne_eyetracking_raw(return_events=False, bandpass=True):
"""Return an MNE Raw object containing the EyeLink dataset.

Parameters
----------
return_events : bool
If ``True``, return the events for the eyetracking and EEG data.
bandpass: bool
If ``True``, applied a [1, 30]Hz bandpass.

Returns
-------
Expand All @@ -40,7 +42,8 @@ def read_mne_eyetracking_raw(return_events=False):
logger.debug(f"## EOGLEARN: Reading data from {et_fpath} and {eeg_fpath}")
raw_et = mne.io.read_raw_eyelink(et_fpath, create_annotations=["blinks"])
raw_eeg = mne.io.read_raw_egi(eeg_fpath, preload=True)
raw_eeg.filter(1, 30)
if bandpass:
raw_eeg.filter(1, 30)

logger.debug("## EOGLEARN: Finding events from the raw objects")
# due to a rogue one-shot event, find_events emits a warning
Expand All @@ -64,6 +67,15 @@ def read_mne_eyetracking_raw(return_events=False):
)
# Add EEG channels to the eye-tracking raw object
raw_et.add_channels([raw_eeg], force_update_info=True)

annots = mne.annotations_from_events(
et_events,
raw_et.info["sfreq"],
event_desc={2: "Flash"},
orig_time=raw_et.info["meas_date"],
)
raw_et.set_annotations(raw_et.annotations + annots)

if return_events:
return raw_et, dict(eyetrack=et_events, eeg=eeg_events)
return raw_et
98 changes: 64 additions & 34 deletions eoglearn/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: BSD-3-Clause

import matplotlib.pyplot as plt
import mne
import numpy as np
from mne.utils import logger
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -31,6 +32,8 @@ class EOGDenoiser:
The number of units to pass into the initial LSTM layer. Defaults to 50.
n_times : int
The number of timepoints to pass into the LSTM model at once. Defaults to 100.
noise_picks: list | None
Channels that contain the noise channels.

Attributes
----------
Expand Down Expand Up @@ -58,36 +61,53 @@ class EOGDenoiser:
The StandardScaler instance used to scale the eyetracking data.
scaler_Y : sklearn.preprocessing.StandardScaler
The StandardScaler instance used to scale the EEG data.
noise_picks: list
Channels that contain the noise channels.
denoised_neural : np.ndarray
The denoised neural data, scaled back to the original units and shape,
i.e. ``(n_samples, n_meeg_channels)`` of the input :class:`~mne.io.Raw`
object.

Notes
-----
See the MNE-Python tutorial on aligning EEG and eyetracking data for information
on how to create a raw object with both EEG and eyetracking channels.
"""

def __init__(
self,
raw,
downsample=10,
n_units=50,
n_times=100,
):
def __init__(self, raw, downsample=10, n_units=50, n_times=100, noise_picks=None):
self.__x = None
self.__y = None
#############################################
# MNE Raw object and preprocessing parameters
#############################################
self.raw = raw
self.__denoised_neural = None

if noise_picks is None:
self.noise_picks = ["eyetrack"]
else:
self.noise_picks = noise_picks

self.downsample = downsample
#############################
# Set up the Keras LSTM Model
#############################
self.n_units = n_units
self.n_times = n_times
self.model = self.setup_model()
self.model = None
self.setup_model()
self.train_test_split() # i.e. self.X_train, self.Y_train

@property
def denoised_neural(self):
"""Return the MEEG signal without EOG artifact."""
if self.__denoised_neural is None:
logger.info(
"Denoising neural data, saving to ``denoised_neural_`` attribute."
)
self.compute_denoised_neural()
return self.__denoised_neural

@property
def downsampled_sfreq(self):
"""Return the sampling frequency after downsampling."""
Expand All @@ -101,7 +121,7 @@ def downsampled_sfreq(self):
def X(self):
"""Return an array of the raw eye-tracking data."""
if self.__x is None:
eye_data = self.raw.get_data(picks=["eyetrack"]).T
eye_data = self.raw.get_data(picks=self.noise_picks).T
if self.downsample is not None:
eye_data = eye_data[:: self.downsample, :] # i.e. eye_data[::10, :]
self.scaler_X = StandardScaler().fit(np.nan_to_num(eye_data))
Expand All @@ -112,14 +132,18 @@ def X(self):
def Y(self):
"""Return an array of the raw EEG data."""
if self.__y is None:
eeg_data = self.raw.copy()
if self.downsample:
eeg_data.resample(self.downsampled_sfreq)
eeg_data = eeg_data.get_data(picks="eeg").T
eeg_data = self._get_y_raw().get_data(picks="eeg").T
self.scaler_Y = StandardScaler().fit(np.nan_to_num(eeg_data))
self.__y = self.scaler_Y.transform(np.nan_to_num(eeg_data))
return self.__y

def _get_y_raw(self):
eeg_data = self.raw.copy()
if self.downsample:
eeg_data.resample(self.downsampled_sfreq)
eeg_data.pick("eeg")
return eeg_data

def setup_model(self):
"""Return a model instance given a raw instance.

Expand All @@ -143,7 +167,7 @@ def setup_model(self):

adagrad = Adagrad(learning_rate=1)
model.compile(loss="mean_squared_error", optimizer=adagrad)
return model
self.model = model

def train_test_split(self):
"""Split Eyetrack and EEG data into training and testing sets.
Expand Down Expand Up @@ -234,7 +258,7 @@ def predict_eog(self):
"""
if not len(self.X_train):
logger.info("setting up train/test split")
_ = self.train_test_split()
self.train_test_split()
predictions = self.predict(self.X_train)
# reshape to back to 2D array matching the original raw data shape
predicted_eog = predictions.reshape(
Expand All @@ -244,17 +268,8 @@ def predict_eog(self):
self.predicted_eog_ = self.scaler_Y.inverse_transform(predicted_eog)
return self.predicted_eog_

def get_denoised_neural(self):
"""Return the denoised M/EEG neural data.

Returns
-------
denoised_neural : np.ndarray
The denoised neural data, scaled back to the original units and shape,
i.e. ``(n_samples, n_meeg_channels)`` of the input :class:`~mne.io.Raw`
object. The denoised neural data are saved to the ``denoised_neural_``
attribute.
"""
def compute_denoised_neural(self):
"""Compute the denoised M/EEG neural data."""
if not len(self.Y_train):
logger.info("setting up train/test split")
_ = self.train_test_split()
Expand All @@ -269,8 +284,28 @@ def get_denoised_neural(self):
# scale back to original units of the Raw data
Y_train = self.scaler_Y.inverse_transform(Y_train)
assert Y_train.shape == predicted_eog.shape
self.denoised_neural_ = Y_train - predicted_eog
return self.denoised_neural_
self.__denoised_neural = Y_train - predicted_eog

def get_denoised_neural_raw(self):
"""Return an mne.io.Raw object of the MEEG signal without EOG artifact."""
raw_y = self._get_y_raw()
raw_denoised = mne.io.RawArray(self.denoised_neural.T, raw_y.info)
raw_denoised.set_annotations(raw_y.annotations)
return raw_denoised

def plot_loss(self):
"""Plot the training and validation loss.

Returns
-------
fig : matplotlib.figure.Figure
The resulting figure showing the loss functions.
"""
fig, ax = plt.subplots(constrained_layout=True)
for key in ["loss", "val_loss"]:
if key in self.model.history.history:
ax.plot(self.model.history.history[key], label=key)
return fig

def plot_eog_topo(self, montage, show=True):
"""Plot the topography of the eyetracking data.
Expand All @@ -290,12 +325,7 @@ def plot_eog_topo(self, montage, show=True):
fig : matplotlib.figure.Figure
The resulting figure object for the topomap plot.
"""
if not hasattr(self, "denoised_neural_"):
logger.info(
"Denoising neural data, saving to ``denoised_neural_`` attribute."
)
_ = self.get_denoised_neural()
squared_errors = np.square(self.denoised_neural_)
squared_errors = np.square(self.denoised_neural)
mean_squared_erros = np.mean(squared_errors, axis=0)
noise = np.sqrt(mean_squared_erros)
# reshape to back to 2D array matching the original raw data shape
Expand Down
2 changes: 1 addition & 1 deletion eoglearn/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,5 @@ def plot_values_topomap(

if colorbar:
fig.colorbar(im[0], ax=axes, shrink=0.6, label="Percentage of EOG in signal")
plt_show(show)
plt_show(show, fig)
return fig
36 changes: 36 additions & 0 deletions examples/plot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# %%
# Import the necessary packages
import mne
import matplotlib.pyplot as plt
from eoglearn.datasets import read_mne_eyetracking_raw
from eoglearn.models import EOGDenoiser

Expand Down Expand Up @@ -38,24 +39,59 @@
# %%
# Fit the model
# We will only use 10 epochs to speed up the example

# %%
eog_denoiser.fit_model(epochs=10)
history = eog_denoiser.model.history

# %%
# display the training history
print(history.history["loss"])
print(history.history["val_loss"])
eog_denoiser.plot_loss()

# %%
# Plot a topomap of the predicted EOG artifact.
# ---------------------------------------------
# The plot below displays the predicted amount of EOG artifact for each EEG sensor.
# The output is as we would expect, with frontal sensors containing the most EOG
# artifact.

# %%
montage = mne.channels.make_standard_montage("GSN-HydroCel-129")
eog_denoiser.plot_eog_topo(montage=montage)

# %%
# .. todo::
# Add a plot of the predicted EOG artifact for each EEG sensor over time.
# Add plots of the denoised EEG data.

# %%
# Compare ERP between the original and "EOG-denoised" signals
# -----------------------------------------------------------
#
# Let's create an averaged evoked response to the flash stimuli for both the original
# data and the "EOG-denoised" data. We'll focus on the frontal EEG channels, since it is
# these will contain the most EOG in the original signal.

# %%
pred_raw = eog_denoiser.get_denoised_neural_raw()
events, event_id = mne.events_from_annotations(pred_raw, regexp="Flash")
pred_epochs = mne.Epochs(
pred_raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True
)

events, event_id = mne.events_from_annotations(eog_denoiser.raw, regexp="Flash")
original_epochs = mne.Epochs(
eog_denoiser.raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True
)

frontal = ["E19", "E11", "E4", "E12", "E5"]
pred_avg_frontal = pred_epochs.average().get_data(picks=frontal).mean(0)
original_avg_frontal = original_epochs.average().get_data(picks=frontal).mean(0)

ax = plt.subplot()
ax.plot(pred_epochs.times, pred_avg_frontal, label="predicted")
ax.plot(original_epochs.times, original_avg_frontal, label="original")
ax.set_xlim(-0.3, 1)
ax.legend()
Loading
Loading