diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index ec2547b70..ae3c37032 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -147,8 +147,6 @@ jobs:
           channel-priority: true
           activate-environment: elephant
           environment-file: requirements/environment-tests.yml
-          auto-activate-base: false
-          use-only-tar-bz2: true  # IMPORTANT: This needs to be set for caching to work properly!
       - name: Install dependencies
         shell: bash -l {0}
diff --git a/elephant/online.py b/elephant/online.py
new file mode 100644
index 000000000..acdb4986e
--- /dev/null
+++ b/elephant/online.py
@@ -0,0 +1,253 @@
+from copy import deepcopy
+
+import numpy as np
+import quantities as pq
+
+from elephant.statistics import isi
+
+msg_same_units = "Each batch must have the same units."
+
+
+class MeanOnline(object):
+    def __init__(self, batch_mode=False):
+        self.mean = None
+        self.count = 0
+        self.units = None
+        self.batch_mode = batch_mode
+
+    def update(self, new_val):
+        units = None
+        if isinstance(new_val, pq.Quantity):
+            units = new_val.units
+            new_val = new_val.magnitude
+        if self.batch_mode:
+            batch_size = new_val.shape[0]
+            new_val_sum = new_val.sum(axis=0)
+        else:
+            batch_size = 1
+            new_val_sum = new_val
+        self.count += batch_size
+        if self.mean is None:
+            self.mean = deepcopy(new_val_sum / batch_size)
+            self.units = units
+        else:
+            if units != self.units:
+                raise ValueError(msg_same_units)
+            self.mean += (new_val_sum - self.mean * batch_size) / self.count
+
+    def as_units(self, val):
+        if self.units is None:
+            return val
+        return pq.Quantity(val, units=self.units, copy=False)
+
+    def get_mean(self):
+        return self.as_units(deepcopy(self.mean))
+
+    def reset(self):
+        self.mean = None
+        self.count = 0
+        self.units = None
+
+
+class VarianceOnline(MeanOnline):
+    def __init__(self, batch_mode=False):
+        super(VarianceOnline, self).__init__(batch_mode=batch_mode)
+        self.variance_sum = 0.
+
+    def update(self, new_val):
+        units = None
+        if isinstance(new_val, pq.Quantity):
+            units = new_val.units
+            new_val = new_val.magnitude
+        if self.mean is None:
+            self.mean = 0.
+            self.variance_sum = 0.
+            self.units = units
+        elif units != self.units:
+            raise ValueError(msg_same_units)
+        delta_var = new_val - self.mean
+        if self.batch_mode:
+            batch_size = new_val.shape[0]
+            self.count += batch_size
+            delta_mean = new_val.sum(axis=0) - self.mean * batch_size
+            self.mean += delta_mean / self.count
+            delta_var *= new_val - self.mean
+            delta_var = delta_var.sum(axis=0)
+        else:
+            self.count += 1
+            self.mean += delta_var / self.count
+            delta_var *= new_val - self.mean
+        self.variance_sum += delta_var
+
+    def get_mean_std(self, unbiased=False):
+        if self.mean is None:
+            return None, None
+        if self.count > 1:
+            count = self.count - 1 if unbiased else self.count
+            std = np.sqrt(self.variance_sum / count)
+        else:
+            # with 1 update biased & unbiased sample variance is zero
+            std = 0.
+        mean = self.as_units(deepcopy(self.mean))
+        std = self.as_units(std)
+        return mean, std
+
+    def reset(self):
+        super(VarianceOnline, self).reset()
+        self.variance_sum = 0.
+
+
+class InterSpikeIntervalOnline(object):
+    def __init__(self, bin_size=0.0005, max_isi_value=1, batch_mode=False):
+        self.max_isi_value = max_isi_value  # in sec
+        self.last_spike_time = None
+        self.bin_size = bin_size  # in sec
+        self.num_bins = int(self.max_isi_value / self.bin_size)
+        self.bin_edges = np.linspace(start=0, stop=self.max_isi_value,
+                                     num=self.num_bins + 1)
+        self.current_isi_histogram = np.zeros(shape=self.num_bins)
+        self.batch_mode = batch_mode
+        self.units = None
+
+    def update(self, new_val):
+        units = None
+        if isinstance(new_val, pq.Quantity):
+            units = new_val.units
+            new_val = new_val.magnitude
+        if self.last_spike_time is None:  # first batch
+            if self.batch_mode:
+                new_isi = isi(new_val)
+                self.last_spike_time = new_val[-1]
+            else:
+                new_isi = np.array([])
+                self.last_spike_time = new_val
+            self.units = units
+        else:  # subsequent batches
+            if units != self.units:
+                raise ValueError(msg_same_units)
+            if self.batch_mode:
+                new_isi = isi(np.append(self.last_spike_time, new_val))
+                self.last_spike_time = new_val[-1]
+            else:
+                new_isi = np.array([new_val - self.last_spike_time])
+                self.last_spike_time = new_val
+        isi_hist, _ = np.histogram(new_isi, bins=self.bin_edges)
+        self.current_isi_histogram += isi_hist
+
+    def as_units(self, val):
+        if self.units is None:
+            return val
+        return pq.Quantity(val, units=self.units, copy=False)
+
+    def get_isi(self):
+        return self.as_units(deepcopy(self.current_isi_histogram))
+
+    def reset(self):
+        self.last_spike_time = None
+        self.units = None
+        self.current_isi_histogram = np.zeros(shape=self.num_bins)
+
+
+class CovarianceOnline(object):
+    def __init__(self, batch_mode=False):
+        self.batch_mode = batch_mode
+        self.var_x = VarianceOnline(batch_mode=batch_mode)
+        self.var_y = VarianceOnline(batch_mode=batch_mode)
+        self.units = None
+        self.covariance_sum = 0.
+        self.count = 0
+
+    def update(self, new_val_pair):
+        units = None
+        if isinstance(new_val_pair, pq.Quantity):
+            units = new_val_pair.units
+            new_val_pair = new_val_pair.magnitude
+        if self.count == 0:
+            self.var_x.mean = 0.
+            self.var_y.mean = 0.
+            self.covariance_sum = 0.
+            self.units = units
+        elif units != self.units:
+            raise ValueError(msg_same_units)
+        if self.batch_mode:
+            self.var_x.update(new_val_pair[0])
+            self.var_y.update(new_val_pair[1])
+            delta_var_x = new_val_pair[0] - self.var_x.mean
+            delta_var_y = new_val_pair[1] - self.var_y.mean
+            delta_covar = delta_var_x * delta_var_y
+            batch_size = len(new_val_pair[0])
+            self.count += batch_size
+            delta_covar = delta_covar.sum(axis=0)
+            self.covariance_sum += delta_covar
+        else:
+            delta_var_x = new_val_pair[0] - self.var_x.mean
+            delta_var_y = new_val_pair[1] - self.var_y.mean
+            delta_covar = delta_var_x * delta_var_y
+            self.var_x.update(new_val_pair[0])
+            self.var_y.update(new_val_pair[1])
+            self.count += 1
+            self.covariance_sum += \
+                ((self.count - 1) / self.count) * delta_covar
+
+    def get_cov(self, unbiased=False):
+        if self.var_x.mean is None and self.var_y.mean is None:
+            return None
+        if self.count > 1:
+            count = self.count - 1 if unbiased else self.count
+            cov = self.covariance_sum / count
+        else:
+            cov = 0.
+        return cov
+
+    def reset(self):
+        self.var_x.reset()
+        self.var_y.reset()
+        self.units = None
+        self.covariance_sum = 0.
+        self.count = 0
+
+
+class PearsonCorrelationCoefficientOnline(object):
+    def __init__(self, batch_mode=False):
+        self.batch_mode = batch_mode
+        self.covariance_xy = CovarianceOnline(batch_mode=batch_mode)
+        self.units = None
+        self.R_xy = 0.
+        self.count = 0
+
+    def update(self, new_val_pair):
+        units = None
+        if isinstance(new_val_pair, pq.Quantity):
+            units = new_val_pair.units
+            new_val_pair = new_val_pair.magnitude
+        if self.count == 0:
+            self.covariance_xy.var_x.mean = 0.
+            self.covariance_xy.var_y.mean = 0.
+            self.units = units
+        elif units != self.units:
+            raise ValueError(msg_same_units)
+        self.covariance_xy.update(new_val_pair)
+        if self.batch_mode:
+            batch_size = len(new_val_pair[0])
+            self.count += batch_size
+        else:
+            self.count += 1
+        if self.count > 1:
+            self.R_xy = np.divide(
+                self.covariance_xy.covariance_sum,
+                (np.sqrt(self.covariance_xy.var_x.variance_sum *
+                         self.covariance_xy.var_y.variance_sum)))
+
+    def get_pcc(self):
+        if self.count == 0:
+            return None
+        elif self.count == 1:
+            return 0.
+        else:
+            return self.R_xy
+
+    def reset(self):
+        self.count = 0
+        self.units = None
+        self.R_xy = 0.
+        self.covariance_xy.reset()
diff --git a/elephant/signal_processing.py b/elephant/signal_processing.py
index 65c66e2a4..e2b06e07b 100644
--- a/elephant/signal_processing.py
+++ b/elephant/signal_processing.py
@@ -25,7 +25,8 @@
 import quantities as pq
 import scipy.signal
 
-from elephant.utils import check_same_units
+from elephant.online import VarianceOnline
+from elephant.utils import check_neo_consistency
 
 __all__ = [
     "zscore",
@@ -65,7 +66,7 @@ def zscore(signal, inplace=True):
         Signals for which to calculate the z-score.
     inplace : bool, optional
         If True, the contents of the input `signal` is replaced by the
-        z-transformed signal, if possible, i.e when the signal type is float.
+        z-transformed signal, if possible, i.e. when the signal type is float.
         If the signal type is not float, an error is raised.
         If False, a copy of the original `signal` is returned.
         Default: True
@@ -154,17 +155,19 @@ def zscore(signal, inplace=True):
     # Transform input to a list
     if isinstance(signal, neo.AnalogSignal):
         signal = [signal]
-    check_same_units(signal, object_type=neo.AnalogSignal)
+    check_neo_consistency(signal, object_type=neo.AnalogSignal)
 
-    # Calculate mean and standard deviation
-    signal_stacked = np.vstack(signal).magnitude
-    mean = signal_stacked.mean(axis=0)
-    std = signal_stacked.std(axis=0)
+    # Calculate mean and standard deviation vectors
+    online = VarianceOnline(batch_mode=True)
+    for sig in signal:
+        online.update(sig.magnitude)
+    mean, std = online.get_mean_std(unbiased=False)
 
     signal_ztransformed = []
     for sig in signal:
         # Perform inplace operation only if array is of dtype float.
         # Otherwise, raise an error.
+
         if inplace and not np.issubdtype(sig.dtype, np.floating):
             raise ValueError(f"Cannot perform inplace operation as the "
                              f"signal dtype is not float. Source: {sig.name}")
@@ -290,6 +293,9 @@
         If `scaleopt` is not one of the predefined above keywords.
 
+    .. bibliography::
+        :keyprefix: signal-
+
     Examples
     --------
     .. plot::
@@ -335,9 +341,8 @@
                          "indices. Cannot define pairs for cross-correlation.")
     if not isinstance(hilbert_envelope, bool):
         raise ValueError("'hilbert_envelope' must be a boolean value")
-    if n_lags is not None:
-        if not isinstance(n_lags, int) or n_lags <= 0:
-            raise ValueError('n_lags must be a non-negative integer')
+    if n_lags is not None and (not isinstance(n_lags, int) or n_lags <= 0):
+        raise ValueError('n_lags must be a positive integer')
 
     # z-score analog signal and store channel time series in different arrays
     # Cross-correlation will be calculated between xsig and ysig
@@ -568,7 +573,7 @@ def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0,
     Parameters
     ----------
     signal : (Nt, Nch) neo.AnalogSignal or np.ndarray or list
-        Time series data to be wavelet-transformed. When multi-dimensional
+        Time series data to be wavelet-transformed. When multidimensional
         `np.ndarray` or list is given, the time axis must be the last
         dimension. If `neo.AnalogSignal`, `Nt` is the number of time points
         and `Nch` is the number of channels.
@@ -921,6 +926,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
         raise ValueError('Input signal is not a neo.AnalogSignal!')
 
     if baseline is None:
+        # do nothing
         pass
     elif baseline == 'mean':
         # subtract mean from each channel
diff --git a/elephant/spike_train_correlation.py b/elephant/spike_train_correlation.py
index 03b4d9acb..b6e467098 100644
--- a/elephant/spike_train_correlation.py
+++ b/elephant/spike_train_correlation.py
@@ -376,7 +376,8 @@ def covariance(binned_spiketrain, binary=False, fast=True):
         binned_spiketrain, corrcoef_norm=False)
 
 
-def correlation_coefficient(binned_spiketrain, binary=False, fast=True):
+def correlation_coefficient(binned_spiketrain, binary=False, fast=True,
+                            zero_diag=False):
     r"""
     Calculate the NxN matrix of pairwise Pearson's correlation coefficients
     between all combinations of N binned spike trains.
@@ -415,6 +416,9 @@
         are counted as 1, resulting in binary binned vectors :math:`b_i`. If
         False, the binned vectors :math:`b_i` contain the spike counts per
         bin.
         Default: False
+    zero_diag : bool, optional
+        If True, zero out the diagonal of the returned correlation matrix.
+        Default: False
     fast : bool, optional
         If `fast=True` and the sparsity of `binned_spiketrain` is `> 0.1`, use
         `np.corrcoef()`. Otherwise, use memory efficient implementation.
@@ -478,10 +482,13 @@
     if fast and binned_spiketrain.sparsity > _SPARSITY_MEMORY_EFFICIENT_THR:
         array = binned_spiketrain.to_array()
-        return np.corrcoef(array)
+        corr_mat = np.corrcoef(array)
+    else:
+        corr_mat = _covariance_sparse(binned_spiketrain, corrcoef_norm=True)
 
-    return _covariance_sparse(
-        binned_spiketrain, corrcoef_norm=True)
+    if zero_diag:
+        np.fill_diagonal(corr_mat, 0)
+    return corr_mat
 
 
 def corrcoef(*args, **kwargs):
diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py
new file mode 100644
index 000000000..8e27bf57f
--- /dev/null
+++ b/elephant/test/test_online.py
@@ -0,0 +1,520 @@
+import unittest
+
+import numpy as np
+import quantities as pq
+from numpy.testing import assert_array_almost_equal
+
+from elephant import statistics
+
+from elephant.online import MeanOnline, VarianceOnline, CovarianceOnline, \
+    PearsonCorrelationCoefficientOnline, InterSpikeIntervalOnline
+
+from elephant.spike_train_generation import StationaryPoissonProcess
+from elephant.spike_train_synchrony import spike_contrast
+
+
+class TestSlidingWindowGeneric(unittest.TestCase):
+    def test_fanofactor(self):
+        """
+        This test computes the Fano factor in a sliding window fashion
+        without the use of MeanOnline or VarianceOnline.
+        """
+        np.random.seed(0)
+        t_stop = 10 * pq.s
+        spiketrains = [StationaryPoissonProcess(
+            rate=20 * pq.Hz, t_stop=t_stop).generate_spiketrain()
+            for _ in range(10)]
+        fanofactor_target = statistics.fanofactor(spiketrains)
+        spike_counts = np.zeros(len(spiketrains))
+        checkpoints = np.linspace(0 * pq.s, t_stop, num=10)
+        for t_start, t_stop in zip(checkpoints[:-1], checkpoints[1:]):
+            sts_chunk = [st.time_slice(t_start=t_start, t_stop=t_stop)
+                         for st in spiketrains]
+            spike_counts += list(map(len, sts_chunk))
+        fanofactor = spike_counts.var() / spike_counts.mean()
+        self.assertAlmostEqual(fanofactor, fanofactor_target)
+
+
+class TestMeanOnline(unittest.TestCase):
+    def test_floats(self):
+        np.random.seed(0)
+        arr = np.random.rand(100)
+        online = MeanOnline()
+        for val in arr:
+            online.update(val)
+        self.assertIsNone(online.units)
+        self.assertIsInstance(online.get_mean(), float)
+        self.assertAlmostEqual(online.get_mean(), arr.mean())
+
+    def test_numpy_array(self):
+        np.random.seed(1)
+        arr = np.random.rand(10, 100)
+        online = MeanOnline()
+        for arr_vec in arr:
+            online.update(arr_vec)
+        self.assertIsInstance(online.get_mean(), np.ndarray)
+        self.assertIsNone(online.units)
+        self.assertEqual(online.get_mean().shape, (arr.shape[1],))
+        assert_array_almost_equal(online.get_mean(), arr.mean(axis=0))
+
+    def test_quantity_scalar(self):
+        np.random.seed(2)
+        arr = np.random.rand(100) * pq.Hz
+        online = MeanOnline()
+        for val in arr:
+            online.update(val)
+        self.assertEqual(online.units, arr.units)
+        self.assertAlmostEqual(online.get_mean(), arr.mean())
+
+    def test_quantities_vector(self):
+        np.random.seed(3)
+        arr = np.random.rand(10, 100) * pq.ms
+        online = MeanOnline()
+        for arr_vec in arr:
+            online.update(arr_vec)
+        self.assertEqual(online.units, arr.units)
+        self.assertEqual(online.get_mean().shape, (arr.shape[1],))
+        assert_array_almost_equal(online.get_mean(), arr.mean(axis=0))
+
+    def test_reset(self):
+        target_value = 2.5
+        online = MeanOnline()
+        online.update(target_value)
+        self.assertEqual(online.get_mean(), target_value)
+        online.reset()
+        self.assertIsNone(online.mean)
+        self.assertIsNone(online.units)
+        self.assertEqual(online.count, 0)
+
+    def test_mean_firing_rate(self):
+        np.random.seed(4)
+        spiketrain = np.random.rand(10000).cumsum()
+        rate_target = statistics.mean_firing_rate(spiketrain)
+        online = MeanOnline()
+        n_batches = 10
+        t_start = None
+        for spiketrain_chunk in np.array_split(spiketrain, n_batches):
+            rate_batch = statistics.mean_firing_rate(spiketrain_chunk,
+                                                     t_start=t_start)
+            online.update(rate_batch)
+            t_start = spiketrain_chunk[-1]
+        self.assertAlmostEqual(online.get_mean(), rate_target, places=3)
+
+    def test_cv2(self):
+        np.random.seed(5)
+        spiketrain = np.random.rand(10000).cumsum()
+        cv2_target = statistics.cv2(statistics.isi(spiketrain))
+        online = MeanOnline()
+        n_batches = 10
+        for spiketrain_chunk in np.array_split(spiketrain, n_batches):
+            isi_batch = statistics.isi(spiketrain_chunk)
+            cv2_batch = statistics.cv2(isi_batch)
+            online.update(cv2_batch)
+        self.assertAlmostEqual(online.get_mean(), cv2_target, places=3)
+
+    def test_lv(self):
+        np.random.seed(6)
+        spiketrain = np.random.rand(10000).cumsum()
+        lv_target = statistics.lv(statistics.isi(spiketrain))
+        online = MeanOnline()
+        n_batches = 10
+        for spiketrain_chunk in np.array_split(spiketrain, n_batches):
+            isi_batch = statistics.isi(spiketrain_chunk)
+            lv_batch = statistics.lv(isi_batch)
+            online.update(lv_batch)
+        self.assertAlmostEqual(online.get_mean(), lv_target, places=3)
+
+    def test_lvr(self):
+        np.random.seed(6)
+        spiketrain = np.random.rand(10000).cumsum()
+        lvr_target = statistics.lvr(statistics.isi(spiketrain))
+        online = MeanOnline()
+        n_batches = 10
+        for spiketrain_chunk in np.array_split(spiketrain, n_batches):
+            isi_batch = statistics.isi(spiketrain_chunk)
+            lvr_batch = statistics.lvr(isi_batch)
+            online.update(lvr_batch)
+        self.assertAlmostEqual(online.get_mean(), lvr_target, places=1)
+
+    def test_spike_contrast(self):
+        np.random.seed(1)
+        t_stop = 100 * pq.s
+        spiketrains = [StationaryPoissonProcess(
+            rate=20 * pq.Hz, t_stop=t_stop).generate_spiketrain()
+            for _ in range(10)]
+        synchrony_target = spike_contrast(spiketrains)
+        checkpoints = np.linspace(0 * pq.s, t_stop, num=10)
+        online = MeanOnline()
+        for t_start, t_stop in zip(checkpoints[:-1], checkpoints[1:]):
+            synchrony_batch = spike_contrast(spiketrains,
+                                             t_start=t_start,
+                                             t_stop=t_stop)
+            online.update(synchrony_batch)
+        synchrony = online.get_mean()
+        self.assertAlmostEqual(synchrony, synchrony_target, places=1)
+
+
+class TestVarianceOnline(unittest.TestCase):
+    def test_floats(self):
+        np.random.seed(0)
+        arr = np.random.rand(100)
+        online = VarianceOnline()
+        for val in arr:
+            online.update(val)
+        self.assertIsNone(online.units)
+        self.assertIsInstance(online.get_mean(), float)
+        self.assertAlmostEqual(online.get_mean(), arr.mean())
+        for unbiased in (False, True):
+            mean, std = online.get_mean_std(unbiased=unbiased)
+            self.assertAlmostEqual(mean, arr.mean())
+            self.assertAlmostEqual(std, arr.std(ddof=unbiased))
+
+    def test_numpy_array(self):
+        np.random.seed(1)
+        arr = np.random.rand(10, 100)
+        online = VarianceOnline()
+        for arr_vec in arr:
+            online.update(arr_vec)
+        self.assertIsNone(online.units)
+        self.assertIsInstance(online.get_mean(), np.ndarray)
+        self.assertEqual(online.get_mean().shape, (arr.shape[1],))
+        assert_array_almost_equal(online.get_mean(), arr.mean(axis=0))
+        for unbiased in (False, True):
+            mean, std = online.get_mean_std(unbiased=unbiased)
+            assert_array_almost_equal(mean, arr.mean(axis=0))
+            assert_array_almost_equal(std, arr.std(axis=0, ddof=unbiased))
+
+    def test_quantity_scalar(self):
+        np.random.seed(2)
+        arr = np.random.rand(100) * pq.Hz
+        online = VarianceOnline()
+        for val in arr:
+            online.update(val)
+        self.assertEqual(online.units, arr.units)
+        self.assertAlmostEqual(online.get_mean(), arr.mean())
+        for unbiased in (False, True):
+            mean, std = online.get_mean_std(unbiased=unbiased)
+            self.assertAlmostEqual(mean, arr.mean())
+            self.assertAlmostEqual(std, arr.std(ddof=unbiased))
+
+    def test_quantities_vector(self):
+        np.random.seed(3)
+        arr = np.random.rand(10, 100) * pq.ms
+        online = VarianceOnline()
+        for arr_vec in arr:
+            online.update(arr_vec)
+        self.assertEqual(online.units, arr.units)
+        self.assertEqual(online.get_mean().shape, (arr.shape[1],))
+        assert_array_almost_equal(online.get_mean(), arr.mean(axis=0))
+        for unbiased in (False, True):
+            mean, std = online.get_mean_std(unbiased=unbiased)
+            assert_array_almost_equal(mean, arr.mean(axis=0))
+            assert_array_almost_equal(std, arr.std(axis=0, ddof=unbiased))
+
+    def test_reset(self):
+        target_value = 2.5
+        online = VarianceOnline()
+        online.update(target_value)
+        self.assertEqual(online.get_mean(), target_value)
+        online.reset()
+        self.assertIsNone(online.mean)
+        self.assertIsNone(online.units)
+        self.assertEqual(online.count, 0)
+        self.assertEqual(online.variance_sum, 0.)
+
+    def test_cv(self):
+        np.random.seed(4)
+        spiketrain = np.random.rand(10000).cumsum()
+        isi_all = statistics.isi(spiketrain)
+        cv_target = statistics.cv(isi_all)
+        online = VarianceOnline(batch_mode=True)
+        n_batches = 10
+        for spiketrain_chunk in np.array_split(spiketrain, n_batches):
+            isi_batch = statistics.isi(spiketrain_chunk)
+            online.update(isi_batch)
+        isi_mean, isi_std = online.get_mean_std(unbiased=False)
+        cv_online = isi_std / isi_mean
+        self.assertAlmostEqual(cv_online, cv_target, places=3)
+
+
+class TestInterSpikeIntervalOnline(unittest.TestCase):
+    def test_single_floats(self):
+        np.random.seed(0)
+        arr = np.sort(np.random.rand(100))
+        online_isi = InterSpikeIntervalOnline()
+        for val in arr:
+            online_isi.update(val)
+        self.assertIsNone(online_isi.units)
+        standard_isi_histo, _ = np.histogram(statistics.isi(arr),
+                                             bins=online_isi.bin_edges)
+        np.testing.assert_allclose(online_isi.get_isi(), standard_isi_histo,
+                                   rtol=1e-15, atol=1e-15)
+
+    def test_numpy_array(self):
+        np.random.seed(1)
+        arr = np.sort(np.random.rand(10 * 100)).reshape(10, 100)
+        online_isi = InterSpikeIntervalOnline(batch_mode=True)
+        for arr_vec in arr:
+            online_isi.update(arr_vec)
+        self.assertIsNone(online_isi.units)
+        standard_isi_histo, _ = np.histogram(
+            statistics.isi(arr.reshape(1, 10*100)), bins=online_isi.bin_edges)
+        np.testing.assert_allclose(online_isi.get_isi(), standard_isi_histo,
+                                   rtol=1e-15, atol=1e-15)
+
+    def test_quantity_scalar(self):
+        np.random.seed(2)
+        arr = np.sort(np.random.rand(100)) * pq.s
+        online_isi = InterSpikeIntervalOnline()
+        for val in arr:
+            online_isi.update(val)
+        self.assertEqual(online_isi.units, arr.units)
+        standard_isi_histo, _ = np.histogram(statistics.isi(arr),
+                                             bins=online_isi.bin_edges)
+        np.testing.assert_allclose(online_isi.get_isi().magnitude,
+                                   standard_isi_histo, rtol=1e-15, atol=1e-15)
+
+    def test_quantities_vector(self):
+        np.random.seed(3)
+        arr = np.sort(np.random.rand(10 * 100)).reshape(10, 100) * pq.s
+        online_isi = InterSpikeIntervalOnline(batch_mode=True)
+        for arr_vec in arr:
+            online_isi.update(arr_vec)
+        self.assertEqual(online_isi.units, arr.units)
+        standard_isi_histo, _ = np.histogram(
+            statistics.isi(arr.reshape(1, 10 * 100)),
+            bins=online_isi.bin_edges)
+        np.testing.assert_allclose(online_isi.get_isi().magnitude,
+                                   standard_isi_histo, rtol=1e-15, atol=1e-15)
+
+    def test_reset(self):
+        np.random.seed(4)
+        arr = np.sort(np.random.rand(100)) * pq.s
+        online_isi = InterSpikeIntervalOnline()
+        for val in arr:
+            online_isi.update(val)
+        self.assertEqual(online_isi.units, arr.units)
+        standard_isi_histo, _ = np.histogram(statistics.isi(arr),
+                                             bins=online_isi.bin_edges)
+        np.testing.assert_allclose(online_isi.get_isi().magnitude,
+                                   standard_isi_histo, rtol=1e-15, atol=1e-15)
+        online_isi.reset()
+        self.assertIsNone(online_isi.units)
+        self.assertIsNone(online_isi.last_spike_time)
+        np.testing.assert_allclose(online_isi.current_isi_histogram,
+                                   np.zeros(shape=online_isi.num_bins))
+
+
+class TestCovarianceOnline(unittest.TestCase):
+    def test_simple_small_sets_XY_unbatched(self):
+        X = np.array([1, 2, 3, 2, 3])
+        Y = np.array([4, 4, 3, 3, 4])
+        online_cov = CovarianceOnline(batch_mode=False)
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i])
+        self.assertIsNone(online_cov.units)
+        for unbiased in (False, True):  # -0.12 (biased) / -0.15 (unbiased)
+            self.assertAlmostEqual(online_cov.get_cov(unbiased=unbiased),
+                                   np.cov([X, Y], bias=not unbiased)[0][1])
+
+    def test_simple_small_sets_XY_batched(self):
+        X = np.array([[1, 2, 3, 2, 3], [1, 2, 3, 2, 3], [1, 2, 3, 2, 3]])
+        Y = np.array([[4, 4, 3, 3, 4], [4, 4, 3, 3, 4], [4, 4, 3, 3, 4]])
+        online_cov = CovarianceOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i])
+        self.assertIsNone(online_cov.units)
+        for unbiased in (False, True):  # -0.12 (biased) / approx. -0.1286
+            # (unbiased, 15 samples)
+            self.assertAlmostEqual(online_cov.get_cov(unbiased=unbiased),
+                                   np.cov([X.reshape(15), Y.reshape(15)],
+                                          bias=not unbiased)[0][1])
+
+    def test_floats(self):
+        np.random.seed(0)
+        X = np.random.rand(100)
+        Y = np.random.rand(100)
+        online_cov = CovarianceOnline()
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i])
+        self.assertIsNone(online_cov.units)
+        self.assertIsInstance(online_cov.get_cov(), float)
+        for unbiased in (False, True):
+            self.assertAlmostEqual(online_cov.get_cov(unbiased=unbiased),
+                                   np.cov([X, Y], bias=not unbiased)[0][1])
+
+    def test_numpy_array(self):
+        np.random.seed(1)
+        X = np.random.rand(10, 100)
+        Y = np.random.rand(10, 100)
+        online_cov = CovarianceOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i])
+        self.assertIsNone(online_cov.units)
+        self.assertIsInstance(online_cov.get_cov(), float)
+        for unbiased in (False, True):
+            self.assertAlmostEqual(online_cov.get_cov(unbiased=unbiased),
+                                   np.cov(X.reshape(10*100),
+                                          Y.reshape(10*100),
+                                          bias=not unbiased)[0][1], places=5)
+
+    def test_quantity_scalar(self):
+        np.random.seed(2)
+        X = np.random.rand(100) * pq.Hz
+        Y = np.random.rand(100) * pq.Hz
+        online_cov = CovarianceOnline()
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i] * X.units)
+        self.assertEqual(online_cov.units, X.units)
+        self.assertIsInstance(online_cov.get_cov(), float)
+        for unbiased in (False, True):
+            self.assertAlmostEqual(
+                online_cov.get_cov(unbiased=unbiased),
+                np.cov([X, Y], bias=not unbiased)[0][1])
+
+    def test_quantities_vector(self):
+        np.random.seed(3)
+        X = np.random.rand(10, 100) * pq.ms
+        Y = np.random.rand(10, 100) * pq.ms
+        online_cov = CovarianceOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i] * X.units)
+        self.assertEqual(online_cov.units, X.units)
+        self.assertIsInstance(online_cov.get_cov(), float)
+        for unbiased in (False, True):
+            self.assertAlmostEqual(
+                online_cov.get_cov(unbiased=unbiased),
+                np.cov(X.reshape(10*100), Y.reshape(10*100),
+                       bias=not unbiased)[0][1], places=4)
+
+    def test_reset(self):
+        X = np.array([1, 2, 3, 2, 3])
+        Y = np.array([4, 4, 3, 3, 4])
+        online_cov = CovarianceOnline()
+        for x_i, y_i in zip(X, Y):
+            online_cov.update([x_i, y_i])
+        self.assertEqual(online_cov.get_cov(), -0.12)
+        online_cov.reset()
+        self.assertIsNone(online_cov.var_x.mean)
+        self.assertIsNone(online_cov.var_y.mean)
+        self.assertIsNone(online_cov.units)
+        self.assertEqual(online_cov.count, 0)
+        self.assertEqual(online_cov.covariance_sum, 0.)
+
+    def test_units(self):
+        X = [[1, 2, 3, 2, 3], [1, 2, 3, 2, 3]]
+        Y = [[4, 4, 3, 3, 4], [4, 4, 3, 3, 4]]
+        online_cov = CovarianceOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            if online_cov.count == 5:
+                np.testing.assert_raises(ValueError, online_cov.update,
+                                         [x_i, y_i]*pq.ms)
+            else:
+                online_cov.update([x_i, y_i]*pq.s)
+
+
+class TestPearsonCorrelationCoefficientOnline(unittest.TestCase):
+    def test_simple_small_sets_XY_unbatched(self):
+        X = np.array([1, 2, 3, 2, 3])
+        Y = np.array([4, 4, 3, 3, 4])
+        online_pcc = PearsonCorrelationCoefficientOnline(batch_mode=False)
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i])
+        self.assertIsNone(online_pcc.units)
+        self.assertAlmostEqual(online_pcc.get_pcc(), np.corrcoef([X, Y])[0][1])
+
+    def test_simple_small_sets_XY_batched(self):
+        X = np.array([[1, 2, 3, 2, 3], [1, 2, 3, 2, 3], [1, 2, 3, 2, 3]])
+        Y = np.array([[4, 4, 3, 3, 4], [4, 4, 3, 3, 4], [4, 4, 3, 3, 4]])
+        online_pcc = PearsonCorrelationCoefficientOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i])
+        self.assertIsNone(online_pcc.units)
+        self.assertAlmostEqual(online_pcc.get_pcc(),
+                               np.corrcoef([X.reshape(15),
+                                            Y.reshape(15)])[0][1])
+
+    def test_floats(self):
+        np.random.seed(0)
+        X = np.random.rand(100)
+        Y = np.random.rand(100)
+        online_pcc = PearsonCorrelationCoefficientOnline()
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i])
+        self.assertIsNone(online_pcc.units)
+        self.assertIsInstance(online_pcc.get_pcc(), float)
+        self.assertAlmostEqual(online_pcc.get_pcc(), np.corrcoef([X, Y])[0][1])
+
+    def test_numpy_array(self):
+        np.random.seed(1)
+        X = np.random.rand(10, 100)
+        Y = np.random.rand(10, 100)
+        online_pcc = PearsonCorrelationCoefficientOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i])
+        self.assertIsNone(online_pcc.units)
+        self.assertIsInstance(online_pcc.get_pcc(), float)
+        self.assertAlmostEqual(online_pcc.get_pcc(),
+                               np.corrcoef(
+                                   [X.reshape(10*100),
+                                    Y.reshape(10*100)])[0][1], places=3)
+
+    def test_quantity_scalar(self):
+        np.random.seed(2)
+        X = np.random.rand(100) * pq.Hz
+        Y = np.random.rand(100) * pq.Hz
+        online_pcc = PearsonCorrelationCoefficientOnline()
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i] * X.units)
+        self.assertEqual(online_pcc.units, X.units)
+        self.assertIsInstance(online_pcc.get_pcc(), float)
+        self.assertAlmostEqual(online_pcc.get_pcc(), np.corrcoef([X, Y])[0][1])
+
+    def test_quantities_vector(self):
+        np.random.seed(3)
+        X = np.random.rand(10, 100) * pq.ms
+        Y = np.random.rand(10, 100) * pq.ms
+        online_pcc = PearsonCorrelationCoefficientOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i] * X.units)
+        self.assertEqual(online_pcc.units, X.units)
+        self.assertIsInstance(online_pcc.get_pcc(), float)
+        self.assertAlmostEqual(
+            online_pcc.get_pcc(),
+            np.corrcoef(X.reshape(10*100), Y.reshape(10*100))[0][1], places=2)
+
+    def test_reset(self):
+        X = np.array([1, 2, 3, 2, 3])
+        Y = X
+        online_pcc = PearsonCorrelationCoefficientOnline()
+        for x_i, y_i in zip(X, Y):
+            online_pcc.update([x_i, y_i])
+        self.assertEqual(online_pcc.get_pcc(), 1)
+        online_pcc.reset()
+        self.assertIsNone(online_pcc.covariance_xy.var_x.mean)
+        self.assertEqual(online_pcc.covariance_xy.var_x.variance_sum, 0.)
+        self.assertEqual(online_pcc.covariance_xy.var_x.count, 0)
+        self.assertIsNone(online_pcc.covariance_xy.var_x.units)
+        self.assertIsNone(online_pcc.covariance_xy.var_y.mean)
+        self.assertEqual(online_pcc.covariance_xy.var_y.variance_sum, 0.)
+        self.assertEqual(online_pcc.covariance_xy.var_y.count, 0)
+        self.assertIsNone(online_pcc.covariance_xy.var_y.units)
+        self.assertIsNone(online_pcc.covariance_xy.units)
+        self.assertEqual(online_pcc.covariance_xy.covariance_sum, 0.)
+        self.assertEqual(online_pcc.covariance_xy.count, 0)
+        self.assertIsNone(online_pcc.units)
+        self.assertEqual(online_pcc.count, 0)
+        self.assertEqual(online_pcc.R_xy, 0.)
+
+    def test_units(self):
+        X = [[1, 2, 3, 2, 3], [1, 2, 3, 2, 3]]
+        Y = [[4, 4, 3, 3, 4], [4, 4, 3, 3, 4]]
+        online_pcc = PearsonCorrelationCoefficientOnline(batch_mode=True)
+        for x_i, y_i in zip(X, Y):
+            if online_pcc.count == 5:
+                np.testing.assert_raises(ValueError, online_pcc.update,
+                                         [x_i, y_i]*pq.ms)
+            else:
+                online_pcc.update([x_i, y_i]*pq.s)
+
+
+if __name__ == '__main__':
+    unittest.main()
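
Usage note (not part of the patch): a minimal sketch of how the new online estimators in elephant/online.py are meant to be driven, updating batch by batch and querying the running statistics at any point. The variable names and data below are illustrative, not from the library.

import numpy as np
import quantities as pq

from elephant.online import VarianceOnline, \
    PearsonCorrelationCoefficientOnline

rng = np.random.RandomState(0)
data = rng.rand(10, 100) * pq.mV  # 10 batches of 100 samples each

# Running mean/std: one update() call per batch. Units are remembered on
# the first call and checked on every subsequent one.
online_var = VarianceOnline(batch_mode=True)
for batch in data:
    online_var.update(batch)
mean, std = online_var.get_mean_std(unbiased=False)
# 'mean' and 'std' match data.magnitude.mean() and data.magnitude.std()
# computed over all 1000 samples at once.

# Pairwise statistics consume one (x, y) pair per update in non-batch mode.
x, y = rng.rand(100), rng.rand(100)
online_pcc = PearsonCorrelationCoefficientOnline()
for x_i, y_i in zip(x, y):
    online_pcc.update([x_i, y_i])
print(online_pcc.get_pcc())  # approximately np.corrcoef(x, y)[0, 1]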