From 14e2da2bcd55373cb1f49b337e82e5174aef6f58 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 10:29:41 +0200 Subject: [PATCH 01/16] use assert_allclose instead of assert np.allclose --- pycrostates/cluster/tests/test_aahc.py | 9 +++++---- pycrostates/cluster/tests/test_kmeans.py | 9 +++++---- pycrostates/cluster/utils/tests/test_utils.py | 5 +++-- pycrostates/io/tests/test_ch_data.py | 11 ++++++----- pycrostates/io/tests/test_fiff.py | 8 ++++---- pycrostates/io/tests/test_meas_info.py | 5 +++-- .../preprocessing/tests/test_resample.py | 3 ++- .../segmentation/tests/test_segmentation.py | 17 +++++++++-------- .../segmentation/tests/test_transitions.py | 7 ++++--- 9 files changed, 41 insertions(+), 33 deletions(-) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 35a2a0e3..f9846a4f 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -14,6 +14,7 @@ from mne.datasets import testing from mne.io import RawArray, read_raw_fif from mne.io.pick import _picks_to_idx +from numpy.testing import assert_allclose from pycrostates import __version__ from pycrostates.cluster import AAHCluster @@ -870,7 +871,7 @@ def test_refit(): with pytest.raises(RuntimeError, match="must be unfitted"): aahCluster_.fit(raw, picks="mag") # works assert eeg_ch_names == aahCluster_.info["ch_names"] - assert np.allclose(eeg_cluster_centers, aahCluster_.cluster_centers_) + assert_allclose(eeg_cluster_centers, aahCluster_.cluster_centers_) # pylint: disable=too-many-statements @@ -1281,9 +1282,9 @@ def test_save(tmp_path, caplog): segmentation1 = aahCluster1.predict(raw_eeg, picks="eeg") segmentation2 = aahCluster2.predict(raw_eeg, picks="eeg") - assert np.allclose(segmentation._labels, segmentation1._labels) - assert np.allclose(segmentation._labels, segmentation2._labels) - assert np.allclose(segmentation1._labels, segmentation2._labels) + assert_allclose(segmentation._labels, segmentation1._labels) + assert_allclose(segmentation._labels, segmentation2._labels) + assert_allclose(segmentation1._labels, segmentation2._labels) def test_comparison(caplog): diff --git a/pycrostates/cluster/tests/test_kmeans.py b/pycrostates/cluster/tests/test_kmeans.py index 7cc95be3..9dff11ea 100644 --- a/pycrostates/cluster/tests/test_kmeans.py +++ b/pycrostates/cluster/tests/test_kmeans.py @@ -14,6 +14,7 @@ from mne.datasets import testing from mne.io import RawArray, read_raw_fif from mne.io.pick import _picks_to_idx +from numpy.testing import assert_allclose from pycrostates import __version__ from pycrostates.cluster import ModKMeans @@ -434,7 +435,7 @@ def test_reorder(caplog): ModK__ = ModK_.copy() ModK_.reorder_clusters(order=np.array([1, 0, 2, 3])) ModK_.reorder_clusters(template=ModK__) - assert np.allclose(ModK_.cluster_centers_, ModK_.cluster_centers_) + assert_allclose(ModK_.cluster_centers_, ModK_.cluster_centers_) def test_properties(caplog): @@ -1260,9 +1261,9 @@ def test_save(tmp_path, caplog): segmentation1 = ModK1.predict(raw_eeg, picks="eeg") segmentation2 = ModK2.predict(raw_eeg, picks="eeg") - assert np.allclose(segmentation._labels, segmentation1._labels) - assert np.allclose(segmentation._labels, segmentation2._labels) - assert np.allclose(segmentation1._labels, segmentation2._labels) + assert_allclose(segmentation._labels, segmentation1._labels) + assert_allclose(segmentation._labels, segmentation2._labels) + assert_allclose(segmentation1._labels, segmentation2._labels) def test_comparison(caplog): diff --git a/pycrostates/cluster/utils/tests/test_utils.py b/pycrostates/cluster/utils/tests/test_utils.py index c85fa995..0a10b999 100644 --- a/pycrostates/cluster/utils/tests/test_utils.py +++ b/pycrostates/cluster/utils/tests/test_utils.py @@ -3,6 +3,7 @@ import numpy as np from mne.datasets import testing from mne.io import read_raw_fif +from numpy.testing import assert_allclose from pycrostates.cluster import ModKMeans from pycrostates.cluster.utils.utils import _optimize_order, optimize_order @@ -46,13 +47,13 @@ def test__optimize_order(): current = random_template ignore_polarity = False order = _optimize_order(current, template, ignore_polarity=ignore_polarity) - assert np.allclose(current[order], template) + assert_allclose(current[order], template) # Shuffle + ignore_polarity current = random_template ignore_polarity = True order = _optimize_order(current, template, ignore_polarity=ignore_polarity) - assert np.allclose(current[order], template) + assert_allclose(current[order], template) # Shuffle + sign + ignore_polarity current = random_pol_template diff --git a/pycrostates/io/tests/test_ch_data.py b/pycrostates/io/tests/test_ch_data.py index b3c698fe..977ef61f 100644 --- a/pycrostates/io/tests/test_ch_data.py +++ b/pycrostates/io/tests/test_ch_data.py @@ -1,6 +1,7 @@ import numpy as np import pytest from mne import create_info, pick_types +from numpy.testing import assert_allclose from pycrostates.io import ChData, ChInfo @@ -27,13 +28,13 @@ def test_ChData(): """Test basic ChData functionalities.""" # create from info ch_data = ChData(data, info.copy()) - assert np.allclose(ch_data._data, data) + assert_allclose(ch_data._data, data) assert isinstance(ch_data.info, ChInfo) assert info.ch_names == ch_data.info.ch_names # create from chinfo ch_data = ChData(data, ch_info.copy()) - assert np.allclose(ch_data._data, data) + assert_allclose(ch_data._data, data) assert isinstance(ch_data.info, ChInfo) assert ch_info.ch_names == ch_data.info.ch_names @@ -51,14 +52,14 @@ def test_ChData(): data_ = ch_data.get_data() data_[0, :] = 0.0 assert not np.allclose(data_, data) - assert np.allclose(ch_data._data, data) + assert_allclose(ch_data._data, data) # test get_data() with picks data_ = ch_data.get_data(picks="eeg") - assert np.allclose(data_, data) + assert_allclose(data_, data) ch_data.info["bads"] = [ch_data.info["ch_names"][0]] data_ = ch_data.get_data(picks="eeg") - assert np.allclose(data_, data[1:, :]) + assert_allclose(data_, data[1:, :]) # test repr assert isinstance(ch_data.__repr__(), str) diff --git a/pycrostates/io/tests/test_fiff.py b/pycrostates/io/tests/test_fiff.py index 67ee3685..8dcad705 100644 --- a/pycrostates/io/tests/test_fiff.py +++ b/pycrostates/io/tests/test_fiff.py @@ -2,11 +2,11 @@ import os -import numpy as np import pytest from mne.datasets import testing from mne.io import read_raw_fif from mne.preprocessing import ICA +from numpy.testing import assert_allclose from pycrostates import __version__ from pycrostates.cluster import ModKMeans @@ -90,9 +90,9 @@ def test_write_and_read(tmp_path, caplog): segmentation1 = ModK1.predict(raw2, picks="eeg") segmentation2 = ModK2.predict(raw2, picks="eeg") - assert np.allclose(segmentation._labels, segmentation1._labels) - assert np.allclose(segmentation._labels, segmentation2._labels) - assert np.allclose(segmentation1._labels, segmentation2._labels) + assert_allclose(segmentation._labels, segmentation1._labels) + assert_allclose(segmentation._labels, segmentation2._labels) + assert_allclose(segmentation1._labels, segmentation2._labels) def test_invalid_write(tmp_path): diff --git a/pycrostates/io/tests/test_meas_info.py b/pycrostates/io/tests/test_meas_info.py index a699fb30..ba7422c9 100644 --- a/pycrostates/io/tests/test_meas_info.py +++ b/pycrostates/io/tests/test_meas_info.py @@ -11,6 +11,7 @@ from mne.io.constants import FIFF from mne.transforms import Transform from mne.utils import check_version +from numpy.testing import assert_allclose from pycrostates.io import ChInfo from pycrostates.utils._logs import logger, set_log_level @@ -265,12 +266,12 @@ def test_montage(): montage.get_positions()[key] == montage2.get_positions()[key] ) elif isinstance(montage.get_positions()[key], np.ndarray): - assert np.allclose( + assert_allclose( montage.get_positions()[key], montage2.get_positions()[key] ) elif isinstance(montage.get_positions()[key], OrderedDict): for k, v in montage.get_positions()[key].items(): - assert np.allclose( + assert_allclose( montage.get_positions()[key][k], montage2.get_positions()[key][k], ) diff --git a/pycrostates/preprocessing/tests/test_resample.py b/pycrostates/preprocessing/tests/test_resample.py index 2932101e..2592a6da 100644 --- a/pycrostates/preprocessing/tests/test_resample.py +++ b/pycrostates/preprocessing/tests/test_resample.py @@ -6,6 +6,7 @@ from mne import BaseEpochs from mne.datasets import testing from mne.io.pick import _picks_to_idx +from numpy.testing import assert_allclose from pycrostates.io import ChData from pycrostates.preprocessing import resample @@ -151,4 +152,4 @@ def test_resample_random_state(): resamples_1 = resample(raw, n_resamples=1, n_samples=500, random_state=42)[ 0 ] - assert np.allclose(resamples_0._data, resamples_1._data) + assert_allclose(resamples_0._data, resamples_1._data) diff --git a/pycrostates/segmentation/tests/test_segmentation.py b/pycrostates/segmentation/tests/test_segmentation.py index 597a1652..95c9a37a 100644 --- a/pycrostates/segmentation/tests/test_segmentation.py +++ b/pycrostates/segmentation/tests/test_segmentation.py @@ -4,6 +4,7 @@ from mne import BaseEpochs, Epochs, make_fixed_length_events from mne.datasets import testing from mne.io import BaseRaw, read_raw_fif +from numpy.testing import assert_allclose from pycrostates.cluster import ModKMeans from pycrostates.segmentation import EpochsSegmentation, RawSegmentation @@ -66,9 +67,9 @@ def test_properties(ModK, inst, caplog): assert isinstance(predict_parameters, dict) cluster_centers_ -= 10 - assert np.allclose(cluster_centers_, segmentation._cluster_centers_ - 10) + assert_allclose(cluster_centers_, segmentation._cluster_centers_ - 10) labels -= 10 - assert np.allclose(labels, segmentation._labels - 10) + assert_allclose(labels, segmentation._labels - 10) predict_parameters["test"] = 10 assert "test" not in segmentation._predict_parameters @@ -308,11 +309,11 @@ def test_compute_transition_matrix_stat(ModK, inst): segmentation.compute_transition_matrix(stat="wrong") T = segmentation.compute_transition_matrix(stat="count") T = segmentation.compute_transition_matrix(stat="probability") - assert np.allclose(np.sum(T, axis=1), 1) + assert_allclose(np.sum(T, axis=1), 1) T = segmentation.compute_transition_matrix(stat="proportion") - assert np.allclose(np.sum(T, axis=1), 1) + assert_allclose(np.sum(T, axis=1), 1) T = segmentation.compute_transition_matrix(stat="percent") - assert np.allclose(np.sum(T, axis=1), 100) + assert_allclose(np.sum(T, axis=1), 100) def test_compute_expected_transition_matrix_Raw(): @@ -341,8 +342,8 @@ def test_compute_expected_transition_matrix_stat(ModK, inst): ): segmentation.compute_expected_transition_matrix(stat="count") T = segmentation.compute_expected_transition_matrix(stat="probability") - assert np.allclose(np.sum(T, axis=1), 1) + assert_allclose(np.sum(T, axis=1), 1) T = segmentation.compute_expected_transition_matrix(stat="proportion") - assert np.allclose(np.sum(T, axis=1), 1) + assert_allclose(np.sum(T, axis=1), 1) T = segmentation.compute_expected_transition_matrix(stat="percent") - assert np.allclose(np.sum(T, axis=1), 100) + assert_allclose(np.sum(T, axis=1), 100) diff --git a/pycrostates/segmentation/tests/test_transitions.py b/pycrostates/segmentation/tests/test_transitions.py index 73c93893..db71e1c2 100644 --- a/pycrostates/segmentation/tests/test_transitions.py +++ b/pycrostates/segmentation/tests/test_transitions.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from numpy.testing import assert_allclose from pycrostates.segmentation.transitions import ( _check_labels_n_clusters, @@ -82,7 +83,7 @@ def test_compute_transition_matrix(labels, ignore_self, T): ) assert isinstance(T, np.ndarray) assert t.shape == (n_clusters, n_clusters) - assert np.allclose(t, T) + assert_allclose(t, T) def test_compute_expected_transition_matrix(): @@ -101,7 +102,7 @@ def test_compute_expected_transition_matrix(): expected_T = _compute_expected_transition_matrix( labels, n_clusters, ignore_self=True, stat="probability" ) - assert np.allclose(boostrap_T, expected_T, atol=1e-2) + assert_allclose(boostrap_T, expected_T, atol=1e-2) # case where 1 state is missing labels = np.random.randint(-1, 3, 500) @@ -118,7 +119,7 @@ def test_compute_expected_transition_matrix(): expected_T = _compute_expected_transition_matrix( labels, n_clusters, ignore_self=True, stat="probability" ) - assert np.allclose(boostrap_T, expected_T, atol=1e-2) + assert_allclose(boostrap_T, expected_T, atol=1e-2) def test_check_labels_n_clusters(): From 135c0b5e4d05fc57facb2d786f15b04e7703f1c4 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 10:43:34 +0200 Subject: [PATCH 02/16] replace assert np.isclose(...).all() with assert_allclose and use sane rtol=1e-7, atol=0 in remaining np.allclose tests --- pycrostates/cluster/tests/test_aahc.py | 129 +++++++++++---------- pycrostates/cluster/tests/test_kmeans.py | 137 ++++++++++++----------- pycrostates/io/tests/test_ch_data.py | 2 +- 3 files changed, 143 insertions(+), 125 deletions(-) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index f9846a4f..224daa82 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -267,22 +267,22 @@ def test_aahClusterMeans(): # Test copy aahCluster2 = aahCluster1.copy() _check_fitted(aahCluster2) - assert np.isclose( + assert_allclose( aahCluster2._cluster_centers_, aahCluster1._cluster_centers_ - ).all() + ) assert np.isclose(aahCluster2.GEV_, aahCluster1.GEV_) - assert np.isclose(aahCluster2._labels_, aahCluster1._labels_).all() + assert_allclose(aahCluster2._labels_, aahCluster1._labels_) aahCluster2.fitted = False _check_fitted(aahCluster1) _check_unfitted(aahCluster2) aahCluster3 = aahCluster1.copy(deep=False) _check_fitted(aahCluster3) - assert np.isclose( + assert_allclose( aahCluster3._cluster_centers_, aahCluster1._cluster_centers_ - ).all() + ) assert np.isclose(aahCluster3.GEV_, aahCluster1.GEV_) - assert np.isclose(aahCluster3._labels_, aahCluster1._labels_).all() + assert_allclose(aahCluster3._labels_, aahCluster1._labels_) aahCluster3.fitted = False _check_fitted(aahCluster1) _check_unfitted(aahCluster3) @@ -314,52 +314,52 @@ def test_invert_polarity(): aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity([True, False, True, False]) - assert np.isclose( + assert_allclose( aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :] - ).all() + ) # bool aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity(True) - assert np.isclose( + assert_allclose( aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[1, :], -cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[3, :], -cluster_centers_[3, :] - ).all() + ) # np.array aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity(np.array([True, False, True, False])) - assert np.isclose( + assert_allclose( aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] - ).all() - assert np.isclose( + ) + assert_allclose( aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :] - ).all() + ) # Test invalid arguments with pytest.raises(ValueError, match="not a 2D iterable"): @@ -450,37 +450,37 @@ def test_reorder(caplog): # Test mapping aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(mapping={0: 1}) - assert np.isclose( + assert_allclose( aah_cluster._cluster_centers_[0, :], aahCluster_._cluster_centers_[1, :], - ).all() - assert np.isclose( + ) + assert_allclose( aah_cluster._cluster_centers_[1, :], aahCluster_._cluster_centers_[0, :], - ).all() + ) assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] # Test order aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(order=[1, 0, 2, 3]) - assert np.isclose( + assert_allclose( aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] - ).all() - assert np.isclose( + ) + assert_allclose( aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] - ).all() + ) assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(order=np.array([1, 0, 2, 3])) - assert np.isclose( + assert_allclose( aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] - ).all() - assert np.isclose( + ) + assert_allclose( aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] - ).all() + ) assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] @@ -783,14 +783,14 @@ def test_fit_data_shapes(): aahCluster_reject_omit.fit(raw_, reject_by_annotation="omit") # Compare 'omit' and True - assert np.isclose( + assert_allclose( aahCluster_reject_omit._fitted_data, aahCluster_reject_True._fitted_data, - ).all() + ) assert np.isclose(aahCluster_reject_omit.GEV_, aahCluster_reject_True.GEV_) - assert np.isclose( + assert_allclose( aahCluster_reject_omit._labels_, aahCluster_reject_True._labels_ - ).all() + ) # due to internal randomness, the sign can be flipped sgn = np.sign( np.sum( @@ -800,10 +800,10 @@ def test_fit_data_shapes(): ) ) aahCluster_reject_True._cluster_centers_ *= sgn[:, None] - assert np.isclose( + assert_allclose( aahCluster_reject_omit._cluster_centers_, aahCluster_reject_True._cluster_centers_, - ).all() + ) # Make sure there is a shape diff between True and False assert ( @@ -835,9 +835,9 @@ def test_fit_data_shapes(): aahCluster_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit" ) assert aahCluster_rej_0_5._fitted_data.shape != fitted_data_0_5.shape - assert np.isclose( + assert_allclose( fitted_data_5_end, aahCluster_rej_5_end._fitted_data - ).all() + ) def test_refit(): @@ -966,15 +966,18 @@ def test_predict_default(caplog): segmentation_no_annot = aah_cluster.predict( raw_eeg, factor=0, reject_edges=True, reject_by_annotation="omit" ) - assert not np.isclose( - segmentation_rej_True._labels, segmentation_rej_False._labels - ).all() - assert np.isclose( + assert not np.allclose( + segmentation_rej_True._labels, + segmentation_rej_False._labels, + rtol=1e-7, + atol=0, + ) + assert_allclose( segmentation_no_annot._labels, segmentation_rej_False._labels - ).all() - assert np.isclose( + ) + assert_allclose( segmentation_rej_None._labels, segmentation_rej_False._labels - ).all() + ) # test different half_window_size segmentation1 = aah_cluster.predict( @@ -986,9 +989,15 @@ def test_predict_default(caplog): segmentation3 = aah_cluster.predict( raw_eeg, factor=0, reject_edges=False, half_window_size=3 ) - assert not np.isclose(segmentation1._labels, segmentation2._labels).all() - assert not np.isclose(segmentation1._labels, segmentation3._labels).all() - assert not np.isclose(segmentation2._labels, segmentation3._labels).all() + assert not np.allclose( + segmentation1._labels, segmentation2._labels, rtol=1e-7, atol=0 + ) + assert not np.allclose( + segmentation1._labels, segmentation3._labels, rtol=1e-7, atol=0 + ) + assert not np.allclose( + segmentation2._labels, segmentation3._labels, rtol=1e-7, atol=0 + ) # pylint: enable=too-many-statements diff --git a/pycrostates/cluster/tests/test_kmeans.py b/pycrostates/cluster/tests/test_kmeans.py index 9dff11ea..d8480dea 100644 --- a/pycrostates/cluster/tests/test_kmeans.py +++ b/pycrostates/cluster/tests/test_kmeans.py @@ -160,18 +160,18 @@ def test_ModKMeans(): # Test copy ModK2 = ModK1.copy() _check_fitted(ModK2) - assert np.isclose(ModK2._cluster_centers_, ModK1._cluster_centers_).all() + assert_allclose(ModK2._cluster_centers_, ModK1._cluster_centers_) assert np.isclose(ModK2.GEV_, ModK1.GEV_) - assert np.isclose(ModK2._labels_, ModK1._labels_).all() + assert_allclose(ModK2._labels_, ModK1._labels_) ModK2.fitted = False _check_fitted(ModK1) _check_unfitted(ModK2) ModK3 = ModK1.copy(deep=False) _check_fitted(ModK3) - assert np.isclose(ModK3._cluster_centers_, ModK1._cluster_centers_).all() + assert_allclose(ModK3._cluster_centers_, ModK1._cluster_centers_) assert np.isclose(ModK3.GEV_, ModK1.GEV_) - assert np.isclose(ModK3._labels_, ModK1._labels_).all() + assert_allclose(ModK3._labels_, ModK1._labels_) ModK3.fitted = False _check_fitted(ModK1) _check_unfitted(ModK3) @@ -203,52 +203,52 @@ def test_invert_polarity(): ModK_ = ModK.copy() cluster_centers_ = deepcopy(ModK_._cluster_centers_) ModK_.invert_polarity([True, False, True, False]) - assert np.isclose( + assert_allclose( ModK_._cluster_centers_[0, :], -cluster_centers_[0, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[1, :], cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[2, :], -cluster_centers_[2, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[3, :], cluster_centers_[3, :] - ).all() + ) # bool ModK_ = ModK.copy() cluster_centers_ = deepcopy(ModK_._cluster_centers_) ModK_.invert_polarity(True) - assert np.isclose( + assert_allclose( ModK_._cluster_centers_[0, :], -cluster_centers_[0, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[1, :], -cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[2, :], -cluster_centers_[2, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[3, :], -cluster_centers_[3, :] - ).all() + ) # np.array ModK_ = ModK.copy() cluster_centers_ = deepcopy(ModK_._cluster_centers_) ModK_.invert_polarity(np.array([True, False, True, False])) - assert np.isclose( + assert_allclose( ModK_._cluster_centers_[0, :], -cluster_centers_[0, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[1, :], cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[2, :], -cluster_centers_[2, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK_._cluster_centers_[3, :], cluster_centers_[3, :] - ).all() + ) # Test invalid arguments with pytest.raises(ValueError, match="not a 2D iterable"): @@ -335,35 +335,35 @@ def test_reorder(caplog): # Test mapping ModK_ = ModK.copy() ModK_.reorder_clusters(mapping={0: 1}) - assert np.isclose( + assert_allclose( ModK._cluster_centers_[0, :], ModK_._cluster_centers_[1, :] - ).all() - assert np.isclose( + ) + assert_allclose( ModK._cluster_centers_[1, :], ModK_._cluster_centers_[0, :] - ).all() + ) assert ModK._cluster_names[0] == ModK_._cluster_names[1] assert ModK._cluster_names[0] == ModK_._cluster_names[1] # Test order ModK_ = ModK.copy() ModK_.reorder_clusters(order=[1, 0, 2, 3]) - assert np.isclose( + assert_allclose( ModK._cluster_centers_[0], ModK_._cluster_centers_[1] - ).all() - assert np.isclose( + ) + assert_allclose( ModK._cluster_centers_[1], ModK_._cluster_centers_[0] - ).all() + ) assert ModK._cluster_names[0] == ModK_._cluster_names[1] assert ModK._cluster_names[0] == ModK_._cluster_names[1] ModK_ = ModK.copy() ModK_.reorder_clusters(order=np.array([1, 0, 2, 3])) - assert np.isclose( + assert_allclose( ModK._cluster_centers_[0], ModK_._cluster_centers_[1] - ).all() - assert np.isclose( + ) + assert_allclose( ModK._cluster_centers_[1], ModK_._cluster_centers_[0] - ).all() + ) assert ModK._cluster_names[0] == ModK_._cluster_names[1] assert ModK._cluster_names[0] == ModK_._cluster_names[1] @@ -699,16 +699,16 @@ def test_fit_data_shapes(): ModK_reject_omit.fit(raw_, n_jobs=1, reject_by_annotation="omit") # Compare 'omit' and True - assert np.isclose( + assert_allclose( ModK_reject_omit._fitted_data, ModK_reject_True._fitted_data - ).all() + ) assert np.isclose(ModK_reject_omit.GEV_, ModK_reject_True.GEV_) - assert np.isclose( + assert_allclose( ModK_reject_omit._labels_, ModK_reject_True._labels_ - ).all() - assert np.isclose( + ) + assert_allclose( ModK_reject_omit._cluster_centers_, ModK_reject_True._cluster_centers_ - ).all() + ) # Make sure there is a shape diff between True and False assert ( @@ -740,7 +740,7 @@ def test_fit_data_shapes(): ModK_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit" ) assert ModK_rej_0_5._fitted_data.shape != fitted_data_0_5.shape - assert np.isclose(fitted_data_5_end, ModK_rej_5_end._fitted_data).all() + assert_allclose(fitted_data_5_end, ModK_rej_5_end._fitted_data) def test_refit(): @@ -778,7 +778,7 @@ def test_refit(): with pytest.raises(RuntimeError, match="must be unfitted"): ModK_.fit(raw, picks="mag") # works assert eeg_ch_names == ModK_.info["ch_names"] - assert np.isclose(eeg_cluster_centers, ModK_.cluster_centers_).all() + assert_allclose(eeg_cluster_centers, ModK_.cluster_centers_) def test_predict_default(caplog): @@ -870,15 +870,18 @@ def test_predict_default(caplog): segmentation_no_annot = ModK.predict( raw_eeg, factor=0, reject_edges=True, reject_by_annotation="omit" ) - assert not np.isclose( - segmentation_rej_True._labels, segmentation_rej_False._labels - ).all() - assert np.isclose( + assert not np.allclose( + segmentation_rej_True._labels, + segmentation_rej_False._labels, + rtol=1e-7, + atol=0, + ) + assert_allclose( segmentation_no_annot._labels, segmentation_rej_False._labels - ).all() - assert np.isclose( + ) + assert_allclose( segmentation_rej_None._labels, segmentation_rej_False._labels - ).all() + ) # test different half_window_size segmentation1 = ModK.predict( @@ -890,9 +893,15 @@ def test_predict_default(caplog): segmentation3 = ModK.predict( raw_eeg, factor=0, reject_edges=False, half_window_size=3 ) - assert not np.isclose(segmentation1._labels, segmentation2._labels).all() - assert not np.isclose(segmentation1._labels, segmentation3._labels).all() - assert not np.isclose(segmentation2._labels, segmentation3._labels).all() + assert not np.allclose( + segmentation1._labels, segmentation2._labels, rtol=1e-7, atol=0 + ) + assert not np.allclose( + segmentation1._labels, segmentation3._labels, rtol=1e-7, atol=0 + ) + assert not np.allclose( + segmentation2._labels, segmentation3._labels, rtol=1e-7, atol=0 + ) def test_picks_fit_predict(caplog): @@ -1106,9 +1115,9 @@ def test_n_jobs(): ) ModK_.fit(raw_eeg, n_jobs=2) _check_fitted(ModK_) - assert np.isclose(ModK_._cluster_centers_, ModK._cluster_centers_).all() + assert_allclose(ModK_._cluster_centers_, ModK._cluster_centers_) assert np.isclose(ModK_.GEV_, ModK.GEV_) - assert np.isclose(ModK_._labels_, ModK._labels_).all() + assert_allclose(ModK_._labels_, ModK._labels_) def test_fit_not_converged(caplog): @@ -1176,10 +1185,10 @@ def test_randomseed(): ) ModK3.fit(raw_eeg, n_jobs=1) - assert np.isclose(ModK1._cluster_centers_, ModK2._cluster_centers_).all() - assert not np.isclose( - ModK1._cluster_centers_, ModK3._cluster_centers_ - ).all() + assert_allclose(ModK1._cluster_centers_, ModK2._cluster_centers_) + assert not np.allclose( + ModK1._cluster_centers_, ModK3._cluster_centers_, rtol=1e-7, atol=0 + ) def test_contains_mixin(): diff --git a/pycrostates/io/tests/test_ch_data.py b/pycrostates/io/tests/test_ch_data.py index 977ef61f..5bc03217 100644 --- a/pycrostates/io/tests/test_ch_data.py +++ b/pycrostates/io/tests/test_ch_data.py @@ -51,7 +51,7 @@ def test_ChData(): # test that data is copied data_ = ch_data.get_data() data_[0, :] = 0.0 - assert not np.allclose(data_, data) + assert not np.allclose(data_, data, rtol=1e-7, atol=0) assert_allclose(ch_data._data, data) # test get_data() with picks From 8430fe1f0d3db91c896f2ffb98e5777806fd6942 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 10:48:11 +0200 Subject: [PATCH 03/16] fix warning about gradient_kwargs --- pycrostates/viz/cluster_centers.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pycrostates/viz/cluster_centers.py b/pycrostates/viz/cluster_centers.py index b05404cb..cb8a278b 100644 --- a/pycrostates/viz/cluster_centers.py +++ b/pycrostates/viz/cluster_centers.py @@ -16,6 +16,13 @@ from ..utils._logs import logger, verbose +_GRADIENT_KWARGS_DEFAULTS: Dict[str, str] = { + "color": "black", + "linestyle": "-", + "marker": "P", +} + + @fill_doc @verbose def plot_cluster_centers( @@ -24,11 +31,7 @@ def plot_cluster_centers( cluster_names: List[str] = None, axes: Optional[Union[Axes, NDArray[Axes]]] = None, show_gradient: Optional[bool] = False, - gradient_kwargs: Dict[str, Any] = { - "color": "black", - "linestyle": "-", - "marker": "P", - }, + gradient_kwargs: Dict[str, Any] = _GRADIENT_KWARGS_DEFAULTS, *, block: bool = False, verbose: Optional[str] = None, @@ -74,7 +77,7 @@ def plot_cluster_centers( (dict,), "gradient_kwargs", ) - if gradient_kwargs is not None and not show_gradient: + if gradient_kwargs != _GRADIENT_KWARGS_DEFAULTS and not show_gradient: logger.warning( "The argument 'gradient_kwargs' has not effect when " "the argument 'show_gradient' is set to False." From 025280388361b0edb978b22cdd84cdf06c9132f4 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 10:48:41 +0200 Subject: [PATCH 04/16] fix missing f-string --- pycrostates/viz/cluster_centers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycrostates/viz/cluster_centers.py b/pycrostates/viz/cluster_centers.py index cb8a278b..3db4fbbd 100644 --- a/pycrostates/viz/cluster_centers.py +++ b/pycrostates/viz/cluster_centers.py @@ -118,7 +118,7 @@ def plot_cluster_centers( raise ValueError( "Argument 'cluster_centers' and 'axes' must contain the same " f"number of clusters and Axes. Provided: {n_clusters} " - "microstates maps and {axes.size} axes." + f"microstates maps and {axes.size} axes." ) figs = [ax.get_figure() for ax in axes.flatten()] if len(set(figs)) == 1: From 48447765f0bacbf07c5811b66c9266fdf11ea186 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 11:08:39 +0200 Subject: [PATCH 05/16] update style configuration --- .codespellignore | 0 .flake8 | 24 ------------------------ .github/workflows/code-style.yml | 10 +++------- pyproject.toml | 25 ++++++++++++++++++------- 4 files changed, 21 insertions(+), 38 deletions(-) create mode 100644 .codespellignore delete mode 100644 .flake8 diff --git a/.codespellignore b/.codespellignore new file mode 100644 index 00000000..e69de29b diff --git a/.flake8 b/.flake8 deleted file mode 100644 index bb010fd5..00000000 --- a/.flake8 +++ /dev/null @@ -1,24 +0,0 @@ -[flake8] -max-line-length = 79 - -ignore = - # these rules don't play well with black - # whitespace before ':' - E203, - # line break before binary operator - W503, - E241,E305,W504,W605,E731 - -exclude = - .git, - .github, - .pytest_cache, - pycrostates.egg-info, - setup.py, - docs/source/conf.py, - tutorials/* - -per-file-ignores = - # __init__.py files are allowed to have unused imports and lines-too-long - */__init__.py:F401 - */**/__init__.py:F401,E501 diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml index a5c38838..d5181af0 100644 --- a/.github/workflows/code-style.yml +++ b/.github/workflows/code-style.yml @@ -1,7 +1,4 @@ name: style -# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#concurrency -# https://docs.github.com/en/developers/webhooks-and-events/events/github-event-types#pullrequestevent -# workflow name, PR number (empty on push), push ref (empty on PR) concurrency: group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} cancel-in-progress: true @@ -27,10 +24,8 @@ jobs: run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel python -m pip install --progress-bar off .[style] - - name: Run flake8 - uses: py-actions/flake8@v2 - with: - path: "pycrostates" + - name: Run Ruff + run: ruff check pycrostates - name: Run isort uses: isort/isort-action@master - name: Run black @@ -43,6 +38,7 @@ jobs: check_filenames: true check_hidden: true skip: ./.git,./build,./.github,*.bib + ignore_words_file: ./.codespellignore - name: Run pydocstyle run: pydocstyle . - name: Run bibclean diff --git a/pyproject.toml b/pyproject.toml index 5de0951a..6bea1255 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,8 +73,8 @@ style = [ 'black', 'codespell', 'isort', - 'flake8', 'pydocstyle[toml]', + 'ruff', ] test = [ 'pymatreader', @@ -107,8 +107,8 @@ exclude = ['pycrostates*tests'] "pycrostates.html_templates" = ["repr/*.jinja"] [tool.black] -line-length = 79 -target-version = ['py37'] +line-length = 88 +target-version = ['py39'] include = '\.pyi?$' extend-exclude = ''' ( @@ -124,8 +124,8 @@ extend-exclude = ''' [tool.isort] profile = 'black' multi_line_output = 3 -line_length = 79 -py_version = 37 +line_length = 88 +py_version = 39 extend_skip_glob = [ 'setup.py', 'docs/*', @@ -135,11 +135,21 @@ extend_skip_glob = [ [tool.pydocstyle] convention = 'numpy' -ignore-decorators= '(copy_doc|property|.*setter|.*getter)' +ignore-decorators= '(copy_doc|property|.*setter|.*getter|pyqtSlot|Slot)' match = '^(?!setup|__init__|test_).*\.py' -match-dir = '^(?!docs|tutorials|build|dist|\.).*' +match-dir = '^pycrostates.*' add_ignore = 'D100,D104,D107' +[tool.ruff] +line-length = 88 +extend-exclude = [ + "doc", + "setup.py", +] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] + [tool.pytest.ini_options] minversion = '6.0' addopts = '--durations 20 --junit-xml=junit-results.xml --verbose' @@ -161,5 +171,6 @@ omit = [ exclude_lines = [ 'pragma: no cover', 'if __name__ == .__main__.:', + 'if TYPE_CHECKING:', ] precision = 2 From 801c0b5b8c2cdf821ba98fa768ff2c1bacdeb2cf Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 11:08:46 +0200 Subject: [PATCH 06/16] run ruff, isort, black --- pycrostates/cluster/_base.py | 62 ++--- pycrostates/cluster/aahc.py | 12 +- pycrostates/cluster/kmeans.py | 15 +- pycrostates/cluster/tests/test_aahc.py | 256 +++++------------- pycrostates/cluster/tests/test_kmeans.py | 213 ++++----------- pycrostates/cluster/utils/tests/test_utils.py | 4 +- pycrostates/cluster/utils/utils.py | 3 +- pycrostates/io/ch_data.py | 14 +- pycrostates/io/fiff.py | 27 +- pycrostates/io/meas_info.py | 15 +- pycrostates/io/reader.py | 4 +- pycrostates/io/tests/test_ch_data.py | 20 +- pycrostates/io/tests/test_fiff.py | 60 +--- pycrostates/io/tests/test_meas_info.py | 13 +- pycrostates/io/tests/test_reader.py | 4 +- pycrostates/metrics/calinski_harabasz.py | 4 +- pycrostates/metrics/dunn.py | 4 +- .../preprocessing/extract_gfp_peaks.py | 11 +- pycrostates/preprocessing/resample.py | 4 +- pycrostates/preprocessing/spatial_filter.py | 17 +- .../tests/test_extract_gfp_peaks.py | 8 +- .../preprocessing/tests/test_resample.py | 32 +-- .../tests/test_spatial_filter.py | 16 +- pycrostates/segmentation/_base.py | 41 +-- .../segmentation/tests/test_segmentation.py | 52 +--- .../segmentation/tests/test_transitions.py | 12 +- pycrostates/segmentation/transitions.py | 4 +- pycrostates/utils/__init__.py | 6 +- pycrostates/utils/_checks.py | 8 +- pycrostates/utils/_docs.py | 4 +- pycrostates/utils/_imports.py | 4 +- pycrostates/utils/_logs.py | 4 +- pycrostates/utils/tests/test_checks.py | 28 +- pycrostates/utils/tests/test_logs.py | 11 +- pycrostates/utils/tests/test_mixin.py | 20 +- pycrostates/utils/utils.py | 20 +- pycrostates/viz/cluster_centers.py | 5 +- pycrostates/viz/segmentation.py | 6 +- pycrostates/viz/tests/test_cluster_centers.py | 4 +- pycrostates/viz/tests/test_segmentation.py | 12 +- 40 files changed, 271 insertions(+), 788 deletions(-) diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index eedfd6c5..bfd44f90 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -225,14 +225,10 @@ def fit( if isinstance(inst, (BaseRaw, BaseEpochs)): tmin, tmax = _check_tmin_tmax(inst, tmin, tmax) if isinstance(inst, BaseRaw): - reject_by_annotation = _check_reject_by_annotation( - reject_by_annotation - ) + reject_by_annotation = _check_reject_by_annotation(reject_by_annotation) # picks - picks_bads_inc = _picks_to_idx( - inst.info, picks, none="all", exclude=[] - ) + picks_bads_inc = _picks_to_idx(inst.info, picks, none="all", exclude=[]) picks = _picks_to_idx(inst.info, picks, none="all", exclude="bads") _check_picks_uniqueness(inst.info, picks) ch_not_used = set(picks_bads_inc) - set(picks) @@ -256,9 +252,7 @@ def fit( del msg # retrieve numpy array - kwargs = ( - dict() if isinstance(inst, ChData) else dict(tmin=tmin, tmax=tmax) - ) + kwargs = dict() if isinstance(inst, ChData) else dict(tmin=tmin, tmax=tmax) if isinstance(inst, BaseRaw): kwargs["reject_by_annotation"] = reject_by_annotation data = inst.get_data(picks=picks, **kwargs) @@ -271,15 +265,9 @@ def fit( info = pick_info(inst.info, picks, copy=True) if info["bads"] != []: if len(info["bads"]) == 1: - msg = ( - "The channel %s is set as bad and will be used for " - "fitting." - ) + msg = "The channel %s is set as bad and will be used for fitting." else: - msg = ( - "The channels %s are set as bad and will be used for " - "fitting." - ) + msg = "The channels %s are set as bad and will be used for fitting." logger.warning(msg, ", ".join(ch_name for ch_name in info["bads"])) del msg self._info = ChInfo(info=info) @@ -314,9 +302,7 @@ def rename_clusters( self._check_fit() if mapping is not None and new_names is not None: - raise ValueError( - "Only one of 'mapping' or 'new_names' must be provided." - ) + raise ValueError("Only one of 'mapping' or 'new_names' must be provided.") if mapping is not None: _check_type(mapping, (dict,), item_name="mapping") @@ -339,8 +325,7 @@ def rename_clusters( # convert to dict mapping = { - old_name: new_names[k] - for k, old_name in enumerate(self._cluster_names) + old_name: new_names[k] for k, old_name in enumerate(self._cluster_names) } else: @@ -351,8 +336,7 @@ def rename_clusters( return self._cluster_names = [ - mapping[name] if name in mapping else name - for name in self._cluster_names + mapping[name] if name in mapping else name for name in self._cluster_names ] def reorder_clusters( @@ -403,8 +387,7 @@ def reorder_clusters( if sum(x is not None for x in (mapping, order, template)) > 1: raise ValueError( - "Only one of 'mapping', 'order' or 'template' " - "must be provided." + "Only one of 'mapping', 'order' or 'template' must be provided." ) # Mapping @@ -420,9 +403,7 @@ def reorder_clusters( # check uniqueness if len(set(mapping.values())) != len(mapping.values()): - raise ValueError( - "Position in the new order can not be repeated." - ) + raise ValueError("Position in the new order can not be repeated.") # check that a cluster is not moved twice for key in mapping: if key in mapping.values(): @@ -677,9 +658,7 @@ def predict( _check_type(factor, ("int",), item_name="factor") _check_type(half_window_size, ("int",), item_name="half_window_size") _check_type(tol, ("numeric",), item_name="tol") - _check_type( - min_segment_length, ("int",), item_name="min_segment_length" - ) + _check_type(min_segment_length, ("int",), item_name="min_segment_length") _check_type(reject_edges, (bool,), item_name="reject_edges") _check_type( reject_by_annotation, @@ -710,9 +689,7 @@ def predict( "The current fit contains bad channels %s" + " which will be used for prediction." ) - logger.warning( - msg, ", ".join(ch_name for ch_name in self._info["bads"]) - ) + logger.warning(msg, ", ".join(ch_name for ch_name in self._info["bads"])) del msg # check that the instance as the required channels (good + bads) @@ -1019,9 +996,9 @@ def _smooth_segmentation( w = np.zeros((Nu, Nt)) w[(rmat == labels)] = 1 - e = np.sum( - Vvar - np.sum(np.dot(w.T, states).T * data, axis=0) ** 2 - ) / (Nt * (Ne - 1)) + e = np.sum(Vvar - np.sum(np.dot(w.T, states).T * data, axis=0) ** 2) / ( + Nt * (Ne - 1) + ) window = np.ones((1, 2 * half_window_size + 1)) S0 = 0 @@ -1035,9 +1012,9 @@ def _smooth_segmentation( labels = dlt w = np.zeros((Nu, Nt)) w[(rmat == labels)] = 1 - Su = np.sum( - Vvar - np.sum(np.dot(w.T, states).T * data, axis=0) ** 2 - ) / (Nt * (Ne - 1)) + Su = np.sum(Vvar - np.sum(np.dot(w.T, states).T * data, axis=0) ** 2) / ( + Nt * (Ne - 1) + ) if np.abs(Su - S0) <= np.abs(tol * Su): break S0 = Su @@ -1064,8 +1041,7 @@ def _reject_short_segments( skip_condition = [ k in (0, len(segments) - 1), # ignore edge segments segment[0] == -1, # ignore segments labelled with 0 - min_segment_length - <= len(segment), # ignore large segments + min_segment_length <= len(segment), # ignore large segments ] if any(skip_condition): idx += len(segment) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 99c12d2b..b9c3558a 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -48,9 +48,7 @@ def __init__( # make the parameter an argument # https://github.com/vferat/pycrostates/pull/93#issue-1431122168 self._ignore_polarity = True - self._normalize_input = AAHCluster._check_normalize_input( - normalize_input - ) + self._normalize_input = AAHCluster._check_normalize_input(normalize_input) # fit variables self._GEV_ = None @@ -225,9 +223,7 @@ def _compute_maps( cluster = np.delete(cluster, to_remove, axis=1) GEV = np.delete(GEV, to_remove, axis=0) - assignment[assignment > to_remove] = ( - assignment[assignment > to_remove] - 1 - ) + assignment[assignment > to_remove] = assignment[assignment > to_remove] - 1 fit = data[:, orphans].T @ cluster if ignore_polarity: @@ -239,9 +235,7 @@ def _compute_maps( for c in cluster_to_update: members = assignment == c if ignore_polarity: - evecs, _, _ = np.linalg.svd( - data[:, members], full_matrices=False - ) + evecs, _, _ = np.linalg.svd(data[:, members], full_matrices=False) cluster[:, c] = evecs[:, 0] else: cluster[:, c] = np.mean(data[:, members], axis=1) diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index 6d841cd0..e6d0ac00 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -207,9 +207,7 @@ def fit( for init in inits ) try: - best_run = np.nanargmax( - [run[0] if run[3] else np.nan for run in runs] - ) + best_run = np.nanargmax([run[0] if run[3] else np.nan for run in runs]) best_gev, best_maps, best_segmentation, _ = runs[best_run] count_converged = sum(run[3] for run in runs) except ValueError: @@ -306,9 +304,7 @@ def _compute_maps( data_sum_sq = np.sum(data**2) # Select random time points for our initial topographic maps - init_times = random_state.choice( - n_samples, size=n_clusters, replace=False - ) + init_times = random_state.choice(n_samples, size=n_clusters, replace=False) maps = data[:, init_times].T # Normalize the maps maps /= np.linalg.norm(maps, axis=1, keepdims=True) @@ -332,9 +328,7 @@ def _compute_maps( maps[state] /= np.linalg.norm(maps[state]) # Estimate residual noise - act_sum_sq = np.sum( - np.sum(maps[segmentation].T * data, axis=0) ** 2 - ) + act_sum_sq = np.sum(np.sum(maps[segmentation].T * data, axis=0) ** 2) residual = abs(data_sum_sq - act_sum_sq) residual /= float(n_samples * (n_channels - 1)) @@ -433,7 +427,6 @@ def _check_tol(tol: Union[int, float]) -> Union[int, float]: _check_type(tol, ("numeric",), item_name="tol") if tol <= 0: raise ValueError( - "The tolerance must be a positive number. " - f"Provided: '{tol}'." + "The tolerance must be a positive number. " f"Provided: '{tol}'." ) return tol diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 224daa82..cb1e71e3 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -37,12 +37,8 @@ raw_meg.load_data().apply_proj() # epochs events = make_fixed_length_events(raw_meg, duration=1) -epochs_meg = Epochs( - raw_meg, events, tmin=0, tmax=0.5, baseline=None, preload=True -) -epochs_eeg = Epochs( - raw_eeg, events, tmin=0, tmax=0.5, baseline=None, preload=True -) +epochs_meg = Epochs(raw_meg, events, tmin=0, tmax=0.5, baseline=None, preload=True) +epochs_eeg = Epochs(raw_eeg, events, tmin=0, tmax=0.5, baseline=None, preload=True) # ch_data ch_data = ChData(raw_eeg.get_data(), raw_eeg.info) # Fit one for general purposes @@ -70,17 +66,12 @@ sim_n_frames = 250 # number of samples to generate sim_n_chans = pos.shape[0] # number of channels # compute forward model -A = np.sum( - (pos[None, ...] - sources[:, None, :3]) * sources[:, None, 3:], axis=2 -) +A = np.sum((pos[None, ...] - sources[:, None, :3]) * sources[:, None, 3:], axis=2) A /= np.linalg.norm(A, axis=1, keepdims=True) # simulate source actvities for 4 sources # with positive and negative polarity mapping = np.arange(sim_n_frames) % (sim_n_ms * 2) -s = ( - np.sign(mapping - sim_n_ms + 0.01) - * np.eye(sim_n_ms)[:, mapping % sim_n_ms] -) +s = np.sign(mapping - sim_n_ms + 0.01) * np.eye(sim_n_ms)[:, mapping % sim_n_ms] # apply forward model X = A.T @ s # add i.i.d. noise @@ -101,13 +92,9 @@ def test_default_algorithm(): # compute Euclidean distances (using the sign that minimizes the distance) sgn = np.sign(A @ A_hat.T) - dists = np.linalg.norm( - (A_hat[None, ...] - A[:, None] * sgn[..., None]), axis=2 - ) + dists = np.linalg.norm((A_hat[None, ...] - A[:, None] * sgn[..., None]), axis=2) # compute tolerance (2 times the expected noise level) - tol = ( - 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms) * np.sqrt(sim_n_chans) - ) + tol = 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms) * np.sqrt(sim_n_chans) # check if there is a cluster center whose distance # is within the tolerance assert (dists.min(axis=0) < tol).all() @@ -127,12 +114,7 @@ def test_ignore_polarity_false(): # compute Euclidean distances dists = np.linalg.norm((A_hat[None, ...] - A_[:, None]), axis=2) # compute tolerance (2 times the expected noise level) - tol = ( - 2 - * sim_sigma - / np.sqrt(sim_n_frames / sim_n_ms / 2) - * np.sqrt(sim_n_chans) - ) + tol = 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms / 2) * np.sqrt(sim_n_chans) # check if there is a cluster center whose distance # is within the tolerance assert (dists.min(axis=0) < tol).all() @@ -154,13 +136,9 @@ def test_normalize_input_true(): # compute Euclidean distances (using the sign that minimizes the distance) sgn = np.sign(A @ A_hat.T) - dists = np.linalg.norm( - (A_hat[None, ...] - A[:, None] * sgn[..., None]), axis=2 - ) + dists = np.linalg.norm((A_hat[None, ...] - A[:, None] * sgn[..., None]), axis=2) # compute tolerance (2 times the expected noise level) - tol = ( - 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms) * np.sqrt(sim_n_chans) - ) + tol = 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms) * np.sqrt(sim_n_chans) # check if there is a cluster center whose distance # is within the tolerance assert (dists.min(axis=0) < tol).all() @@ -193,9 +171,7 @@ def _check_unfitted(aah_cluster): assert aah_cluster._labels_ is None -def _check_fitted_data_raw( - fitted_data, raw, picks, tmin, tmax, reject_by_annotation -): +def _check_fitted_data_raw(fitted_data, raw, picks, tmin, tmax, reject_by_annotation): """Check the fitted data array for a raw instance.""" # Trust MNE .get_data() to correctly select data picks = _picks_to_idx(raw.info, picks) @@ -267,9 +243,7 @@ def test_aahClusterMeans(): # Test copy aahCluster2 = aahCluster1.copy() _check_fitted(aahCluster2) - assert_allclose( - aahCluster2._cluster_centers_, aahCluster1._cluster_centers_ - ) + assert_allclose(aahCluster2._cluster_centers_, aahCluster1._cluster_centers_) assert np.isclose(aahCluster2.GEV_, aahCluster1.GEV_) assert_allclose(aahCluster2._labels_, aahCluster1._labels_) aahCluster2.fitted = False @@ -278,9 +252,7 @@ def test_aahClusterMeans(): aahCluster3 = aahCluster1.copy(deep=False) _check_fitted(aahCluster3) - assert_allclose( - aahCluster3._cluster_centers_, aahCluster1._cluster_centers_ - ) + assert_allclose(aahCluster3._cluster_centers_, aahCluster1._cluster_centers_) assert np.isclose(aahCluster3.GEV_, aahCluster1.GEV_) assert_allclose(aahCluster3._labels_, aahCluster1._labels_) aahCluster3.fitted = False @@ -314,52 +286,28 @@ def test_invert_polarity(): aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity([True, False, True, False]) - assert_allclose( - aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :] - ) + assert_allclose(aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :]) + assert_allclose(aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :]) + assert_allclose(aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :]) + assert_allclose(aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :]) # bool aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity(True) - assert_allclose( - aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[1, :], -cluster_centers_[1, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[3, :], -cluster_centers_[3, :] - ) + assert_allclose(aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :]) + assert_allclose(aahCluster_._cluster_centers_[1, :], -cluster_centers_[1, :]) + assert_allclose(aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :]) + assert_allclose(aahCluster_._cluster_centers_[3, :], -cluster_centers_[3, :]) # np.array aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity(np.array([True, False, True, False])) - assert_allclose( - aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] - ) - assert_allclose( - aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :] - ) + assert_allclose(aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :]) + assert_allclose(aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :]) + assert_allclose(aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :]) + assert_allclose(aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :]) # Test invalid arguments with pytest.raises(ValueError, match="not a 2D iterable"): @@ -384,9 +332,7 @@ def test_rename(caplog): # Test mapping aahCluster_ = aah_cluster.copy() - mapping = { - old: alphabet[k] for k, old in enumerate(aah_cluster._cluster_names) - } + mapping = {old: alphabet[k] for k, old in enumerate(aah_cluster._cluster_names)} for key, value in mapping.items(): assert isinstance(key, str) assert isinstance(value, str) @@ -407,28 +353,20 @@ def test_rename(caplog): aahCluster_.rename_clusters(mapping=101) with pytest.raises(ValueError, match="Invalid value for the 'old name'"): mapping = { - old + "101": alphabet[k] - for k, old in enumerate(aah_cluster._cluster_names) + old + "101": alphabet[k] for k, old in enumerate(aah_cluster._cluster_names) } aahCluster_.rename_clusters(mapping=mapping) with pytest.raises(TypeError, match="'new name' must be an instance of "): mapping = {old: k for k, old in enumerate(aah_cluster._cluster_names)} aahCluster_.rename_clusters(mapping=mapping) - with pytest.raises( - ValueError, match="Argument 'new_names' should contain" - ): + with pytest.raises(ValueError, match="Argument 'new_names' should contain"): aahCluster_.rename_clusters(new_names=alphabet + ["E"]) aahCluster_.rename_clusters() assert "Either 'mapping' or 'new_names' should not be" in caplog.text - with pytest.raises( - ValueError, match="Only one of 'mapping' or 'new_names'" - ): - mapping = { - old: alphabet[k] - for k, old in enumerate(aah_cluster._cluster_names) - } + with pytest.raises(ValueError, match="Only one of 'mapping' or 'new_names'"): + mapping = {old: alphabet[k] for k, old in enumerate(aah_cluster._cluster_names)} aahCluster_.rename_clusters(mapping=mapping, new_names=alphabet) # Test unfitted @@ -436,10 +374,7 @@ def test_rename(caplog): aahCluster_.fitted = False _check_unfitted(aahCluster_) with pytest.raises(RuntimeError, match="must be fitted before"): - mapping = { - old: alphabet[k] - for k, old in enumerate(aah_cluster._cluster_names) - } + mapping = {old: alphabet[k] for k, old in enumerate(aah_cluster._cluster_names)} aahCluster_.rename_clusters(mapping=mapping) with pytest.raises(RuntimeError, match="must be fitted before"): aahCluster_.rename_clusters(new_names=alphabet) @@ -464,23 +399,15 @@ def test_reorder(caplog): # Test order aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(order=[1, 0, 2, 3]) - assert_allclose( - aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] - ) - assert_allclose( - aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] - ) + assert_allclose(aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1]) + assert_allclose(aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0]) assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(order=np.array([1, 0, 2, 3])) - assert_allclose( - aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] - ) - assert_allclose( - aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] - ) + assert_allclose(aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1]) + assert_allclose(aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0]) assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] @@ -496,21 +423,15 @@ def test_reorder(caplog): aahCluster_ = aah_cluster.copy() with pytest.raises(TypeError, match="'mapping' must be an instance of "): aahCluster_.reorder_clusters(mapping=101) - with pytest.raises( - ValueError, match="Invalid value for the 'old position'" - ): + with pytest.raises(ValueError, match="Invalid value for the 'old position'"): aahCluster_.reorder_clusters(mapping={4: 1}) - with pytest.raises( - ValueError, match="Invalid value for the 'new position'" - ): + with pytest.raises(ValueError, match="Invalid value for the 'new position'"): aahCluster_.reorder_clusters(mapping={0: 4}) with pytest.raises( ValueError, match="Position in the new order can not be repeated." ): aahCluster_.reorder_clusters(mapping={0: 1, 2: 1}) - with pytest.raises( - ValueError, match="A position can not be present in both" - ): + with pytest.raises(ValueError, match="A position can not be present in both"): aahCluster_.reorder_clusters(mapping={0: 1, 1: 2}) with pytest.raises(TypeError, match="'order' must be an instance of "): @@ -521,17 +442,12 @@ def test_reorder(caplog): ValueError, match="Argument 'order' should contain 'n_clusters'" ): aahCluster_.reorder_clusters(order=[0, 3, 1, 2, 0]) - with pytest.raises( - ValueError, match="Argument 'order' should be a 1D iterable" - ): - aahCluster_.reorder_clusters( - order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]]) - ) + with pytest.raises(ValueError, match="Argument 'order' should be a 1D iterable"): + aahCluster_.reorder_clusters(order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]])) aahCluster_.reorder_clusters() assert ( - "Either 'mapping', 'order' or 'template' should not be 'None' " - in caplog.text + "Either 'mapping', 'order' or 'template' should not be 'None' " in caplog.text ) with pytest.raises( @@ -610,9 +526,7 @@ def test_properties(caplog): def test_invalid_arguments(): """Test invalid arguments for init and for fit.""" # n_clusters - with pytest.raises( - TypeError, match="'n_clusters' must be an instance of " - ): + with pytest.raises(TypeError, match="'n_clusters' must be an instance of "): aahCluster_ = AAHCluster(n_clusters="4") with pytest.raises(ValueError, match="The number of clusters must be a"): aahCluster_ = AAHCluster(n_clusters=0) @@ -690,9 +604,7 @@ def test_fit_data_shapes(): tmax=None, reject_by_annotation=False, ) - _check_fitted_data_raw( - aahCluster_._fitted_data, raw_eeg, "eeg", 5, None, None - ) + _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", 5, None, None) # save for later fitted_data_5_end = deepcopy(aahCluster_._fitted_data) @@ -705,9 +617,7 @@ def test_fit_data_shapes(): tmax=None, reject_by_annotation=False, ) - _check_fitted_data_epochs( - aahCluster_._fitted_data, epochs_eeg, "eeg", 0.2, None - ) + _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", 0.2, None) # tmax aahCluster_.fitted = False @@ -719,9 +629,7 @@ def test_fit_data_shapes(): tmax=5, reject_by_annotation=False, ) - _check_fitted_data_raw( - aahCluster_._fitted_data, raw_eeg, "eeg", None, 5, None - ) + _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", None, 5, None) # save for later fitted_data_0_5 = deepcopy(aahCluster_._fitted_data) @@ -734,9 +642,7 @@ def test_fit_data_shapes(): tmax=0.3, reject_by_annotation=False, ) - _check_fitted_data_epochs( - aahCluster_._fitted_data, epochs_eeg, "eeg", None, 0.3 - ) + _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", None, 0.3) # tmin, tmax aahCluster_.fitted = False @@ -748,9 +654,7 @@ def test_fit_data_shapes(): tmax=8, reject_by_annotation=False, ) - _check_fitted_data_raw( - aahCluster_._fitted_data, raw_eeg, "eeg", 2, 8, None - ) + _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", 2, 8, None) aahCluster_.fitted = False _check_unfitted(aahCluster_) @@ -761,9 +665,7 @@ def test_fit_data_shapes(): tmax=0.4, reject_by_annotation=False, ) - _check_fitted_data_epochs( - aahCluster_._fitted_data, epochs_eeg, "eeg", 0.1, 0.4 - ) + _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", 0.1, 0.4) # --------------------- # Reject by annotations @@ -788,9 +690,7 @@ def test_fit_data_shapes(): aahCluster_reject_True._fitted_data, ) assert np.isclose(aahCluster_reject_omit.GEV_, aahCluster_reject_True.GEV_) - assert_allclose( - aahCluster_reject_omit._labels_, aahCluster_reject_True._labels_ - ) + assert_allclose(aahCluster_reject_omit._labels_, aahCluster_reject_True._labels_) # due to internal randomness, the sign can be flipped sgn = np.sign( np.sum( @@ -823,9 +723,7 @@ def test_fit_data_shapes(): aahCluster_rej_0_5 = aahCluster_.copy() aahCluster_rej_0_5.fit(raw_, tmin=0, tmax=5, reject_by_annotation=True) aahCluster_rej_5_end = aahCluster_.copy() - aahCluster_rej_5_end.fit( - raw_, tmin=5, tmax=None, reject_by_annotation=True - ) + aahCluster_rej_5_end.fit(raw_, tmin=5, tmax=None, reject_by_annotation=True) _check_fitted(aahCluster_rej_0_5) _check_fitted(aahCluster_rej_5_end) _check_fitted_data_raw( @@ -835,9 +733,7 @@ def test_fit_data_shapes(): aahCluster_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit" ) assert aahCluster_rej_0_5._fitted_data.shape != fitted_data_0_5.shape - assert_allclose( - fitted_data_5_end, aahCluster_rej_5_end._fitted_data - ) + assert_allclose(fitted_data_5_end, aahCluster_rej_5_end._fitted_data) def test_refit(): @@ -904,17 +800,13 @@ def test_predict_default(caplog): raw_eeg, factor=0, reject_edges=False, min_segment_length=5 ) assert isinstance(segmentation, RawSegmentation) - segment_lengths = [ - len(list(group)) for _, group in groupby(segmentation._labels) - ] + segment_lengths = [len(list(group)) for _, group in groupby(segmentation._labels)] assert all(5 <= size for size in segment_lengths[1:-1]) assert "Rejecting segments shorter than" in caplog.text caplog.clear() # epochs, no smoothing, no_edge - segmentation = aah_cluster.predict( - epochs_eeg, factor=0, reject_edges=False - ) + segmentation = aah_cluster.predict(epochs_eeg, factor=0, reject_edges=False) assert isinstance(segmentation, EpochsSegmentation) assert "Segmenting data without smoothing" in caplog.text caplog.clear() @@ -943,9 +835,7 @@ def test_predict_default(caplog): ) assert isinstance(segmentation, EpochsSegmentation) for epoch_labels in segmentation._labels: - segment_lengths = [ - len(list(group)) for _, group in groupby(epoch_labels) - ] + segment_lengths = [len(list(group)) for _, group in groupby(epoch_labels)] assert all(5 <= size for size in segment_lengths[1:-1]) assert "Rejecting segments shorter than" in caplog.text caplog.clear() @@ -972,12 +862,8 @@ def test_predict_default(caplog): rtol=1e-7, atol=0, ) - assert_allclose( - segmentation_no_annot._labels, segmentation_rej_False._labels - ) - assert_allclose( - segmentation_rej_None._labels, segmentation_rej_False._labels - ) + assert_allclose(segmentation_no_annot._labels, segmentation_rej_False._labels) + assert_allclose(segmentation_rej_None._labels, segmentation_rej_False._labels) # test different half_window_size segmentation1 = aah_cluster.predict( @@ -1028,9 +914,7 @@ def test_picks_fit_predict(caplog): aahCluster_.fitted = False # create mock raw for fitting - info_ = create_info( - ["Fp1", "Fp2", "CP1", "CP2"], sfreq=1024, ch_types="eeg" - ) + info_ = create_info(["Fp1", "Fp2", "CP1", "CP2"], sfreq=1024, ch_types="eeg") info_.set_montage("standard_1020") data = np.random.randn(4, 1024 * 10) @@ -1079,9 +963,7 @@ def test_picks_fit_predict(caplog): aahCluster_.predict(raw_predict, picks=["CP2", "CP1"]) # Try with one additional channel in the instance used for prediction. - info_ = create_info( - ["Fp1", "Fp2", "Fpz", "CP2", "CP1"], sfreq=1024, ch_types="eeg" - ) + info_ = create_info(["Fp1", "Fp2", "Fpz", "CP2", "CP1"], sfreq=1024, ch_types="eeg") info_.set_montage("standard_1020") data = np.random.randn(5, 1024 * 10) raw_predict = RawArray(data, info_) @@ -1109,9 +991,7 @@ def test_picks_fit_predict(caplog): # try with a missing channel from the prediction instance # fails, because Fp1 is used in aah_cluster.info raw_predict.drop_channels(["Fp1"]) - with pytest.raises( - ValueError, match="Fp1 was used during fitting but is missing" - ): + with pytest.raises(ValueError, match="Fp1 was used during fitting but is missing"): aahCluster_.predict(raw_predict, picks="eeg") # set a bad channel during fitting @@ -1186,17 +1066,11 @@ def test_predict_invalid_arguments(): aah_cluster.predict(epochs_eeg.average()) with pytest.raises(TypeError, match="'factor' must be an instance of "): aah_cluster.predict(raw_eeg, factor="0") - with pytest.raises( - TypeError, match="'reject_edges' must be an instance of " - ): + with pytest.raises(TypeError, match="'reject_edges' must be an instance of "): aah_cluster.predict(raw_eeg, reject_edges=1) - with pytest.raises( - TypeError, match="'half_window_size' must be an instance of " - ): + with pytest.raises(TypeError, match="'half_window_size' must be an instance of "): aah_cluster.predict(raw_eeg, half_window_size="1") - with pytest.raises( - TypeError, match="'min_segment_length' must be an instance of " - ): + with pytest.raises(TypeError, match="'min_segment_length' must be an instance of "): aah_cluster.predict(raw_eeg, min_segment_length="0") with pytest.raises( TypeError, match="'reject_by_annotation' must be an instance of " @@ -1210,9 +1084,7 @@ def test_contains_mixin(): """Test contains mixin class.""" assert "eeg" in aah_cluster assert aah_cluster.compensation_grade is None - assert ( - aah_cluster.get_channel_types() == ["eeg"] * aah_cluster._info["nchan"] - ) + assert aah_cluster.get_channel_types() == ["eeg"] * aah_cluster._info["nchan"] # test raise with non-fitted instance aahCluster_ = AAHCluster( @@ -1327,9 +1199,7 @@ def test_comparison(caplog): assert aahCluster1 != aahCluster2 aahCluster1 = aah_cluster.copy() aahCluster1._info = ChInfo( - ch_names=[ - str(k) for k in range(aahCluster1._cluster_centers_.shape[1]) - ], + ch_names=[str(k) for k in range(aahCluster1._cluster_centers_.shape[1])], ch_types=["eeg"] * aahCluster1._cluster_centers_.shape[1], ) assert aahCluster1 != aahCluster2 diff --git a/pycrostates/cluster/tests/test_kmeans.py b/pycrostates/cluster/tests/test_kmeans.py index d8480dea..db4be0c3 100644 --- a/pycrostates/cluster/tests/test_kmeans.py +++ b/pycrostates/cluster/tests/test_kmeans.py @@ -37,12 +37,8 @@ raw_meg.load_data().apply_proj() # epochs events = make_fixed_length_events(raw_meg, duration=1) -epochs_meg = Epochs( - raw_meg, events, tmin=0, tmax=0.5, baseline=None, preload=True -) -epochs_eeg = Epochs( - raw_eeg, events, tmin=0, tmax=0.5, baseline=None, preload=True -) +epochs_meg = Epochs(raw_meg, events, tmin=0, tmax=0.5, baseline=None, preload=True) +epochs_eeg = Epochs(raw_eeg, events, tmin=0, tmax=0.5, baseline=None, preload=True) # ch_data ch_data = ChData(raw_eeg.get_data(), raw_eeg.info) # Fit one for general purposes @@ -82,9 +78,7 @@ def _check_unfitted(ModK): assert ModK._labels_ is None -def _check_fitted_data_raw( - fitted_data, raw, picks, tmin, tmax, reject_by_annotation -): +def _check_fitted_data_raw(fitted_data, raw, picks, tmin, tmax, reject_by_annotation): """Check the fitted data array for a raw instance.""" # Trust MNE .get_data() to correctly select data picks = _picks_to_idx(raw.info, picks) @@ -203,52 +197,28 @@ def test_invert_polarity(): ModK_ = ModK.copy() cluster_centers_ = deepcopy(ModK_._cluster_centers_) ModK_.invert_polarity([True, False, True, False]) - assert_allclose( - ModK_._cluster_centers_[0, :], -cluster_centers_[0, :] - ) - assert_allclose( - ModK_._cluster_centers_[1, :], cluster_centers_[1, :] - ) - assert_allclose( - ModK_._cluster_centers_[2, :], -cluster_centers_[2, :] - ) - assert_allclose( - ModK_._cluster_centers_[3, :], cluster_centers_[3, :] - ) + assert_allclose(ModK_._cluster_centers_[0, :], -cluster_centers_[0, :]) + assert_allclose(ModK_._cluster_centers_[1, :], cluster_centers_[1, :]) + assert_allclose(ModK_._cluster_centers_[2, :], -cluster_centers_[2, :]) + assert_allclose(ModK_._cluster_centers_[3, :], cluster_centers_[3, :]) # bool ModK_ = ModK.copy() cluster_centers_ = deepcopy(ModK_._cluster_centers_) ModK_.invert_polarity(True) - assert_allclose( - ModK_._cluster_centers_[0, :], -cluster_centers_[0, :] - ) - assert_allclose( - ModK_._cluster_centers_[1, :], -cluster_centers_[1, :] - ) - assert_allclose( - ModK_._cluster_centers_[2, :], -cluster_centers_[2, :] - ) - assert_allclose( - ModK_._cluster_centers_[3, :], -cluster_centers_[3, :] - ) + assert_allclose(ModK_._cluster_centers_[0, :], -cluster_centers_[0, :]) + assert_allclose(ModK_._cluster_centers_[1, :], -cluster_centers_[1, :]) + assert_allclose(ModK_._cluster_centers_[2, :], -cluster_centers_[2, :]) + assert_allclose(ModK_._cluster_centers_[3, :], -cluster_centers_[3, :]) # np.array ModK_ = ModK.copy() cluster_centers_ = deepcopy(ModK_._cluster_centers_) ModK_.invert_polarity(np.array([True, False, True, False])) - assert_allclose( - ModK_._cluster_centers_[0, :], -cluster_centers_[0, :] - ) - assert_allclose( - ModK_._cluster_centers_[1, :], cluster_centers_[1, :] - ) - assert_allclose( - ModK_._cluster_centers_[2, :], -cluster_centers_[2, :] - ) - assert_allclose( - ModK_._cluster_centers_[3, :], cluster_centers_[3, :] - ) + assert_allclose(ModK_._cluster_centers_[0, :], -cluster_centers_[0, :]) + assert_allclose(ModK_._cluster_centers_[1, :], cluster_centers_[1, :]) + assert_allclose(ModK_._cluster_centers_[2, :], -cluster_centers_[2, :]) + assert_allclose(ModK_._cluster_centers_[3, :], cluster_centers_[3, :]) # Test invalid arguments with pytest.raises(ValueError, match="not a 2D iterable"): @@ -294,27 +264,20 @@ def test_rename(caplog): ModK_.rename_clusters(mapping=101) with pytest.raises(ValueError, match="Invalid value for the 'old name'"): mapping = { - old + "101": alphabet[k] - for k, old in enumerate(ModK._cluster_names) + old + "101": alphabet[k] for k, old in enumerate(ModK._cluster_names) } ModK_.rename_clusters(mapping=mapping) with pytest.raises(TypeError, match="'new name' must be an instance of "): mapping = {old: k for k, old in enumerate(ModK._cluster_names)} ModK_.rename_clusters(mapping=mapping) - with pytest.raises( - ValueError, match="Argument 'new_names' should contain" - ): + with pytest.raises(ValueError, match="Argument 'new_names' should contain"): ModK_.rename_clusters(new_names=alphabet + ["E"]) ModK_.rename_clusters() assert "Either 'mapping' or 'new_names' should not be" in caplog.text - with pytest.raises( - ValueError, match="Only one of 'mapping' or 'new_names'" - ): - mapping = { - old: alphabet[k] for k, old in enumerate(ModK._cluster_names) - } + with pytest.raises(ValueError, match="Only one of 'mapping' or 'new_names'"): + mapping = {old: alphabet[k] for k, old in enumerate(ModK._cluster_names)} ModK_.rename_clusters(mapping=mapping, new_names=alphabet) # Test unfitted @@ -322,9 +285,7 @@ def test_rename(caplog): ModK_.fitted = False _check_unfitted(ModK_) with pytest.raises(RuntimeError, match="must be fitted before"): - mapping = { - old: alphabet[k] for k, old in enumerate(ModK._cluster_names) - } + mapping = {old: alphabet[k] for k, old in enumerate(ModK._cluster_names)} ModK_.rename_clusters(mapping=mapping) with pytest.raises(RuntimeError, match="must be fitted before"): ModK_.rename_clusters(new_names=alphabet) @@ -335,35 +296,23 @@ def test_reorder(caplog): # Test mapping ModK_ = ModK.copy() ModK_.reorder_clusters(mapping={0: 1}) - assert_allclose( - ModK._cluster_centers_[0, :], ModK_._cluster_centers_[1, :] - ) - assert_allclose( - ModK._cluster_centers_[1, :], ModK_._cluster_centers_[0, :] - ) + assert_allclose(ModK._cluster_centers_[0, :], ModK_._cluster_centers_[1, :]) + assert_allclose(ModK._cluster_centers_[1, :], ModK_._cluster_centers_[0, :]) assert ModK._cluster_names[0] == ModK_._cluster_names[1] assert ModK._cluster_names[0] == ModK_._cluster_names[1] # Test order ModK_ = ModK.copy() ModK_.reorder_clusters(order=[1, 0, 2, 3]) - assert_allclose( - ModK._cluster_centers_[0], ModK_._cluster_centers_[1] - ) - assert_allclose( - ModK._cluster_centers_[1], ModK_._cluster_centers_[0] - ) + assert_allclose(ModK._cluster_centers_[0], ModK_._cluster_centers_[1]) + assert_allclose(ModK._cluster_centers_[1], ModK_._cluster_centers_[0]) assert ModK._cluster_names[0] == ModK_._cluster_names[1] assert ModK._cluster_names[0] == ModK_._cluster_names[1] ModK_ = ModK.copy() ModK_.reorder_clusters(order=np.array([1, 0, 2, 3])) - assert_allclose( - ModK._cluster_centers_[0], ModK_._cluster_centers_[1] - ) - assert_allclose( - ModK._cluster_centers_[1], ModK_._cluster_centers_[0] - ) + assert_allclose(ModK._cluster_centers_[0], ModK_._cluster_centers_[1]) + assert_allclose(ModK._cluster_centers_[1], ModK_._cluster_centers_[0]) assert ModK._cluster_names[0] == ModK_._cluster_names[1] assert ModK._cluster_names[0] == ModK_._cluster_names[1] @@ -379,21 +328,15 @@ def test_reorder(caplog): ModK_ = ModK.copy() with pytest.raises(TypeError, match="'mapping' must be an instance of "): ModK_.reorder_clusters(mapping=101) - with pytest.raises( - ValueError, match="Invalid value for the 'old position'" - ): + with pytest.raises(ValueError, match="Invalid value for the 'old position'"): ModK_.reorder_clusters(mapping={4: 1}) - with pytest.raises( - ValueError, match="Invalid value for the 'new position'" - ): + with pytest.raises(ValueError, match="Invalid value for the 'new position'"): ModK_.reorder_clusters(mapping={0: 4}) with pytest.raises( ValueError, match="Position in the new order can not be repeated." ): ModK_.reorder_clusters(mapping={0: 1, 2: 1}) - with pytest.raises( - ValueError, match="A position can not be present in both" - ): + with pytest.raises(ValueError, match="A position can not be present in both"): ModK_.reorder_clusters(mapping={0: 1, 1: 2}) with pytest.raises(TypeError, match="'order' must be an instance of "): @@ -404,16 +347,11 @@ def test_reorder(caplog): ValueError, match="Argument 'order' should contain 'n_clusters'" ): ModK_.reorder_clusters(order=[0, 3, 1, 2, 0]) - with pytest.raises( - ValueError, match="Argument 'order' should be a 1D iterable" - ): + with pytest.raises(ValueError, match="Argument 'order' should be a 1D iterable"): ModK_.reorder_clusters(order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]])) ModK_.reorder_clusters() - assert ( - "Either 'mapping', 'order' or 'template' should not be 'None'" - in caplog.text - ) + assert "Either 'mapping', 'order' or 'template' should not be 'None'" in caplog.text with pytest.raises( ValueError, @@ -502,9 +440,7 @@ def test_properties(caplog): def test_invalid_arguments(): """Test invalid arguments for init and for fit.""" # n_clusters - with pytest.raises( - TypeError, match="'n_clusters' must be an instance of " - ): + with pytest.raises(TypeError, match="'n_clusters' must be an instance of "): ModK_ = ModKMeans(n_clusters="4") with pytest.raises(ValueError, match="The number of clusters must be a"): ModK_ = ModKMeans(n_clusters=0) @@ -514,25 +450,17 @@ def test_invalid_arguments(): # n_init with pytest.raises(TypeError, match="'n_init' must be an instance of "): ModK_ = ModKMeans(n_clusters=4, n_init="100") - with pytest.raises( - ValueError, match="The number of initialization must be a" - ): + with pytest.raises(ValueError, match="The number of initialization must be a"): ModK_ = ModKMeans(n_clusters=4, n_init=0) - with pytest.raises( - ValueError, match="The number of initialization must be a" - ): + with pytest.raises(ValueError, match="The number of initialization must be a"): ModK_ = ModKMeans(n_clusters=4, n_init=-101) # max_iter with pytest.raises(TypeError, match="'max_iter' must be an instance of "): ModK_ = ModKMeans(n_clusters=4, max_iter="100") - with pytest.raises( - ValueError, match="The number of max iteration must be a" - ): + with pytest.raises(ValueError, match="The number of max iteration must be a"): ModK_ = ModKMeans(n_clusters=4, max_iter=0) - with pytest.raises( - ValueError, match="The number of max iteration must be a" - ): + with pytest.raises(ValueError, match="The number of max iteration must be a"): ModK_ = ModKMeans(n_clusters=4, max_iter=-101) # tol @@ -699,46 +627,31 @@ def test_fit_data_shapes(): ModK_reject_omit.fit(raw_, n_jobs=1, reject_by_annotation="omit") # Compare 'omit' and True - assert_allclose( - ModK_reject_omit._fitted_data, ModK_reject_True._fitted_data - ) + assert_allclose(ModK_reject_omit._fitted_data, ModK_reject_True._fitted_data) assert np.isclose(ModK_reject_omit.GEV_, ModK_reject_True.GEV_) - assert_allclose( - ModK_reject_omit._labels_, ModK_reject_True._labels_ - ) + assert_allclose(ModK_reject_omit._labels_, ModK_reject_True._labels_) assert_allclose( ModK_reject_omit._cluster_centers_, ModK_reject_True._cluster_centers_ ) # Make sure there is a shape diff between True and False - assert ( - ModK_reject_True._fitted_data.shape - != ModK_no_reject._fitted_data.shape - ) + assert ModK_reject_True._fitted_data.shape != ModK_no_reject._fitted_data.shape # Check fitted data _check_fitted_data_raw( ModK_reject_True._fitted_data, raw_, "eeg", None, None, "omit" ) - _check_fitted_data_raw( - ModK_no_reject._fitted_data, raw_, "eeg", None, None, None - ) + _check_fitted_data_raw(ModK_no_reject._fitted_data, raw_, "eeg", None, None, None) # Check with reject with tmin/tmax ModK_rej_0_5 = ModK_.copy() ModK_rej_0_5.fit(raw_, n_jobs=1, tmin=0, tmax=5, reject_by_annotation=True) ModK_rej_5_end = ModK_.copy() - ModK_rej_5_end.fit( - raw_, n_jobs=1, tmin=5, tmax=None, reject_by_annotation=True - ) + ModK_rej_5_end.fit(raw_, n_jobs=1, tmin=5, tmax=None, reject_by_annotation=True) _check_fitted(ModK_rej_0_5) _check_fitted(ModK_rej_5_end) - _check_fitted_data_raw( - ModK_rej_0_5._fitted_data, raw_, "eeg", None, 5, "omit" - ) - _check_fitted_data_raw( - ModK_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit" - ) + _check_fitted_data_raw(ModK_rej_0_5._fitted_data, raw_, "eeg", None, 5, "omit") + _check_fitted_data_raw(ModK_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit") assert ModK_rej_0_5._fitted_data.shape != fitted_data_0_5.shape assert_allclose(fitted_data_5_end, ModK_rej_5_end._fitted_data) @@ -810,9 +723,7 @@ def test_predict_default(caplog): raw_eeg, factor=0, reject_edges=False, min_segment_length=5 ) assert isinstance(segmentation, RawSegmentation) - segment_lengths = [ - len(list(group)) for _, group in groupby(segmentation._labels) - ] + segment_lengths = [len(list(group)) for _, group in groupby(segmentation._labels)] assert all(5 <= size for size in segment_lengths[1:-1]) assert "Rejecting segments shorter than" in caplog.text caplog.clear() @@ -847,9 +758,7 @@ def test_predict_default(caplog): ) assert isinstance(segmentation, EpochsSegmentation) for epoch_labels in segmentation._labels: - segment_lengths = [ - len(list(group)) for _, group in groupby(epoch_labels) - ] + segment_lengths = [len(list(group)) for _, group in groupby(epoch_labels)] assert all(5 <= size for size in segment_lengths[1:-1]) assert "Rejecting segments shorter than" in caplog.text caplog.clear() @@ -876,12 +785,8 @@ def test_predict_default(caplog): rtol=1e-7, atol=0, ) - assert_allclose( - segmentation_no_annot._labels, segmentation_rej_False._labels - ) - assert_allclose( - segmentation_rej_None._labels, segmentation_rej_False._labels - ) + assert_allclose(segmentation_no_annot._labels, segmentation_rej_False._labels) + assert_allclose(segmentation_rej_None._labels, segmentation_rej_False._labels) # test different half_window_size segmentation1 = ModK.predict( @@ -930,9 +835,7 @@ def test_picks_fit_predict(caplog): ModK_.fitted = False # create mock raw for fitting - info_ = create_info( - ["Fp1", "Fp2", "CP1", "CP2"], sfreq=1024, ch_types="eeg" - ) + info_ = create_info(["Fp1", "Fp2", "CP1", "CP2"], sfreq=1024, ch_types="eeg") info_.set_montage("standard_1020") data = np.random.randn(4, 1024 * 10) @@ -979,9 +882,7 @@ def test_picks_fit_predict(caplog): ModK_.predict(raw_predict, picks=["CP2", "CP1"]) # Try with one additional channel in the instance used for prediction. - info_ = create_info( - ["Fp1", "Fp2", "Fpz", "CP2", "CP1"], sfreq=1024, ch_types="eeg" - ) + info_ = create_info(["Fp1", "Fp2", "Fpz", "CP2", "CP1"], sfreq=1024, ch_types="eeg") info_.set_montage("standard_1020") data = np.random.randn(5, 1024 * 10) raw_predict = RawArray(data, info_) @@ -1009,9 +910,7 @@ def test_picks_fit_predict(caplog): # try with a missing channel from the prediction instance # fails, because Fp1 is used in ModK.info raw_predict.drop_channels(["Fp1"]) - with pytest.raises( - ValueError, match="Fp1 was used during fitting but is missing" - ): + with pytest.raises(ValueError, match="Fp1 was used during fitting but is missing"): ModK_.predict(raw_predict, picks="eeg") # set a bad channel during fitting @@ -1082,19 +981,13 @@ def test_predict_invalid_arguments(): ModK.predict(epochs_eeg.average()) with pytest.raises(TypeError, match="'factor' must be an instance of "): ModK.predict(raw_eeg, factor="0") - with pytest.raises( - TypeError, match="'reject_edges' must be an instance of " - ): + with pytest.raises(TypeError, match="'reject_edges' must be an instance of "): ModK.predict(raw_eeg, reject_edges=1) - with pytest.raises( - TypeError, match="'half_window_size' must be an instance of " - ): + with pytest.raises(TypeError, match="'half_window_size' must be an instance of "): ModK.predict(raw_eeg, half_window_size="1") with pytest.raises(TypeError, match="'tol' must be an instance of "): ModK.predict(raw_eeg, tol="0") - with pytest.raises( - TypeError, match="'min_segment_length' must be an instance of " - ): + with pytest.raises(TypeError, match="'min_segment_length' must be an instance of "): ModK.predict(raw_eeg, min_segment_length="0") with pytest.raises( TypeError, match="'reject_by_annotation' must be an instance of " diff --git a/pycrostates/cluster/utils/tests/test_utils.py b/pycrostates/cluster/utils/tests/test_utils.py index 0a10b999..9b0ad0d1 100644 --- a/pycrostates/cluster/utils/tests/test_utils.py +++ b/pycrostates/cluster/utils/tests/test_utils.py @@ -58,9 +58,7 @@ def test__optimize_order(): # Shuffle + sign + ignore_polarity current = random_pol_template ignore_polarity = True - order_ = _optimize_order( - current, template, ignore_polarity=ignore_polarity - ) + order_ = _optimize_order(current, template, ignore_polarity=ignore_polarity) assert np.all(order == order_) # Shuffle + sign diff --git a/pycrostates/cluster/utils/utils.py b/pycrostates/cluster/utils/utils.py index adc04d9f..aad9af75 100644 --- a/pycrostates/cluster/utils/utils.py +++ b/pycrostates/cluster/utils/utils.py @@ -52,8 +52,7 @@ def optimize_order(inst: Cluster, template_inst: Cluster): if inst.n_clusters != template_inst.n_clusters: raise ValueError( - "Instance and the template must have the same " - "number of cluster centers." + "Instance and the template must have the same number of cluster centers." ) if inst._ignore_polarity != template_inst._ignore_polarity: raise ValueError( diff --git a/pycrostates/io/ch_data.py b/pycrostates/io/ch_data.py index 319528ff..317d4d20 100644 --- a/pycrostates/io/ch_data.py +++ b/pycrostates/io/ch_data.py @@ -43,8 +43,7 @@ def __init__(self, data: NDArray[float], info: Union[Info, CHInfo]): ) if not len(info["ch_names"]) == data.shape[0]: raise ValueError( - "Argument 'data' and 'info' do not have the same " - "number of channels." + "Argument 'data' and 'info' do not have the same number of channels." ) self._data = data self._info = info if isinstance(info, ChInfo) else ChInfo(info) @@ -60,12 +59,8 @@ def _repr_html_(self, caption=None): from ..html_templates import repr_templates_env template = repr_templates_env.get_template("ChData.html.jinja") - info_repr = ( - self._info._repr_html_() - ) # pylint: disable=protected-access - return template.render( - n_samples=self._data.shape[-1], info_repr=info_repr - ) + info_repr = self._info._repr_html_() # pylint: disable=protected-access + return template.render(n_samples=self._data.shape[-1], info_repr=info_repr) def __eq__(self, other: Any) -> bool: """Equality == method.""" @@ -157,8 +152,7 @@ def _get_channel_positions(self, picks=None): n_zero = np.sum(np.sum(np.abs(pos), axis=1) == 0) if n_zero > 1: # XXX some systems have origin (0, 0, 0) raise ValueError( - "Could not extract channel positions for " - f"{n_zero} channels." + "Could not extract channel positions for " f"{n_zero} channels." ) return pos diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index 5e76e96f..77068300 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -114,8 +114,7 @@ def _write_cluster( _check_type(cluster_names, (list,), "cluster_names") if len(cluster_names) != cluster_centers_.shape[0]: raise ValueError( - "Argument 'cluster_names' and 'cluster_centers_' shapes do not " - "match." + "Argument 'cluster_names' and 'cluster_centers_' shapes do not match." ) _check_type(fitted_data, (np.ndarray,), "fitted_data") if fitted_data.ndim != 2: @@ -162,9 +161,7 @@ def _write_cluster( fid, FIFF.FIFF_MNE_ICA_INTERFACE_PARAMS, _serialize(fit_parameters) ) # write fit_variables - write_string( - fid, FIFF.FIFF_MNE_ICA_MISC_PARAMS, _serialize(fit_variables) - ) + write_string(fid, FIFF.FIFF_MNE_ICA_MISC_PARAMS, _serialize(fit_variables)) # ------------------------------------------------------------ # close writing block @@ -216,13 +213,13 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): fit_parameters["tol"] = ModKMeans._check_tol(value) elif algorithm == "AAHCluster": if key == "ignore_polarity": - fit_parameters[ - "ignore_polarity" - ] = AAHCluster._check_ignore_polarity(value) + fit_parameters["ignore_polarity"] = AAHCluster._check_ignore_polarity( + value + ) elif key == "normalize_input": - fit_parameters[ - "normalize_input" - ] = AAHCluster._check_normalize_input(value) + fit_parameters["normalize_input"] = AAHCluster._check_normalize_input( + value + ) # pylint: enable=protected-access if key == "GEV_": _check_type(value, ("numeric",), "GEV_") @@ -318,9 +315,7 @@ def _read_cluster(fname: Union[str, Path]): fit_variables, ) if any(elt is None for elt in data): - raise RuntimeError( - "One of the required tag was not found in .fif file." - ) + raise RuntimeError("One of the required tag was not found in .fif file.") algorithm, version = _check_fit_parameters_and_variables( fit_parameters, fit_variables ) @@ -590,9 +585,7 @@ def _read_meas_info(fid, tree): info["custom_ref_applied"] = custom_ref_applied # add coordinate transformation - info["dev_head_t"] = ( - Transform("meg", "head") if dev_head_t is None else dev_head_t - ) + info["dev_head_t"] = Transform("meg", "head") if dev_head_t is None else dev_head_t info["ctf_head_t"] = ctf_head_t info["dev_ctf_t"] = dev_ctf_t if dev_head_t is not None and ctf_head_t is not None and dev_ctf_t is None: diff --git a/pycrostates/io/meas_info.py b/pycrostates/io/meas_info.py index 3d23f36c..6cc105d9 100644 --- a/pycrostates/io/meas_info.py +++ b/pycrostates/io/meas_info.py @@ -206,15 +206,11 @@ def __init__( raise RuntimeError( "Either 'info' or 'ch_names' and 'ch_types' must not be None." ) - if info is None and all( - arg is not None for arg in (ch_names, ch_types) - ): + if info is None and all(arg is not None for arg in (ch_names, ch_types)): _check_type(ch_names, (None, "int", list, tuple), "ch_names") _check_type(ch_types, (None, str, list, tuple), "ch_types") self._init_from_channels(ch_names, ch_types) - elif info is not None and all( - arg is None for arg in (ch_names, ch_types) - ): + elif info is not None and all(arg is None for arg in (ch_names, ch_types)): _check_type(info, (None, Info), "info") self._init_from_info(info) else: @@ -287,8 +283,7 @@ def _init_from_channels( _check_type(ch_type, (str,)) if ch_type not in ch_types_dict: raise KeyError( - f"kind must be one of {list(ch_types_dict)}, not " - f"{ch_type}." + f"kind must be one of {list(ch_types_dict)}, not " f"{ch_type}." ) this_ch_dict = ch_types_dict[ch_type] kind = this_ch_dict["kind"] @@ -333,9 +328,7 @@ def __getattribute__(self, name): # invalid attributes _inv_attributes = () # invalid methods/properties - _inv_methods = ( - "pick_channels" # TODO: Can be removed when req. for MNE = 1.1.0 - ) + _inv_methods = "pick_channels" # TODO: Can be removed when req. for MNE = 1.1.0 if name in _inv_attributes or name in _inv_methods: raise AttributeError( f"'{self.__class__.__name__}' has not attribute '{name}'" diff --git a/pycrostates/io/reader.py b/pycrostates/io/reader.py index 90f95436..034aa45f 100644 --- a/pycrostates/io/reader.py +++ b/pycrostates/io/reader.py @@ -33,9 +33,7 @@ def read_cluster(fname: Union[str, Path]): ext = "".join(fname.suffixes) if ext in readers: cluster, version = readers[ext](fname) - logger.info( - "Cluster solution loaded was saved with pycrostates '%s'.", version - ) + logger.info("Cluster solution loaded was saved with pycrostates '%s'.", version) return cluster else: raise ValueError("File format is not supported.") diff --git a/pycrostates/io/tests/test_ch_data.py b/pycrostates/io/tests/test_ch_data.py index 5bc03217..6d485240 100644 --- a/pycrostates/io/tests/test_ch_data.py +++ b/pycrostates/io/tests/test_ch_data.py @@ -9,15 +9,9 @@ times = np.linspace(0, 5, 2000) signals = np.array([np.sin(2 * np.pi * k * times) for k in (7, 22, 37)]) coeffs = np.random.rand(6, 3) -data = np.dot(coeffs, signals) + np.random.normal( - 0, 0.1, (coeffs.shape[0], times.size) -) -info = create_info( - ["Fpz", "Cz", "CPz", "Oz", "M1", "M2"], sfreq=400, ch_types="eeg" -) -ch_info = ChInfo( - ch_names=["Fpz", "Cz", "CPz", "Oz", "M1", "M2"], ch_types="eeg" -) +data = np.dot(coeffs, signals) + np.random.normal(0, 0.1, (coeffs.shape[0], times.size)) +info = create_info(["Fpz", "Cz", "CPz", "Oz", "M1", "M2"], sfreq=400, ch_types="eeg") +ch_info = ChInfo(ch_names=["Fpz", "Cz", "CPz", "Oz", "M1", "M2"], ch_types="eeg") ch_info_types = ChInfo( ch_names=["Fpz", "Cz", "MEG01", "STIM01", "GRAD01", "EOG"], ch_types=["eeg", "eeg", "mag", "stim", "grad", "eog"], @@ -103,13 +97,9 @@ def test_ChData_picks(picks, exclude, ch_names): def test_ChData_invalid_arguments(): """Test error raised when invalid arguments are provided to ChData.""" - with pytest.raises( - TypeError, match="'data' must be an instance of ndarray" - ): + with pytest.raises(TypeError, match="'data' must be an instance of ndarray"): ChData(list(data[0, :]), create_info(1, 400, "eeg")) - with pytest.raises( - TypeError, match="'info' must be an instance of Info or ChInfo" - ): + with pytest.raises(TypeError, match="'info' must be an instance of Info or ChInfo"): ChData(data, 101) with pytest.raises(ValueError, match="'data' should be a 2D array"): ChData(data.reshape(6, 5, 400), ch_info) diff --git a/pycrostates/io/tests/test_fiff.py b/pycrostates/io/tests/test_fiff.py index 8dcad705..c443a8d7 100644 --- a/pycrostates/io/tests/test_fiff.py +++ b/pycrostates/io/tests/test_fiff.py @@ -23,9 +23,7 @@ raw = read_raw_fif(fname, preload=True) raw2 = raw.copy().crop(10, None).apply_proj() raw = raw.crop(0, 10).apply_proj() -ModK = ModKMeans( - n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1 -) +ModK = ModKMeans(n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1) ModK.fit(raw, picks="eeg", n_jobs=1) @@ -112,9 +110,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - TypeError, match="'cluster_centers_' must be an instance of" - ): + with pytest.raises(TypeError, match="'cluster_centers_' must be an instance of"): _write_cluster( tmp_path / "cluster.fif", list(ModK._cluster_centers_), @@ -129,9 +125,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - ValueError, match="Argument 'cluster_centers_' should be a 2D" - ): + with pytest.raises(ValueError, match="Argument 'cluster_centers_' should be a 2D"): _write_cluster( tmp_path / "cluster.fif", ModK._cluster_centers_.flatten(), @@ -176,9 +170,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - ValueError, match="Invalid value for the 'algorithm' parameter" - ): + with pytest.raises(ValueError, match="Invalid value for the 'algorithm' parameter"): _write_cluster( tmp_path / "cluster.fif", ModK._cluster_centers_, @@ -193,9 +185,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - TypeError, match="'cluster_names' must be an instance of" - ): + with pytest.raises(TypeError, match="'cluster_names' must be an instance of"): _write_cluster( tmp_path / "cluster.fif", ModK._cluster_centers_, @@ -227,9 +217,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - TypeError, match="'fitted_data' must be an instance of" - ): + with pytest.raises(TypeError, match="'fitted_data' must be an instance of"): _write_cluster( tmp_path / "cluster.fif", ModK._cluster_centers_, @@ -244,9 +232,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - ValueError, match="Argument 'fitted_data' should be a 2D" - ): + with pytest.raises(ValueError, match="Argument 'fitted_data' should be a 2D"): _write_cluster( tmp_path / "cluster.fif", ModK._cluster_centers_, @@ -276,9 +262,7 @@ def test_invalid_write(tmp_path): GEV_=ModK._GEV_, ) - with pytest.raises( - ValueError, match="Argument 'labels_' should be a 1D array." - ): + with pytest.raises(ValueError, match="Argument 'labels_' should be a 1D array."): _write_cluster( tmp_path / "cluster.fif", ModK._cluster_centers_, @@ -313,31 +297,21 @@ def test_prepare_kwargs(): def test_prepare_kwargs_ModKMeans(): """Test invalid key/values for ModKMeans.""" - kwargs = dict( - n_init=-101, max_iter=ModK._max_iter, tol=ModK._tol, GEV_=ModK._GEV_ - ) - with pytest.raises( - ValueError, match="initialization must be a positive integer" - ): + kwargs = dict(n_init=-101, max_iter=ModK._max_iter, tol=ModK._tol, GEV_=ModK._GEV_) + with pytest.raises(ValueError, match="initialization must be a positive integer"): _prepare_kwargs("ModKMeans", kwargs) - kwargs = dict( - n_init=ModK._n_init, max_iter=-101, tol=ModK._tol, GEV_=ModK._GEV_ - ) + kwargs = dict(n_init=ModK._n_init, max_iter=-101, tol=ModK._tol, GEV_=ModK._GEV_) with pytest.raises(ValueError, match="max iteration must be a positive"): _prepare_kwargs("ModKMeans", kwargs) kwargs = dict( n_init=ModK._n_init, max_iter=ModK.max_iter, tol=-101, GEV_=ModK._GEV_ ) - with pytest.raises( - ValueError, match="tolerance must be a positive number" - ): + with pytest.raises(ValueError, match="tolerance must be a positive number"): _prepare_kwargs("ModKMeans", kwargs) - kwargs = dict( - n_init=ModK.n_init, max_iter=ModK.max_iter, tol=ModK.tol, GEV_=101 - ) + kwargs = dict(n_init=ModK.n_init, max_iter=ModK.max_iter, tol=ModK.tol, GEV_=101) with pytest.raises( ValueError, match="'GEV_' should be a percentage between 0 and 1" ): @@ -350,9 +324,7 @@ def test_invalid_read(tmp_path): _read_cluster(101) fname = directory / "sample_audvis_trunc_raw.fif" - with pytest.raises( - RuntimeError, match="Could not find clustering solution data." - ): + with pytest.raises(RuntimeError, match="Could not find clustering solution data."): _read_cluster(fname) # save an ICA @@ -360,7 +332,5 @@ def test_invalid_read(tmp_path): ica.fit(raw, picks="eeg") ica.save(tmp_path / "decomposition-ica.fif") # try loading the ICA - with pytest.raises( - RuntimeError, match="Could not find clustering solution data." - ): + with pytest.raises(RuntimeError, match="Could not find clustering solution data."): _read_cluster(fname) diff --git a/pycrostates/io/tests/test_meas_info.py b/pycrostates/io/tests/test_meas_info.py index ba7422c9..23723b5e 100644 --- a/pycrostates/io/tests/test_meas_info.py +++ b/pycrostates/io/tests/test_meas_info.py @@ -121,7 +121,7 @@ def test_create_from_info_invalid_arguments(): with pytest.raises(RuntimeError, match="If 'info' is provided"): ChInfo(info, ch_types=ch_types) with pytest.raises( - TypeError, match="'info' must be an instance of None " "or Info" + TypeError, match="'info' must be an instance of None or Info" ): ChInfo(info=ch_names) @@ -221,8 +221,7 @@ def test_create_without_arguments(): """Test error raised if both arguments are None.""" with pytest.raises( RuntimeError, - match="Either 'info' or 'ch_names' and " - "'ch_types' must not be None.", + match="Either 'info' or 'ch_names' and 'ch_types' must not be None.", ): ChInfo() @@ -262,13 +261,9 @@ def test_montage(): if montage.get_positions()[key] is None: assert montage2.get_positions()[key] is None elif isinstance(montage.get_positions()[key], str): - assert ( - montage.get_positions()[key] == montage2.get_positions()[key] - ) + assert montage.get_positions()[key] == montage2.get_positions()[key] elif isinstance(montage.get_positions()[key], np.ndarray): - assert_allclose( - montage.get_positions()[key], montage2.get_positions()[key] - ) + assert_allclose(montage.get_positions()[key], montage2.get_positions()[key]) elif isinstance(montage.get_positions()[key], OrderedDict): for k, v in montage.get_positions()[key].items(): assert_allclose( diff --git a/pycrostates/io/tests/test_reader.py b/pycrostates/io/tests/test_reader.py index 7453fc38..5b896aa9 100644 --- a/pycrostates/io/tests/test_reader.py +++ b/pycrostates/io/tests/test_reader.py @@ -18,9 +18,7 @@ raw = read_raw_fif(fname, preload=False) raw.crop(0, 10).pick("eeg") raw.load_data().apply_proj() -ModK = ModKMeans( - n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1 -) +ModK = ModKMeans(n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1) ModK.fit(raw, picks="eeg") diff --git a/pycrostates/metrics/calinski_harabasz.py b/pycrostates/metrics/calinski_harabasz.py index 24e90996..1e88a84b 100644 --- a/pycrostates/metrics/calinski_harabasz.py +++ b/pycrostates/metrics/calinski_harabasz.py @@ -1,9 +1,7 @@ """Calinski Harabasz score.""" import numpy as np -from sklearn.metrics import ( - calinski_harabasz_score as sk_calinski_harabasz_score, -) +from sklearn.metrics import calinski_harabasz_score as sk_calinski_harabasz_score from ..cluster._base import _BaseCluster from ..utils._checks import _check_type diff --git a/pycrostates/metrics/dunn.py b/pycrostates/metrics/dunn.py index 9223a1e6..25b156f5 100644 --- a/pycrostates/metrics/dunn.py +++ b/pycrostates/metrics/dunn.py @@ -67,9 +67,7 @@ def _dunn_score(X, labels): # higher the better for j, ks_j in enumerate(ks): if i == j: continue # skip diagonal - deltas[i, j] = _delta_fast( - (labels == ks_i), (labels == ks_j), distances - ) + deltas[i, j] = _delta_fast((labels == ks_i), (labels == ks_j), distances) big_deltas[i] = _big_delta_fast((labels == ks_i), distances) di = np.min(deltas) / np.max(big_deltas) diff --git a/pycrostates/preprocessing/extract_gfp_peaks.py b/pycrostates/preprocessing/extract_gfp_peaks.py index 7dc2787e..cdcc1ded 100644 --- a/pycrostates/preprocessing/extract_gfp_peaks.py +++ b/pycrostates/preprocessing/extract_gfp_peaks.py @@ -87,15 +87,11 @@ def extract_gfp_peaks( ) tmin, tmax = _check_tmin_tmax(inst, tmin, tmax) if isinstance(inst, BaseRaw): - reject_by_annotation = _check_reject_by_annotation( - reject_by_annotation - ) + reject_by_annotation = _check_reject_by_annotation(reject_by_annotation) # retrieve picks picks = _picks_to_idx(inst.info, picks, none="all", exclude="bads") - picks_all = _picks_to_idx( - inst.info, inst.ch_names, none="all", exclude="bads" - ) + picks_all = _picks_to_idx(inst.info, inst.ch_names, none="all", exclude="bads") _check_picks_uniqueness(inst.info, picks) # set kwargs for .get_data() @@ -130,8 +126,7 @@ def extract_gfp_peaks( if isinstance(inst, BaseEpochs): n_samples *= len(inst) logger.info( - "%s GFP peaks extracted out of %s samples (%.2f%% of the original " - "data).", + "%s GFP peaks extracted out of %s samples (%.2f%% of the original data).", peaks.shape[1], n_samples, peaks.shape[1] / n_samples * 100, diff --git a/pycrostates/preprocessing/resample.py b/pycrostates/preprocessing/resample.py index 091fb33a..6eb08d3a 100644 --- a/pycrostates/preprocessing/resample.py +++ b/pycrostates/preprocessing/resample.py @@ -78,9 +78,7 @@ def resample( if isinstance(inst, (BaseRaw, BaseEpochs)): tmin, tmax = _check_tmin_tmax(inst, tmin, tmax) if isinstance(inst, BaseRaw): - reject_by_annotation = _check_reject_by_annotation( - reject_by_annotation - ) + reject_by_annotation = _check_reject_by_annotation(reject_by_annotation) _check_type(n_resamples, (None, "int"), "n_resamples") _check_type(n_samples, (None, "int"), "n_samples") _check_type(coverage, (None, "numeric"), "coverage") diff --git a/pycrostates/preprocessing/spatial_filter.py b/pycrostates/preprocessing/spatial_filter.py index ef74cdbc..c810769a 100644 --- a/pycrostates/preprocessing/spatial_filter.py +++ b/pycrostates/preprocessing/spatial_filter.py @@ -42,17 +42,13 @@ def _check_adjacency(adjacency, info, ch_type): "Adjacency must have exactly 2 dimensions but got " f"{adjacency.ndim} dimensions instead." ) - if (adjacency.shape[0] != n_channels) or ( - adjacency.shape[1] != n_channels - ): + if (adjacency.shape[0] != n_channels) or (adjacency.shape[1] != n_channels): raise ValueError( "Adjacency must be of shape (n_channels, n_channels) " f"but got {adjacency.shape} instead." ) if not np.array_equal(adjacency, adjacency.astype(bool)): - raise ValueError( - "Values contained in adjacency can only be 0 or 1." - ) + raise ValueError("Values contained in adjacency can only be 0 or 1.") return (adjacency, ch_names) @@ -207,9 +203,7 @@ def _channel_spatial_filter(index, data, adjacency_vector, interpolate_matrix): print(index) return data[index] # neighbor_matrix shape (n_neighbor, n_samples) - neighbor_matrix = np.array( - [neighbor_indices.flatten().tolist()] * data.shape[-1] - ).T + neighbor_matrix = np.array([neighbor_indices.flatten().tolist()] * data.shape[-1]).T # Create a mask max_mask = neighbors_data == np.amax(neighbors_data, keepdims=True, axis=0) @@ -217,10 +211,7 @@ def _channel_spatial_filter(index, data, adjacency_vector, interpolate_matrix): keep_mask = ~(max_mask | min_mask) keep_indices = np.array( - [ - neighbor_matrix[:, i][keep_mask[:, i]] - for i in range(keep_mask.shape[-1]) - ] + [neighbor_matrix[:, i][keep_mask[:, i]] for i in range(keep_mask.shape[-1])] ) channel_data = data[index] for i, keep_ind in enumerate(keep_indices): diff --git a/pycrostates/preprocessing/tests/test_extract_gfp_peaks.py b/pycrostates/preprocessing/tests/test_extract_gfp_peaks.py index 67a0942a..8979a36c 100644 --- a/pycrostates/preprocessing/tests/test_extract_gfp_peaks.py +++ b/pycrostates/preprocessing/tests/test_extract_gfp_peaks.py @@ -56,11 +56,7 @@ def test_extract_gfp_invalid_arguments(inst): """Test errors raised when invalid arguments are provided.""" with pytest.raises(TypeError, match="'inst' must be an instance of "): extract_gfp_peaks(101) - with pytest.raises( - TypeError, match="'min_peak_distance' must be an instance" - ): + with pytest.raises(TypeError, match="'min_peak_distance' must be an instance"): extract_gfp_peaks(inst, min_peak_distance=True) - with pytest.raises( - ValueError, match="Argument 'min_peak_distance' must be" - ): + with pytest.raises(ValueError, match="Argument 'min_peak_distance' must be"): extract_gfp_peaks(inst, min_peak_distance=-2) diff --git a/pycrostates/preprocessing/tests/test_resample.py b/pycrostates/preprocessing/tests/test_resample.py index 2592a6da..2fb49798 100644 --- a/pycrostates/preprocessing/tests/test_resample.py +++ b/pycrostates/preprocessing/tests/test_resample.py @@ -59,9 +59,7 @@ def test_resample_n_resamples_n_samples(inst, replace, n_resamples, n_samples): ) def test_resample_n_resamples_coverage(inst, replace, n_resamples, cov): """Test resampling with n_resamples and coverage provided.""" - resamples = resample( - inst, n_resamples=n_resamples, coverage=cov, replace=replace - ) + resamples = resample(inst, n_resamples=n_resamples, coverage=cov, replace=replace) n_ch = _picks_to_idx(inst.info, None, exclude="bads").size n_data = inst.times.size if isinstance(inst, BaseEpochs): @@ -88,9 +86,7 @@ def test_resample_n_resamples_coverage(inst, replace, n_resamples, cov): ) def test_resample_n_samples_coverage(inst, replace, n_samples, cov): """Test resampling with n_samples and coverage provided.""" - resamples = resample( - inst, n_samples=n_samples, coverage=cov, replace=replace - ) + resamples = resample(inst, n_samples=n_samples, coverage=cov, replace=replace) n_ch = _picks_to_idx(inst.info, None, exclude="bads").size assert isinstance(resamples, list) n_data = inst.times.size @@ -114,21 +110,15 @@ def test_n_resamples_n_samples_coverage_errors(): """Test error raised by wrong combination of n_resamples, n_samples and coverage.""" # n_resamples is not None - with pytest.raises( - ValueError, match="'n_resamples' must be a strictly positive" - ): + with pytest.raises(ValueError, match="'n_resamples' must be a strictly positive"): resample(raw, n_resamples=-1, n_samples=50, replace=False) with pytest.raises(ValueError, match="'n_resamples', at least one of"): resample(raw, n_resamples=10, replace=False) with pytest.raises(ValueError, match="'n_resamples', only one of"): - resample( - raw, n_resamples=10, n_samples=50, coverage=0.2, replace=False - ) + resample(raw, n_resamples=10, n_samples=50, coverage=0.2, replace=False) with pytest.raises(ValueError, match="'coverage' must respect"): resample(raw, n_resamples=10, coverage=1.2, replace=False) - with pytest.raises( - ValueError, match="'n_samples' must be a strictly positive" - ): + with pytest.raises(ValueError, match="'n_samples' must be a strictly positive"): resample(raw, n_resamples=10, n_samples=-10, replace=False) # n_resamples is None @@ -136,9 +126,7 @@ def test_n_resamples_n_samples_coverage_errors(): resample(raw, n_samples=50, replace=False) with pytest.raises(ValueError, match="'n_resamples' is None, both"): resample(raw, coverage=0.2, replace=False) - with pytest.raises( - ValueError, match="'n_samples' must be a strictly positive" - ): + with pytest.raises(ValueError, match="'n_samples' must be a strictly positive"): resample(raw, n_samples=-10, coverage=0.2, replace=False) with pytest.raises(ValueError, match="'coverage' must respect"): resample(raw, n_samples=10, coverage=1.2, replace=False) @@ -146,10 +134,6 @@ def test_n_resamples_n_samples_coverage_errors(): def test_resample_random_state(): """Test resampling with n_samples and coverage provided.""" - resamples_0 = resample(raw, n_resamples=1, n_samples=500, random_state=42)[ - 0 - ] - resamples_1 = resample(raw, n_resamples=1, n_samples=500, random_state=42)[ - 0 - ] + resamples_0 = resample(raw, n_resamples=1, n_samples=500, random_state=42)[0] + resamples_1 = resample(raw, n_resamples=1, n_samples=500, random_state=42)[0] assert_allclose(resamples_0._data, resamples_1._data) diff --git a/pycrostates/preprocessing/tests/test_spatial_filter.py b/pycrostates/preprocessing/tests/test_spatial_filter.py index ea1fde69..200b67cd 100644 --- a/pycrostates/preprocessing/tests/test_spatial_filter.py +++ b/pycrostates/preprocessing/tests/test_spatial_filter.py @@ -115,24 +115,16 @@ def test_spatial_filter_eeg_and_meg(): picks_non_eeg = [ idx for idx in np.arange(len(raw_all.ch_names)) if idx not in picks_eeg ] - assert not np.all( - new_inst._data[picks_eeg, :] == raw_all._data[picks_eeg, :] - ) - assert np.all( - new_inst._data[picks_non_eeg, :] == raw_all._data[picks_non_eeg, :] - ) + assert not np.all(new_inst._data[picks_eeg, :] == raw_all._data[picks_eeg, :]) + assert np.all(new_inst._data[picks_non_eeg, :] == raw_all._data[picks_non_eeg, :]) def test_spatial_filter_custom_adjacency(): """Test apply_spatial_filter with custom adjacency.""" adjacency_matrix, ch_names = find_ch_adjacency(raw_all.info, "eeg") apply_spatial_filter(raw_all.copy(), "eeg", adjacency=adjacency_matrix) - with pytest.raises( - ValueError, match="Adjacency must have exactly 2 dimensions" - ): - apply_spatial_filter( - raw_all.copy(), "eeg", adjacency=np.ones((len(ch_names))) - ) + with pytest.raises(ValueError, match="Adjacency must have exactly 2 dimensions"): + apply_spatial_filter(raw_all.copy(), "eeg", adjacency=np.ones((len(ch_names)))) with pytest.raises(ValueError, match="Adjacency must be of shape"): apply_spatial_filter( raw_all.copy(), "eeg", adjacency=adjacency_matrix[:-2, :-2] diff --git a/pycrostates/segmentation/_base.py b/pycrostates/segmentation/_base.py index 092ff9ae..ea2d1c2d 100644 --- a/pycrostates/segmentation/_base.py +++ b/pycrostates/segmentation/_base.py @@ -15,10 +15,7 @@ from ..utils._docs import fill_doc from ..utils._logs import logger from ..viz import plot_cluster_centers -from .transitions import ( - _compute_expected_transition_matrix, - _compute_transition_matrix, -) +from .transitions import _compute_expected_transition_matrix, _compute_transition_matrix @fill_doc @@ -72,13 +69,9 @@ def __repr__(self) -> str: return s def _repr_html_(self, caption=None): - from ..html_templates import ( # pylint: disable=C0415 - repr_templates_env, - ) + from ..html_templates import repr_templates_env # pylint: disable=C0415 - template = repr_templates_env.get_template( - "BaseSegmentation.html.jinja" - ) + template = repr_templates_env.get_template("BaseSegmentation.html.jinja") return template.render( name=self.__class__.__name__, n_clusters=len(self._cluster_centers_), @@ -86,9 +79,7 @@ def _repr_html_(self, caption=None): inst_repr=self._inst._repr_html_(), ) - def compute_parameters( - self, norm_gfp: bool = True, return_dist: bool = False - ): + def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): """Compute microstate parameters. Parameters @@ -184,12 +175,8 @@ def compute_parameters( dist_gev = (labeled_gfp * dist_corr) ** 2 / np.sum(gfp**2) params[f"{state_name}_gev"] = np.sum(dist_gev) - s_segments = np.array( - [len(group) for s_, group in segments if s_ == s] - ) - occurrences = ( - len(s_segments) / len(np.where(labels != -1)[0]) * sfreq - ) + s_segments = np.array([len(group) for s_, group in segments if s_ == s]) + occurrences = len(s_segments) / len(np.where(labels != -1)[0]) * sfreq params[f"{state_name}_occurrences"] = occurrences timecov = np.sum(s_segments) / len(np.where(labels != -1)[0]) @@ -211,15 +198,9 @@ def compute_parameters( params[f"{state_name}_occurrences"] = 0.0 if return_dist: - params[f"{state_name}_dist_corr"] = np.array( - [], dtype=float - ) - params[f"{state_name}_dist_gev"] = np.array( - [], dtype=float - ) - params[f"{state_name}_dist_durs"] = np.array( - [], dtype=float - ) + params[f"{state_name}_dist_corr"] = np.array([], dtype=float) + params[f"{state_name}_dist_gev"] = np.array([], dtype=float) + params[f"{state_name}_dist_durs"] = np.array([], dtype=float) params["unlabeled"] = len(np.argwhere(labels == -1)) / len(gfp) return params @@ -256,9 +237,7 @@ def compute_transition_matrix(self, stat="probability", ignore_self=True): ) @fill_doc - def compute_expected_transition_matrix( - self, stat="probability", ignore_self=True - ): + def compute_expected_transition_matrix(self, stat="probability", ignore_self=True): """Compute the expected transition matrix. Compute the theoretical transition matrix as if time course was diff --git a/pycrostates/segmentation/tests/test_segmentation.py b/pycrostates/segmentation/tests/test_segmentation.py index 95c9a37a..d5703cf9 100644 --- a/pycrostates/segmentation/tests/test_segmentation.py +++ b/pycrostates/segmentation/tests/test_segmentation.py @@ -23,12 +23,8 @@ events = make_fixed_length_events(raw, 1) epochs = Epochs(raw, events, preload=True) -ModK_raw = ModKMeans( - n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1 -) -ModK_epochs = ModKMeans( - n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1 -) +ModK_raw = ModKMeans(n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1) +ModK_epochs = ModKMeans(n_clusters=4, n_init=10, max_iter=100, tol=1e-4, random_state=1) ModK_raw.fit(raw, n_jobs=1) ModK_epochs.fit(epochs, n_jobs=1) @@ -217,9 +213,7 @@ def test_plot_segmentation(ModK, inst): def test_invalid_segmentation(Segmentation, inst, bad_inst, caplog): """Test that we can not create an invalid segmentation.""" labels = np.zeros((inst.times.size)) - cluster_centers = np.zeros( - (4, len(inst.ch_names) - len(inst.info["bads"])) - ) + cluster_centers = np.zeros((4, len(inst.ch_names) - len(inst.info["bads"]))) cluster_names = ["a", "b", "c", "d"] # types @@ -237,25 +231,17 @@ def test_invalid_segmentation(Segmentation, inst, bad_inst, caplog): Segmentation(labels, inst, cluster_centers, cluster_names, []) # values - with pytest.raises( - ValueError, match="number of cluster centers and cluster names" - ): + with pytest.raises(ValueError, match="number of cluster centers and cluster names"): Segmentation(labels, inst, cluster_centers, cluster_names[:2], None) with pytest.raises(ValueError, match="should be a 2D array"): - Segmentation( - labels, inst, cluster_centers.flatten(), cluster_names, None - ) + Segmentation(labels, inst, cluster_centers.flatten(), cluster_names, None) # raw/epochs specific with pytest.raises(ValueError, match="and labels do not have"): if isinstance(inst, BaseRaw): - Segmentation( - labels[:10], inst, cluster_centers, cluster_names, None - ) + Segmentation(labels[:10], inst, cluster_centers, cluster_names, None) if isinstance(inst, BaseEpochs): - Segmentation( - np.zeros((2, 10)), inst, cluster_centers, cluster_names, None - ) + Segmentation(np.zeros((2, 10)), inst, cluster_centers, cluster_names, None) with pytest.raises(ValueError, match="'labels' should be"): if isinstance(inst, BaseRaw): @@ -272,9 +258,7 @@ def test_invalid_segmentation(Segmentation, inst, bad_inst, caplog): # unsupported predict_parameters caplog.clear() if isinstance(inst, BaseRaw): - Segmentation( - labels, inst, cluster_centers, cluster_names, dict(test=101) - ) + Segmentation(labels, inst, cluster_centers, cluster_names, dict(test=101)) if isinstance(inst, BaseEpochs): Segmentation( np.zeros((len(inst), inst.times.size)), @@ -298,14 +282,10 @@ def test_compute_transition_matrix_Epochs(): segmentation.compute_transition_matrix(ignore_self=False) -@pytest.mark.parametrize( - "ModK, inst", [(ModK_raw, raw), (ModK_epochs, epochs)] -) +@pytest.mark.parametrize("ModK, inst", [(ModK_raw, raw), (ModK_epochs, epochs)]) def test_compute_transition_matrix_stat(ModK, inst): segmentation = ModK.predict(inst) - with pytest.raises( - ValueError, match="Invalid value for the 'stat' parameter" - ): + with pytest.raises(ValueError, match="Invalid value for the 'stat' parameter"): segmentation.compute_transition_matrix(stat="wrong") T = segmentation.compute_transition_matrix(stat="count") T = segmentation.compute_transition_matrix(stat="probability") @@ -328,18 +308,12 @@ def test_compute_expected_transition_matrix_Epochs(): segmentation.compute_expected_transition_matrix(ignore_self=False) -@pytest.mark.parametrize( - "ModK, inst", [(ModK_raw, raw), (ModK_epochs, epochs)] -) +@pytest.mark.parametrize("ModK, inst", [(ModK_raw, raw), (ModK_epochs, epochs)]) def test_compute_expected_transition_matrix_stat(ModK, inst): segmentation = ModK.predict(inst) - with pytest.raises( - ValueError, match="Invalid value for the 'stat' parameter" - ): + with pytest.raises(ValueError, match="Invalid value for the 'stat' parameter"): segmentation.compute_expected_transition_matrix(stat="wrong") - with pytest.raises( - ValueError, match="Invalid value for the 'stat' parameter" - ): + with pytest.raises(ValueError, match="Invalid value for the 'stat' parameter"): segmentation.compute_expected_transition_matrix(stat="count") T = segmentation.compute_expected_transition_matrix(stat="probability") assert_allclose(np.sum(T, axis=1), 1) diff --git a/pycrostates/segmentation/tests/test_transitions.py b/pycrostates/segmentation/tests/test_transitions.py index db71e1c2..5180f6e8 100644 --- a/pycrostates/segmentation/tests/test_transitions.py +++ b/pycrostates/segmentation/tests/test_transitions.py @@ -74,9 +74,7 @@ ) def test_compute_transition_matrix(labels, ignore_self, T): n_clusters = ( - np.unique(labels).size - 1 - if np.any(labels == -1) - else np.unique(labels).size + np.unique(labels).size - 1 if np.any(labels == -1) else np.unique(labels).size ) t = _compute_transition_matrix( labels, n_clusters=n_clusters, ignore_self=ignore_self @@ -137,15 +135,11 @@ def test_check_labels_n_clusters(): # invalids with pytest.raises(ValueError, match="'-101' is invalid."): _check_labels_n_clusters(np.random.randint(-1, 5, size=100), -101) - with pytest.raises( - ValueError, match="Negative integers except -1 are invalid." - ): + with pytest.raises(ValueError, match="Negative integers except -1 are invalid."): _check_labels_n_clusters(np.random.randint(-2, 5, size=100), 5) with pytest.raises(ValueError, match=re.escape("'[4]' is invalid.")): _check_labels_n_clusters(np.random.randint(1, 5, size=100), 4) with pytest.raises(ValueError, match="'float64' is invalid."): - _check_labels_n_clusters( - np.random.randint(0, 5, size=100).astype(float), 5 - ) + _check_labels_n_clusters(np.random.randint(0, 5, size=100).astype(float), 5) with pytest.raises(ValueError, match=re.escape("'[6 7]' are invalid")): _check_labels_n_clusters(np.random.randint(0, 8, size=100), 6) diff --git a/pycrostates/segmentation/transitions.py b/pycrostates/segmentation/transitions.py index 91ac5921..2ef8de85 100644 --- a/pycrostates/segmentation/transitions.py +++ b/pycrostates/segmentation/transitions.py @@ -48,9 +48,7 @@ def _compute_transition_matrix( ) -> NDArray[float]: """Compute observed transition.""" # common error checking - _check_value( - stat, ("count", "probability", "proportion", "percent"), "stat" - ) + _check_value(stat, ("count", "probability", "proportion", "percent"), "stat") _check_type(ignore_self, (bool,), "ignore_self") # reshape if epochs (returns a view) diff --git a/pycrostates/utils/__init__.py b/pycrostates/utils/__init__.py index dd181c66..9ab02015 100644 --- a/pycrostates/utils/__init__.py +++ b/pycrostates/utils/__init__.py @@ -1,10 +1,6 @@ """Utils module for utilities.""" from ._config import get_config -from .utils import ( # noqa: F401 - _compare_infos, - _corr_vectors, - _distance_matrix, -) +from .utils import _compare_infos, _corr_vectors, _distance_matrix # noqa: F401 __all__ = ("get_config",) diff --git a/pycrostates/utils/_checks.py b/pycrostates/utils/_checks.py index b42d1f2c..35c0fca7 100644 --- a/pycrostates/utils/_checks.py +++ b/pycrostates/utils/_checks.py @@ -43,9 +43,7 @@ def _ensure_int(item, item_name=None): item = int(operator.index(item)) except TypeError: item_name = "Item" if item_name is None else "'%s'" % item_name - raise TypeError( - "%s must be an int, got %s instead." % (item_name, type(item)) - ) + raise TypeError("%s must be an int, got %s instead." % (item_name, type(item))) return item @@ -174,9 +172,7 @@ def _check_value(item, allowed_values, item_name=None, extra=None): options += ", ".join([f"{repr(v)}" for v in allowed_values[:-1]]) options += f", and {repr(allowed_values[-1])}" raise ValueError( - msg.format( - item_name=item_name, extra=extra, options=options, item=item - ) + msg.format(item_name=item_name, extra=extra, options=options, item=item) ) return item diff --git a/pycrostates/utils/_docs.py b/pycrostates/utils/_docs.py index 59d5cfd3..62b7b9ad 100644 --- a/pycrostates/utils/_docs.py +++ b/pycrostates/utils/_docs.py @@ -25,9 +25,7 @@ for key in keys: entry = docdict_mne[key] if ".. versionchanged::" in entry: - entry = entry.replace( - ".. versionchanged::", ".. versionchanged:: MNE " - ) + entry = entry.replace(".. versionchanged::", ".. versionchanged:: MNE ") if ".. versionadded::" in entry: entry = entry.replace(".. versionadded::", ".. versionadded:: MNE ") docdict[key] = entry diff --git a/pycrostates/utils/_imports.py b/pycrostates/utils/_imports.py index 13ff072f..2a7a7e50 100644 --- a/pycrostates/utils/_imports.py +++ b/pycrostates/utils/_imports.py @@ -12,9 +12,7 @@ INSTALL_MAPPING = {} -def import_optional_dependency( - name: str, extra: str = "", raise_error: bool = True -): +def import_optional_dependency(name: str, extra: str = "", raise_error: bool = True): """ Import an optional dependency. diff --git a/pycrostates/utils/_logs.py b/pycrostates/utils/_logs.py index d0b3682b..6b74f670 100644 --- a/pycrostates/utils/_logs.py +++ b/pycrostates/utils/_logs.py @@ -11,9 +11,7 @@ @fill_doc -def _init_logger( - *, verbose: Optional[Union[bool, str, int]] = None -) -> logging.Logger: +def _init_logger(*, verbose: Optional[Union[bool, str, int]] = None) -> logging.Logger: """Initialize a logger. Assigns sys.stdout as the first handler of the logger. diff --git a/pycrostates/utils/tests/test_checks.py b/pycrostates/utils/tests/test_checks.py index b6c66877..c557e30c 100644 --- a/pycrostates/utils/tests/test_checks.py +++ b/pycrostates/utils/tests/test_checks.py @@ -72,9 +72,7 @@ def test_check_value(): # invalids with pytest.raises(ValueError, match="Invalid value for the parameter."): _check_value(5, [1, 2, 3, 4]) - with pytest.raises( - ValueError, match="Invalid value for the 'number' parameter." - ): + with pytest.raises(ValueError, match="Invalid value for the 'number' parameter."): _check_value(5, [1, 2, 3, 4], "number") @@ -110,9 +108,7 @@ def test_check_random_state(): rs = _check_random_state(rng) assert isinstance(rs, Generator) - with pytest.raises( - ValueError, match=re.escape("[101] cannot be used to seed") - ): + with pytest.raises(ValueError, match=re.escape("[101] cannot be used to seed")): _check_random_state([101]) @@ -154,9 +150,7 @@ def test_check_reject_by_annotation(): TypeError, match="'reject_by_annotation' must be an instance of" ): _check_reject_by_annotation(1) - with pytest.raises( - ValueError, match="'reject_by_annotation' only allows for" - ): + with pytest.raises(ValueError, match="'reject_by_annotation' only allows for"): _check_reject_by_annotation("101") @@ -182,13 +176,9 @@ def test_check_tmin_tmax(): _check_tmin_tmax(epochs, 0, 0.5) # test invalid tmin/tmax - with pytest.raises( - ValueError, match="Argument 'tmax' must be shorter than" - ): + with pytest.raises(ValueError, match="Argument 'tmax' must be shorter than"): _check_tmin_tmax(raw, 1, 6) - with pytest.raises( - ValueError, match="Argument 'tmax' must be shorter than" - ): + with pytest.raises(ValueError, match="Argument 'tmax' must be shorter than"): _check_tmin_tmax(epochs, 1, 6) with pytest.raises(ValueError, match="Argument 'tmin' must be positive"): _check_tmin_tmax(raw, -1, 4) @@ -202,13 +192,9 @@ def test_check_tmin_tmax(): ValueError, match="Argument 'tmax' must be strictly larger than 'tmin'" ): _check_tmin_tmax(epochs, 0.3, 0.1) - with pytest.raises( - ValueError, match="Argument 'tmin' must be shorter than" - ): + with pytest.raises(ValueError, match="Argument 'tmin' must be shorter than"): _check_tmin_tmax(raw, 6, None) - with pytest.raises( - ValueError, match="Argument 'tmin' must be shorter than" - ): + with pytest.raises(ValueError, match="Argument 'tmin' must be shorter than"): _check_tmin_tmax(epochs, 2, None) diff --git a/pycrostates/utils/tests/test_logs.py b/pycrostates/utils/tests/test_logs.py index ce72df92..4a0d061f 100644 --- a/pycrostates/utils/tests/test_logs.py +++ b/pycrostates/utils/tests/test_logs.py @@ -5,12 +5,7 @@ import pytest -from pycrostates.utils._logs import ( - add_file_handler, - logger, - set_log_level, - verbose, -) +from pycrostates.utils._logs import add_file_handler, logger, set_log_level, verbose logger.propagate = True @@ -40,9 +35,7 @@ def test_default_log_level(caplog): assert "101" in caplog.text -@pytest.mark.parametrize( - "level", ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") -) +@pytest.mark.parametrize("level", ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")) def test_logger(level, caplog): """Test basic logger functionalities.""" level_functions = { diff --git a/pycrostates/utils/tests/test_mixin.py b/pycrostates/utils/tests/test_mixin.py index b9685c38..fb1e5022 100644 --- a/pycrostates/utils/tests/test_mixin.py +++ b/pycrostates/utils/tests/test_mixin.py @@ -38,17 +38,11 @@ def test_contains_mixin(): # test with info equal to None foo = Foo(None) - with pytest.raises( - ValueError, match="Instance 'Foo' attribute 'info' is None." - ): + with pytest.raises(ValueError, match="Instance 'Foo' attribute 'info' is None."): "eeg" in foo - with pytest.raises( - ValueError, match="Instance 'Foo' attribute 'info' is None." - ): + with pytest.raises(ValueError, match="Instance 'Foo' attribute 'info' is None."): foo.get_channel_types() - with pytest.raises( - ValueError, match="Instance 'Foo' attribute 'info' is None." - ): + with pytest.raises(ValueError, match="Instance 'Foo' attribute 'info' is None."): foo.compensation_grade # test without attribute info @@ -82,13 +76,9 @@ def test_montage_mixin(): # test with info equal to None foo = Foo(None) - with pytest.raises( - ValueError, match="Instance 'Foo' attribute 'info' is None." - ): + with pytest.raises(ValueError, match="Instance 'Foo' attribute 'info' is None."): foo.set_montage("standard_1020") - with pytest.raises( - ValueError, match="Instance 'Foo' attribute 'info' is None." - ): + with pytest.raises(ValueError, match="Instance 'Foo' attribute 'info' is None."): foo.get_montage() # test without attribute info diff --git a/pycrostates/utils/utils.py b/pycrostates/utils/utils.py index 2cb4f34f..6225801c 100644 --- a/pycrostates/utils/utils.py +++ b/pycrostates/utils/utils.py @@ -107,12 +107,8 @@ def _compare_infos(cluster_info, inst_info): inst_units.append((ch["ch_name"], ch["unit"])) inst_coord_frames.append((ch["ch_name"], ch["coord_frame"])) - cluster_kinds = [ - elt[1] for elt in sorted(cluster_kinds, key=lambda x: x[0]) - ] - cluster_units = [ - elt[1] for elt in sorted(cluster_units, key=lambda x: x[0]) - ] + cluster_kinds = [elt[1] for elt in sorted(cluster_kinds, key=lambda x: x[0])] + cluster_units = [elt[1] for elt in sorted(cluster_units, key=lambda x: x[0])] cluster_coord_frame = [ elt[1] for elt in sorted(cluster_coord_frame, key=lambda x: x[0]) ] @@ -122,23 +118,17 @@ def _compare_infos(cluster_info, inst_info): elt[1] for elt in sorted(inst_coord_frames, key=lambda x: x[0]) ] - if not all( - kind1 == kind2 for kind1, kind2 in zip(cluster_kinds, inst_kinds) - ): + if not all(kind1 == kind2 for kind1, kind2 in zip(cluster_kinds, inst_kinds)): logger.warning( "Instance to segment into microstates sequence does not have " "the same channels kinds as the instance used for fitting. " ) - if not all( - unit1 == unit2 for unit1, unit2 in zip(cluster_units, inst_units) - ): + if not all(unit1 == unit2 for unit1, unit2 in zip(cluster_units, inst_units)): logger.warning( "Instance to segment into microstates sequence does not have " "the same channels units as the instance used for fitting. " ) - if not all( - f1 == f2 for f1, f2 in zip(cluster_coord_frame, inst_coord_frames) - ): + if not all(f1 == f2 for f1, f2 in zip(cluster_coord_frame, inst_coord_frames)): logger.warning( "Instance to segment into microstates sequence does not have " "the same coordinate frames as the instance used for fitting. " diff --git a/pycrostates/viz/cluster_centers.py b/pycrostates/viz/cluster_centers.py index 3db4fbbd..59128379 100644 --- a/pycrostates/viz/cluster_centers.py +++ b/pycrostates/viz/cluster_centers.py @@ -15,7 +15,6 @@ from ..utils._docs import fill_doc from ..utils._logs import logger, verbose - _GRADIENT_KWARGS_DEFAULTS: Dict[str, str] = { "color": "black", "linestyle": "-", @@ -86,9 +85,7 @@ def plot_cluster_centers( # check cluster_names if cluster_names is None: - cluster_names = [ - str(k) for k in range(1, cluster_centers.shape[0] + 1) - ] + cluster_names = [str(k) for k in range(1, cluster_centers.shape[0] + 1)] if len(cluster_names) != cluster_centers.shape[0]: raise ValueError( "Argument 'cluster_centers' and 'cluster_names' should have the " diff --git a/pycrostates/viz/segmentation.py b/pycrostates/viz/segmentation.py index 6197bb91..80c8b733 100644 --- a/pycrostates/viz/segmentation.py +++ b/pycrostates/viz/segmentation.py @@ -78,8 +78,7 @@ def plot_raw_segmentation( # make sure shapes are correct if data.shape[1] != labels.size: raise ValueError( - "Argument 'labels' and 'raw' do not have the same number of " - "samples." + "Argument 'labels' and 'raw' do not have the same number of samples." ) fig, axes, show = _plot_segmentation( @@ -155,8 +154,7 @@ def plot_epoch_segmentation( # make sure shapes are correct if data.shape[1] != labels.size: raise ValueError( - "Argument 'labels' and 'epochs' do not have the same number of " - "samples." + "Argument 'labels' and 'epochs' do not have the same number of samples." ) fig, axes, show = _plot_segmentation( diff --git a/pycrostates/viz/tests/test_cluster_centers.py b/pycrostates/viz/tests/test_cluster_centers.py index 979f0874..b44293e4 100644 --- a/pycrostates/viz/tests/test_cluster_centers.py +++ b/pycrostates/viz/tests/test_cluster_centers.py @@ -74,9 +74,7 @@ def test_plot_cluster_centers(caplog): ): plot_cluster_centers(cluster_centers, info=chinfo, cluster_names=["A"]) f, ax = plt.subplots(1, 1) - with pytest.raises( - ValueError, match="Argument 'cluster_centers' and 'axes' must " - ): + with pytest.raises(ValueError, match="Argument 'cluster_centers' and 'axes' must "): plot_cluster_centers(cluster_centers, info=chinfo, axes=ax) plt.close("all") diff --git a/pycrostates/viz/tests/test_segmentation.py b/pycrostates/viz/tests/test_segmentation.py index 87fb5341..afd0810e 100644 --- a/pycrostates/viz/tests/test_segmentation.py +++ b/pycrostates/viz/tests/test_segmentation.py @@ -33,9 +33,7 @@ def test_plot_raw_segmentation(): # provide ax and cbar_ax f, axes = plt.subplots(1, 2) - plot_raw_segmentation( - labels, raw, n_clusters, axes=axes[0], cbar_axes=axes[1] - ) + plot_raw_segmentation(labels, raw, n_clusters, axes=axes[0], cbar_axes=axes[1]) plt.close("all") # provide cmap @@ -46,9 +44,7 @@ def test_plot_raw_segmentation(): def test_plot_epoch_segmentation(): """Test segmentation plots for epochs.""" n_clusters = 4 - labels = np.random.choice( - [-1, 0, 1, 2, 3], (len(epochs), epochs.times.size) - ) + labels = np.random.choice([-1, 0, 1, 2, 3], (len(epochs), epochs.times.size)) plot_epoch_segmentation(labels, epochs, n_clusters) plt.close("all") @@ -65,9 +61,7 @@ def test_plot_epoch_segmentation(): # provide ax and cbar_ax f, axes = plt.subplots(1, 2) - plot_epoch_segmentation( - labels, epochs, n_clusters, axes=axes[0], cbar_axes=axes[1] - ) + plot_epoch_segmentation(labels, epochs, n_clusters, axes=axes[0], cbar_axes=axes[1]) plt.close("all") # provide cmap From df5bf5b50a3c84b0a5c8958c287b1a5679bfd6c1 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 11:10:21 +0200 Subject: [PATCH 07/16] bump MNE req. to 1.1.0 and above --- pycrostates/io/meas_info.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pycrostates/io/meas_info.py b/pycrostates/io/meas_info.py index 6cc105d9..89bb0d44 100644 --- a/pycrostates/io/meas_info.py +++ b/pycrostates/io/meas_info.py @@ -328,7 +328,7 @@ def __getattribute__(self, name): # invalid attributes _inv_attributes = () # invalid methods/properties - _inv_methods = "pick_channels" # TODO: Can be removed when req. for MNE = 1.1.0 + _inv_methods = () if name in _inv_attributes or name in _inv_methods: raise AttributeError( f"'{self.__class__.__name__}' has not attribute '{name}'" diff --git a/pyproject.toml b/pyproject.toml index 6bea1255..533127c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ classifiers = [ dependencies = [ 'numpy>=1.21', 'scipy', - 'mne>=1.0.0', + 'mne>=1.1.0', 'joblib', 'matplotlib', 'scikit-learn', From 90cd979f6c54c4650e6d6f9ff81aab801dee844b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 11:12:23 +0200 Subject: [PATCH 08/16] fix missed black --- pycrostates/io/tests/test_meas_info.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pycrostates/io/tests/test_meas_info.py b/pycrostates/io/tests/test_meas_info.py index 23723b5e..a8cbe2fe 100644 --- a/pycrostates/io/tests/test_meas_info.py +++ b/pycrostates/io/tests/test_meas_info.py @@ -120,9 +120,7 @@ def test_create_from_info_invalid_arguments(): ChInfo(info, ch_names=ch_names) with pytest.raises(RuntimeError, match="If 'info' is provided"): ChInfo(info, ch_types=ch_types) - with pytest.raises( - TypeError, match="'info' must be an instance of None or Info" - ): + with pytest.raises(TypeError, match="'info' must be an instance of None or Info"): ChInfo(info=ch_names) From ddca9f40f948a8c236e6373ba44114b2d7e34146 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 11:33:51 +0200 Subject: [PATCH 09/16] remove test on pick_channels --- pycrostates/io/tests/test_meas_info.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/pycrostates/io/tests/test_meas_info.py b/pycrostates/io/tests/test_meas_info.py index a8cbe2fe..f055271e 100644 --- a/pycrostates/io/tests/test_meas_info.py +++ b/pycrostates/io/tests/test_meas_info.py @@ -333,17 +333,6 @@ def test_setting_invalid_keys(): chinfo._check_consistency() -def test_invalid_attributes(): - """Test that attribute error is raised when calling invalid attributes or - methods.""" - info = create_info(ch_names=3, sfreq=1, ch_types="eeg") - chinfo = ChInfo(info=info) - with pytest.raises( - AttributeError, match="'ChInfo' has not attribute 'pick_channels'" - ): - chinfo.pick_channels(["1"]) - - def test_comparison(caplog): """Test == and != methods.""" # simple info without montage From bb64976638e640ee8b8556fa13195d31bfa36f0b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 13:33:12 +0200 Subject: [PATCH 10/16] clean-up following black --- pycrostates/cluster/kmeans.py | 2 +- pycrostates/io/ch_data.py | 2 +- pycrostates/io/meas_info.py | 2 +- pycrostates/utils/_checks.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index e6d0ac00..22451131 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -427,6 +427,6 @@ def _check_tol(tol: Union[int, float]) -> Union[int, float]: _check_type(tol, ("numeric",), item_name="tol") if tol <= 0: raise ValueError( - "The tolerance must be a positive number. " f"Provided: '{tol}'." + "The tolerance must be a positive number. Provided: '{tol}'." ) return tol diff --git a/pycrostates/io/ch_data.py b/pycrostates/io/ch_data.py index 317d4d20..3a3fa715 100644 --- a/pycrostates/io/ch_data.py +++ b/pycrostates/io/ch_data.py @@ -152,7 +152,7 @@ def _get_channel_positions(self, picks=None): n_zero = np.sum(np.sum(np.abs(pos), axis=1) == 0) if n_zero > 1: # XXX some systems have origin (0, 0, 0) raise ValueError( - "Could not extract channel positions for " f"{n_zero} channels." + f"Could not extract channel positions for {n_zero} channels." ) return pos diff --git a/pycrostates/io/meas_info.py b/pycrostates/io/meas_info.py index 89bb0d44..639fdf18 100644 --- a/pycrostates/io/meas_info.py +++ b/pycrostates/io/meas_info.py @@ -283,7 +283,7 @@ def _init_from_channels( _check_type(ch_type, (str,)) if ch_type not in ch_types_dict: raise KeyError( - f"kind must be one of {list(ch_types_dict)}, not " f"{ch_type}." + f"kind must be one of {list(ch_types_dict)}, not {ch_type}." ) this_ch_dict = ch_types_dict[ch_type] kind = this_ch_dict["kind"] diff --git a/pycrostates/utils/_checks.py b/pycrostates/utils/_checks.py index 35c0fca7..80a77322 100644 --- a/pycrostates/utils/_checks.py +++ b/pycrostates/utils/_checks.py @@ -259,7 +259,7 @@ def _check_tmin_tmax(inst, tmin, tmax): continue if arg < 0: raise ValueError( - f"Argument '{name}' must be positive. " f"Provided '{arg}'." + f"Argument '{name}' must be positive. Provided '{arg}'." ) # check tmax is shorter than instance if tmax is not None and inst.times[-1] < tmax: From 12e4f1c58b984da1db0efe96c02410e6a5bd8fa5 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 3 Jul 2023 14:50:07 +0200 Subject: [PATCH 11/16] run black again.. --- pycrostates/utils/_checks.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pycrostates/utils/_checks.py b/pycrostates/utils/_checks.py index 80a77322..06705532 100644 --- a/pycrostates/utils/_checks.py +++ b/pycrostates/utils/_checks.py @@ -258,9 +258,7 @@ def _check_tmin_tmax(inst, tmin, tmax): if arg is None: continue if arg < 0: - raise ValueError( - f"Argument '{name}' must be positive. Provided '{arg}'." - ) + raise ValueError(f"Argument '{name}' must be positive. Provided '{arg}'.") # check tmax is shorter than instance if tmax is not None and inst.times[-1] < tmax: raise ValueError( From d18d07b94c505de79b9a7d9cb51c3c4f51fb3415 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Tue, 25 Jul 2023 10:00:31 +0200 Subject: [PATCH 12/16] trigger cis From 084f15ea967af9a0f037760a114bdc853e51bbd1 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 16 Aug 2023 10:19:38 +0200 Subject: [PATCH 13/16] trigger cis with MNE 1.5 From aa4d672f7adcf30ecbccc103406822261a5527b8 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 16 Aug 2023 13:03:53 +0200 Subject: [PATCH 14/16] shorten strings through 88 character length --- pycrostates/cluster/_base.py | 222 ++++++++---------- pycrostates/cluster/aahc.py | 6 +- pycrostates/cluster/kmeans.py | 32 ++- pycrostates/cluster/utils/utils.py | 12 +- pycrostates/datasets/lemon/lemon.py | 32 ++- pycrostates/io/ch_data.py | 17 +- pycrostates/io/meas_info.py | 137 +++++------ pycrostates/metrics/calinski_harabasz.py | 7 +- pycrostates/metrics/davies_bouldin.py | 13 +- pycrostates/metrics/silhouette.py | 6 +- .../preprocessing/extract_gfp_peaks.py | 45 ++-- pycrostates/preprocessing/resample.py | 16 +- pycrostates/preprocessing/spatial_filter.py | 65 +++-- pycrostates/segmentation/_base.py | 58 +++-- pycrostates/segmentation/segmentation.py | 13 +- pycrostates/segmentation/transitions.py | 16 +- pycrostates/utils/_checks.py | 13 +- pycrostates/utils/_docs.py | 53 ++--- pycrostates/utils/_fixes.py | 4 +- pycrostates/utils/_imports.py | 15 +- pycrostates/utils/utils.py | 11 +- pycrostates/viz/cluster_centers.py | 17 +- pycrostates/viz/segmentation.py | 10 +- 23 files changed, 366 insertions(+), 454 deletions(-) diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index bfd44f90..1e1cdf7a 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -88,8 +88,7 @@ def __eq__(self, other: Any) -> bool: # check fit if self._fitted + other._fitted == 0: # Both False raise RuntimeError( - "Clustering algorithms must be fitted before using '==' " - "comparison." + "Clustering algorithms must be fitted before using '==' comparison." ) if self._fitted + other._fitted == 1: # One False return False @@ -130,8 +129,7 @@ def __eq__(self, other: Any) -> bool: if self._cluster_names != other._cluster_names: logger.warning( "Cluster names differ between both clustering solution. " - "Consider using '.rename_clusters' to change the cluster " - "names." + "Consider using '.rename_clusters' to change the cluster names." ) return True @@ -205,13 +203,12 @@ def fit( MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object from which to extract :term:`cluster centers`. picks : str | list | slice | None - Channels to include. Note that all channels selected must have the - same type. Slices and lists of integers will be interpreted as - channel indices. In lists, channel name strings (e.g. - ``['Fp1', 'Fp2']``) will pick the given channels. Can also be the - string values “all” to pick all channels, or “data” to pick data - channels. ``"eeg"`` (default) will pick all eeg channels. - Note that channels in ``info['bads']`` will be included if their + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their names or indices are explicitly provided. %(tmin_raw)s %(tmax_raw)s @@ -235,15 +232,14 @@ def fit( if len(ch_not_used) != 0: if len(ch_not_used) == 1: msg = ( - "The channel %s is set as bad and ignored. To include " - "it, either remove it from inst.info['bads'] or " - "provide it explicitly in the 'picks' argument." + "The channel %s is set as bad and ignored. To include it, either " + "remove it from inst.info['bads'] or provide it explicitly in the " + "'picks' argument." ) else: msg = ( - "The channels %s are set as bads and ignored. To " - "include them, either remove them from " - "inst.info['bads'] or provide them " + "The channels %s are set as bads and ignored. To include them, " + "either remove them from inst.info['bads'] or provide them " "explicitly in the 'picks' argument." ) logger.warning( @@ -289,11 +285,11 @@ def rename_clusters( Parameters ---------- mapping : dict - Mapping from the old names to the new names. The keys are the old - names and the values are the new names. + Mapping from the old names to the new names. The keys are the old names and + the values are the new names. new_names : list | tuple - 1D iterable containing the new cluster names. The length of the - iterable should match the number of clusters. + 1D iterable containing the new cluster names. The length of the iterable + should match the number of clusters. Notes ----- @@ -316,8 +312,7 @@ def rename_clusters( if len(new_names) != self._n_clusters: raise ValueError( "Argument 'new_names' should contain 'n_clusters': " - f"{self._n_clusters} elements. " - f"Provided '{len(new_names)}'." + f"{self._n_clusters} elements. Provided '{len(new_names)}'." ) # sanity-check @@ -330,8 +325,8 @@ def rename_clusters( else: logger.warning( - "Either 'mapping' or 'new_names' should not be 'None' " - "for method 'rename_clusters' to operate." + "Either 'mapping' or 'new_names' should not be 'None' for method " + "'rename_clusters' to operate." ) return @@ -356,11 +351,10 @@ def reorder_clusters( Specify one of the following arguments to change the current order: - * ``mapping``: a dictionary that maps old cluster positions - to new positions, + * ``mapping``: a dictionary that maps old cluster positions to new positions, * ``order``: a 1D iterable containing the new order, - * ``template``: a fitted clustering algorithm used as a reference - to match the order. + * ``template``: a fitted clustering algorithm used as a reference to match the + order. Only one argument can be set at a time. @@ -370,14 +364,11 @@ def reorder_clusters( Mapping from the old order to the new order. key: old position, value: new position. order : list of int | tuple of int | array of int - 1D iterable containing the new order. - Positions are 0-indexed. + 1D iterable containing the new order. Positions are 0-indexed. template : :ref:`cluster` - Fitted clustering algorithm use as template for - ordering optimization. For more details about the - current implementation, check the - :func:`pycrostates.cluster.utils.optimize_order` - documentation. + Fitted clustering algorithm use as template for ordering optimization. For + more details about the current implementation, check the + :func:`pycrostates.cluster.utils.optimize_order` documentation. Notes ----- @@ -408,9 +399,8 @@ def reorder_clusters( for key in mapping: if key in mapping.values(): raise ValueError( - "A position can not be present in both the old and " - f"new order. Position '{key}' is mapped to " - f"'{mapping[key]}' and position " + "A position can not be present in both the old and new order. " + f"Position '{key}' is mapped to '{mapping[key]}' and position " f"'{inverse_mapping[key]}' is mapped to '{key}'." ) @@ -446,8 +436,8 @@ def reorder_clusters( else: logger.warning( - "Either 'mapping', 'order' or 'template' should not be 'None' " - "for method 'reorder_clusters' to operate." + "Either 'mapping', 'order' or 'template' should not be 'None' for " + "method 'reorder_clusters' to operate." ) return @@ -475,16 +465,16 @@ def invert_polarity( invert : bool | list of bool | array of bool List of bool of length ``n_clusters``. True will invert map polarity, while False will have no effect. - If a bool is provided, it is applied to all maps. + If a `bool` is provided, it is applied to all maps. Notes ----- Operates in-place. - Inverting polarities has no effect on the other steps - of the analysis as polarity is ignored in the current methodology. - This function is only used for tuning visualization - (i.e. for visual inspection and/or to generate figure for an article). + Inverting polarities has no effect on the other steps of the analysis as + polarity is ignored in the current methodology. This function is only used for + tuning visualization (i.e. for visual inspection and/or to generate figure for + an article). """ self._check_fit() @@ -507,8 +497,8 @@ def invert_polarity( _check_type(inv, (bool, np.bool_), item_name="invert") if len(invert) != self._n_clusters: raise ValueError( - "Argument 'invert' should be either a bool or a list of bools " - f"of length 'n_clusters' ({self._n_clusters}). The provided " + "Argument 'invert' should be either a bool or a list of bools of " + f"length 'n_clusters' ({self._n_clusters}). The provided " f"'invert' length is '{len(invert)}'." ) @@ -540,17 +530,15 @@ def plot( ---------- %(axes_topo)s show_gradient : bool - If True, plot a line between channel locations - with highest and lowest values. + If True, plot a line between channel locations with highest and lowest + values. gradient_kwargs : dict - Additional keyword arguments passed to - :meth:`matplotlib.axes.Axes.plot` to plot - gradient line. + Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.plot` to + plot gradient line. %(block)s %(verbose)s **kwargs - Additional keyword arguments are passed to - :func:`mne.viz.plot_topomap`. + Additional keyword arguments are passed to :func:`mne.viz.plot_topomap`. Returns ------- @@ -574,8 +562,7 @@ def plot( @abstractmethod def save(self, fname: Union[str, Path]): - """ - Save clustering solution to disk. + """Save clustering solution to disk. Parameters ---------- @@ -602,38 +589,35 @@ def predict( ): r"""Segment `~mne.io.Raw` or `~mne.Epochs` into microstate sequence. - Segment instance into microstate sequence using the segmentation - smoothing algorithm\ :footcite:p:`Marqui1995`. + Segment instance into microstate sequence using the segmentation smoothing + algorithm\ :footcite:p:`Marqui1995`. Parameters ---------- inst : Raw | Epochs - MNE `~mne.io.Raw` or `~mne.Epochs` object containing the data to - use for prediction. + MNE `~mne.io.Raw` or `~mne.Epochs` object containing the data to use for + prediction. picks : str | list | slice | None - Channels to include. Note that all channels selected must have the - same type. Slices and lists of integers will be interpreted as - channel indices. In lists, channel name strings (e.g. - ``['Fp1', 'Fp2']``) will pick the given channels. - Can also be the string values “all” to pick all channels, or “data” - to pick data channels. ``None`` (default) will pick all channels - used during fitting (e.g., ``self.info['ch_names']``). - Note that channels in ``info['bads']`` will be included if their - names or indices are explicitly provided. + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``None`` (default) will pick all channels + used during fitting (e.g., ``self.info['ch_names']``). Note that channels in + ``info['bads']`` will be included if their names or indices are explicitly + provided. factor : int - Factor used for label smoothing. ``0`` means no smoothing. - Default to 0. + Factor used for label smoothing. ``0`` means no smoothing. Default to 0. half_window_size : int - Number of samples used for the half window size while smoothing - labels. The half window size is defined as - ``window_size = 2 * half_window_size + 1``. It has no effect if - ``factor=0`` (default). Default to 1. + Number of samples used for the half window size while smoothing labels. The + half window size is defined as ``window_size = 2 * half_window_size + 1``. + It has no effect if ``factor=0`` (default). Default to 1. tol : float Convergence tolerance. min_segment_length : int - Minimum segment length (in samples). If a segment is shorter than - this value, it will be recursively reasigned to neighbouring - segments based on absolute spatial correlation. + Minimum segment length (in samples). If a segment is shorter than this + value, it will be recursively reasigned to neighbouring segments based on + absolute spatial correlation. reject_edges : bool If ``True``, set first and last segments to unlabeled. %(reject_by_annotation_raw)s @@ -642,10 +626,9 @@ def predict( Returns ------- segmentation : RawSegmentation | EpochsSegmentation - Microstate sequence derivated from instance data. Timepoints are - labeled according to cluster centers number: ``0`` for the first - center, ``1`` for the second, etc.. - ``-1`` is used for unlabeled time points. + Microstate sequence derivated from instance data. Timepoints are labeled + according to cluster centers number: ``0`` for the first center, ``1`` for + the second, etc.. ``-1`` is used for unlabeled time points. References ---------- @@ -670,9 +653,8 @@ def predict( reject_by_annotation = True else: raise ValueError( - "Argument 'reject_by_annotation' can be set to 'True', " - f"'False' or 'omit' (True). '{reject_by_annotation}' is " - "not supported." + "Argument 'reject_by_annotation' can be set to 'True', 'False' or " + f"'omit' (True). '{reject_by_annotation}' is not supported." ) elif reject_by_annotation is None: reject_by_annotation = False @@ -681,13 +663,13 @@ def predict( if self._info["bads"] != []: if len(self._info["bads"]) == 1: msg = ( - "The current fit contains bad channel %s" - + " which will be used for prediction." + "The current fit contains bad channel %s which will be used for " + "prediction." ) else: msg = ( - "The current fit contains bad channels %s" - + " which will be used for prediction." + "The current fit contains bad channels %s which will be used for " + "prediction." ) logger.warning(msg, ", ".join(ch_name for ch_name in self._info["bads"])) del msg @@ -718,8 +700,7 @@ def predict( else: msg = ( f"The channels {missing_non_existing_channel} were used " - "during fitting but are missing from the provided " - "instance." + "during fitting but are missing from the provided instance." ) raise ValueError(msg) @@ -728,20 +709,18 @@ def predict( missing_existing_channel = list(missing_existing_channel) if len(missing_existing_channel) == 1: msg = ( - f"The channel {missing_existing_channel[0]} is required " - "to predict because it was included during fitting. The " - "provided 'picks' argument does not select " - f"{missing_existing_channel[0]}. To include it, either " - "remove it from inst.info['bads'] or provide its name or " - "indice explicitly in the 'picks' argument." + f"The channel {missing_existing_channel[0]} is required to predict " + "because it was included during fitting. The provided 'picks' " + f"argument does not select {missing_existing_channel[0]}. To " + "include it, either remove it from inst.info['bads'] or provide " + "its name or indice explicitly in the 'picks' argument." ) else: msg = ( - f"The channels {missing_existing_channel} are required " - " to predict because they were included during fitting. " - "The provided 'picks' argument does not select " - f"{missing_existing_channel}. To include then, either " - "remove them from inst.info['bads'] or provide their " + f"The channels {missing_existing_channel} are required to predict " + "because they were included during fitting. The provided 'picks' " + f"argument does not select {missing_existing_channel}. To include " + "then, either remove them from inst.info['bads'] or provide their " "names or indices explicitly in the 'picks' argument." ) raise ValueError(msg) @@ -751,15 +730,15 @@ def predict( if len(unused_ch) != 0: if len(unused_ch) == 1: msg = ( - "The provided instance and the 'picks' argument results " - "in the selection of %s which was not used during " - "fitting. Thus, it will not be used for prediction." + "The provided instance and the 'picks' argument results in the " + "selection of %s which was not used during fitting. Thus, it will " + "not be used for prediction." ) else: msg = ( - "The provided instance and the 'picks' argument results " - "in the selection of %s which were not used during " - "fitting. Thus, they will not be used for prediction." + "The provided instance and the 'picks' argument results in the " + "selection of %s which were not used during fitting. Thus, they " + "will not be used for prediction." ) logger.warning(msg, ", ".join(ch_name for ch_name in unused_ch)) del msg @@ -768,13 +747,13 @@ def predict( if len(info["bads"]) != 0: if len(info["bads"]) == 1: msg = ( - "The channel %s is set as bad in the instance but was " - "selected. It will be used for prediction." + "The channel %s is set as bad in the instance but was selected. It " + "will be used for prediction." ) else: msg = ( - "The channels %s are set as bad in the instance but were " - "selected. They will be used for prediction." + "The channels %s are set as bad in the instance but were selected. " + "They will be used for prediction." ) logger.warning(msg, ", ".join(ch_name for ch_name in info["bads"])) del msg @@ -791,8 +770,8 @@ def predict( logger.info("Segmenting data without smoothing.") else: logger.info( - "Segmenting data with factor %s and effective smoothing " - "window size: %.4f (s).", + "Segmenting data with factor %s and effective smoothing window size: " + "%.4f (s).", factor, (2 * half_window_size + 1) / inst.info["sfreq"], ) @@ -1029,8 +1008,8 @@ def _reject_short_segments( ) -> NDArray[int]: """Reject segments that are too short. - Reject segments that are too short by replacing the labels with the - adjacent labels based on data correlation. + Reject segments that are too short by replacing the labels with the adjacent + labels based on data correlation. """ while True: # list all segments @@ -1143,8 +1122,7 @@ def fitted(self, fitted): if fitted and not self._fitted: logger.warning( "The property 'fitted' can not be set to 'True' directly. " - "Please use the .fit() method to fit the clustering " - "algorithm." + "Please use the .fit() method to fit the clustering algorithm." ) elif fitted and self._fitted: logger.warning( @@ -1164,7 +1142,7 @@ def cluster_centers_(self) -> NDArray[float]: Returns None if cluster algorithm has not been fitted. - :type: `~numpy.array` (n_clusters, n_channels) | None + :type: `~numpy.array` of shape (n_clusters, n_channels) | None """ if self._cluster_centers_ is None: assert not self._fitted # sanity-check @@ -1176,7 +1154,7 @@ def cluster_centers_(self) -> NDArray[float]: def fitted_data(self) -> NDArray[float]: """Data array used to fit the clustering algorithm. - :type: `~numpy.array` shape (n_channels, n_samples) | None + :type: `~numpy.array` of shape (n_channels, n_samples) | None """ if self._fitted_data is None: assert not self._fitted # sanity-check @@ -1188,7 +1166,7 @@ def fitted_data(self) -> NDArray[float]: def labels_(self) -> NDArray[int]: """Microstate label attributed to each sample of the fitted data. - :type: `~numpy.array` shape (n_samples, ) | None + :type: `~numpy.array` of shape (n_samples, ) | None """ if self._labels_ is None: assert not self._fitted # sanity-check diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index b9c3558a..2c0642b1 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -156,11 +156,10 @@ def fit( @copy_doc(_BaseCluster.save) def save(self, fname: Union[str, Path]): super().save(fname) - # TODO: to be replaced by a general writer than infers the writer from - # the file extension. + # TODO: to be replaced by a general writer than infers the writer from the file + # extension. # pylint: disable=import-outside-toplevel from ..io.fiff import _write_cluster - # pylint: enable=import-outside-toplevel _write_cluster( @@ -250,7 +249,6 @@ def _compute_maps( # pylint: enable=too-many-locals # -------------------------------------------------------------------- - @property def normalize_input(self) -> bool: """If set, the input data is normalized along the channel dimension. diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index 22451131..c7de600c 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -28,14 +28,13 @@ class ModKMeans(_BaseCluster): ---------- %(n_clusters)s n_init : int - Number of time the k-means algorithm is run with different centroid - seeds. The final result will be the run with the highest Global - Explained Variance (GEV). + Number of time the k-means algorithm is run with different centroid seeds. The + final result will be the run with the highest Global Explained Variance (GEV). max_iter : int Maximum number of iterations of the K-means algorithm for a single run. tol : float - Relative tolerance with regards estimate residual noise in the cluster - centers of two consecutive iterations to declare convergence. + Relative tolerance with regards estimate residual noise in the cluster centers + of two consecutive iterations to declare convergence. %(random_state)s References @@ -151,16 +150,15 @@ def fit( Parameters ---------- inst : Raw | Epochs | ChData - MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object - from which to extract :term:`cluster centers`. + MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object from + which to extract :term:`cluster centers`. picks : str | list | slice | None - Channels to include. Note that all channels selected must have the - same type. Slices and lists of integers will be interpreted as - channel indices. In lists, channel name strings (e.g. - ``['Fp1', 'Fp2']``) will pick the given channels. Can also be the - string values “all” to pick all channels, or “data” to pick data - channels. ``"eeg"`` (default) will pick all eeg channels. - Note that channels in ``info['bads']`` will be included if their + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their names or indices are explicitly provided. %(tmin_raw)s %(tmax_raw)s @@ -224,8 +222,8 @@ def fit( ) else: logger.error( - "All the K-means run failed to converge. Please adapt the " - "tolerance and the maximum number of iteration." + "All the K-means run failed to converge. Please adapt the tolerance " + "and the maximum number of iteration." ) self.fitted = False # reset variables related to fit return # break early @@ -427,6 +425,6 @@ def _check_tol(tol: Union[int, float]) -> Union[int, float]: _check_type(tol, ("numeric",), item_name="tol") if tol <= 0: raise ValueError( - "The tolerance must be a positive number. Provided: '{tol}'." + f"The tolerance must be a positive number. Provided: '{tol}'." ) return tol diff --git a/pycrostates/cluster/utils/utils.py b/pycrostates/cluster/utils/utils.py index aad9af75..e50094ed 100644 --- a/pycrostates/cluster/utils/utils.py +++ b/pycrostates/cluster/utils/utils.py @@ -24,11 +24,10 @@ def _optimize_order( def optimize_order(inst: Cluster, template_inst: Cluster): """Optimize the order of cluster centers between two cluster instances. - Optimize the order of cluster centers in an instance of a clustering - algorithm to maximize auto-correlation, based on a template instance - as determined by the Hungarian algorithm. - The two cluster instances must have the same number of cluster centers - and the same polarity setting. + Optimize the order of cluster centers in an instance of a clustering algorithm to + maximize auto-correlation, based on a template instance as determined by the + Hungarian algorithm. The two cluster instances must have the same number of cluster + centers and the same polarity setting. Parameters ---------- @@ -40,8 +39,7 @@ def optimize_order(inst: Cluster, template_inst: Cluster): Returns ------- order : list of int - The new order to apply to inst to maximize auto-correlation - of cluster centers. + The new order to apply to inst to maximize auto-correlation of cluster centers. """ from .._base import _BaseCluster diff --git a/pycrostates/datasets/lemon/lemon.py b/pycrostates/datasets/lemon/lemon.py index 629cd62a..08cca02d 100644 --- a/pycrostates/datasets/lemon/lemon.py +++ b/pycrostates/datasets/lemon/lemon.py @@ -19,22 +19,20 @@ def data_path(subject_id: str, condition: str) -> Path: r"""Get path to a local copy of preprocessed EEG recording from the LEMON dataset. - Get path to a local copy of preprocessed EEG recording from the - mind-brain-body dataset of MRI, EEG, cognition, emotion, and peripheral - physiology in young and old adults\ :footcite:p:`babayan_mind-brain-body_2019`. - If there is no local copy of the recording, this function will fetch it - from the online repository and store it. The default location is - ``~/pycrostates_data``. + Get path to a local copy of preprocessed EEG recording from the mind-brain-body + dataset of MRI, EEG, cognition, emotion, and peripheral physiology in young and old + adults\ :footcite:p:`babayan_mind-brain-body_2019`. If there is no local copy of the + recording, this function will fetch it from the online repository and store it. The + default location is ``~/pycrostates_data``. Parameters ---------- subject_id : str The subject id to use. For example ``'010276'``. - The list of available subjects can be found - on this `FTP server `_. + The list of available subjects can be found on this + `FTP server `_. condition : str - Can be ``'EO'`` for eyes open condition or ``'EC'`` for eyes closed - condition. + Can be ``'EO'`` for eyes open condition or ``'EC'`` for eyes closed condition. Returns ------- @@ -51,8 +49,8 @@ def data_path(subject_id: str, condition: str) -> Path: - ``pip install pymatreader`` - ``conda install -c conda-forge pymatreader`` - Note that an environment created via the MNE installers includes - ``pymatreader`` by default. + Note that an environment created via the MNE installers includes ``pymatreader`` by + default. References ---------- @@ -65,7 +63,7 @@ def data_path(subject_id: str, condition: str) -> Path: config = get_config() fetcher = pooch.create( path=config["PREPROCESSED_LEMON_DATASET_PATH"], - base_url="https://ftp.gwdg.de/pub/misc/MPI-Leipzig_Mind-Brain-Body-LEMON/EEG_MPILMBB_LEMON/EEG_Preprocessed_BIDS_ID/EEG_Preprocessed/", # noqa, + base_url="https://ftp.gwdg.de/pub/misc/MPI-Leipzig_Mind-Brain-Body-LEMON/EEG_MPILMBB_LEMON/EEG_Preprocessed_BIDS_ID/EEG_Preprocessed/", # noqa: E501 version=None, registry=None, ) @@ -86,8 +84,8 @@ def data_path(subject_id: str, condition: str) -> Path: def standardize(raw: BaseRaw): """Standardize :class:`~mne.io.Raw` from the lemon dataset. - This function will interpolate missing channels from the standard setup, - then reorder channels and finally reference to a common average. + This function will interpolate missing channels from the standard setup, then + reorder channels and finally reference to a common average. Parameters ---------- @@ -102,8 +100,8 @@ def standardize(raw: BaseRaw): Notes ----- If you don't want to interpolate missing channels, you can use - :func:`mne.channels.equalize_channels` instead to have the same electrodes - across different recordings. + :func:`mne.channels.equalize_channels` instead to have the same electrodes across + different recordings. """ raw = raw.copy() # fmt: off diff --git a/pycrostates/io/ch_data.py b/pycrostates/io/ch_data.py index 3a3fa715..11009cb2 100644 --- a/pycrostates/io/ch_data.py +++ b/pycrostates/io/ch_data.py @@ -16,8 +16,8 @@ class ChData(CHData, ChannelsMixin, ContainsMixin, MontageMixin): """ChData stores atemporal data with its spatial information. - ChData is similar to a raw instance where temporality has been removed. - Only the spatial information, stored as a `~pycrostates.io.ChInfo` is + `~pycrostates.io.ChData` is similar to a raw instance where temporality has been + removed. Only the spatial information, stored as a `~pycrostates.io.ChInfo` is retained. Parameters @@ -25,8 +25,8 @@ class ChData(CHData, ChannelsMixin, ContainsMixin, MontageMixin): data : array Data array of shape ``(n_channels, n_samples)``. info : mne.Info | ChInfo - Atemporal measurement information. If a `mne.Info` is provided, it is - converted to a `~pycrostates.io.ChInfo`. + Atemporal measurement information. If a `mne.Info` is provided, it is converted + to a `~pycrostates.io.ChInfo`. """ def __init__(self, data: NDArray[float], info: Union[Info, CHInfo]): @@ -36,9 +36,8 @@ def __init__(self, data: NDArray[float], info: Union[Info, CHInfo]): _check_type(info, (Info, ChInfo), "info") if data.ndim != 2: raise ValueError( - "Argument 'data' should be a 2D array " - "(n_channels, n_samples). The provided array " - f"shape is {data.shape} which has {data.ndim} " + "Argument 'data' should be a 2D array (n_channels, n_samples). The " + f"provided array shape is {data.shape} which has {data.ndim} " "dimensions." ) if not len(info["ch_names"]) == data.shape[0]: @@ -118,8 +117,8 @@ def pick(self, picks, exclude="bads"): ---------- %(picks_all)s exclude : list | str - Set of channels to exclude, only used when picking based on - types (e.g., ``exclude="bads"`` when ``picks="meg"``). + Set of channels to exclude, only used when picking based on types (e.g., + ``exclude="bads"`` when ``picks="meg"``). Returns ------- diff --git a/pycrostates/io/meas_info.py b/pycrostates/io/meas_info.py index 639fdf18..e36b55fb 100644 --- a/pycrostates/io/meas_info.py +++ b/pycrostates/io/meas_info.py @@ -28,66 +28,62 @@ class ChInfo(CHInfo, Info): """Atemporal measurement information. Similar to a :class:`mne.Info` class, but without any temporal information. - Only the channel-related information are present. A - :class:`~pycrostates.io.ChInfo` can be created either: + Only the channel-related information are present. A :class:`~pycrostates.io.ChInfo` + can be created either: - - by providing a :class:`~mne.Info` class from which information are - retrieved. - - by providing the ``ch_names`` and the ``ch_types`` to create a new - instance. + - by providing a :class:`~mne.Info` class from which information are retrieved. + - by providing the ``ch_names`` and the ``ch_types`` to create a new instance. Only one of those 2 methods should be used at once. - .. warning:: The only entry that should be manually changed by the user - is ``info['bads']``. All other entries should be - considered read-only, though they can be modified by various - functions or methods (which have safeguards to ensure all - fields remain in sync). + .. warning:: The only entry that should be manually changed by the user is + ``info['bads']``. All other entries should be considered read-only, + though they can be modified by various functions or methods (which have + safeguards to ensure all fields remain in sync). Parameters ---------- info : Info | None - MNE measurement information instance from which channel-related - variables are retrieved. + MNE measurement information instance from which channel-related variables are + retrieved. ch_names : list of str | int | None - Channel names. If an int, a list of channel names will be created - from ``range(ch_names)``. + Channel names. If an int, a list of channel names will be created from + ``range(ch_names)``. ch_types : list of str | str | None Channel types. If str, all channels are assumed to be of the same type. Attributes ---------- bads : list of str - List of bad (noisy/broken) channels, by name. These channels will by - default be ignored by many processing steps. + List of bad (noisy/broken) channels, by name. These channels will by default be + ignored by many processing steps. ch_names : tuple of str The names of the channels. chs : tuple of dict - A list of channel information dictionaries, one per channel. - See Notes for more information. + A list of channel information dictionaries, one per channel. See Notes for more + information. comps : list of dict - CTF software gradient compensation data. - See Notes for more information. + CTF software gradient compensation data. See Notes for more information. ctf_head_t : dict | None - The transformation from 4D/CTF head coordinates to Neuromag head - coordinates. This is only present in 4D/CTF data. + The transformation from 4D/CTF head coordinates to Neuromag head coordinates. + This is only present in 4D/CTF data. custom_ref_applied : int - Whether a custom (=other than average) reference has been applied to - the EEG data. This flag is checked by some algorithms that require an - average reference to be set. + Whether a custom (=other than average) reference has been applied to the EEG + data. This flag is checked by some algorithms that require an average reference + to be set. dev_ctf_t : dict | None - The transformation from device coordinates to 4D/CTF head coordinates. - This is only present in 4D/CTF data. + The transformation from device coordinates to 4D/CTF head coordinates. This is + only present in 4D/CTF data. dev_head_t : dict | None The device to head transformation. dig : tuple of dict | None - The Polhemus digitization data in head coordinates. - See Notes for more information. + The Polhemus digitization data in head coordinates. See Notes for more + information. nchan : int Number of channels. projs : list of Projection - List of SSP operators that operate on the data. - See :class:`mne.Projection` for details. + List of SSP operators that operate on the data. See :class:`mne.Projection` for + details. Notes ----- @@ -96,9 +92,8 @@ class ChInfo(CHInfo, Info): * ``chs`` list of dict: cal : float - The calibration factor to bring the channels to physical - units. Used in product with ``range`` to scale the data read - from disk. + The calibration factor to bring the channels to physical units. Used in + product with ``range`` to scale the data read from disk. ch_name : str The channel name. coil_type : int @@ -108,18 +103,16 @@ class ChInfo(CHInfo, Info): kind : int The kind of channel, e.g. ``FIFFV_EEG_CH``. loc : array, shape (12,) - Channel location. For MEG this is the position plus the - normal given by a 3x3 rotation matrix. For EEG this is the - position followed by reference position (with 6 unused). - The values are specified in device coordinates for MEG and in - head coordinates for EEG channels, respectively. + Channel location. For MEG this is the position plus the normal given by a + 3x3 rotation matrix. For EEG this is the position followed by reference + position (with 6 unused). The values are specified in device coordinates for + MEG and in head coordinates for EEG channels, respectively. logno : int - Logical channel number, conventions in the usage of this - number vary. + Logical channel number, conventions in the usage of this number vary. range : float - The hardware-oriented part of the calibration factor. - This should be only applied to the continuous raw data. - Used in product with ``cal`` to scale data read from disk. + The hardware-oriented part of the calibration factor. This should be only + applied to the continuous raw data. Used in product with ``cal`` to scale + data read from disk. scanno : int Scanning order number, starting from 1. unit : int @@ -144,14 +137,12 @@ class ChInfo(CHInfo, Info): * ``dig`` list of dict: kind : int - The kind of channel, - e.g. ``FIFFV_POINT_EEG``, ``FIFFV_POINT_CARDINAL``. + The kind of channel, e.g. ``FIFFV_POINT_EEG``, ``FIFFV_POINT_CARDINAL``. r : array, shape (3,) 3D position in m. and coord_frame. ident : int - Number specifying the identity of the point. - e.g. ``FIFFV_POINT_NASION`` if kind is ``FIFFV_POINT_CARDINAL``, or - 42 if kind is ``FIFFV_POINT_EEG``. + Number specifying the identity of the point. e.g. ``FIFFV_POINT_NASION`` if + kind is ``FIFFV_POINT_CARDINAL``, or 42 if kind is ``FIFFV_POINT_EEG``. coord_frame : int The coordinate frame used, e.g. ``FIFFV_COORD_HEAD``. """ @@ -160,33 +151,27 @@ class ChInfo(CHInfo, Info): # fmt: off _attributes = { "bads": _check_bads, - "ch_names": "ch_names cannot be set directly. " - "Please use methods inst.add_channels(), " - "inst.drop_channels(), inst.pick_channels(), " - "inst.rename_channels(), inst.reorder_channels() " - "and inst.set_channel_types() instead.", - "chs": "chs cannot be set directly. " - "Please use methods inst.add_channels(), " - "inst.drop_channels(), inst.pick_channels(), " - "inst.rename_channels(), inst.reorder_channels() " - "and inst.set_channel_types() instead.", + "ch_names": "ch_names cannot be set directly. Please use methods " + "inst.add_channels(), inst.drop_channels(), inst.pick_channels(), " + "inst.rename_channels(), inst.reorder_channels() and " + "inst.set_channel_types() instead.", + "chs": "chs cannot be set directly. Please use methods inst.add_channels(), " + "inst.drop_channels(), inst.pick_channels(), inst.rename_channels(), " + "inst.reorder_channels() and inst.set_channel_types() instead.", "comps": "comps cannot be set directly. " - "Please use method Raw.apply_gradient_compensation() " - "instead.", + "Please use method Raw.apply_gradient_compensation() instead.", "ctf_head_t": "ctf_head_t cannot be set directly.", "custom_ref_applied": "custom_ref_applied cannot be set directly. " - "Please use method inst.set_eeg_reference() " - "instead.", + "Please use method inst.set_eeg_reference() instead.", "dev_ctf_t": "dev_ctf_t cannot be set directly.", "dev_head_t": _check_dev_head_t, - "dig": "dig cannot be set directly. " - "Please use method inst.set_montage() instead.", - "nchan": "nchan cannot be set directly. " - "Please use methods inst.add_channels(), " - "inst.drop_channels(), and inst.pick_channels() instead.", - "projs": "projs cannot be set directly. " - "Please use methods inst.add_proj() and inst.del_proj() " + "dig": "dig cannot be set directly. Please use method inst.set_montage() " + "instead.", + "nchan": "nchan cannot be set directly. Please use methods " + "inst.add_channels(), inst.drop_channels(), and inst.pick_channels() " "instead.", + "projs": "projs cannot be set directly. Please use methods inst.add_proj() and " + "inst.del_proj() instead.", } # fmt: on @@ -216,8 +201,7 @@ def __init__( else: raise RuntimeError( "If 'info' is provided, 'ch_names' and 'ch_types' must be " - "None. If 'ch_names' and 'ch_types' are provided, 'info' " - "must be None." + "None. If 'ch_names' and 'ch_types' are provided, 'info' must be None." ) def _init_from_info(self, info: Info): @@ -404,9 +388,8 @@ def __setitem__(self, key, val): val = self._attributes[key](val) # attribute checker function else: raise RuntimeError( - f"Info does not support setting the key {repr(key)}. " - "Supported keys are " - f"{', '.join(repr(k) for k in self._attributes)}" + f"Info does not support setting the key {repr(key)}. Supported keys " + f"are {', '.join(repr(k) for k in self._attributes)}" ) super().__setitem__(key, val) # calls the dict __setitem__ @@ -452,7 +435,7 @@ def _check_consistency(self, prepend_error: str = ""): ): raise RuntimeError( f"{prepend_error}info channel name inconsistency detected, " - "please notify developers." + "please notify the developers." ) for pi, proj in enumerate(self.get("projs", [])): diff --git a/pycrostates/metrics/calinski_harabasz.py b/pycrostates/metrics/calinski_harabasz.py index 1e88a84b..0524ac58 100644 --- a/pycrostates/metrics/calinski_harabasz.py +++ b/pycrostates/metrics/calinski_harabasz.py @@ -12,10 +12,9 @@ def calinski_harabasz_score(cluster): # higher the better r"""Compute the Calinski-Harabasz score. - This function computes the Calinski-Harabasz - score\ :footcite:p:`Calinski-Harabasz` with - :func:`sklearn.metrics.calinski_harabasz_score` from a fitted - :ref:`Clustering` instance. + This function computes the Calinski-Harabasz score\ :footcite:p:`Calinski-Harabasz` + with :func:`sklearn.metrics.calinski_harabasz_score` from a fitted :ref:`Clustering` + instance. Parameters ---------- diff --git a/pycrostates/metrics/davies_bouldin.py b/pycrostates/metrics/davies_bouldin.py index 012ba181..801a0382 100644 --- a/pycrostates/metrics/davies_bouldin.py +++ b/pycrostates/metrics/davies_bouldin.py @@ -14,10 +14,9 @@ def davies_bouldin_score(cluster): # lower the better r"""Compute the Davies-Bouldin score. - This function computes the Davies-Bouldin - score\ :footcite:p:`Davies-Bouldin` with - :func:`sklearn.metrics.davies_bouldin_score` from a fitted - :ref:`Clustering` instance. + This function computes the Davies-Bouldin score\ :footcite:p:`Davies-Bouldin` with + :func:`sklearn.metrics.davies_bouldin_score` from a fitted :ref:`Clustering` + instance. Parameters ---------- @@ -31,9 +30,9 @@ def davies_bouldin_score(cluster): # lower the better Notes ----- For more details regarding the implementation, please refer to - :func:`sklearn.metrics.davies_bouldin_score`. - This function was modified in order to use the absolute spatial correlation - for distance computations instead of the euclidean distance. + :func:`sklearn.metrics.davies_bouldin_score`. This function was modified in order to + use the absolute spatial correlation for distance computations instead of the + euclidean distance. References ---------- diff --git a/pycrostates/metrics/silhouette.py b/pycrostates/metrics/silhouette.py index fd2a91ec..655edf34 100644 --- a/pycrostates/metrics/silhouette.py +++ b/pycrostates/metrics/silhouette.py @@ -13,10 +13,8 @@ def silhouette_score(cluster): # higher the better r"""Compute the mean Silhouette Coefficient. - This function computes the Silhouette - Coefficient\ :footcite:p:`Silhouettes` with - :func:`sklearn.metrics.silhouette_score` from a fitted - :ref:`Clustering` instance. + This function computes the Silhouette Coefficient\ :footcite:p:`Silhouettes` with + :func:`sklearn.metrics.silhouette_score` from a fitted :ref:`Clustering` instance. Parameters ---------- diff --git a/pycrostates/preprocessing/extract_gfp_peaks.py b/pycrostates/preprocessing/extract_gfp_peaks.py index cdcc1ded..f7b17a38 100644 --- a/pycrostates/preprocessing/extract_gfp_peaks.py +++ b/pycrostates/preprocessing/extract_gfp_peaks.py @@ -33,32 +33,29 @@ def extract_gfp_peaks( ) -> CHData: """:term:`Global Field Power` (:term:`GFP`) peaks extraction. - Extract :term:`Global Field Power` (:term:`GFP`) peaks from - :class:`~mne.Epochs` or :class:`~mne.io.Raw`. + Extract :term:`Global Field Power` (:term:`GFP`) peaks from :class:`~mne.Epochs` or + :class:`~mne.io.Raw`. Parameters ---------- inst : Raw | Epochs Instance from which to extract :term:`global field power` (GFP) peaks. picks : str | list | slice | None - Channels to use for GFP computation. - Note that all channels selected must have the - same type. Slices and lists of integers will be interpreted as - channel indices. In lists, channel name strings (e.g. - ``['Fp1', 'Fp2']``) will pick the given channels. Can also be the - string values “all” to pick all channels, or “data” to pick data - channels. ``"eeg"`` (default) will pick all eeg channels. - Note that channels in ``info['bads']`` will be included if their + Channels to use for GFP computation. Note that all channels selected must have + the same type. Slices and lists of integers will be interpreted as channel + indices. In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the + given channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their names or indices are explicitly provided. return_all : bool - If True, the returned `~pycrostates.io.ChData` instance will - include all channels. - If False (default), the returned `~pycrostates.io.ChData` instance will - only include channels used for GFP computation (i.e ``picks``). + If True, the returned `~pycrostates.io.ChData` instance will include all + channels. If False (default), the returned `~pycrostates.io.ChData` instance + will only include channels used for GFP computation (i.e ``picks``). min_peak_distance : int - Required minimal horizontal distance (``≥ 1`) in samples between - neighboring peaks. Smaller peaks are removed first until the condition - is fulfilled for all remaining peaks. Default to ``1``. + Required minimal horizontal distance (``≥ 1`) in samples between neighboring + peaks. Smaller peaks are removed first until the condition is fulfilled for all + remaining peaks. Default to ``1``. %(tmin_raw)s %(tmax_raw)s %(reject_by_annotation_raw)s @@ -72,9 +69,9 @@ def extract_gfp_peaks( Notes ----- The :term:`Global Field Power` (:term:`GFP`) peaks are extracted with - :func:`scipy.signal.find_peaks`. Only the ``distance`` argument is - filled with the value provided in ``min_peak_distance``. The other - arguments are set to their default values. + :func:`scipy.signal.find_peaks`. Only the ``distance`` argument is filled with the + value provided in ``min_peak_distance``. The other arguments are set to their + default values. """ from ..io import ChData @@ -146,13 +143,13 @@ def _extract_gfp_peaks( data : array of shape (n_channels, n_samples) The data to extract GFP peaks from. min_peak_distance : int - Required minimal horizontal distance (>= 1) in samples between - neighboring peaks. Smaller peaks are removed first until the condition - is fulfilled for all remaining peaks. Default to 2. + Required minimal horizontal distance (>= 1) in samples between neighboring + peaks. Smaller peaks are removed first until the condition is fulfilled for all + remaining peaks. Default to 2. Returns ------- - peaks : array of shape (n_picks) + peaks : array of shape (n_picks,) The indices when peaks occur. """ gfp = np.std(data, axis=0) diff --git a/pycrostates/preprocessing/resample.py b/pycrostates/preprocessing/resample.py index 6eb08d3a..f0a92c25 100644 --- a/pycrostates/preprocessing/resample.py +++ b/pycrostates/preprocessing/resample.py @@ -48,14 +48,13 @@ def resample( %(tmax_raw)s %(reject_by_annotation_raw)s n_resamples : int - Number of resamples to draw. Each epoch can be used to fit a separate - clustering solution. See notes for additional information. + Number of resamples to draw. Each epoch can be used to fit a separate clustering + solution. See notes for additional information. n_samples : int - Length of each epoch (in samples). See notes for additional - information. + Length of each epoch (in samples). See notes for additional information. coverage : float - Strictly positive ratio between resampling data size and size of the - original recording. See notes for additional information. + Strictly positive ratio between resampling data size and size of the original + recording. See notes for additional information. replace : bool Whether or not to allow resampling with replacement. %(random_state)s @@ -68,9 +67,8 @@ def resample( Notes ----- - Only two of ``n_resamples``, ``n_samples`` and ``coverage`` - parameters must be defined, the non-defined one will be - determine at runtime by the 2 other parameters. + Only two of ``n_resamples``, ``n_samples`` and ``coverage`` parameters must be + defined, the non-defined one will be determine at runtime by the 2 other parameters. """ from ..io import ChData diff --git a/pycrostates/preprocessing/spatial_filter.py b/pycrostates/preprocessing/spatial_filter.py index c810769a..e9b2fa07 100644 --- a/pycrostates/preprocessing/spatial_filter.py +++ b/pycrostates/preprocessing/spatial_filter.py @@ -65,50 +65,42 @@ def apply_spatial_filter( ): r"""Apply a spatial filter. - Adapted from \ :footcite:t:`michel2019eeg`. - Apply an instantaneous filter which interpolates channels - with local neighbors while removing outliers. + Adapted from \ :footcite:t:`michel2019eeg`. Apply an instantaneous filter which + interpolates channels with local neighbors while removing outliers. The current implementation proceeds as follows: * An interpolation matrix is computed using - mne.channels.interpolation._make_interpolation_matrix. - * An ajdacency matrix is computed using - `mne.channels.find_ch_adjacency`. - * If ``exclude_bads`` is set to ``True``, - bad channels are removed from the ajdacency matrix. - * For each timepoint and each channel, - a reduced adjacency matrix is built by removing neighbors - with lowest and highest value. - * For each timepoint and each channel, - a reduced interpolation matrix is built by extracting neighbor - weights based on the reduced adjacency matrix. + ``mne.channels.interpolation._make_interpolation_matrix``. + * An ajdacency matrix is computed using `mne.channels.find_ch_adjacency`. + * If ``exclude_bads`` is set to ``True``, bad channels are removed from the + ajdacency matrix. + * For each timepoint and each channel, a reduced adjacency matrix is built by + removing neighbors with lowest and highest value. + * For each timepoint and each channel, a reduced interpolation matrix is built by + extracting neighbor weights based on the reduced adjacency matrix. * The reduced interpolation matrices are normalized. - * The channel's timepoints are interpolated - using their reduced interpolation matrix. + * The channel's timepoints are interpolated using their reduced interpolation + matrix. Parameters ---------- inst : Raw | Epochs | ChData Instance to filter spatially. ch_type : str - The channel type on which to apply the spatial filter. - Currently only supports ``'eeg'``. + The channel type on which to apply the spatial filter. Currently only supports + ``'eeg'``. exclude_bads : bool - If set to ``True``, bad channels will be removed - from the adjacency matrix and therefore not used - to interpolate neighbors. In addition, bad channels - will not be filtered. - If set to ``False``, proceed as if all channels were good. + If set to ``True``, bad channels will be removed from the adjacency matrix and + therefore not used to interpolate neighbors. In addition, bad channels will not + be filtered. If set to ``False``, proceed as if all channels were good. origin : array of shape (3,) | str - Origin of the sphere in the head coordinate frame and in meters. - Can be ``'auto'`` (default), which means a head-digitization-based - origin fit. + Origin of the sphere in the head coordinate frame and in meters. Can be + ``'auto'`` (default), which means a head-digitization-based origin fit. adjacency : array or csr_matrix of shape (n_channels, n_channels) | str - An adjacency matrix. Can be created using - `mne.channels.find_ch_adjacency` and edited with - `mne.viz.plot_ch_adjacency`. - If ``'auto'`` (default), the matrix will be automatically created - using `mne.channels.find_ch_adjacency` and other parameters. + An adjacency matrix. Can be created using `mne.channels.find_ch_adjacency` and + edited with `mne.viz.plot_ch_adjacency`. If ``'auto'`` (default), the matrix + will be automatically created using `mne.channels.find_ch_adjacency` and other + parameters. %(n_jobs)s %(verbose)s @@ -133,10 +125,9 @@ def apply_spatial_filter( _check_preload(inst, "Apply spatial filter") if inst.get_montage() is None: raise ValueError( - "No montage was set on your data, but spatial filter" - "can only work if digitization points for the EEG " - "channels are available. Consider calling inst.set_montage() " - "to apply a montage." + "No montage was set on your data, but spatial filter can only work if " + "digitization points for the EEG channels are available. Consider calling " + "inst.set_montage() to apply a montage." ) # retrieve picks picks = dict(_picks_by_type(inst.info, exclude=[]))[ch_type] @@ -156,8 +147,8 @@ def apply_spatial_filter( distance = np.mean(distance / np.mean(distance)) if np.abs(1.0 - distance) > 0.1: logger.warn( - "Your spherical fit is poor, interpolation results are " - "likely to be inaccurate." + "Your spherical fit is poor, interpolation results are likely to be " + "inaccurate." ) pos = pos - origin interpolate_matrix = _make_interpolation_matrix(pos, pos) diff --git a/pycrostates/segmentation/_base.py b/pycrostates/segmentation/_base.py index ea2d1c2d..9de7d200 100644 --- a/pycrostates/segmentation/_base.py +++ b/pycrostates/segmentation/_base.py @@ -24,7 +24,7 @@ class _BaseSegmentation(ABC): Parameters ---------- - labels : array (n_samples, ) or (n_epochs, n_samples) + labels : array of shape (n_samples, ) or (n_epochs, n_samples) Microstates labels attributed to each sample, i.e. the segmentation. inst : Raw | Epochs MNE instance used to predict the segmentation. @@ -97,34 +97,32 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): Available parameters are listed below: - * ``mean_corr``: Mean correlation value for each time point - assigned to a given state. + * ``mean_corr``: Mean correlation value for each time point assigned to a + given state. * ``gev``: Global explained variance expressed by a given state. - It is the sum of global explained variance values of each - time point assigned to a given state. + It is the sum of global explained variance values of each time point + assigned to a given state. * ``timecov``: Time coverage, the proportion of time during which a given state is active. This metric is expressed as a ratio. * ``meandurs``: Mean durations of segments assigned to a given state. This metric is expressed in seconds (s). * ``occurrences``: Occurrences per second, the mean number of - segment assigned to a given state per second. This metrics is - expressed in segment per second. + segment assigned to a given state per second. This metrics is expressed + in segment per second. * ``dist_corr`` (req. ``return_dist=True``): Distribution of correlations values of each time point assigned to a given state. * ``dist_gev`` (req. ``return_dist=True``): Distribution of global - explained variances values of each time point assigned to a given - state. + explained variances values of each time point assigned to a given state. * ``dist_durs`` (req. ``return_dist=True``): Distribution of - durations of each segments assigned to a given state. Each value - is expressed in seconds (s). + durations of each segments assigned to a given state. Each value is + expressed in seconds (s). Warnings -------- - When working with `~mne.Epochs`, this method will put together - segments of all epochs. This could lead to wrong interpretation - especially on state durations. To avoid this behaviour, - make sure to set the ``reject_edges`` parameter to ``True`` - when creating the segmentation. + When working with `~mne.Epochs`, this method will put together segments of all + epochs. This could lead to wrong interpretation especially on state durations. + To avoid this behaviour, make sure to set the ``reject_edges`` parameter to + ``True`` when creating the segmentation. """ _check_type(norm_gfp, (bool,), "norm_gfp") _check_type(return_dist, (bool,), "return_dist") @@ -133,8 +131,7 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): sfreq = self._inst.info["sfreq"] # don't copy the data/labels array, get_data, swapaxes, reshape are - # returning a new view of the array, which is fine since we do not - # modify it. + # returning a new view of the array, which is fine since we do not modify it. labels = self._labels # same pointer, no memory overhead. if isinstance(self._inst, BaseRaw): data = self._inst.get_data() @@ -208,9 +205,9 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): def compute_transition_matrix(self, stat="probability", ignore_self=True): """Compute the observed transition matrix. - Count the number of transitions from one state to another - and aggregate the result as statistic. - Transition "from" and "to" unlabeled segments ``-1`` are ignored. + Count the number of transitions from one state to another and aggregate the + result as statistic. Transition "from" and "to" unlabeled segments ``-1`` are + ignored. Parameters ---------- @@ -223,11 +220,10 @@ def compute_transition_matrix(self, stat="probability", ignore_self=True): Warnings -------- - When working with `~mne.Epochs`, this method will take into account - transitions that occur between epochs. This could lead to wrong - interpretation when working with discontinuous data. - To avoid this behaviour, make sure to set the ``reject_edges`` - parameter to ``True`` when predicting the segmentation. + When working with `~mne.Epochs`, this method will take into account transitions + that occur between epochs. This could lead to wrong interpretation when working + with discontinuous data. To avoid this behaviour, make sure to set the + ``reject_edges`` parameter to ``True`` when predicting the segmentation. """ return _compute_transition_matrix( self._labels, @@ -240,11 +236,11 @@ def compute_transition_matrix(self, stat="probability", ignore_self=True): def compute_expected_transition_matrix(self, stat="probability", ignore_self=True): """Compute the expected transition matrix. - Compute the theoretical transition matrix as if time course was - ignored, but microstate proportions kept (i.e. shuffled segmentation). - This matrix can be used to quantify/correct the effect of microstate - time coverage on the observed transition matrix obtained with - the method ``compute_expected_transition_matrix``. + Compute the theoretical transition matrix as if time course was ignored, but + microstate proportions kept (i.e. shuffled segmentation). This matrix can be + used to quantify/correct the effect of microstate time coverage on the observed + transition matrix obtained with the method + ``compute_expected_transition_matrix``. Transition "from" and "to" unlabeled segments ``-1`` are ignored. Parameters diff --git a/pycrostates/segmentation/segmentation.py b/pycrostates/segmentation/segmentation.py index 7488035b..951bdfe7 100644 --- a/pycrostates/segmentation/segmentation.py +++ b/pycrostates/segmentation/segmentation.py @@ -32,9 +32,8 @@ def __init__(self, *args, **kwargs): _check_type(self._inst, (BaseRaw,), item_name="raw") if self._labels.ndim != 1: raise ValueError( - "Argument 'labels' should be a 1D array. The provided array " - f"shape is {self._labels.shape} which has {self._labels.ndim} " - "dimensions." + "Argument 'labels' should be a 1D array. The provided array shape " + f"is {self._labels.shape} which has {self._labels.ndim} dimensions." ) if self._inst.times.size != self._labels.shape[-1]: @@ -115,9 +114,8 @@ def __init__(self, *args, **kwargs): if self._labels.ndim != 2: raise ValueError( - "Argument 'labels' should be a 2D array. The provided array " - f"shape is {self._labels.shape} which has {self._labels.ndim} " - "dimensions." + "Argument 'labels' should be a 2D array. The provided array shape " + f"is {self._labels.shape} which has {self._labels.ndim} dimensions." ) if len(self._inst) != self._labels.shape[0]: raise ValueError( @@ -129,8 +127,7 @@ def __init__(self, *args, **kwargs): raise ValueError( "Provided MNE epochs and labels do not have the same number " f"of samples. The 'epochs' have {self._inst.times.size} " - f"samples, while the 'labels' has {self._labels.shape[-1]} " - "samples." + f"samples, while the 'labels' has {self._labels.shape[-1]} samples." ) @fill_doc diff --git a/pycrostates/segmentation/transitions.py b/pycrostates/segmentation/transitions.py index 2ef8de85..b3c36252 100644 --- a/pycrostates/segmentation/transitions.py +++ b/pycrostates/segmentation/transitions.py @@ -16,9 +16,8 @@ def compute_transition_matrix( ) -> NDArray[float]: """Compute the observed transition matrix. - Count the number of transitions from one state to another and aggregate the - result as statistic. Transitions "from" and "to" unlabeled segments ``-1`` - are ignored. + Count the number of transitions from one state to another and aggregate the result + as statistic. Transitions "from" and "to" unlabeled segments ``-1`` are ignored. Parameters ---------- @@ -84,10 +83,10 @@ def compute_expected_transition_matrix( ) -> NDArray[float]: """Compute the expected transition matrix. - Compute the theoretical transition matrix as if time course was - ignored, but microstate proportions was kept (i.e. shuffled segmentation). - This matrix can be used to quantify/correct the effect of microstate - time coverage on the observed transition matrix obtained with the + Compute the theoretical transition matrix as if time course was ignored, but + microstate proportions was kept (i.e. shuffled segmentation). This matrix can be + used to quantify/correct the effect of microstate time coverage on the observed + transition matrix obtained with the :func:`pycrostates.segmentation.compute_transition_matrix`. Transition "from" and "to" unlabeled segments ``-1`` are ignored. @@ -174,8 +173,7 @@ def _check_labels_n_clusters( raise ValueError( "The argument 'labels' must contain the labels of each timepoint " "encoded as consecutive positive integers (0-indexed). Make sure " - f"you are providing an integer array. '{labels.dtype}' is " - "invalid." + f"you are providing an integer array. '{labels.dtype}' is invalid." ) # check for negative integers except -1 if np.any(labels < -1): diff --git a/pycrostates/utils/_checks.py b/pycrostates/utils/_checks.py index 06705532..1394e0c4 100644 --- a/pycrostates/utils/_checks.py +++ b/pycrostates/utils/_checks.py @@ -144,8 +144,7 @@ def _check_value(item, allowed_values, item_name=None, extra=None): item_name : str | None Name of the item to show inside the error message. extra : str | None - Extra string to append to the invalid value sentence, e.g. - "when using ico mode". + Extra string to append to the invalid value sentence, e.g. "with ico mode". Raises ------ @@ -179,11 +178,10 @@ def _check_value(item, allowed_values, item_name=None, extra=None): def _check_n_jobs(n_jobs): - """ - Check n_jobs parameter. + """Check n_jobs parameter. - Check that n_jobs is a positive integer or a negative integer for all - cores. CUDA is not supported. + Check that n_jobs is a positive integer or a negative integer for all cores. CUDA is + not supported. """ _check_type(n_jobs, ("int",), "n_jobs") if n_jobs <= 0: @@ -292,8 +290,7 @@ def _check_picks_uniqueness(info, picks): "%s '%s' channel(s)" % t for t in zip(counts, ch_types) ) raise ValueError( - "Only one datatype can be selected, but 'picks' " - f"results in {channels_msg}." + f"Only one datatype can be selected, but 'picks' results in {channels_msg}." ) diff --git a/pycrostates/utils/_docs.py b/pycrostates/utils/_docs.py index 62b7b9ad..7bf23960 100644 --- a/pycrostates/utils/_docs.py +++ b/pycrostates/utils/_docs.py @@ -34,11 +34,10 @@ "verbose" ] = """ verbose : int | str | bool | None - Sets the verbosity level. The verbosity increases gradually between - ``"CRITICAL"``, ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. - If None is provided, the verbosity is set to ``"WARNING"``. - If a bool is provided, the verbosity is set to ``"WARNING"`` for False and - to ``"INFO"`` for True.""" + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True.""" # ---- Clusters ---- docdict[ @@ -63,9 +62,9 @@ "cluster" ] = """ cluster : :ref:`cluster` - Fitted clustering algorithm from which to compute score. - For more details about current clustering implementations, - check the :ref:`Clustering` section of the documentation. + Fitted clustering algorithm from which to compute score. For more details about + current clustering implementations, check the :ref:`Clustering` section of the + documentation. """ # ------ I/O ------- @@ -109,33 +108,32 @@ Aggregate statistic to compute transitions. Can be: * ``count``: show the number of observations of each transition. - * ``probability`` or ``proportion``: normalize count such as the - probabilities along the first axis is always equal to ``1``. - * ``percent``: normalize count such as the - probabilities along the first axis is always equal to ``100``.""" + * ``probability`` or ``proportion``: normalize count such as the probabilities along + the first axis is always equal to ``1``. + * ``percent``: normalize count such as the probabilities along the first axis is + always equal to ``100``.""" docdict[ "stat_expected_transitions" ] = """ stat : str Aggregate statistic to compute transitions. Can be: - * ``probability`` or ``proportion``: normalize count such as the - probabilities along the first axis is always equal to ``1``. - * ``percent``: normalize count such as the - probabilities along the first axis is always equal to ``100``.""" + * ``probability`` or ``proportion``: normalize count such as the probabilities along + the first axis is always equal to ``1``. + * ``percent``: normalize count such as the probabilities along the first axis is + always equal to ``100``.""" docdict[ "ignore_self" ] = """ ignore_self : bool - If True, ignores the transition from one state to itself. - This is equivalent to setting the duration of all states to 1 sample.""" + If True, ignores the transition from one state to itself. This is equivalent to + setting the duration of all states to 1 sample.""" docdict[ "transition_matrix" ] = """ T : array of shape ``(n_cluster, n_cluster)`` Array of transition probability values from one label to another. - First axis indicates state ``"from"``. Second axis indicates state - ``"to"``.""" + First axis indicates state ``"from"``. Second axis indicates state ``"to"``.""" # ------ Viz ------- docdict[ @@ -152,22 +150,21 @@ "axes_topo" ] = """ axes : Axes | None - Either ``None`` to create a new figure or axes (or an array of - axes) on which the topographic map should be plotted. If the number of - microstates maps to plot is ``≥ 1``, an array of axes of size - ``n_clusters`` should be provided.""" + Either ``None`` to create a new figure or axes (or an array of axes) on which the + topographic map should be plotted. If the number of microstates maps to plot is + ``≥ 1``, an array of axes of size ``n_clusters`` should be provided.""" docdict[ "axes_seg" ] = """ axes : Axes | None - Either ``None`` to create a new figure or axes on which the - segmentation is plotted.""" + Either ``None`` to create a new figure or axes on which the segmentation is + plotted.""" docdict[ "axes_cbar" ] = """ cbar_axes : Axes | None - Axes on which to draw the colorbar, otherwise the colormap takes - space from the main axes.""" + Axes on which to draw the colorbar, otherwise the colormap takes space from the main + axes.""" # ------------------------- Documentation functions -------------------------- docdict_indented: Dict[int, Dict[str, str]] = {} diff --git a/pycrostates/utils/_fixes.py b/pycrostates/utils/_fixes.py index f81e1404..579aebce 100644 --- a/pycrostates/utils/_fixes.py +++ b/pycrostates/utils/_fixes.py @@ -7,8 +7,8 @@ class _WrapStdOut(object): """Dynamically wrap to sys.stdout. - This makes packages that monkey-patch sys.stdout (e.g.doctest, - sphinx-gallery) work properly. + This makes packages that monkey-patch sys.stdout (e.g.doctest, sphinx-gallery) work + properly. """ def __getattr__(self, name): # noqa: D105 diff --git a/pycrostates/utils/_imports.py b/pycrostates/utils/_imports.py index 2a7a7e50..a60ce880 100644 --- a/pycrostates/utils/_imports.py +++ b/pycrostates/utils/_imports.py @@ -16,8 +16,8 @@ def import_optional_dependency(name: str, extra: str = "", raise_error: bool = T """ Import an optional dependency. - By default, if a dependency is missing an ImportError with a nice message - will be raised. + By default, if a dependency is missing an ImportError with a nice message will be + raised. Parameters ---------- @@ -27,17 +27,16 @@ def import_optional_dependency(name: str, extra: str = "", raise_error: bool = T Additional text to include in the ImportError message. raise_error : bool What to do when a dependency is not found. - * True : If the module is not installed, raise an ImportError, + * True : If the module is not installed, raise an ImportError, otherwise, return + the module. + * False: If the module is not installed, issue a warning and return None, otherwise, return the module. - * False: If the module is not installed, issue a warning and return - None, otherwise, return the module. Returns ------- maybe_module : Optional[ModuleType] The imported module when found. - None is returned when the package is not found and raise_error is - False. + None is returned when the package is not found and raise_error is False. """ package_name = INSTALL_MAPPING.get(name) install_name = package_name if package_name is not None else name @@ -47,7 +46,7 @@ def import_optional_dependency(name: str, extra: str = "", raise_error: bool = T except ImportError: msg = ( f"Missing optional dependency '{install_name}'. {extra} " - + f"Use pip or conda to install '{install_name}'." + f"Use pip or conda to install '{install_name}'." ) if raise_error: raise ImportError(msg) diff --git a/pycrostates/utils/utils.py b/pycrostates/utils/utils.py index 6225801c..1ffac1a3 100644 --- a/pycrostates/utils/utils.py +++ b/pycrostates/utils/utils.py @@ -15,12 +15,11 @@ def _corr_vectors(A, B, axis=0): # written by Marijn van Vliet """Compute pairwise correlation of multiple pairs of vectors. - Fast way to compute correlation of multiple pairs of vectors without - computing all pairs as would with corr(A,B). Borrowed from Oli at - StackOverflow. Note the resulting coefficients vary slightly from the ones - obtained from corr due to differences in the order of the calculations. - (Differences are of a magnitude of 1e-9 to 1e-17 depending on the tested - data). + Fast way to compute correlation of multiple pairs of vectors without computing all + pairs as would with corr(A,B). Borrowed from Oli at StackOverflow. Note the + resulting coefficients vary slightly from the ones obtained from corr due to + differences in the order of the calculations. (Differences are of a magnitude of + 1e-9 to 1e-17 depending on the tested data). Parameters ---------- diff --git a/pycrostates/viz/cluster_centers.py b/pycrostates/viz/cluster_centers.py index 59128379..6fb323bd 100644 --- a/pycrostates/viz/cluster_centers.py +++ b/pycrostates/viz/cluster_centers.py @@ -46,17 +46,14 @@ def plot_cluster_centers( %(cluster_names)s %(axes_topo)s show_gradient : bool - If True, plot a line between channel locations - with highest and lowest values. + If True, plot a line between channel locations with highest and lowest values. gradient_kwargs : dict - Additional keyword arguments passed to - :meth:`matplotlib.axes.Axes.plot` to plot + Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.plot` to plot gradient line. %(block)s %(verbose)s **kwargs - Additional keyword arguments are passed to - :func:`mne.viz.plot_topomap`. + Additional keyword arguments are passed to :func:`mne.viz.plot_topomap`. Returns ------- @@ -78,8 +75,8 @@ def plot_cluster_centers( ) if gradient_kwargs != _GRADIENT_KWARGS_DEFAULTS and not show_gradient: logger.warning( - "The argument 'gradient_kwargs' has not effect when " - "the argument 'show_gradient' is set to False." + "The argument 'gradient_kwargs' has not effect when the argument " + "'show_gradient' is set to False." ) _check_type(block, (bool,), "block") @@ -88,8 +85,8 @@ def plot_cluster_centers( cluster_names = [str(k) for k in range(1, cluster_centers.shape[0] + 1)] if len(cluster_names) != cluster_centers.shape[0]: raise ValueError( - "Argument 'cluster_centers' and 'cluster_names' should have the " - "same number of elements." + "Argument 'cluster_centers' and 'cluster_names' should have the same " + "number of elements." ) # create axes if needed, and retrieve figure diff --git a/pycrostates/viz/segmentation.py b/pycrostates/viz/segmentation.py index 80c8b733..4ff65fd6 100644 --- a/pycrostates/viz/segmentation.py +++ b/pycrostates/viz/segmentation.py @@ -214,8 +214,8 @@ def _plot_segmentation( _check_type(n_clusters, ("int",), "n_clusters") if n_clusters <= 0: raise ValueError( - f"Provided number of clusters {n_clusters} is invalid. The number " - "of clusters must be strictly positive." + f"Provided number of clusters {n_clusters} is invalid. The number of " + "clusters must be strictly positive." ) _check_type(cluster_names, (None, list, tuple), "cluster_names") _check_type(cmap, (None, str, colors.Colormap), "cmap") @@ -270,10 +270,8 @@ def _plot_segmentation( times, gfp, color=color, where=x, step=None, interpolate=False ) logger.info( - "For visualization purposes, " - "the last segment appears truncated by 1 sample. " - "In the case where the last segment is 1 sample long, " - "it does not appear." + "For visualization purposes, the last segment appears truncated by 1 sample. " + "In the case where the last segment is 1 sample long, it does not appear." ) # commonm formatting From 889fcd32ab3da0d737591def524d4569860f201c Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 16 Aug 2023 13:07:18 +0200 Subject: [PATCH 15/16] fix style --- pycrostates/cluster/aahc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 2c0642b1..87099b16 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -160,6 +160,7 @@ def save(self, fname: Union[str, Path]): # extension. # pylint: disable=import-outside-toplevel from ..io.fiff import _write_cluster + # pylint: enable=import-outside-toplevel _write_cluster( From 40c1225a3a9a13da059019c1c275b97c7e56c88b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 16 Aug 2023 13:38:07 +0200 Subject: [PATCH 16/16] fix documentation build --- docs/source/conf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 66ae92a6..1ef7e93d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -156,6 +156,7 @@ "bool": ":class:`python:bool`", "Path": "pathlib.Path", # MNE + "ConductorModel": "mne.bem.ConductorModel", "DigMontage": "mne.channels.DigMontage", "Epochs": "mne.Epochs", "Evoked": "mne.Evoked", @@ -169,6 +170,7 @@ "EpochsSegmentation": "pycrostates.segmentation.EpochsSegmentation", # Matplotlib "Axes": "matplotlib.axes.Axes", + "Axes3D": "mpl_toolkits.mplot3d.axes3d.Axes3D", "colormap": ":doc:`colormap `", "Figure": "matplotlib.figure.Figure", # Scipy @@ -178,9 +180,11 @@ "instance", "of", "shape", + "n_ch_groups", "n_channels", "n_clusters", "n_epochs", + "n_picks", "n_samples", }