Skip to content

Commit

Permalink
Merge pull request #34 from AminAlam/dev
Browse files Browse the repository at this point in the history
Power of Freq bands over time added
  • Loading branch information
AminAlam authored Jan 17, 2024
2 parents 1e4794c + e258fff commit 7aef5b5
Show file tree
Hide file tree
Showing 10 changed files with 492 additions and 51 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,5 @@ docs/source/_templates
docs/*.bat
docs/Makefile

paradigm.sh
paradigm.sh
*.csv
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,25 @@ Where each COMMNAD can be one of the supported commands such as convert_rhd_to_m
To learn more about the commnads, you can use the following command:
```console
elecphys --help
Usage: main.py COMMAND1 [ARGS]... [COMMAND2 [ARGS]...]...
Usage: main.py [OPTIONS] COMMAND1 [ARGS]... [COMMAND2 [ARGS]...]...

ElecPhys is a Python package for electrophysiology data analysis. It
provides tools for data loading, conversion, preprocessing, and
visualization.

Options:
--help Show this message and exit.
-v, --verbose Verbose mode
-d, --debug Debug mode
--help Show this message and exit.

Commands:
convert_mat_to_npz Converts MAT files to NPZ files using MAT...
convert_rhd_to_mat Converts RHD files to mat files using RHD...
dft_numeric_output_from_npz Computes DFT and saves results as NPZ files
freq_bands_power_over_time Computes signal's power in given...
frequncy_domain_filter Filtering in frequency domain using...
normalize_npz Normalizes NPZ files
pca_from_npz Computes PCA from NPZ files
plot_avg_stft Plots average STFT from NPZ files
plot_dft Plots DFT from NPZ file
plot_filter_freq_response Plots filter frequency response
Expand Down
38 changes: 34 additions & 4 deletions elecphys/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_npz(npz_file) -> [np.ndarray, int]:


def load_all_npz_files(npz_folder: str, ignore_channels: [
list, str] = None) -> [np.ndarray, int]:
list, str] = None, channels_list: [list, str] = None) -> [np.ndarray, int, list]:
""" Function that Loads all NPZ files in a folder
Parameters
Expand All @@ -53,21 +53,51 @@ def load_all_npz_files(npz_folder: str, ignore_channels: [
path to npz folder containing NPZ files
ignore_channels: list, str
list of channels to be ignored and not loaded. If None, all channels will be loaded. Either a list of channel names or a string of channel names separated by commas.
channels_list: list, str
list of channels to be loaded. If None, all channels will be loaded. Either a list of channel names or a string of channel names separated by commas.
Returns
--------
data_all: np.ndarray
data from all NPZ files. Shape: (num_channels, num_samples)
fs: int
sampling frequency (Hz)
channels_map: list
list of channel indices corresponding to the order of channels in data_all
"""
files_list = os.listdir(npz_folder)
for file_name in files_list:
if not file_name.endswith('.npz'):
files_list.remove(file_name)
files_list = utils.sort_file_names(files_list)
all_channels_in_folder = list(range(0, len(files_list)))
if channels_list is None:
channels_list = all_channels_in_folder
else:
channels_list = utils.convert_string_to_list(channels_list)
if ignore_channels is not None:
ignore_channels = utils.convert_string_to_list(ignore_channels)
for ch_indx in ignore_channels:
files_list.remove(ch_indx)
ignore_channels = [i - 1 for i in ignore_channels]
else:
ignore_channels = []
# all elements of channels_list that are not in all_channels_in_folder
invalid_channels = [channel for channel in all_channels_in_folder if channel not in channels_list]
if len(invalid_channels) > 0:
ignore_channels.extend(invalid_channels)
channels_map = all_channels_in_folder

channels_map_new = []
for channel in channels_map:
if channel not in ignore_channels:
channels_map_new.append(channel)
channels_map = channels_map_new

files_list_new = []
for indx, file_name in enumerate(files_list):
if indx not in ignore_channels:
files_list_new.append(file_name)
files_list = files_list_new

num_channels = len(files_list)
ch_indx = 0
for npz_file in files_list:
Expand All @@ -79,7 +109,7 @@ def load_all_npz_files(npz_folder: str, ignore_channels: [
data, _ = load_npz(npz_file_path)
data_all[ch_indx, :] = data
ch_indx += 1
return data_all, fs
return data_all, fs, channels_map


def load_npz_stft(npz_file) -> [np.ndarray, np.ndarray, np.ndarray]:
Expand Down
110 changes: 110 additions & 0 deletions elecphys/fourier_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from scipy import signal
from tqdm import tqdm
import json
import csv

import utils
import cfc
import data_io
import visualization


def stft_numeric_output_from_npz(input_npz_folder: str, output_npz_folder: str,
Expand Down Expand Up @@ -378,3 +381,110 @@ def calc_cfc_from_npz(input_npz_folder: str, output_npz_folder: str,
freqs_amp=freqs_amp,
freqs_phase=freqs_phase,
time_interval=time_interval)


def freq_bands_power_over_time(
input_npz_folder: str,
freq_bands: [
tuple,
list] = None,
channels_list: str = None,
ignore_channels: str = None,
window_size: float = 1,
overlap: float = 0.5,
t_min: float = None,
t_max: float = None,
output_csv_file: str = None,
output_plot_file: str = None,
plot_type: str = 'average_of_channels') -> None:
""" Calculates power over time for given frequency bands
Parameters
----------
input_npz_folder: str
path to input npz folder containing signal npz files (in time domain)
freq_bands: tuple
tuple or list of frequency bands to calculate power over time for. It should be a tuple or list of lists, where each list contains two elements: the lower and upper frequency bounds of the band (in Hz). For example, freq_bands = [[1, 4], [4, 8], [8, 12]] would calculate power over time for the delta, theta, and alpha bands.
channels_list: str
list of channels to include in analysis
ignore_channels: str
list of channels to ignore in analysis
window_size: float
window size in seconds to calculate power over time
overlap: float
windows overlap in seconds to calculate power over time
t_min: float
start of time interval to calculate power over time. Default is None which means start from beginning of signal.
t_max: float
end of time interval to calculate power over time. Default is None which means end at end of signal.
output_csv_file: str
path to output csv file to save power over time results
output_plot_file: str
path to output plot file to save power over time results
plot_type: str
type of plot to generate. Options are 'avg' or 'all'. Default is 'avg' which plots average power over time for all channels with an erros cloud. 'all' plots power over time for all channels individually.
Returns
----------
"""
if channels_list is not None:
channels_list = utils.convert_string_to_list(channels_list)
if ignore_channels is not None:
ignore_channels = utils.convert_string_to_list(ignore_channels)

data_all, fs, channels_map = data_io.load_all_npz_files(input_npz_folder, ignore_channels, channels_list)
# if freq_bands only has one list, we should make sure it is a list of lists
if len(freq_bands) == 2 and isinstance(freq_bands[0], int) and isinstance(freq_bands[1], int):
freq_bands = [freq_bands]
freq_bands = [utils.convert_string_to_list(freq_band) for freq_band in freq_bands]
freq_bands = utils.check_freq_bands(freq_bands, fs)

if t_min is None:
t_min = 0
if t_max is None:
t_max = data_all.shape[1] / fs

if t_max <= t_min:
raise ValueError(
f'Invalid time interval: [{t_min}, {t_max}]. t_max must be larger than t_min.')

for freq_band in freq_bands:
for ch_indx in range(data_all.shape[0]):
data = data_all[ch_indx, :]
f, t, Zxx = stft_from_array(data, fs, window_size, overlap)
Zxx = np.abs(Zxx)
if ch_indx == 0:
spectrum_all = np.zeros((data_all.shape[0], len(t), len(f)))
spectrum_all[ch_indx, :, :] = Zxx.T

t0 = np.where(t >= t_min)[0][0]
t1 = np.where(t <= t_max)[0][-1]
t = t[t0:t1 + 1]
f0 = np.where(f >= freq_band[0])[0][0]
f1 = np.where(f <= freq_band[1])[0][-1]
spectrum_all = spectrum_all[:, :, f0:f1 + 1]
spectrum_all = spectrum_all[:, t0:t1 + 1, :]
power_all = np.sum(spectrum_all**2, axis=2)
avg_power = np.mean(power_all, axis=0)
avg_power = 10 * np.log10(avg_power)
power_all = 10 * np.log10(power_all)
if output_csv_file is not None:
if 'csv' in output_csv_file:
output_csv_file = output_csv_file.replace('.csv', '')
# save to csv file
with open(f'{output_csv_file}_{freq_band[0]}_{freq_band[1]}.csv', 'w', newline='') as csvfile:
csvwriter = csv.writer(csvfile)
csvwriter.writerow(['Channel', 'Time', 'Power'])
for ch_indx in range(data_all.shape[0]):
for t_indx in range(len(t)):
csvwriter.writerow([channels_map[ch_indx] + 1, t[t_indx], power_all[ch_indx, t_indx]])
for t_indx in range(len(t)):
csvwriter.writerow(['Avg_channels', t[t_indx], avg_power[t_indx]])
if output_plot_file is not None:
output_plot_file_format = output_plot_file.split('.')[-1]
if output_plot_file_format == '':
output_plot_file_format = 'pdf'
output_plot_file_band = f"{output_plot_file.split('.')[0]}_{freq_band[0]}_{freq_band[1]}.{output_plot_file_format}"
else:
output_plot_file_band = None
visualization.plot_power_over_time_from_array(power_all, t, channels_map, plot_type, output_plot_file_band)
Loading

0 comments on commit 7aef5b5

Please sign in to comment.