Skip to content

Commit

Permalink
DOC: add new plotting features to MNE tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-huberty committed Dec 21, 2023
1 parent a693189 commit 906e9db
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions examples/plot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# %%
# Import the necessary packages
import mne
import matplotlib.pyplot as plt
from eoglearn.datasets import read_mne_eyetracking_raw
from eoglearn.models import EOGDenoiser

Expand Down Expand Up @@ -38,24 +39,59 @@
# %%
# Fit the model
# We will only use 10 epochs to speed up the example

# %%
eog_denoiser.fit_model(epochs=10)
history = eog_denoiser.model.history

# %%
# display the training history
print(history.history["loss"])
print(history.history["val_loss"])
eog_denoiser.plot_loss()

# %%
# Plot a topomap of the predicted EOG artifact.
# ---------------------------------------------
# The plot below displays the predicted amount of EOG artifact for each EEG sensor.
# The output is as we would expect, with frontal sensors containing the most EOG
# artifact.

# %%
montage = mne.channels.make_standard_montage("GSN-HydroCel-129")
eog_denoiser.plot_eog_topo(montage=montage)

# %%
# .. todo::
# Add a plot of the predicted EOG artifact for each EEG sensor over time.
# Add plots of the denoised EEG data.

# %%
# Compare ERP between the original and "EOG-denoised" signals
# -----------------------------------------------------------
#
# Let's create an averaged evoked response to the flash stimuli for both the original
# data and the "EOG-denoised" data. We'll focus on the frontal EEG channels, since it is
# these will contain the most EOG in the original signal.

# %%
pred_raw = eog_denoiser.get_denoised_neural_raw()
events, event_id = mne.events_from_annotations(pred_raw, regexp="Flash")
pred_epochs = mne.Epochs(
pred_raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True
)

events, event_id = mne.events_from_annotations(eog_denoiser.raw, regexp="Flash")
original_epochs = mne.Epochs(
eog_denoiser.raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True
)

frontal = ["E19", "E11", "E4", "E12", "E5"]
pred_avg_frontal = pred_epochs.average().get_data(picks=frontal).mean(0)
original_avg_frontal = original_epochs.average().get_data(picks=frontal).mean(0)

ax = plt.subplot()
ax.plot(pred_epochs.times, pred_avg_frontal, label="predicted")
ax.plot(original_epochs.times, original_avg_frontal, label="original")
ax.set_xlim(-0.3, 1)
ax.legend()

0 comments on commit 906e9db

Please sign in to comment.