Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a vizualitation module #5

Merged
merged 8 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
extensions = [
"numpydoc",
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx_design",
"sphinx_gallery.gen_gallery",
Expand All @@ -37,6 +38,18 @@
# Allows us to use the ..todo:: directive
todo_include_todos = True

# -- Options for intersphinx extension ---------------------------------------

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
"xarray": ("https://docs.xarray.dev/en/stable/", None),
"mne": ("https://mne.tools/dev", None),
"mne_icalabel": ("https://mne.tools/mne-icalabel/dev", None),
"mne_bids": ("https://mne.tools/mne-bids/dev", None),
"eoglearn": ("https://eoglearn.readthedocs.io/en/latest/", None),
}

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
Expand Down
3 changes: 2 additions & 1 deletion eoglearn/datasets/mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def read_mne_eyetracking_raw(return_events=False):

logger.debug(f"## EOGLEARN: Reading data from {et_fpath} and {eeg_fpath}")
raw_et = mne.io.read_raw_eyelink(et_fpath, create_annotations=["blinks"])
raw_eeg = mne.io.read_raw_egi(eeg_fpath, preload=True, verbose="warning")
raw_eeg = mne.io.read_raw_egi(eeg_fpath, preload=True)
raw_eeg.filter(1, 30)

logger.debug("## EOGLEARN: Finding events from the raw objects")
# due to a rogue one-shot event, find_events emits a warning
Expand Down
223 changes: 179 additions & 44 deletions eoglearn/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
#
# License: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
from mne.utils import logger
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.layers import LSTM
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adagrad

from eoglearn.viz import plot_values_topomap


class EOGDenoiser:
"""Use simultaneous EEG and Eyetracking to Denoise EOG from the EEG data.
Expand All @@ -23,9 +27,6 @@ class EOGDenoiser:
Eyetracking channels will be decimated without any filtering. resampling and
decimating will be done on copies of the data, so the original input data will
be preserved.
filter : tuple
The bandpass filter to apply to the EEG data. The filter will only be applied
to a copy of the raw data, and the original data will be preserved.
n_units : int
The number of units to pass into the initial LSTM layer. Defaults to 50.
n_times : int
Expand All @@ -37,8 +38,6 @@ class EOGDenoiser:
The original input ``mne.io.Raw`` instance.
downsample : int
The factor by which the data was downsampled.
filter : tuple
The bandpass filter that was applied to the EEG data.
n_units : int
The number of units in the initial LSTM layer.
n_times : int
Expand Down Expand Up @@ -70,7 +69,6 @@ def __init__(
self,
raw,
downsample=10,
mne_filter=(1, 30),
n_units=50,
n_times=100,
):
Expand All @@ -80,8 +78,8 @@ def __init__(
# MNE Raw object and preprocessing parameters
#############################################
self.raw = raw

self.downsample = downsample
self.filter = mne_filter
#############################
# Set up the Keras LSTM Model
#############################
Expand All @@ -90,6 +88,38 @@ def __init__(
self.model = self.setup_model()
self.train_test_split() # i.e. self.X_train, self.Y_train

@property
def downsampled_sfreq(self):
"""Return the sampling frequency after downsampling."""
return (
self.raw.info["sfreq"] // self.downsample
if self.downsample
else self.raw.info["sfreq"]
)

@property
def X(self):
"""Return an array of the raw eye-tracking data."""
if self.__x is None:
eye_data = self.raw.get_data(picks=["eyetrack"]).T
if self.downsample is not None:
eye_data = eye_data[:: self.downsample, :] # i.e. eye_data[::10, :]
self.scaler_X = StandardScaler().fit(np.nan_to_num(eye_data))
self.__x = self.scaler_X.transform(np.nan_to_num(eye_data))
return self.__x

@property
def Y(self):
"""Return an array of the raw EEG data."""
if self.__y is None:
eeg_data = self.raw.copy()
if self.downsample:
eeg_data.resample(self.downsampled_sfreq)
eeg_data = eeg_data.get_data(picks="eeg").T
self.scaler_Y = StandardScaler().fit(np.nan_to_num(eeg_data))
self.__y = self.scaler_Y.transform(np.nan_to_num(eeg_data))
return self.__y

def setup_model(self):
"""Return a model instance given a raw instance.

