Merge additions from work leading to Conference submission #18

Closed
wants to merge 16 commits
5 changes: 5 additions & 0 deletions .gitignore
@@ -144,3 +144,8 @@ docs/auto_examples/
*-bem.fif
*-fwd.fif
*-src.fif

# data files generated by our scripts that are too big for git
data/paper/*.netcdf
data/paper/*.edf

Empty file added data/paper/.gitkeep
Binary file added images/Evoked.png
Binary file added images/Overlaid.png
Binary file added images/eyepos.png
Binary file added images/montage.png
Binary file added images/mosaic.png
Binary file added images/raw.png
Binary file added images/sensitivity_specificity.png
Binary file added images/topo.png
Binary file added images/topo_Model.png
Binary file added images/topo_clean.png
Binary file added images/topo_ica.png
Binary file added images/topodots_ICA.png
Binary file added images/topodots_Model.png
1,082 changes: 1,082 additions & 0 deletions notebooks/paper/LSTM_compute_xr.ipynb


5,085 changes: 5,085 additions & 0 deletions notebooks/paper/Pytorch_EOG_LSTM_analysis.ipynb


60 changes: 60 additions & 0 deletions scripts/run_eog_lstm_ica_mp.py
@@ -0,0 +1,60 @@
import multiprocessing
from pathlib import Path

# Scientific Stack
import numpy as np

# File I/O, Signal Processing
import mne
import eoglearn # This is my package for this project
import mne_icalabel


def process(*args, tmax=None):
    try:
        subject, run = args[0]

        # reload raw and bandpass 1-100 Hz to be fair to ICLabel
        fpath = eoglearn.datasets.fetch_eegeyenet(subject=subject, run=run)
        raw_ica = eoglearn.io.read_raw_eegeyenet(fpath)

        raw_ica.set_montage("GSN-HydroCel-129")
        raw_ica.set_eeg_reference("average")
        raw_ica.set_annotations(None)  # get rid of BAD_blinks annots
        raw_ica.pick("eeg").filter(1, 100)

        # crop to whole seconds unless an explicit tmax was passed
        if tmax is None:
            tmax = int(raw_ica.times[-1])
        raw_ica.crop(tmax=tmax, include_tmax=False)

        ica = mne.preprocessing.ICA(method="infomax", fit_params=dict(extended=True))
        ica.fit(raw_ica)
        component_dict = mne_icalabel.label_components(raw_ica, ica, "iclabel")

        exclude_idx = [idx for idx, label in enumerate(component_dict["labels"])
                       if label == "eye blink"]

        # Now apply the ICA to raw, lowpass to 30 Hz to match our DL Raw, and export.
        ica.apply(raw_ica, exclude=exclude_idx)
        raw_ica.filter(1, 30).resample(100)
        raw_ica.export(root / f"{subject}_{run}_ica.edf")
    except Exception:
        # skip subject/run pairs that fail rather than crashing the whole pool
        pass


root = Path(__file__).parent.parent / "data" / "paper" / "processed"

if __name__ == "__main__":

    nb_processes = 5
    if not root.exists():
        root.mkdir()

    runs_dict = eoglearn.datasets.eegeyenet.get_subjects_runs()
    subject_run = np.concatenate([[(subject, run)
                                   for run in runs_dict[subject]]
                                  for subject in runs_dict])
    subject_run = [(subject, run)
                   for subject, run in subject_run
                   if not (root / f"{subject}_{run}_ica.edf").exists()]

    with multiprocessing.Pool(nb_processes) as pool:
        pool.map(process, subject_run)
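
Note (not part of the diff): a minimal sketch of how a single (subject, run) pair could be pushed through this script without the multiprocessing pool, assuming eoglearn and the EEGEyeNet fetcher are installed. The pair ("EP10", 1) is only an example, borrowed from the defaults in the regression script below.

# hedged sketch: run one pair directly; __main__ normally maps process()
# over every pair that does not yet have a *_ica.edf under data/paper/processed
from run_eog_lstm_ica_mp import process
process(("EP10", 1))
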
167 changes: 167 additions & 0 deletions scripts/run_eog_lstm_regression.py
@@ -0,0 +1,167 @@
#!/work/co20/eog_lstm/venv_lstm/bin/python

from pathlib import Path
import sys

# Scientific Stack
import numpy as np

# ML/DL Stack
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F  # imported for convenience (relu etc.); not used below


# File I/O, Signal Processing
import mne
import eoglearn # This is my package for this project


class EOGRegressor(nn.Module):
    def __init__(self, n_input_features, n_output_features, hidden_size=64,
                 num_layers=1, dropout=0.5):
        super().__init__()
        self.input_size = n_input_features
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)

        self.rnn = nn.LSTM(n_input_features, hidden_size, num_layers=num_layers,
                           batch_first=True)
        self.fc = nn.Linear(hidden_size, n_output_features)

    def forward(self, input):
        # input shape: (batch_size, seq_len, input_size)
        batch_size = input.size(0)  # same as input.shape[0]

        # Initialize hidden and cell states
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)

        # Forward propagate the LSTM
        out, (h0, c0) = self.rnn(input, (h0, c0))

        # Apply dropout to the hidden states, then decode every time step
        out = self.dropout(out)
        out = self.fc(out)

        return out


def train_the_model(X, Y, num_epochs=1000, hidden_size=64, num_layers=1, dropout=0.5):
    """Train the PyTorch model."""

    # Instantiate the model
    if X.ndim == 3:
        assert Y.ndim == 3
        input_features = X.shape[2]  # Assuming (batch_size, seq_len, input_size)
        output_features = Y.shape[2]
    else:
        raise ValueError("Input data must have 3 dimensions: (batch_size, seq_len, input_size)")

    model = EOGRegressor(input_features, output_features, hidden_size=hidden_size,
                         num_layers=num_layers, dropout=dropout)

    # Loss function (Mean Squared Error)
    criterion = nn.MSELoss()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    losses = np.zeros(num_epochs)
    # Training loop
    model.train()
    for epoch in range(num_epochs):
        # Forward pass
        outputs = model(X)

        # Compute loss
        loss = criterion(outputs, Y)
        losses[epoch] = loss.detach().numpy()

        # Zero gradients, backward pass, and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the loss every 100 iterations
        if epoch % 100 == 0:
            print(f"Epoch: {epoch} Loss: {loss.item():.4f}")

    # Set model to eval mode to turn off dropout
    model.eval()
    with torch.no_grad():
        predicted_noise = model(X)
        denoised_output = (Y - predicted_noise).numpy()

    return losses, predicted_noise, denoised_output


def prep_data(subject="EP10", run=1):
    fpath = eoglearn.datasets.fetch_eegeyenet(subject=subject, run=run)
    raw = eoglearn.io.read_raw_eegeyenet(fpath)

    raw.set_montage("GSN-HydroCel-129")
    raw.filter(1, 30, picks="eeg").resample(100)  # DO NOT filter eyetrack channels
    raw.set_eeg_reference("average")
    return raw


def format_data_for_ml(raw, tmax):
    # normalize the dataset
    X = raw.get_data(picks=["eyetrack"]).T  # [::5] to decimate the eyetracking data
    Y = raw.get_data(picks="eeg").T

    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    # For Y we need to split the fit and transform into 2 steps
    # because we will need to inverse transform the model output later during evaluation
    scaler = StandardScaler()
    scaler_y = scaler.fit(Y)
    Y = scaler_y.transform(Y)

    # 1s epochs
    X = X.reshape(tmax, int(raw.info["sfreq"]), 3)
    Y = Y.reshape(tmax, int(raw.info["sfreq"]), 129)

    # Convert data to tensors
    X_tensor = torch.from_numpy(X).float()
    Y_tensor = torch.from_numpy(Y).float()

    return X_tensor, Y_tensor, scaler_y


def clean_data(subject, run, tmax=None):

    raw = prep_data(subject=subject, run=run)

    if tmax is None:
        tmax = int(raw.times[-1])
    raw.crop(tmax=tmax, include_tmax=False)

    X_tensor, Y_tensor, scaler_y = format_data_for_ml(raw, tmax)
    losses, predicted_noise, denoised_output = train_the_model(
        X_tensor, Y_tensor, dropout=.5, num_layers=2)

    # Reshape back to 2D and inverse transform to original units (Volts)
    n_samples = tmax * int(raw.info["sfreq"])
    predicted_noise = scaler_y.inverse_transform(predicted_noise.reshape(n_samples, 129)).T
    denoised_output = scaler_y.inverse_transform(denoised_output.reshape(n_samples, 129)).T

    raw_clean = mne.io.RawArray(denoised_output, raw.copy().pick("eeg").info)
    raw_noise = mne.io.RawArray(predicted_noise, raw.copy().pick("eeg").info)
    return raw, raw_clean, raw_noise


if __name__ == "__main__":

    no = int(sys.argv[1])  # index into the flattened (subject, run) list

    runs_dict = eoglearn.datasets.eegeyenet.get_subjects_runs()
    subject_run = np.concatenate([[(subject, run) for run in runs_dict[subject]]
                                  for subject in runs_dict])

    tmax = None
    subject, run = subject_run[no]
    raw, raw_clean, raw_noise = clean_data(subject=subject, run=run, tmax=tmax)
    data_dir = Path(__file__).parent.parent / "data" / "paper"
    raw.export(data_dir / f"{subject}_{run}_original.edf")
    raw_clean.export(data_dir / f"{subject}_{run}_clean.edf")
    raw_noise.export(data_dir / f"{subject}_{run}_noise.edf")
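
Note (not part of the diff): the script expects a single integer argument indexing the flattened (subject, run) list, e.g. python scripts/run_eog_lstm_regression.py 0. A rough interactive sketch of the same pipeline is shown below; the tmax=60 crop is only an assumption to keep the example short, and ("EP10", 1) mirrors the defaults of prep_data.

# hedged usage sketch: denoise one recording without going through sys.argv
raw, raw_clean, raw_noise = clean_data(subject="EP10", run=1, tmax=60)
raw_clean.plot()  # compare against raw.plot() and raw_noise.plot()
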