[Feature] Online batched statistics measures #405

Open

wants to merge 46 commits into master from feature/online_measures

Commits (46)
31cfded
MeanOnline & VarianceOnline
dizcza Oct 23, 2020
e6fa691
mean firing rate and cv
dizcza Oct 23, 2020
a8da7f7
test cv2, lv, lvr online
dizcza Oct 24, 2020
977979f
test fanofactor
dizcza Oct 24, 2020
5e5b57b
fixed bug in spike_contrast; test spike_contrast online
dizcza Oct 24, 2020
60e79c3
waveform_snr works directly with waveforms
dizcza Oct 25, 2020
eb70d27
easier use-case of online spike_contrast
dizcza Oct 25, 2020
080fcb3
fast conversion
dizcza Oct 25, 2020
7c886c6
zero-out corrcoef diag
dizcza Oct 25, 2020
9442955
fixed bug in n_discarded calculation
dizcza Oct 25, 2020
abf43ce
rewrote BinnedSpikeTrain
dizcza Oct 26, 2020
73fa2d1
don't cut spiketrains while binning: raise an error if t_start or t_s…
dizcza Oct 26, 2020
7bc09a4
efficient to_array()
dizcza Oct 26, 2020
97cfba9
faster binning in BinnedSpikeTrain
dizcza Oct 26, 2020
61cbfee
utils doc
dizcza Oct 26, 2020
8526e05
Merge branch 'master' into feature/online_measures
dizcza Oct 26, 2020
cc4aca0
faster BinnedSpikeTrain with a bugfix for the incorrectly estimated n…
dizcza Oct 26, 2020
f57c763
don't wrap spiketrain.data in array in python2
dizcza Oct 27, 2020
54fdd55
fixed tests style
dizcza Oct 27, 2020
3d09266
using spiketrain.magnitude
dizcza Oct 28, 2020
00150af
Merge branch 'opt/binned_spiketrain' into feature/online_measures
dizcza Oct 29, 2020
0378385
Merge branch 'master' into feature/online_measures
dizcza Nov 18, 2020
481fa9c
in-place mean, std in zscore
dizcza Nov 18, 2020
2b0b1c6
Merge branch 'master' into feature/online_measures
dizcza Feb 8, 2021
a7d24ad
Merge branch 'master' into feature/online_measures
dizcza Feb 22, 2021
96e3eb8
Batched-version of Covariance and Pearson Correlation Coefficient (#90)
ojoenlanuca Feb 7, 2022
1dcaa98
created batched-version of InterSpikeInterval with UnitTests (#89)
ojoenlanuca Feb 7, 2022
fd71e12
Merge branch 'master' into feature/online_measures
Jun 23, 2022
112ede2
removed doubled Neo import
Jun 23, 2022
7420a8f
fixed indents and replaced deprecated homogeneous_poisson_process
Jun 24, 2022
fca26d0
adjusted test precision for test_spike_contrast in test_online.py
Jun 24, 2022
efab69d
fix pep8
Sep 2, 2022
2b0ecbb
Merge branch 'master' into feature/online_measures
Nov 7, 2022
dd18b3c
Merge branch 'master' into feature/online_measures
Nov 17, 2022
305a8ac
update citation
Nov 17, 2022
9985314
defined constant for warning message
Nov 17, 2022
a15669c
fix typos, simplify if statement
Nov 17, 2022
e60149d
add comment to empty if statement
Nov 17, 2022
f305583
fix if statement
Nov 17, 2022
9ca854b
fix deprecations
Nov 17, 2022
3f6c5f0
fix deprecations
Nov 17, 2022
5ff8803
Merge branch 'master' into feature/online_measures
Moritz-Alexander-Kern Apr 4, 2024
ef5155d
Fix neo consistency import
Moritz-Alexander-Kern Apr 4, 2024
f157bb5
fix missing parameter in correlation_coefficient
Moritz-Alexander-Kern Apr 4, 2024
d512848
fix pep8
Moritz-Alexander-Kern Apr 4, 2024
bf80d48
fix macOS CI runner
Moritz-Alexander-Kern Apr 4, 2024
253 changes: 253 additions & 0 deletions elephant/online.py
@@ -0,0 +1,253 @@
from copy import deepcopy

import numpy as np
import quantities as pq

from elephant.statistics import isi

msg_same_units = "Each batch must have the same units."


class MeanOnline(object):
def __init__(self, batch_mode=False):
self.mean = None
self.count = 0
self.units = None
self.batch_mode = batch_mode

def update(self, new_val):
units = None
if isinstance(new_val, pq.Quantity):
units = new_val.units
new_val = new_val.magnitude
if self.batch_mode:
batch_size = new_val.shape[0]
new_val_sum = new_val.sum(axis=0)
else:
batch_size = 1
new_val_sum = new_val
self.count += batch_size
if self.mean is None:
self.mean = deepcopy(new_val_sum / batch_size)
self.units = units
else:
if units != self.units:
raise ValueError(msg_same_units)
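            # incremental mean update: new_mean = old_mean +
            # (batch_sum - old_mean * batch_size) / new_count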
self.mean += (new_val_sum - self.mean * batch_size) / self.count

def as_units(self, val):
if self.units is None:
return val
return pq.Quantity(val, units=self.units, copy=False)

def get_mean(self):
return self.as_units(deepcopy(self.mean))

def reset(self):
self.mean = None
self.count = 0
self.units = None


class VarianceOnline(MeanOnline):
def __init__(self, batch_mode=False):
super(VarianceOnline, self).__init__(batch_mode=batch_mode)
self.variance_sum = 0.

def update(self, new_val):
units = None
if isinstance(new_val, pq.Quantity):
units = new_val.units
new_val = new_val.magnitude
if self.mean is None:
self.mean = 0.
self.variance_sum = 0.
self.units = units
elif units != self.units:
raise ValueError(msg_same_units)
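        # Welford's online algorithm: self.variance_sum accumulates
        # M2 = sum((x - old_mean) * (x - new_mean))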
delta_var = new_val - self.mean
if self.batch_mode:
batch_size = new_val.shape[0]
self.count += batch_size
delta_mean = new_val.sum(axis=0) - self.mean * batch_size
self.mean += delta_mean / self.count
delta_var *= new_val - self.mean
delta_var = delta_var.sum(axis=0)
else:
self.count += 1
self.mean += delta_var / self.count
delta_var *= new_val - self.mean
self.variance_sum += delta_var

def get_mean_std(self, unbiased=False):
if self.mean is None:
return None, None
if self.count > 1:
count = self.count - 1 if unbiased else self.count
std = np.sqrt(self.variance_sum / count)
else:
            # with a single update, biased and unbiased sample variance
            # are both zero
std = 0.
mean = self.as_units(deepcopy(self.mean))
std = self.as_units(std)
return mean, std

def reset(self):
super(VarianceOnline, self).reset()
self.variance_sum = 0.


class InterSpikeIntervalOnline(object):
def __init__(self, bin_size=0.0005, max_isi_value=1, batch_mode=False):
self.max_isi_value = max_isi_value # in sec
self.last_spike_time = None
self.bin_size = bin_size # in sec
self.num_bins = int(self.max_isi_value / self.bin_size)
self.bin_edges = np.linspace(start=0, stop=self.max_isi_value,
num=self.num_bins + 1)
self.current_isi_histogram = np.zeros(shape=self.num_bins)
        self.batch_mode = batch_mode
self.units = None

def update(self, new_val):
units = None
if isinstance(new_val, pq.Quantity):
units = new_val.units
new_val = new_val.magnitude
if self.last_spike_time is None: # for first batch
            if self.batch_mode:
new_isi = isi(new_val)
self.last_spike_time = new_val[-1]
else:
new_isi = np.array([])
self.last_spike_time = new_val
self.units = units
        else:  # for the second and subsequent batches
if units != self.units:
raise ValueError(msg_same_units)
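            # bridge batches: prepend the last spike of the previous batch
            # so the interval spanning the batch boundary is also counted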
            if self.batch_mode:
new_isi = isi(np.append(self.last_spike_time, new_val))
self.last_spike_time = new_val[-1]
else:
new_isi = np.array([new_val - self.last_spike_time])
self.last_spike_time = new_val
isi_hist, _ = np.histogram(new_isi, bins=self.bin_edges)
self.current_isi_histogram += isi_hist

def as_units(self, val):
if self.units is None:
return val
return pq.Quantity(val, units=self.units, copy=False)

def get_isi(self):
return self.as_units(deepcopy(self.current_isi_histogram))

def reset(self):
self.last_spike_time = None
self.units = None
self.current_isi_histogram = np.zeros(shape=self.num_bins)


class CovarianceOnline(object):
def __init__(self, batch_mode=False):
self.batch_mode = batch_mode
self.var_x = VarianceOnline(batch_mode=batch_mode)
self.var_y = VarianceOnline(batch_mode=batch_mode)
self.units = None
self.covariance_sum = 0.
self.count = 0

def update(self, new_val_pair):
units = None
if isinstance(new_val_pair, pq.Quantity):
units = new_val_pair.units
new_val_pair = new_val_pair.magnitude
if self.count == 0:
self.var_x.mean = 0.
self.var_y.mean = 0.
self.covariance_sum = 0.
self.units = units
elif units != self.units:
raise ValueError(msg_same_units)
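        # self.covariance_sum accumulates the co-moment
        # C = sum((x - mean_x) * (y - mean_y)); cov = C / count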
if self.batch_mode:
self.var_x.update(new_val_pair[0])
self.var_y.update(new_val_pair[1])
delta_var_x = new_val_pair[0] - self.var_x.mean
delta_var_y = new_val_pair[1] - self.var_y.mean
delta_covar = delta_var_x * delta_var_y
batch_size = len(new_val_pair[0])
self.count += batch_size
delta_covar = delta_covar.sum(axis=0)
self.covariance_sum += delta_covar
else:
delta_var_x = new_val_pair[0] - self.var_x.mean
delta_var_y = new_val_pair[1] - self.var_y.mean
delta_covar = delta_var_x * delta_var_y
self.var_x.update(new_val_pair[0])
self.var_y.update(new_val_pair[1])
self.count += 1
self.covariance_sum += \
((self.count - 1) / self.count) * delta_covar

def get_cov(self, unbiased=False):
if self.var_x.mean is None and self.var_y.mean is None:
return None
if self.count > 1:
count = self.count - 1 if unbiased else self.count
cov = self.covariance_sum / count
else:
cov = 0.
return cov

def reset(self):
self.var_x.reset()
self.var_y.reset()
self.units = None
self.covariance_sum = 0.
self.count = 0


class PearsonCorrelationCoefficientOnline(object):
def __init__(self, batch_mode=False):
self.batch_mode = batch_mode
self.covariance_xy = CovarianceOnline(batch_mode=batch_mode)
self.units = None
self.R_xy = 0.
self.count = 0

def update(self, new_val_pair):
units = None
if isinstance(new_val_pair, pq.Quantity):
units = new_val_pair.units
new_val_pair = new_val_pair.magnitude
if self.count == 0:
            self.covariance_xy.var_x.mean = 0.
            self.covariance_xy.var_y.mean = 0.
self.units = units
elif units != self.units:
raise ValueError(msg_same_units)
self.covariance_xy.update(new_val_pair)
if self.batch_mode:
batch_size = len(new_val_pair[0])
self.count += batch_size
else:
self.count += 1
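        # r = C_xy / sqrt(M2_x * M2_y): the 1/count normalizations of the
        # covariance and the two variances cancel, so raw sums suffice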
if self.count > 1:
self.R_xy = np.divide(
self.covariance_xy.covariance_sum,
(np.sqrt(self.covariance_xy.var_x.variance_sum *
self.covariance_xy.var_y.variance_sum)))

def get_pcc(self):
if self.count == 0:
return None
elif self.count == 1:
return 0.
else:
return self.R_xy

def reset(self):
self.count = 0
self.units = None
self.R_xy = 0.
self.covariance_xy.reset()
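
A minimal usage sketch of the streaming estimators above (an illustration, not part of the diff; it feeds synthetic data in batches and assumes only the classes defined in elephant/online.py):

import numpy as np
from elephant.online import VarianceOnline

rng = np.random.default_rng(0)
chunks = rng.normal(loc=2., scale=3., size=(10, 1000))  # 10 batches

online = VarianceOnline(batch_mode=True)
for chunk in chunks:
    online.update(chunk)  # one batch at a time, O(1) memory
mean, std = online.get_mean_std(unbiased=False)

# the streaming result matches the all-at-once computation
assert np.isclose(mean, chunks.mean())
assert np.isclose(std, chunks.std())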
35 changes: 19 additions & 16 deletions elephant/signal_processing.py
@@ -25,9 +25,8 @@
import quantities as pq
import scipy.signal

-from elephant.utils import deprecated_alias, check_same_units
-
-import warnings
+from elephant.online import VarianceOnline
+from elephant.utils import deprecated_alias, check_neo_consistency

__all__ = [
"zscore",
@@ -67,7 +66,7 @@ def zscore(signal, inplace=True):
Signals for which to calculate the z-score.
inplace : bool, optional
If True, the contents of the input `signal` is replaced by the
-        z-transformed signal, if possible, i.e when the signal type is float.
+        z-transformed signal, if possible, i.e. when the signal type is float.
If the signal type is not float, an error is raised.
If False, a copy of the original `signal` is returned.
Default: True
@@ -156,18 +155,19 @@
# Transform input to a list
if isinstance(signal, neo.AnalogSignal):
signal = [signal]
-    check_same_units(signal, object_type=neo.AnalogSignal)
+    check_neo_consistency(signal, object_type=neo.AnalogSignal)

-    # Calculate mean and standard deviation
-    signal_stacked = np.vstack(signal).magnitude
-    mean = signal_stacked.mean(axis=0)
-    std = signal_stacked.std(axis=0)
+    # Calculate mean and standard deviation vectors
+    online = VarianceOnline(batch_mode=True)
+    for sig in signal:
+        online.update(sig.magnitude)
+    mean, std = online.get_mean_std(unbiased=False)

signal_ztransformed = []
for sig in signal:
# Perform inplace operation only if array is of dtype float.
# Otherwise, raise an error.
-        if inplace and not np.issubdtype(np.float, sig.dtype):
+        if inplace and not np.issubdtype(float, sig.dtype):
raise ValueError(f"Cannot perform inplace operation as the "
f"signal dtype is not float. Source: {sig.name}")

@@ -294,6 +294,9 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False,

If `scaleopt` is not one of the predefined above keywords.

+    .. bibliography::
+        :keyprefix: signal-

Examples
--------
.. plot::
@@ -339,9 +342,8 @@
"indices. Cannot define pairs for cross-correlation.")
if not isinstance(hilbert_envelope, bool):
raise ValueError("'hilbert_envelope' must be a boolean value")
-    if n_lags is not None:
-        if not isinstance(n_lags, int) or n_lags <= 0:
-            raise ValueError('n_lags must be a non-negative integer')
+    if n_lags is not None and (not isinstance(n_lags, int) or n_lags <= 0):
+        raise ValueError('n_lags must be a positive integer')

# z-score analog signal and store channel time series in different arrays
# Cross-correlation will be calculated between xsig and ysig
@@ -581,7 +583,7 @@ def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0,
Parameters
----------
signal : (Nt, Nch) neo.AnalogSignal or np.ndarray or list
-        Time series data to be wavelet-transformed. When multi-dimensional
+        Time series data to be wavelet-transformed. When multidimensional
`np.ndarray` or list is given, the time axis must be the last
dimension. If `neo.AnalogSignal`, `Nt` is the number of time points
and `Nch` is the number of channels.
@@ -673,7 +675,7 @@ def _morlet_wavelet_ft(freq, n_cycles, fs, n):
# in Le van Quyen et al. J Neurosci Meth 111:83-98 (2001).
sigma = n_cycles / (6. * freq)
freqs = np.fft.fftfreq(n, 1.0 / fs)
-    heaviside = np.array(freqs > 0., dtype=np.float)
+    heaviside = np.array(freqs > 0., dtype=float)
ft_real = np.sqrt(2 * np.pi * freq) * sigma * np.exp(
-2 * (np.pi * sigma * (freqs - freq)) ** 2) * heaviside * fs
ft_imag = np.zeros_like(ft_real)
@@ -717,7 +719,7 @@ def _morlet_wavelet_ft(freq, n_cycles, fs, n):
n = n_orig

# generate Morlet wavelets (in the frequency domain)
-    wavelet_fts = np.empty([len(freqs), n], dtype=np.complex)
+    wavelet_fts = np.empty([len(freqs), n], dtype=complex)
for i, f in enumerate(freqs):
wavelet_fts[i] = _morlet_wavelet_ft(f, n_cycles, sampling_frequency, n)

@@ -935,6 +937,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
raise ValueError('Input signal is not a neo.AnalogSignal!')

if baseline is None:
+        # do nothing
pass
elif baseline == 'mean':
# subtract mean from each channel
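
A quick check of the reworked zscore path (a sketch, not part of the diff; neo, quantities and synthetic data are used purely for illustration):

import neo
import numpy as np
import quantities as pq
from elephant.signal_processing import zscore

rng = np.random.default_rng(0)
signal = neo.AnalogSignal(rng.normal(5., 2., size=(1000, 4)),
                          units='mV', sampling_rate=1 * pq.kHz)
z = zscore(signal, inplace=False)
# every channel of the z-transformed signal is zero-mean with unit variance
assert np.allclose(z.magnitude.mean(axis=0), 0.)
assert np.allclose(z.magnitude.std(axis=0), 1.)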
15 changes: 11 additions & 4 deletions elephant/spike_train_correlation.py
@@ -379,7 +379,8 @@ def covariance(binned_spiketrain, binary=False, fast=True):


@deprecated_alias(binned_sts='binned_spiketrain')
-def correlation_coefficient(binned_spiketrain, binary=False, fast=True):
+def correlation_coefficient(binned_spiketrain, binary=False, zero_diag=False,
+                            fast=True):
r"""
Calculate the NxN matrix of pairwise Pearson's correlation coefficients
between all combinations of N binned spike trains.
@@ -418,6 +419,9 @@ def correlation_coefficient(binned_spiketrain, binary=False, fast=True):
are counted as 1, resulting in binary binned vectors :math:`b_i`. If
False, the binned vectors :math:`b_i` contain the spike counts per bin.
Default: False
+    zero_diag : bool, optional
+        If True, zero out the diagonal of the correlation matrix.
+        Default: False
fast : bool, optional
If `fast=True` and the sparsity of `binned_spiketrain` is `> 0.1`, use
`np.corrcoef()`. Otherwise, use memory efficient implementation.
@@ -481,10 +485,13 @@

if fast and binned_spiketrain.sparsity > _SPARSITY_MEMORY_EFFICIENT_THR:
array = binned_spiketrain.to_array()
-        return np.corrcoef(array)
+        corr_mat = np.corrcoef(array)
+    else:
+        corr_mat = _covariance_sparse(binned_spiketrain, corrcoef_norm=True)

-    return _covariance_sparse(
-        binned_spiketrain, corrcoef_norm=True)
+    if zero_diag:
+        np.fill_diagonal(corr_mat, 0)
+    return corr_mat


def corrcoef(*args, **kwargs):
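
A short sketch of the new zero_diag option (illustrative only; it assumes this branch's correlation_coefficient together with the existing BinnedSpikeTrain from elephant.conversion):

import neo
import numpy as np
import quantities as pq
from elephant.conversion import BinnedSpikeTrain
from elephant.spike_train_correlation import correlation_coefficient

rng = np.random.default_rng(0)
spiketrains = [neo.SpikeTrain(np.sort(rng.uniform(0, 10, 30)) * pq.s,
                              t_stop=10 * pq.s) for _ in range(3)]
binned = BinnedSpikeTrain(spiketrains, bin_size=100 * pq.ms)
corr = correlation_coefficient(binned, zero_diag=True)
# the trivial self-correlations (1 by definition) are zeroed out
assert np.all(np.diag(corr) == 0)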