Expand Down Expand Up @@ -138,7 +168,14 @@ def train_test_split(self):
(-1, self.n_times, self.Y.shape[-1]), order="C"
)

def fit_model(self, fitting_kwargs=None):
def fit_model(
self,
epochs=10,
validation_split=0.2,
batch_size=1,
verbose=2,
fitting_kwargs=None,
):
"""Fit the EOGDenoiser model using the input Raw object.

Parameters
Expand All @@ -149,48 +186,146 @@ def fit_model(self, fitting_kwargs=None):
``dict(epochs=50, validation_split=0.2, batch_size=1, verbose=2)``.
"""
if fitting_kwargs is None:
fitting_kwargs = dict(
epochs=50,
validation_split=0.2,
batch_size=1,
verbose=2,
)
fitting_kwargs = dict()
self.model.fit(
self.X_train,
self.Y_train,
epochs=epochs,
validation_split=validation_split,
batch_size=batch_size,
verbose=verbose,
**fitting_kwargs,
)

@property
def downsampled_sfreq(self):
"""Return the sampling frequency after downsampling."""
return (
self.raw.info["sfreq"] // self.downsample
if self.downsample
else self.raw.info["sfreq"]
def predict(self, data, predict_kwargs=None):
"""Return Model predictions.

Parameters
----------
data : np.ndarray
The data to predict. Must be the same shape as ``self.X_train``.
predict_kwargs : dict
A dictionary of keyword arguments to pass into the ``predict`` method of
the Keras ``Sequential`` model. Defaults to ``None``.

Returns
-------
predictions : np.ndarray
The predicted data.

Notes
-----
This method is a wrapper for the ``predict`` method of the Keras
:class:`~tensorflow.keras.models.Sequential` model.
"""
if not predict_kwargs:
predict_kwargs = dict()
return self.model.predict(data, **predict_kwargs)

def predict_eog(self):
"""Return the predicted EOG data.

Returns
-------
predicted_eog : np.ndarray
The predicted EOG data, scaled back to the original units and shape, i.e.
``(n_samples, n_meeg_channels)`` of the input :class:`~mne.io.Raw` object.
The predictions are saved to the ``predicted_eog_`` attribute.
"""
if not len(self.X_train):
logger.info("setting up train/test split")
_ = self.train_test_split()
predictions = self.predict(self.X_train)
# reshape to back to 2D array matching the original raw data shape
predicted_eog = predictions.reshape(
(np.prod(predictions.shape[:-1]), predictions.shape[-1]), order="C"
)
# inverse transform to get back to original raw data units
self.predicted_eog_ = self.scaler_Y.inverse_transform(predicted_eog)
return self.predicted_eog_

@property
def X(self):
"""Return an array of the raw eye-tracking data."""
if self.__x is None:
eye_data = self.raw.get_data(picks=["eyetrack"]).T
if self.downsample is not None:
eye_data = eye_data[:: self.downsample, :] # i.e. eye_data[::10, :]
self.scaler_X = StandardScaler().fit(np.nan_to_num(eye_data))
self.__x = self.scaler_X.transform(np.nan_to_num(eye_data))
return self.__x
def get_denoised_neural(self):
"""Return the denoised M/EEG neural data.

@property
def Y(self):
"""Return an array of the raw EEG data."""
if self.__y is None:
eeg_data = self.raw.copy()
if self.filter:
eeg_data.filter(*self.filter)
if self.downsample:
eeg_data.resample(self.downsampled_sfreq)
eeg_data = eeg_data.get_data(picks="eeg").T
self.scaler_Y = StandardScaler().fit(np.nan_to_num(eeg_data))
self.__y = self.scaler_Y.transform(np.nan_to_num(eeg_data))
return self.__y
Returns
-------
denoised_neural : np.ndarray
The denoised neural data, scaled back to the original units and shape,
i.e. ``(n_samples, n_meeg_channels)`` of the input :class:`~mne.io.Raw`
object. The denoised neural data are saved to the ``denoised_neural_``
attribute.
"""
if not len(self.Y_train):
logger.info("setting up train/test split")
_ = self.train_test_split()
if not hasattr(self, "predicted_eog_"):
logger.info("Predicting EOG data, saving to ``predicted_eog_`` attribute.")
_ = self.predict_eog()
predicted_eog = self.predicted_eog_
# reshape to back to 2D array matching the original raw data shape
Y_train = self.Y_train.reshape(
(np.prod(self.Y_train.shape[:-1]), self.Y_train.shape[-1]), order="C"
)
# scale back to original units of the Raw data
Y_train = self.scaler_Y.inverse_transform(Y_train)
assert Y_train.shape == predicted_eog.shape
self.denoised_neural_ = Y_train - predicted_eog
return self.denoised_neural_

def plot_eog_topo(self, montage, show=True):
"""Plot the topography of the eyetracking data.

Parameters
----------
montage : mne.channels.DigMontage | str
Montage for digitized electrode and headshape position data.
See mne.channels.make_standard_montage(), and
mne.channels.get_builtin_montages() for more information
on making montage objects in MNE.
show : bool
Whether to show the plot or not. Defaults to True.

Returns
-------
fig : matplotlib.figure.Figure
The resulting figure object for the topomap plot.
"""
if not hasattr(self, "denoised_neural_"):
logger.info(
"Denoising neural data, saving to ``denoised_neural_`` attribute."
)
_ = self.get_denoised_neural()
squared_errors = np.square(self.denoised_neural_)
mean_squared_erros = np.mean(squared_errors, axis=0)
noise = np.sqrt(mean_squared_erros)
# reshape to back to 2D array matching the original raw data shape
Y_train = self.Y_train.reshape(
(np.prod(self.Y_train.shape[:-1]), self.Y_train.shape[-1]), order="C"
)
# scale back to original units of the Raw data
Y_train = self.scaler_Y.inverse_transform(Y_train)
signal = np.sqrt((Y_train**2).mean(axis=0))
assert noise.shape == signal.shape
snr = (noise / signal)[:-1]
percent_noise = 1 - snr
percent_noise *= 100
eeg_names = self.raw.copy().pick("eeg").ch_names[:-1]
data_dict = dict(list(zip(eeg_names, percent_noise)))
montage = montage
fig, ax = plt.subplots(constrained_layout=True)

plot_values_topomap(
data_dict,
montage,
axes=ax,
vmin=percent_noise.min(),
vmax=percent_noise.max(),
names=None,
image_interp="linear",
sensors=True,
show=show,
)
ax.set_title(
"Percentage of EEG signal that is accounted for by Ocular Artifact"
)
return fig
15 changes: 11 additions & 4 deletions eoglearn/models/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# License: BSD-3-Clause

import mne

from eoglearn.models import EOGDenoiser


Expand Down Expand Up @@ -30,8 +32,13 @@ def test_build_model(mne_fixture):
assert eog_denoiser.model.layers[1].input_shape == (None, 100, 50)

# test model training
fitting_kwargs = dict(epochs=1, validation_split=0.2, batch_size=1, verbose=2)
eog_denoiser.fit_model(fitting_kwargs=fitting_kwargs)
eog_denoiser.fit_model(epochs=3)
history = eog_denoiser.model.history
# For now, just check that the final loss is somewhat reasonable
history.history["loss"][-1] < 0.05
# For now, just check that the loss isn't any higher than what we've seen so far.
assert history.history["loss"][-1] < 0.05
assert history.history["val_loss"][-1] < 0.07

# test viz
montage = mne.channels.make_standard_montage("GSN-HydroCel-129")
fig = eog_denoiser.plot_eog_topo(montage=montage, show=False)
del fig
2 changes: 2 additions & 0 deletions eoglearn/viz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._training import plot_training
from .topo import plot_values_topomap
27 changes: 27 additions & 0 deletions eoglearn/viz/_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import matplotlib.pyplot as plt


def plot_training(model, axes=None):
"""Plot the training history of a model.

Parameters
----------
model : keras.Model
A compiled Keras model.

Returns
-------
None
"""
if axes:
ax = axes
fig = ax.get_figure()
else:
fig, ax = plt.subplots(constrained_layout=True)
ax.plot(model.history.history["loss"], label="Training Loss")
ax.plot(model.history.history["val_loss"], label="Validation Loss")
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training History")
return fig.show()
Loading
Loading