diff --git a/examples/plot_model.py b/examples/plot_model.py index 1570a55..c4b9b2b 100644 --- a/examples/plot_model.py +++ b/examples/plot_model.py @@ -11,6 +11,7 @@ # %% # Import the necessary packages import mne +import matplotlib.pyplot as plt from eoglearn.datasets import read_mne_eyetracking_raw from eoglearn.models import EOGDenoiser @@ -38,6 +39,8 @@ # %% # Fit the model # We will only use 10 epochs to speed up the example + +# %% eog_denoiser.fit_model(epochs=10) history = eog_denoiser.model.history @@ -45,6 +48,7 @@ # display the training history print(history.history["loss"]) print(history.history["val_loss"]) +eog_denoiser.plot_loss() # %% # Plot a topomap of the predicted EOG artifact. @@ -52,6 +56,8 @@ # The plot below displays the predicted amount of EOG artifact for each EEG sensor. # The output is as we would expect, with frontal sensors containing the most EOG # artifact. + +# %% montage = mne.channels.make_standard_montage("GSN-HydroCel-129") eog_denoiser.plot_eog_topo(montage=montage) @@ -59,3 +65,33 @@ # .. todo:: # Add a plot of the predicted EOG artifact for each EEG sensor over time. # Add plots of the denoised EEG data. + +# %% +# Compare ERP between the original and "EOG-denoised" signals +# ----------------------------------------------------------- +# +# Let's create an averaged evoked response to the flash stimuli for both the original +# data and the "EOG-denoised" data. We'll focus on the frontal EEG channels, since it is +# these will contain the most EOG in the original signal. + +# %% +pred_raw = eog_denoiser.get_denoised_neural_raw() +events, event_id = mne.events_from_annotations(pred_raw, regexp="Flash") +pred_epochs = mne.Epochs( + pred_raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True +) + +events, event_id = mne.events_from_annotations(eog_denoiser.raw, regexp="Flash") +original_epochs = mne.Epochs( + eog_denoiser.raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True +) + +frontal = ["E19", "E11", "E4", "E12", "E5"] +pred_avg_frontal = pred_epochs.average().get_data(picks=frontal).mean(0) +original_avg_frontal = original_epochs.average().get_data(picks=frontal).mean(0) + +ax = plt.subplot() +ax.plot(pred_epochs.times, pred_avg_frontal, label="predicted") +ax.plot(original_epochs.times, original_avg_frontal, label="original") +ax.set_xlim(-0.3, 1) +ax.legend()