Skip to content

Commit

Permalink
update init
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheltienne committed Mar 21, 2024
1 parent 167d318 commit 9b72a5f
Showing 1 changed file with 52 additions and 3 deletions.
55 changes: 52 additions & 3 deletions pycrostates/cluster/array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
from mne import Info
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
self._n_clusters = data.shape[0]
self._info = ChInfo(info=info) # no-op if a ChInfo is priovided
self._cluster_centers_ = data
# validate optional inputs
# validate cluster names
if cluster_names is not None:
_check_type(cluster_names, (list, tuple), "cluster_names")
if len(cluster_names) != self._n_clusters:
Expand All @@ -65,9 +66,57 @@ def __init__(
self._cluster_names = cluster_names
else:
self._cluster_names = [str(k) for k in range(self._n_clusters)]
# validate fitted_data, which is either from a Raw ()
# validate fitted_data, which is either from a Raw (n_channels, n_times) or
# Epochs (n_epochs, n_channels, n_times).
_check_type(fitted_data, (np.ndarray, None), "fitted_data")
if fitted_data is not None and fitted_data.ndim == 2:
if fitted_data.shape[0] != len(info["ch_names"]):
raise ValueError(
f"The number of channels in 'fitted_data' ({fitted_data.shape[0]}) "
"must match the number of channels in 'info' "
f"({len(info['ch_names'])})."
)
elif fitted_data is not None and fitted_data.ndim == 3:
# either with the (n_channels,) as first or second dimension
if fitted_data.shape[1] != len(info["ch_names"]):
raise ValueError(
"The number of channels in 'fitted_data' "
f"({fitted_data.shape[1]}) must match the number of channels "
f"in 'info' ({len(info['ch_names'])}). Please provide "
"'fitted_data' as (n_epochs, n_channels, n_times) for Epochs."
)
fitted_data = np.swapaxes(fitted_data, 0, 1)
fitted_data = fitted_data.reshape(fitted_data.shape[0], -1)
else:
raise ValueError(
"'fitted_data' must be a 2D (raw, ChData) or 3D (epochs) array. The "
f"provided {fitted_data.ndim}D array is invalid."
)
self._fitted_data = fitted_data
# validate labels
_check_type(labels, (np.ndarray, None), "labels")
if labels is not None:
if labels.ndim != 1:
raise ValueError(
f"'labels' must be a 1D array. The provided {labels.ndim}D array "
"is invalid."
)
if self._fitted_data is None:
warn(
RuntimeWarning,
"'labels' were provided without 'fitted_data' to which "
"they apply.",
stacklevel=2,
)
else:
if labels.size != self._fitted_data.shape[1]:
raise ValueError(
f"The number of samples in 'labels' ({labels.size}) must match "
"the number of samples in 'fitted_data' "
f"({self._fitted_data.shape[1]})."
)
self._labels_ = labels

# left for empty
@copy_doc(_BaseCluster.save)
def save(self, fname: Union[str, Path]):
super().save(fname)
Expand Down

0 comments on commit 9b72a5f

Please sign in to comment.