From 9b72a5f6847cf2f9acf6d4e280e5db72b3edd622 Mon Sep 17 00:00:00 2001
From: Mathieu Scheltienne
Date: Thu, 21 Mar 2024 14:40:22 +0100
Subject: [PATCH] update init

---
 pycrostates/cluster/array.py | 55 ++++++++++++++++++++++++++++++++++--
 1 file changed, 52 insertions(+), 3 deletions(-)

diff --git a/pycrostates/cluster/array.py b/pycrostates/cluster/array.py
index de6ade98..17e967c7 100644
--- a/pycrostates/cluster/array.py
+++ b/pycrostates/cluster/array.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
+from warnings import warn
 
 import numpy as np
 from mne import Info
@@ -54,7 +55,7 @@ def __init__(
         self._n_clusters = data.shape[0]
         self._info = ChInfo(info=info)  # no-op if a ChInfo is priovided
         self._cluster_centers_ = data
-        # validate optional inputs
+        # validate cluster names
         if cluster_names is not None:
             _check_type(cluster_names, (list, tuple), "cluster_names")
             if len(cluster_names) != self._n_clusters:
@@ -65,9 +66,57 @@ def __init__(
                 self._cluster_names = cluster_names
         else:
             self._cluster_names = [str(k) for k in range(self._n_clusters)]
-        # validate fitted_data, which is either from a Raw ()
+        # validate fitted_data, which is either from a Raw (n_channels, n_times) or
+        # Epochs (n_epochs, n_channels, n_times).
+        _check_type(fitted_data, (np.ndarray, None), "fitted_data")
+        if fitted_data is not None and fitted_data.ndim == 2:
+            if fitted_data.shape[0] != len(info["ch_names"]):
+                raise ValueError(
+                    f"The number of channels in 'fitted_data' ({fitted_data.shape[0]}) "
+                    "must match the number of channels in 'info' "
+                    f"({len(info['ch_names'])})."
+                )
+        elif fitted_data is not None and fitted_data.ndim == 3:
+            # channels are expected on axis 1: (n_epochs, n_channels, n_times)
+            if fitted_data.shape[1] != len(info["ch_names"]):
+                raise ValueError(
+                    "The number of channels in 'fitted_data' "
+                    f"({fitted_data.shape[1]}) must match the number of channels "
+                    f"in 'info' ({len(info['ch_names'])}). Please provide "
+                    "'fitted_data' as (n_epochs, n_channels, n_times) for Epochs."
+                )
+            fitted_data = np.swapaxes(fitted_data, 0, 1)
+            fitted_data = fitted_data.reshape(fitted_data.shape[0], -1)
+        elif fitted_data is not None:
+            raise ValueError(
+                "'fitted_data' must be a 2D (raw, ChData) or 3D (epochs) array. The "
+                f"provided {fitted_data.ndim}D array is invalid."
+            )
+        self._fitted_data = fitted_data
+        # validate labels
+        _check_type(labels, (np.ndarray, None), "labels")
+        if labels is not None:
+            if labels.ndim != 1:
+                raise ValueError(
+                    f"'labels' must be a 1D array. The provided {labels.ndim}D array "
+                    "is invalid."
+                )
+            if self._fitted_data is None:
+                warn(
+                    "'labels' were provided without 'fitted_data' to which "
+                    "they apply.",
+                    RuntimeWarning,
+                    stacklevel=2,
+                )
+            else:
+                if labels.size != self._fitted_data.shape[1]:
+                    raise ValueError(
+                        f"The number of samples in 'labels' ({labels.size}) must match "
+                        "the number of samples in 'fitted_data' "
+                        f"({self._fitted_data.shape[1]})."
+                    )
+        self._labels_ = labels
 
-    # left for empty
     @copy_doc(_BaseCluster.save)
     def save(self, fname: Union[str, Path]):
         super().save(fname)