From b13b777ab96af5dd1ef67077f93600964dcdbd64 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 3 Jan 2024 12:38:08 +0000 Subject: [PATCH] deploy stub files [ci skip] --- pycrostates/__init__.pyi | 11 + pycrostates/_typing.pyi | 20 + pycrostates/cluster/__init__.pyi | 5 + pycrostates/cluster/_base.pyi | 455 ++++++++++++++++++ pycrostates/cluster/aahc.pyi | 150 ++++++ pycrostates/cluster/kmeans.pyi | 218 +++++++++ pycrostates/cluster/utils/__init__.pyi | 3 + pycrostates/cluster/utils/utils.pyi | 31 ++ pycrostates/datasets/__init__.pyi | 3 + pycrostates/datasets/lemon/__init__.pyi | 6 + pycrostates/datasets/lemon/lemon.pyi | 71 +++ pycrostates/io/__init__.pyi | 5 + pycrostates/io/ch_data.pyi | 128 +++++ pycrostates/io/fiff.pyi | 125 +++++ pycrostates/io/meas_info.pyi | 163 +++++++ pycrostates/io/reader.pyi | 20 + pycrostates/metrics/__init__.pyi | 6 + pycrostates/metrics/calinski_harabasz.pyi | 32 ++ pycrostates/metrics/davies_bouldin.pyi | 52 ++ pycrostates/metrics/dunn.pyi | 49 ++ pycrostates/metrics/silhouette.pyi | 34 ++ pycrostates/preprocessing/__init__.pyi | 5 + .../preprocessing/extract_gfp_peaks.pyi | 99 ++++ pycrostates/preprocessing/resample.pyi | 91 ++++ pycrostates/preprocessing/spatial_filter.pyi | 93 ++++ pycrostates/segmentation/__init__.pyi | 8 + pycrostates/segmentation/_base.pyi | 217 +++++++++ pycrostates/segmentation/segmentation.pyi | 133 +++++ pycrostates/segmentation/transitions.pyi | 106 ++++ pycrostates/utils/__init__.pyi | 6 + pycrostates/utils/_checks.pyi | 118 +++++ pycrostates/utils/_config.pyi | 38 ++ pycrostates/utils/_docs.pyi | 70 +++ pycrostates/utils/_fixes.pyi | 8 + pycrostates/utils/_imports.pyi | 30 ++ pycrostates/utils/_logs.pyi | 119 +++++ pycrostates/utils/mixin.pyi | 16 + pycrostates/utils/sys_info.pyi | 22 + pycrostates/utils/utils.pyi | 33 ++ pycrostates/viz/__init__.pyi | 5 + pycrostates/viz/cluster_centers.pyi | 60 +++ pycrostates/viz/segmentation.pyi | 138 ++++++ 42 files changed, 3002 insertions(+) create mode 100644 pycrostates/__init__.pyi create mode 100644 pycrostates/_typing.pyi create mode 100644 pycrostates/cluster/__init__.pyi create mode 100644 pycrostates/cluster/_base.pyi create mode 100644 pycrostates/cluster/aahc.pyi create mode 100644 pycrostates/cluster/kmeans.pyi create mode 100644 pycrostates/cluster/utils/__init__.pyi create mode 100644 pycrostates/cluster/utils/utils.pyi create mode 100644 pycrostates/datasets/__init__.pyi create mode 100644 pycrostates/datasets/lemon/__init__.pyi create mode 100644 pycrostates/datasets/lemon/lemon.pyi create mode 100644 pycrostates/io/__init__.pyi create mode 100644 pycrostates/io/ch_data.pyi create mode 100644 pycrostates/io/fiff.pyi create mode 100644 pycrostates/io/meas_info.pyi create mode 100644 pycrostates/io/reader.pyi create mode 100644 pycrostates/metrics/__init__.pyi create mode 100644 pycrostates/metrics/calinski_harabasz.pyi create mode 100644 pycrostates/metrics/davies_bouldin.pyi create mode 100644 pycrostates/metrics/dunn.pyi create mode 100644 pycrostates/metrics/silhouette.pyi create mode 100644 pycrostates/preprocessing/__init__.pyi create mode 100644 pycrostates/preprocessing/extract_gfp_peaks.pyi create mode 100644 pycrostates/preprocessing/resample.pyi create mode 100644 pycrostates/preprocessing/spatial_filter.pyi create mode 100644 pycrostates/segmentation/__init__.pyi create mode 100644 pycrostates/segmentation/_base.pyi create mode 100644 pycrostates/segmentation/segmentation.pyi create mode 100644 pycrostates/segmentation/transitions.pyi create mode 100644 pycrostates/utils/__init__.pyi create mode 100644 pycrostates/utils/_checks.pyi create mode 100644 pycrostates/utils/_config.pyi create mode 100644 pycrostates/utils/_docs.pyi create mode 100644 pycrostates/utils/_fixes.pyi create mode 100644 pycrostates/utils/_imports.pyi create mode 100644 pycrostates/utils/_logs.pyi create mode 100644 pycrostates/utils/mixin.pyi create mode 100644 pycrostates/utils/sys_info.pyi create mode 100644 pycrostates/utils/utils.pyi create mode 100644 pycrostates/viz/__init__.pyi create mode 100644 pycrostates/viz/cluster_centers.pyi create mode 100644 pycrostates/viz/segmentation.pyi diff --git a/pycrostates/__init__.pyi b/pycrostates/__init__.pyi new file mode 100644 index 00000000..4f127561 --- /dev/null +++ b/pycrostates/__init__.pyi @@ -0,0 +1,11 @@ +from . import cluster as cluster +from . import datasets as datasets +from . import metrics as metrics +from . import preprocessing as preprocessing +from . import utils as utils +from . import viz as viz +from ._version import __version__ as __version__ +from .utils._logs import set_log_level as set_log_level +from .utils.sys_info import sys_info as sys_info + +__all__: tuple[str, ...] diff --git a/pycrostates/_typing.pyi b/pycrostates/_typing.pyi new file mode 100644 index 00000000..7417a49c --- /dev/null +++ b/pycrostates/_typing.pyi @@ -0,0 +1,20 @@ +from abc import ABC +from typing import Optional, Union + +from numpy.random import Generator, RandomState +from numpy.typing import NDArray + +class CHData(ABC): + """Typing for CHData.""" + +class CHInfo(ABC): + """Typing for CHInfo.""" + +class Cluster(ABC): + """Typing for a clustering class.""" + +class Segmentation(ABC): + """Typing for a clustering class.""" + +RANDomState = Optional[Union[int, RandomState, Generator]] +Picks = Optional[Union[str, NDArray[int]]] diff --git a/pycrostates/cluster/__init__.pyi b/pycrostates/cluster/__init__.pyi new file mode 100644 index 00000000..fb13d124 --- /dev/null +++ b/pycrostates/cluster/__init__.pyi @@ -0,0 +1,5 @@ +from . import utils as utils +from .aahc import AAHCluster as AAHCluster +from .kmeans import ModKMeans as ModKMeans + +__all__: tuple[str, ...] diff --git a/pycrostates/cluster/_base.pyi b/pycrostates/cluster/_base.pyi new file mode 100644 index 00000000..ee4bf7f0 --- /dev/null +++ b/pycrostates/cluster/_base.pyi @@ -0,0 +1,455 @@ +from abc import abstractmethod +from pathlib import Path as Path +from typing import Any, Optional, Union + +from _typeshed import Incomplete +from matplotlib.axes import Axes as Axes +from mne import BaseEpochs +from mne.io import BaseRaw +from numpy.typing import NDArray + +from .._typing import CHData as CHData +from .._typing import Cluster as Cluster +from .._typing import Picks as Picks +from ..segmentation import EpochsSegmentation as EpochsSegmentation +from ..segmentation import RawSegmentation as RawSegmentation +from ..utils import _corr_vectors as _corr_vectors +from ..utils._checks import _check_picks_uniqueness as _check_picks_uniqueness +from ..utils._checks import _check_reject_by_annotation as _check_reject_by_annotation +from ..utils._checks import _check_tmin_tmax as _check_tmin_tmax +from ..utils._checks import _check_type as _check_type +from ..utils._checks import _check_value as _check_value +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger +from ..utils.mixin import ChannelsMixin as ChannelsMixin +from ..utils.mixin import ContainsMixin as ContainsMixin +from ..utils.mixin import MontageMixin as MontageMixin +from .utils import optimize_order as optimize_order + +class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin): + """Base Class for Microstates Clustering algorithms.""" + + _n_clusters: Incomplete + _cluster_names: Incomplete + _cluster_centers_: Incomplete + _ignore_polarity: Incomplete + _info: Incomplete + _fitted_data: Incomplete + _labels_: Incomplete + _fitted: bool + + @abstractmethod + def __init__(self): ... + def __repr__(self) -> str: + """String representation.""" + + def _repr_html_(self, caption: Incomplete | None = None): + """HTML representation.""" + + def __eq__(self, other: Any) -> bool: + """Equality == method.""" + + def __ne__(self, other: Any) -> bool: + """Different != method.""" + + def copy(self, deep: bool = True): + """Return a copy of the instance. + + Parameters + ---------- + deep : bool + If True, `~copy.deepcopy` is used instead of `~copy.copy`. + """ + + def _check_fit(self) -> None: + """Check if the cluster is fitted.""" + + def _check_unfitted(self) -> None: + """Check if the cluster is unfitted.""" + + @abstractmethod + def fit( + self, + inst: Union[BaseRaw, BaseEpochs, CHData], + picks: Picks = "eeg", + tmin: Optional[Union[int, float]] = None, + tmax: Optional[Union[int, float]] = None, + reject_by_annotation: bool = True, + *, + verbose: Optional[str] = None, + ) -> NDArray[float]: + """Compute cluster centers. + + Parameters + ---------- + inst : Raw | Epochs | ChData + MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object + from which to extract :term:`cluster centers`. + picks : str | list | slice | None + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their + names or indices are explicitly provided. + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + reject_by_annotation : bool + Whether to omit bad segments from the data before fitting. If ``True`` + (default), annotated segments whose description begins with ``'bad'`` are + omitted. If ``False``, no rejection based on annotations is performed. + + Has no effect if ``inst`` is not a :class:`mne.io.Raw` object. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + """ + + def rename_clusters( + self, + mapping: Optional[dict[str, str]] = None, + new_names: Optional[Union[list[str], tuple[str, ...]]] = None, + ) -> None: + """Rename the clusters. + + Parameters + ---------- + mapping : dict + Mapping from the old names to the new names. The keys are the old names and + the values are the new names. + new_names : list | tuple + 1D iterable containing the new cluster names. The length of the iterable + should match the number of clusters. + + Notes + ----- + Operates in-place. + """ + + def reorder_clusters( + self, + mapping: Optional[dict[int, int]] = None, + order: Optional[Union[list[int], tuple[int, ...], NDArray[int]]] = None, + template: Optional[Cluster] = None, + ) -> None: + """ + Reorder the clusters of the fitted model. + + Specify one of the following arguments to change the current order: + + * ``mapping``: a dictionary that maps old cluster positions to new positions, + * ``order``: a 1D iterable containing the new order, + * ``template``: a fitted clustering algorithm used as a reference to match the + order. + + Only one argument can be set at a time. + + Parameters + ---------- + mapping : dict + Mapping from the old order to the new order. + key: old position, value: new position. + order : list of int | tuple of int | array of int + 1D iterable containing the new order. Positions are 0-indexed. + template : :ref:`cluster` + Fitted clustering algorithm use as template for ordering optimization. For + more details about the current implementation, check the + :func:`pycrostates.cluster.utils.optimize_order` documentation. + + Notes + ----- + Operates in-place. + """ + + def invert_polarity( + self, invert: Union[bool, list[bool], tuple[bool, ...], NDArray[bool]] + ) -> None: + """Invert map polarities. + + Parameters + ---------- + invert : bool | list of bool | array of bool + List of bool of length ``n_clusters``. + True will invert map polarity, while False will have no effect. + If a `bool` is provided, it is applied to all maps. + + Notes + ----- + Operates in-place. + + Inverting polarities has no effect on the other steps of the analysis as + polarity is ignored in the current methodology. This function is only used for + tuning visualization (i.e. for visual inspection and/or to generate figure for + an article). + """ + + def plot( + self, + axes: Optional[Union[Axes, NDArray[Axes]]] = None, + show_gradient: Optional[bool] = False, + gradient_kwargs: dict[str, Any] = { + "color": "black", + "linestyle": "-", + "marker": "P", + }, + *, + block: bool = False, + verbose: Optional[str] = None, + **kwargs, + ): + """ + Plot cluster centers as topographic maps. + + Parameters + ---------- + axes : Axes | None + Either ``None`` to create a new figure or axes (or an array of axes) on which the + topographic map should be plotted. If the number of microstates maps to plot is + ``≥ 1``, an array of axes of size ``n_clusters`` should be provided. + show_gradient : bool + If True, plot a line between channel locations with highest and lowest + values. + gradient_kwargs : dict + Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.plot` to + plot gradient line. + block : bool + Whether to halt program execution until the figure is closed. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + **kwargs + Additional keyword arguments are passed to :func:`mne.viz.plot_topomap`. + + Returns + ------- + f : Figure + Matplotlib figure containing the topographic plots. + """ + + @abstractmethod + def save(self, fname: Union[str, Path]): + """Save clustering solution to disk. + + Parameters + ---------- + fname : path-like + Path to the ``.fif`` file where the clustering solution is saved. + """ + + def predict( + self, + inst: Union[BaseRaw, BaseEpochs], + picks: Picks = None, + factor: int = 0, + half_window_size: int = 1, + tol: Union[int, float] = 1e-05, + min_segment_length: int = 0, + reject_edges: bool = True, + reject_by_annotation: bool = True, + *, + verbose: Optional[str] = None, + ): + """Segment `~mne.io.Raw` or `~mne.Epochs` into microstate sequence. + + Segment instance into microstate sequence using the segmentation smoothing + algorithm\\ :footcite:p:`Marqui1995`. + + Parameters + ---------- + inst : Raw | Epochs + MNE `~mne.io.Raw` or `~mne.Epochs` object containing the data to use for + prediction. + picks : str | list | slice | None + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``None`` (default) will pick all channels + used during fitting (e.g., ``self.info['ch_names']``). Note that channels in + ``info['bads']`` will be included if their names or indices are explicitly + provided. + factor : int + Factor used for label smoothing. ``0`` means no smoothing. Default to 0. + half_window_size : int + Number of samples used for the half window size while smoothing labels. The + half window size is defined as ``window_size = 2 * half_window_size + 1``. + It has no effect if ``factor=0`` (default). Default to 1. + tol : float + Convergence tolerance. + min_segment_length : int + Minimum segment length (in samples). If a segment is shorter than this + value, it will be recursively reasigned to neighbouring segments based on + absolute spatial correlation. + reject_edges : bool + If ``True``, set first and last segments to unlabeled. + reject_by_annotation : bool + Whether to omit bad segments from the data before fitting. If ``True`` + (default), annotated segments whose description begins with ``'bad'`` are + omitted. If ``False``, no rejection based on annotations is performed. + + Has no effect if ``inst`` is not a :class:`mne.io.Raw` object. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + segmentation : RawSegmentation | EpochsSegmentation + Microstate sequence derivated from instance data. Timepoints are labeled + according to cluster centers number: ``0`` for the first center, ``1`` for + the second, etc.. ``-1`` is used for unlabeled time points. + + References + ---------- + .. footbibliography:: + """ + + def _predict_raw( + self, + raw: BaseRaw, + picks_data: NDArray[int], + factor: int, + tol: Union[int, float], + half_window_size: int, + min_segment_length: int, + reject_edges: bool, + reject_by_annotation: bool, + ) -> RawSegmentation: + """Create segmentation for raw.""" + + def _predict_epochs( + self, + epochs: BaseEpochs, + picks_data: NDArray[int], + factor: int, + tol: Union[int, float], + half_window_size: int, + min_segment_length: int, + reject_edges: bool, + ) -> EpochsSegmentation: + """Create segmentation for epochs.""" + + @staticmethod + def _segment( + data: NDArray[float], + states: NDArray[float], + factor: int, + tol: Union[int, float], + half_window_size: int, + ) -> NDArray[int]: + """Create segmentation. Must operate on a copy of states.""" + + @staticmethod + def _smooth_segmentation( + data: NDArray[float], + states: NDArray[float], + labels: NDArray[int], + factor: int, + tol: Union[int, float], + half_window_size: int, + ) -> NDArray[int]: + """Apply smoothing. + + Adapted from [1]. + + References + ---------- + .. [1] R. D. Pascual-Marqui, C. M. Michel and D. Lehmann. + Segmentation of brain electrical activity into microstates: + model estimation and validation. + IEEE Transactions on Biomedical Engineering, + vol. 42, no. 7, pp. 658-665, July 1995, + https://doi.org/10.1109/10.391164. + """ + + @staticmethod + def _reject_short_segments( + segmentation: NDArray[int], data: NDArray[float], min_segment_length: int + ) -> NDArray[int]: + """Reject segments that are too short. + + Reject segments that are too short by replacing the labels with the adjacent + labels based on data correlation. + """ + + @staticmethod + def _reject_edge_segments(segmentation: NDArray[int]) -> NDArray[int]: + """Set the first and last segment as unlabeled (0).""" + + @property + def n_clusters(self) -> int: + """Number of clusters (number of microstates). + + :type: `int` + """ + + @property + def info(self): + """Info instance with the channel information used to fit the instance. + + :type: `~pycrostates.io.ChInfo` + """ + + @property + def fitted(self) -> bool: + """Fitted state. + + :type: `bool` + """ + + @fitted.setter + def fitted(self, fitted) -> None: + """Fitted state. + + :type: `bool` + """ + + @property + def cluster_centers_(self) -> NDArray[float]: + """Fitted clusters (the microstates maps). + + Returns None if cluster algorithm has not been fitted. + + :type: `~numpy.array` of shape (n_clusters, n_channels) | None + """ + + @property + def fitted_data(self) -> NDArray[float]: + """Data array used to fit the clustering algorithm. + + :type: `~numpy.array` of shape (n_channels, n_samples) | None + """ + + @property + def labels_(self) -> NDArray[int]: + """Microstate label attributed to each sample of the fitted data. + + :type: `~numpy.array` of shape (n_samples, ) | None + """ + + @property + def cluster_names(self) -> list[str]: + """Name of the clusters. + + :type: `list` + """ + + @cluster_names.setter + def cluster_names(self, other: Any): + """Name of the clusters. + + :type: `list` + """ + + @staticmethod + def _check_n_clusters(n_clusters: int) -> int: + """Check that the number of clusters is a positive integer.""" diff --git a/pycrostates/cluster/aahc.pyi b/pycrostates/cluster/aahc.pyi new file mode 100644 index 00000000..72f416f0 --- /dev/null +++ b/pycrostates/cluster/aahc.pyi @@ -0,0 +1,150 @@ +from pathlib import Path as Path +from typing import Any, Optional, Union + +from _typeshed import Incomplete +from mne import BaseEpochs as BaseEpochs +from mne.io import BaseRaw as BaseRaw +from numpy.typing import NDArray + +from .._typing import Picks as Picks +from ..utils import _corr_vectors as _corr_vectors +from ..utils._checks import _check_type as _check_type +from ..utils._docs import copy_doc as copy_doc +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger +from ._base import _BaseCluster as _BaseCluster + +class AAHCluster(_BaseCluster): + """Atomize and Agglomerate Hierarchical Clustering (AAHC) algorithm. + + See :footcite:t:`Murray2008` for additional information. + + Parameters + ---------- + n_clusters : int + The number of clusters, i.e. the number of microstates. + normalize_input : bool + If set, the input data is normalized along the channel dimension. + + References + ---------- + .. footbibliography:: + """ + + _n_clusters: Incomplete + _cluster_names: Incomplete + _ignore_polarity: bool + _normalize_input: Incomplete + _GEV_: Incomplete + + def __init__(self, n_clusters: int, normalize_input: bool = False) -> None: ... + def _repr_html_(self, caption: Incomplete | None = None): ... + def __eq__(self, other: Any) -> bool: + """Equality == method.""" + + def __ne__(self, other: Any) -> bool: + """Different != method.""" + + def _check_fit(self) -> None: + """Check if the cluster is fitted.""" + _cluster_centers_: Incomplete + _labels_: Incomplete + _fitted: bool + + def fit( + self, + inst: Union[BaseRaw, BaseEpochs], + picks: Picks = "eeg", + tmin: Optional[Union[int, float]] = None, + tmax: Optional[Union[int, float]] = None, + reject_by_annotation: bool = True, + *, + verbose: Optional[str] = None, + ) -> None: + """Compute cluster centers. + + Parameters + ---------- + inst : Raw | Epochs | ChData + MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object + from which to extract :term:`cluster centers`. + picks : str | list | slice | None + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their + names or indices are explicitly provided. + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + reject_by_annotation : bool + Whether to omit bad segments from the data before fitting. If ``True`` + (default), annotated segments whose description begins with ``'bad'`` are + omitted. If ``False``, no rejection based on annotations is performed. + + Has no effect if ``inst`` is not a :class:`mne.io.Raw` object. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + """ + + def save(self, fname: Union[str, Path]): + """Save clustering solution to disk. + + Parameters + ---------- + fname : path-like + Path to the ``.fif`` file where the clustering solution is saved. + """ + + @staticmethod + def _aahc( + data: NDArray[float], + n_clusters: int, + ignore_polarity: bool, + normalize_input: bool, + ) -> tuple[float, NDArray[float], NDArray[int]]: + """Run the AAHC algorithm.""" + + @staticmethod + def _compute_maps( + data: NDArray[float], + n_clusters: int, + ignore_polarity: bool, + normalize_input: bool, + ) -> tuple[NDArray[float], NDArray[int]]: + """Compute microstates maps.""" + + @property + def normalize_input(self) -> bool: + """If set, the input data is normalized along the channel dimension. + + :type: `bool` + """ + + @property + def GEV_(self) -> float: + """Global Explained Variance. + + :type: `float` + """ + + @_BaseCluster.fitted.setter + def fitted(self, fitted) -> None: + """Fitted state. + + :type: `bool` + """ + + @staticmethod + def _check_ignore_polarity(ignore_polarity: bool) -> bool: + """Check that ignore_polarity is a boolean.""" + + @staticmethod + def _check_normalize_input(normalize_input: bool) -> bool: + """Check that normalize_input is a boolean.""" diff --git a/pycrostates/cluster/kmeans.pyi b/pycrostates/cluster/kmeans.pyi new file mode 100644 index 00000000..8d173412 --- /dev/null +++ b/pycrostates/cluster/kmeans.pyi @@ -0,0 +1,218 @@ +from pathlib import Path as Path +from typing import Any, Optional, Union + +from _typeshed import Incomplete +from mne import BaseEpochs as BaseEpochs +from mne.io import BaseRaw as BaseRaw +from numpy.random import Generator as Generator +from numpy.random import RandomState as RandomState +from numpy.typing import NDArray + +from .._typing import CHData as CHData +from .._typing import Picks as Picks +from .._typing import RANDomState as RANDomState +from ..utils import _corr_vectors as _corr_vectors +from ..utils._checks import _check_n_jobs as _check_n_jobs +from ..utils._checks import _check_random_state as _check_random_state +from ..utils._checks import _check_type as _check_type +from ..utils._docs import copy_doc as copy_doc +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger +from ._base import _BaseCluster as _BaseCluster + +class ModKMeans(_BaseCluster): + """Modified K-Means clustering algorithm. + + See :footcite:t:`Marqui1995` for additional information. + + Parameters + ---------- + n_clusters : int + The number of clusters, i.e. the number of microstates. + n_init : int + Number of time the k-means algorithm is run with different centroid seeds. The + final result will be the run with the highest Global Explained Variance (GEV). + max_iter : int + Maximum number of iterations of the K-means algorithm for a single run. + tol : float + Relative tolerance with regards estimate residual noise in the cluster centers + of two consecutive iterations to declare convergence. + random_state : None | int | instance of ~numpy.random.RandomState + A seed for the NumPy random number generator (RNG). If ``None`` (default), + the seed will be obtained from the operating system + (see :class:`~numpy.random.RandomState` for details), meaning it will most + likely produce different output every time this function or method is run. + To achieve reproducible results, pass a value here to explicitly initialize + the RNG with a defined state. + + References + ---------- + .. footbibliography:: + """ + + _n_clusters: Incomplete + _cluster_names: Incomplete + _n_init: Incomplete + _max_iter: Incomplete + _tol: Incomplete + _random_state: Incomplete + _GEV_: Incomplete + + def __init__( + self, + n_clusters: int, + n_init: int = 100, + max_iter: int = 300, + tol: Union[int, float] = 1e-06, + random_state: RANDomState = None, + ) -> None: ... + def _repr_html_(self, caption: Incomplete | None = None): ... + def __eq__(self, other: Any) -> bool: + """Equality == method.""" + + def __ne__(self, other: Any) -> bool: + """Different != method.""" + + def _check_fit(self) -> None: + """Check if the cluster is fitted.""" + _cluster_centers_: Incomplete + _labels_: Incomplete + _fitted: bool + _ignore_polarity: bool + + def fit( + self, + inst: Union[BaseRaw, BaseEpochs, CHData], + picks: Picks = "eeg", + tmin: Optional[Union[int, float]] = None, + tmax: Optional[Union[int, float]] = None, + reject_by_annotation: bool = True, + n_jobs: int = 1, + *, + verbose: Optional[str] = None, + ) -> None: + """Compute cluster centers. + + Parameters + ---------- + inst : Raw | Epochs | ChData + MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object from + which to extract :term:`cluster centers`. + picks : str | list | slice | None + Channels to include. Note that all channels selected must have the same + type. Slices and lists of integers will be interpreted as channel indices. + In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given + channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their + names or indices are explicitly provided. + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + reject_by_annotation : bool + Whether to omit bad segments from the data before fitting. If ``True`` + (default), annotated segments whose description begins with ``'bad'`` are + omitted. If ``False``, no rejection based on annotations is performed. + + Has no effect if ``inst`` is not a :class:`mne.io.Raw` object. + n_jobs : int | None + The number of jobs to run in parallel. If ``-1``, it is set + to the number of CPU cores. Requires the :mod:`joblib` package. + ``None`` (default) is a marker for 'unset' that will be interpreted + as ``n_jobs=1`` (sequential execution) unless the call is performed under + a :class:`joblib:joblib.parallel_config` context manager that sets another + value for ``n_jobs``. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + """ + + def save(self, fname: Union[str, Path]): + """Save clustering solution to disk. + + Parameters + ---------- + fname : path-like + Path to the ``.fif`` file where the clustering solution is saved. + """ + + @staticmethod + def _kmeans( + data: NDArray[float], + n_clusters: int, + max_iter: int, + random_state: Union[RandomState, Generator], + tol: Union[int, float], + ) -> tuple[float, NDArray[float], NDArray[int], bool]: + """Run the k-means algorithm.""" + + @staticmethod + def _compute_maps( + data: NDArray[float], + n_clusters: int, + max_iter: int, + random_state: Union[RandomState, Generator], + tol: Union[int, float], + ) -> tuple[NDArray[float], bool]: + """Compute microstates maps. + + Based on mne_microstates by Marijn van Vliet + https://github.com/wmvanvliet/mne_microstates/blob/master/microstates.py + """ + + @property + def n_init(self) -> int: + """Number of k-means algorithms run with different centroid seeds. + + :type: `int` + """ + + @property + def max_iter(self) -> int: + """Maximum number of iterations of the k-means algorithm for a run. + + :type: `int` + """ + + @property + def tol(self) -> Union[int, float]: + """Relative tolerance to reach convergence. + + :type: `float` + """ + + @property + def random_state(self) -> Union[RandomState, Generator]: + """Random state to fix seed generation. + + :type: `~numpy.random.RandomState` | `~numpy.random.Generator` + """ + + @property + def GEV_(self) -> float: + """Global Explained Variance. + + :type: `float` + """ + + @_BaseCluster.fitted.setter + def fitted(self, fitted) -> None: + """Fitted state. + + :type: `bool` + """ + + @staticmethod + def _check_n_init(n_init: int) -> int: + """Check that n_init is a positive integer.""" + + @staticmethod + def _check_max_iter(max_iter: int) -> int: + """Check that max_iter is a positive integer.""" + + @staticmethod + def _check_tol(tol: Union[int, float]) -> Union[int, float]: + """Check that tol is a positive number.""" diff --git a/pycrostates/cluster/utils/__init__.pyi b/pycrostates/cluster/utils/__init__.pyi new file mode 100644 index 00000000..57630b0c --- /dev/null +++ b/pycrostates/cluster/utils/__init__.pyi @@ -0,0 +1,3 @@ +from .utils import optimize_order as optimize_order + +__all__: tuple[str, ...] diff --git a/pycrostates/cluster/utils/utils.pyi b/pycrostates/cluster/utils/utils.pyi new file mode 100644 index 00000000..44a00fd8 --- /dev/null +++ b/pycrostates/cluster/utils/utils.pyi @@ -0,0 +1,31 @@ +from numpy.typing import NDArray + +from ..._typing import Cluster as Cluster +from ...utils._checks import _check_type as _check_type +from ...utils._docs import fill_doc as fill_doc + +def _optimize_order( + centers: NDArray[float], + template_centers: NDArray[float], + ignore_polarity: bool = True, +): ... +def optimize_order(inst: Cluster, template_inst: Cluster): + """Optimize the order of cluster centers between two cluster instances. + + Optimize the order of cluster centers in an instance of a clustering algorithm to + maximize auto-correlation, based on a template instance as determined by the + Hungarian algorithm. The two cluster instances must have the same number of cluster + centers and the same polarity setting. + + Parameters + ---------- + inst : :ref:`cluster` + Fitted clustering algorithm to reorder. + template_inst : :ref:`cluster` + Fitted clustering algorithm to use as template for reordering. + + Returns + ------- + order : list of int + The new order to apply to inst to maximize auto-correlation of cluster centers. + """ diff --git a/pycrostates/datasets/__init__.pyi b/pycrostates/datasets/__init__.pyi new file mode 100644 index 00000000..934fcf9a --- /dev/null +++ b/pycrostates/datasets/__init__.pyi @@ -0,0 +1,3 @@ +from . import lemon as lemon + +__all__: tuple[str, ...] diff --git a/pycrostates/datasets/lemon/__init__.pyi b/pycrostates/datasets/lemon/__init__.pyi new file mode 100644 index 00000000..f9efd995 --- /dev/null +++ b/pycrostates/datasets/lemon/__init__.pyi @@ -0,0 +1,6 @@ +from _typeshed import Incomplete + +from .lemon import data_path as data_path +from .lemon import standardize as standardize + +__all__: Incomplete diff --git a/pycrostates/datasets/lemon/lemon.pyi b/pycrostates/datasets/lemon/lemon.pyi new file mode 100644 index 00000000..c9be379d --- /dev/null +++ b/pycrostates/datasets/lemon/lemon.pyi @@ -0,0 +1,71 @@ +from pathlib import Path + +from mne.io import BaseRaw + +from ...utils._checks import _check_type as _check_type +from ...utils._checks import _check_value as _check_value +from ...utils._config import get_config as get_config + +def data_path(subject_id: str, condition: str) -> Path: + """Get path to a local copy of preprocessed EEG recording from the LEMON dataset. + + Get path to a local copy of preprocessed EEG recording from the mind-brain-body + dataset of MRI, EEG, cognition, emotion, and peripheral physiology in young and old + adults\\ :footcite:p:`babayan_mind-brain-body_2019`. If there is no local copy of the + recording, this function will fetch it from the online repository and store it. The + default location is ``~/pycrostates_data``. + + Parameters + ---------- + subject_id : str + The subject id to use. For example ``'010276'``. + The list of available subjects can be found on this + `FTP server `_. + condition : str + Can be ``'EO'`` for eyes open condition or ``'EC'`` for eyes closed condition. + + Returns + ------- + path : Path + Path to a local copy of the requested recording. + + Notes + ----- + The lemon datasets is composed of EEGLAB files. To use the MNE reader + :func:`mne.io.read_raw_eeglab`, the ``pymatreader`` optional dependency + is required. Use the following installation method appropriate for your + environment: + + - ``pip install pymatreader`` + - ``conda install -c conda-forge pymatreader`` + + Note that an environment created via the MNE installers includes ``pymatreader`` by + default. + + References + ---------- + .. footbibliography:: + """ + +def standardize(raw: BaseRaw): + """Standardize :class:`~mne.io.Raw` from the lemon dataset. + + This function will interpolate missing channels from the standard setup, then + reorder channels and finally reference to a common average. + + Parameters + ---------- + raw : Raw + Raw data from the lemon dataset. + + Returns + ------- + raw : Raw + Standardize raw. + + Notes + ----- + If you don't want to interpolate missing channels, you can use + :func:`mne.channels.equalize_channels` instead to have the same electrodes across + different recordings. + """ diff --git a/pycrostates/io/__init__.pyi b/pycrostates/io/__init__.pyi new file mode 100644 index 00000000..f6aaf151 --- /dev/null +++ b/pycrostates/io/__init__.pyi @@ -0,0 +1,5 @@ +from .ch_data import ChData as ChData +from .meas_info import ChInfo as ChInfo +from .reader import read_cluster as read_cluster + +__all__: tuple[str, ...] diff --git a/pycrostates/io/ch_data.pyi b/pycrostates/io/ch_data.pyi new file mode 100644 index 00000000..d1dd8332 --- /dev/null +++ b/pycrostates/io/ch_data.pyi @@ -0,0 +1,128 @@ +from typing import Any, Union + +from _typeshed import Incomplete +from mne import Info +from numpy.typing import NDArray + +from .._typing import CHData as CHData +from .._typing import CHInfo as CHInfo +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils.mixin import ChannelsMixin as ChannelsMixin +from ..utils.mixin import ContainsMixin as ContainsMixin +from ..utils.mixin import MontageMixin as MontageMixin + +class ChData(CHData, ChannelsMixin, ContainsMixin, MontageMixin): + """ChData stores atemporal data with its spatial information. + + `~pycrostates.io.ChData` is similar to a raw instance where temporality has been + removed. Only the spatial information, stored as a `~pycrostates.io.ChInfo` is + retained. + + Parameters + ---------- + data : array + Data array of shape ``(n_channels, n_samples)``. + info : mne.Info | ChInfo + Atemporal measurement information. If a `mne.Info` is provided, it is converted + to a `~pycrostates.io.ChInfo`. + """ + + _data: Incomplete + _info: Incomplete + + def __init__(self, data: NDArray[float], info: Union[Info, CHInfo]) -> None: ... + def __repr__(self) -> str: + """String representation.""" + + def _repr_html_(self, caption: Incomplete | None = None): + """HTML representation.""" + + def __eq__(self, other: Any) -> bool: + """Equality == method.""" + + def __ne__(self, other: Any) -> bool: + """Different != method.""" + + def copy(self, deep: bool = True): + """Return a copy of the instance. + + Parameters + ---------- + deep : bool + If True, `~copy.deepcopy` is used instead of `~copy.copy`. + """ + + def get_data(self, picks: Incomplete | None = None) -> NDArray[float]: + """Retrieve the data array. + + Parameters + ---------- + picks : str | array-like | slice | None + Channels to include. Slices and lists of integers will be interpreted as + channel indices. In lists, channel *type* strings (e.g., ``['meg', + 'eeg']``) will pick channels of those types, channel *name* strings (e.g., + ``['MEG0111', 'MEG2623']`` will pick the given channels. Can also be the + string values "all" to pick all channels, or "data" to pick :term:`data + channels`. None (default) will pick all channels. Note that channels in + ``info['bads']`` *will be included* if their names or indices are + explicitly provided. + + Returns + ------- + data : array + Data array of shape ``(n_channels, n_samples)``. + """ + + def pick(self, picks, exclude: str = "bads"): + """Pick a subset of channels. + + Parameters + ---------- + picks : str | array-like | slice | None + Channels to include. Slices and lists of integers will be interpreted as + channel indices. In lists, channel *type* strings (e.g., ``['meg', + 'eeg']``) will pick channels of those types, channel *name* strings (e.g., + ``['MEG0111', 'MEG2623']`` will pick the given channels. Can also be the + string values "all" to pick all channels, or "data" to pick :term:`data + channels`. None (default) will pick all channels. Note that channels in + ``info['bads']`` *will be included* if their names or indices are + explicitly provided. + exclude : list | str + Set of channels to exclude, only used when picking based on types (e.g., + ``exclude="bads"`` when ``picks="meg"``). + + Returns + ------- + inst : ChData + The instance modified in-place. + """ + + def _get_channel_positions(self, picks: Incomplete | None = None): + """Get channel locations from info. + + Parameters + ---------- + picks : str | list | slice | None + None selects the good data channels. + + Returns + ------- + pos : array of shape (n_channels, 3) + Channel X/Y/Z locations. + """ + + @property + def info(self) -> CHInfo: + """Atemporal measurement information. + + :type: ChInfo + """ + + @property + def ch_names(self): + """Channel names.""" + + @property + def preload(self): + """Preload required by some MNE functions.""" diff --git a/pycrostates/io/fiff.pyi b/pycrostates/io/fiff.pyi new file mode 100644 index 00000000..77dac0fe --- /dev/null +++ b/pycrostates/io/fiff.pyi @@ -0,0 +1,125 @@ +from pathlib import Path as Path +from typing import Union + +from mne import Info +from numpy.typing import NDArray + +from .. import __version__ as __version__ +from .._typing import CHInfo as CHInfo +from ..cluster import AAHCluster as AAHCluster +from ..cluster import ModKMeans as ModKMeans +from ..utils._checks import _check_type as _check_type +from ..utils._checks import _check_value as _check_value +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +def _write_cluster( + fname: Union[str, Path], + cluster_centers_: NDArray[float], + chinfo: Union[CHInfo, Info], + algorithm: str, + cluster_names: list[str], + fitted_data: NDArray[float], + labels_: NDArray[int], + **kwargs, +): + """Save clustering solution to disk. + + Parameters + ---------- + fname : str | Path + Path to the ``.fif`` file where the clustering solution is saved. + cluster_centers_ : array + Cluster centers as a numpy array of shape (n_clusters, n_channels). + chinfo : ChInfo + Channel information (name, type, montage, ..) + algorithm : str + Clustering algorithm used. Valids are: + 'ModKMeans' + cluster_names : list + List of names for each of the clusters. + fitted_data : array + Data array used for fitting of shape (n_channels, n_samples) + labels_ : array + Array of labels for each sample of shape (n_samples, ) + """ + +def _prepare_kwargs(algorithm: str, kwargs: dict): + """Prepare params to save from kwargs.""" + +def _read_cluster(fname: Union[str, Path]): + """Read clustering solution from disk. + + Parameters + ---------- + fname : str | Path + Path to the ``.fif`` file where the clustering solution is saved. + + Returns + ------- + cluster : _BaseCluster + Loaded cluster solution. + version : str + pycrostates version used to save the cluster solution. + """ + +def _check_fit_parameters_and_variables(fit_parameters: dict, fit_variables: dict): + """Check that we have all the keys we are looking for and return algo.""" + +def _create_ModKMeans( + cluster_centers_: NDArray[float], + info: CHInfo, + cluster_names: list[str], + fitted_data: NDArray[float], + labels_: NDArray[int], + n_init: int, + max_iter: int, + tol: Union[int, float], + GEV_: float, +): + """Create a ModKMeans cluster.""" + +def _create_AAHCluster( + cluster_centers_: NDArray[float], + info: CHInfo, + cluster_names: list[str], + fitted_data: NDArray[float], + labels_: NDArray[int], + ignore_polarity: bool, + normalize_input: bool, + GEV_: float, +): + """Create a AAHCluster object.""" + +def _write_meas_info(fid, info: CHInfo): + """Write measurement info into a file id (from a fif file). + + Parameters + ---------- + fid : file + Open file descriptor. + info : ChInfo + Channel information. + """ + +def _read_meas_info(fid, tree): + """Read the measurement info. + + Parameters + ---------- + fid : file + Open file descriptor. + tree : tree + FIF tree structure. + + Returns + ------- + info : ChInfo + Channel information instance. + """ + +def _serialize(dict_: dict, outer_sep: str = ";", inner_sep: str = ":"): + """Aux function.""" + +def _deserialize(str_: str, outer_sep: str = ";", inner_sep: str = ":"): + """Aux Function.""" diff --git a/pycrostates/io/meas_info.pyi b/pycrostates/io/meas_info.pyi new file mode 100644 index 00000000..73b9834b --- /dev/null +++ b/pycrostates/io/meas_info.pyi @@ -0,0 +1,163 @@ +from typing import Optional, Union + +from mne import Info + +from .._typing import CHInfo as CHInfo +from ..utils._checks import _check_type as _check_type +from ..utils._checks import _IntLike as _IntLike +from ..utils._logs import logger as logger + +class ChInfo(CHInfo, Info): + """Atemporal measurement information. + + Similar to a :class:`mne.Info` class, but without any temporal information. + Only the channel-related information are present. A :class:`~pycrostates.io.ChInfo` + can be created either: + + - by providing a :class:`~mne.Info` class from which information are retrieved. + - by providing the ``ch_names`` and the ``ch_types`` to create a new instance. + + Only one of those 2 methods should be used at once. + + .. warning:: The only entry that should be manually changed by the user is + ``info['bads']``. All other entries should be considered read-only, + though they can be modified by various functions or methods (which have + safeguards to ensure all fields remain in sync). + + Parameters + ---------- + info : Info | None + MNE measurement information instance from which channel-related variables are + retrieved. + ch_names : list of str | int | None + Channel names. If an int, a list of channel names will be created from + ``range(ch_names)``. + ch_types : list of str | str | None + Channel types. If str, all channels are assumed to be of the same type. + + Attributes + ---------- + bads : list of str + List of bad (noisy/broken) channels, by name. These channels will by default be + ignored by many processing steps. + ch_names : tuple of str + The names of the channels. + chs : tuple of dict + A list of channel information dictionaries, one per channel. See Notes for more + information. + comps : list of dict + CTF software gradient compensation data. See Notes for more information. + ctf_head_t : dict | None + The transformation from 4D/CTF head coordinates to Neuromag head coordinates. + This is only present in 4D/CTF data. + custom_ref_applied : int + Whether a custom (=other than average) reference has been applied to the EEG + data. This flag is checked by some algorithms that require an average reference + to be set. + dev_ctf_t : dict | None + The transformation from device coordinates to 4D/CTF head coordinates. This is + only present in 4D/CTF data. + dev_head_t : dict | None + The device to head transformation. + dig : tuple of dict | None + The Polhemus digitization data in head coordinates. See Notes for more + information. + nchan : int + Number of channels. + projs : list of Projection + List of SSP operators that operate on the data. See :class:`mne.Projection` for + details. + + Notes + ----- + The following parameters have a nested structure. + + * ``chs`` list of dict: + + cal : float + The calibration factor to bring the channels to physical units. Used in + product with ``range`` to scale the data read from disk. + ch_name : str + The channel name. + coil_type : int + Coil type, e.g. ``FIFFV_COIL_MEG``. + coord_frame : int + The coordinate frame used, e.g. ``FIFFV_COORD_HEAD``. + kind : int + The kind of channel, e.g. ``FIFFV_EEG_CH``. + loc : array, shape (12,) + Channel location. For MEG this is the position plus the normal given by a + 3x3 rotation matrix. For EEG this is the position followed by reference + position (with 6 unused). The values are specified in device coordinates for + MEG and in head coordinates for EEG channels, respectively. + logno : int + Logical channel number, conventions in the usage of this number vary. + range : float + The hardware-oriented part of the calibration factor. This should be only + applied to the continuous raw data. Used in product with ``cal`` to scale + data read from disk. + scanno : int + Scanning order number, starting from 1. + unit : int + The unit to use, e.g. ``FIFF_UNIT_T_M``. + unit_mul : int + Unit multipliers, most commonly ``FIFF_UNITM_NONE``. + + * ``comps`` list of dict: + + ctfkind : int + CTF compensation grade. + colcals : ndarray + Column calibrations. + mat : dict + A named matrix dictionary (with entries "data", "col_names", etc.) + containing the compensation matrix. + rowcals : ndarray + Row calibrations. + save_calibrated : bool + Were the compensation data saved in calibrated form. + + * ``dig`` list of dict: + + kind : int + The kind of channel, e.g. ``FIFFV_POINT_EEG``, ``FIFFV_POINT_CARDINAL``. + r : array, shape (3,) + 3D position in m. and coord_frame. + ident : int + Number specifying the identity of the point. e.g. ``FIFFV_POINT_NASION`` if + kind is ``FIFFV_POINT_CARDINAL``, or 42 if kind is ``FIFFV_POINT_EEG``. + coord_frame : int + The coordinate frame used, e.g. ``FIFFV_COORD_HEAD``. + """ + + def __init__( + self, + info: Optional[Info] = None, + ch_names: Optional[Union[int, list[str], tuple[str, ...]]] = None, + ch_types: Optional[Union[str, list[str], tuple[str, ...]]] = None, + ) -> None: ... + def _init_from_info(self, info: Info): + """Init instance from mne Info.""" + _unlocked: bool + + def _init_from_channels( + self, + ch_names: Union[int, list[str], tuple[str, ...]], + ch_types: Union[str, list[str], tuple[str, ...]], + ): + """Init instance from channel names and types.""" + + def __getattribute__(self, name): + """Attribute getter.""" + + def __eq__(self, other): + """Equality == method.""" + + def __ne__(self, other): + """Different != method.""" + + def __deepcopy__(self, memodict): + """Make a deepcopy.""" + + def _check_consistency(self, prepend_error: str = ""): + """Do some self-consistency checks and datatype tweaks.""" diff --git a/pycrostates/io/reader.pyi b/pycrostates/io/reader.pyi new file mode 100644 index 00000000..b39bab4b --- /dev/null +++ b/pycrostates/io/reader.pyi @@ -0,0 +1,20 @@ +from pathlib import Path +from typing import Union + +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +def read_cluster(fname: Union[str, Path]): + """Read clustering solution from disk. + + Parameters + ---------- + fname : str | Path + Path to the ``.fif`` file where the clustering solution is saved. + + Returns + ------- + cluster : :ref:`Clustering` + Fitted clustering instance. + """ diff --git a/pycrostates/metrics/__init__.pyi b/pycrostates/metrics/__init__.pyi new file mode 100644 index 00000000..f2c61947 --- /dev/null +++ b/pycrostates/metrics/__init__.pyi @@ -0,0 +1,6 @@ +from .calinski_harabasz import calinski_harabasz_score as calinski_harabasz_score +from .davies_bouldin import davies_bouldin_score as davies_bouldin_score +from .dunn import dunn_score as dunn_score +from .silhouette import silhouette_score as silhouette_score + +__all__: tuple[str, ...] diff --git a/pycrostates/metrics/calinski_harabasz.pyi b/pycrostates/metrics/calinski_harabasz.pyi new file mode 100644 index 00000000..1e3c5cd4 --- /dev/null +++ b/pycrostates/metrics/calinski_harabasz.pyi @@ -0,0 +1,32 @@ +from ..cluster._base import _BaseCluster as _BaseCluster +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc + +def calinski_harabasz_score(cluster): + """Compute the Calinski-Harabasz score. + + This function computes the Calinski-Harabasz score\\ :footcite:p:`Calinski-Harabasz` + with :func:`sklearn.metrics.calinski_harabasz_score` from a fitted :ref:`Clustering` + instance. + + Parameters + ---------- + cluster : :ref:`cluster` + Fitted clustering algorithm from which to compute score. For more details about + current clustering implementations, check the :ref:`Clustering` section of the + documentation. + + Returns + ------- + score : float + The resulting Calinski-Harabasz score. + + Notes + ----- + For more details regarding the implementation, please refer to + :func:`sklearn.metrics.calinski_harabasz_score`. + + References + ---------- + .. footbibliography:: + """ diff --git a/pycrostates/metrics/davies_bouldin.pyi b/pycrostates/metrics/davies_bouldin.pyi new file mode 100644 index 00000000..ee1cb124 --- /dev/null +++ b/pycrostates/metrics/davies_bouldin.pyi @@ -0,0 +1,52 @@ +from ..cluster._base import _BaseCluster as _BaseCluster +from ..utils import _distance_matrix as _distance_matrix +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc + +def davies_bouldin_score(cluster): + """Compute the Davies-Bouldin score. + + This function computes the Davies-Bouldin score\\ :footcite:p:`Davies-Bouldin` with + :func:`sklearn.metrics.davies_bouldin_score` from a fitted :ref:`Clustering` + instance. + + Parameters + ---------- + cluster : :ref:`cluster` + Fitted clustering algorithm from which to compute score. For more details about + current clustering implementations, check the :ref:`Clustering` section of the + documentation. + + Returns + ------- + score : float + The resulting Davies-Bouldin score. + + Notes + ----- + For more details regarding the implementation, please refer to + :func:`sklearn.metrics.davies_bouldin_score`. This function was modified in order to + use the absolute spatial correlation for distance computations instead of the + euclidean distance. + + References + ---------- + .. footbibliography:: + """ + +def _davies_bouldin_score(X, labels): + """Compute the Davies-Bouldin score. + + Parameters + ---------- + X : array of shape (n_samples, n_features) + A list of ``n_features``-dimensional data points. Each row corresponds + to a single data point. + labels : array of shape (n_samples,) + Predicted labels for each sample. + + Returns + ------- + score: float + The resulting Davies-Bouldin score. + """ diff --git a/pycrostates/metrics/dunn.pyi b/pycrostates/metrics/dunn.pyi new file mode 100644 index 00000000..a55a740f --- /dev/null +++ b/pycrostates/metrics/dunn.pyi @@ -0,0 +1,49 @@ +from ..cluster._base import _BaseCluster as _BaseCluster +from ..utils import _distance_matrix as _distance_matrix +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc + +def dunn_score(cluster): + """Compute the Dunn index score. + + This function computes the Dunn index score\\ :footcite:p:`Dunn` from a + fitted :ref:`Clustering` instance. + + Parameters + ---------- + cluster : :ref:`cluster` + Fitted clustering algorithm from which to compute score. For more details about + current clustering implementations, check the :ref:`Clustering` section of the + documentation. + + Returns + ------- + score : float + The resulting Dunn score. + + Notes + ----- + This function uses the absolute spatial correlation for distance. + + References + ---------- + .. footbibliography:: + """ + +def _dunn_score(X, labels): + """Compute the Dunn index. + + Parameters + ---------- + X : np.array + np.array([N, p]) of all points + labels: np.array + np.array([N]) labels of all points + + Notes + ----- + Based on https://github.com/jqmviegas/jqm_cvi + """ + +def _delta_fast(ck, cl, distances): ... +def _big_delta_fast(ci, distances): ... diff --git a/pycrostates/metrics/silhouette.pyi b/pycrostates/metrics/silhouette.pyi new file mode 100644 index 00000000..092ab805 --- /dev/null +++ b/pycrostates/metrics/silhouette.pyi @@ -0,0 +1,34 @@ +from ..cluster._base import _BaseCluster as _BaseCluster +from ..utils import _distance_matrix as _distance_matrix +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc + +def silhouette_score(cluster): + """Compute the mean Silhouette Coefficient. + + This function computes the Silhouette Coefficient\\ :footcite:p:`Silhouettes` with + :func:`sklearn.metrics.silhouette_score` from a fitted :ref:`Clustering` instance. + + Parameters + ---------- + cluster : :ref:`cluster` + Fitted clustering algorithm from which to compute score. For more details about + current clustering implementations, check the :ref:`Clustering` section of the + documentation. + + Returns + ------- + silhouette : float + The mean Silhouette Coefficient. + + Notes + ----- + For more details regarding the implementation, please refer to + :func:`sklearn.metrics.silhouette_score`. + This proxy function uses ``metric="precomputed"`` with the absolute spatial + correlation for distance computations. + + References + ---------- + .. footbibliography:: + """ diff --git a/pycrostates/preprocessing/__init__.pyi b/pycrostates/preprocessing/__init__.pyi new file mode 100644 index 00000000..af15a317 --- /dev/null +++ b/pycrostates/preprocessing/__init__.pyi @@ -0,0 +1,5 @@ +from .extract_gfp_peaks import extract_gfp_peaks as extract_gfp_peaks +from .resample import resample as resample +from .spatial_filter import apply_spatial_filter as apply_spatial_filter + +__all__: tuple[str, ...] diff --git a/pycrostates/preprocessing/extract_gfp_peaks.pyi b/pycrostates/preprocessing/extract_gfp_peaks.pyi new file mode 100644 index 00000000..b7ac9075 --- /dev/null +++ b/pycrostates/preprocessing/extract_gfp_peaks.pyi @@ -0,0 +1,99 @@ +from typing import Optional, Union + +from _typeshed import Incomplete +from mne import BaseEpochs +from mne.io import BaseRaw +from numpy.typing import NDArray + +from .._typing import CHData as CHData +from .._typing import Picks as Picks +from ..utils._checks import _check_picks_uniqueness as _check_picks_uniqueness +from ..utils._checks import _check_reject_by_annotation as _check_reject_by_annotation +from ..utils._checks import _check_tmin_tmax as _check_tmin_tmax +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +def extract_gfp_peaks( + inst: Union[BaseRaw, BaseEpochs], + picks: Picks = "eeg", + return_all: bool = False, + min_peak_distance: int = 1, + tmin: Optional[float] = None, + tmax: Optional[float] = None, + reject_by_annotation: bool = True, + verbose: Incomplete | None = None, +) -> CHData: + """:term:`Global Field Power` (:term:`GFP`) peaks extraction. + + Extract :term:`Global Field Power` (:term:`GFP`) peaks from :class:`~mne.Epochs` or + :class:`~mne.io.Raw`. + + Parameters + ---------- + inst : Raw | Epochs + Instance from which to extract :term:`global field power` (GFP) peaks. + picks : str | list | slice | None + Channels to use for GFP computation. Note that all channels selected must have + the same type. Slices and lists of integers will be interpreted as channel + indices. In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the + given channels. Can also be the string values ``“all”`` to pick all channels, or + ``“data”`` to pick data channels. ``"eeg"`` (default) will pick all eeg + channels. Note that channels in ``info['bads']`` will be included if their + names or indices are explicitly provided. + return_all : bool + If True, the returned `~pycrostates.io.ChData` instance will include all + channels. If False (default), the returned `~pycrostates.io.ChData` instance + will only include channels used for GFP computation (i.e ``picks``). + min_peak_distance : int + Required minimal horizontal distance (``≥ 1`) in samples between neighboring + peaks. Smaller peaks are removed first until the condition is fulfilled for all + remaining peaks. Default to ``1``. + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + reject_by_annotation : bool + Whether to omit bad segments from the data before fitting. If ``True`` + (default), annotated segments whose description begins with ``'bad'`` are + omitted. If ``False``, no rejection based on annotations is performed. + + Has no effect if ``inst`` is not a :class:`mne.io.Raw` object. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + ch_data : ChData + Samples at global field power peaks. + + Notes + ----- + The :term:`Global Field Power` (:term:`GFP`) peaks are extracted with + :func:`scipy.signal.find_peaks`. Only the ``distance`` argument is filled with the + value provided in ``min_peak_distance``. The other arguments are set to their + default values. + """ + +def _extract_gfp_peaks( + data: NDArray[float], min_peak_distance: int = 2 +) -> NDArray[float]: + """Extract GFP peaks from input data. + + Parameters + ---------- + data : array of shape (n_channels, n_samples) + The data to extract GFP peaks from. + min_peak_distance : int + Required minimal horizontal distance (>= 1) in samples between neighboring + peaks. Smaller peaks are removed first until the condition is fulfilled for all + remaining peaks. Default to 2. + + Returns + ------- + peaks : array of shape (n_picks,) + The indices when peaks occur. + """ diff --git a/pycrostates/preprocessing/resample.pyi b/pycrostates/preprocessing/resample.pyi new file mode 100644 index 00000000..d68f3146 --- /dev/null +++ b/pycrostates/preprocessing/resample.pyi @@ -0,0 +1,91 @@ +from typing import Optional, Union + +from _typeshed import Incomplete +from mne import BaseEpochs +from mne.io import BaseRaw + +from .._typing import CHData as CHData +from .._typing import Picks as Picks +from .._typing import RANDomState as RANDomState +from ..utils._checks import _check_random_state as _check_random_state +from ..utils._checks import _check_reject_by_annotation as _check_reject_by_annotation +from ..utils._checks import _check_tmin_tmax as _check_tmin_tmax +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +def resample( + inst: Union[BaseRaw, BaseEpochs, CHData], + picks: Picks = None, + tmin: Optional[float] = None, + tmax: Optional[float] = None, + reject_by_annotation: bool = True, + n_resamples: int = None, + n_samples: int = None, + coverage: float = None, + replace: bool = True, + random_state: RANDomState = None, + verbose: Incomplete | None = None, +) -> list[CHData]: + """Resample a recording into epochs of random samples. + + Resample :class:`~mne.io.Raw`. :class:`~mne.Epochs` or + `~pycrostates.io.ChData` into ``n_resamples`` each containing ``n_samples`` + random samples of the original recording. + + Parameters + ---------- + inst : Raw | Epochs | ChData + Instance to resample. + picks : str | array-like | slice | None + Channels to include. Slices and lists of integers will be interpreted as + channel indices. In lists, channel *type* strings (e.g., ``['meg', + 'eeg']``) will pick channels of those types, channel *name* strings (e.g., + ``['MEG0111', 'MEG2623']`` will pick the given channels. Can also be the + string values "all" to pick all channels, or "data" to pick :term:`data + channels`. None (default) will pick all channels. Note that channels in + ``info['bads']`` *will be included* if their names or indices are + explicitly provided. + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + reject_by_annotation : bool + Whether to omit bad segments from the data before fitting. If ``True`` + (default), annotated segments whose description begins with ``'bad'`` are + omitted. If ``False``, no rejection based on annotations is performed. + + Has no effect if ``inst`` is not a :class:`mne.io.Raw` object. + n_resamples : int + Number of resamples to draw. Each epoch can be used to fit a separate clustering + solution. See notes for additional information. + n_samples : int + Length of each epoch (in samples). See notes for additional information. + coverage : float + Strictly positive ratio between resampling data size and size of the original + recording. See notes for additional information. + replace : bool + Whether or not to allow resampling with replacement. + random_state : None | int | instance of ~numpy.random.RandomState + A seed for the NumPy random number generator (RNG). If ``None`` (default), + the seed will be obtained from the operating system + (see :class:`~numpy.random.RandomState` for details), meaning it will most + likely produce different output every time this function or method is run. + To achieve reproducible results, pass a value here to explicitly initialize + the RNG with a defined state. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + resamples : list of :class:`~pycrostates.io.ChData` + List of resamples. + + Notes + ----- + Only two of ``n_resamples``, ``n_samples`` and ``coverage`` parameters must be + defined, the non-defined one will be determine at runtime by the 2 other parameters. + """ diff --git a/pycrostates/preprocessing/spatial_filter.pyi b/pycrostates/preprocessing/spatial_filter.pyi new file mode 100644 index 00000000..e49fe3a0 --- /dev/null +++ b/pycrostates/preprocessing/spatial_filter.pyi @@ -0,0 +1,93 @@ +from typing import Union + +from _typeshed import Incomplete +from mne import BaseEpochs +from mne.io import BaseRaw +from numpy.typing import NDArray as NDArray +from scipy.sparse import csr_matrix + +from .._typing import CHData as CHData +from ..utils._checks import _check_n_jobs as _check_n_jobs +from ..utils._checks import _check_type as _check_type +from ..utils._checks import _check_value as _check_value +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +def _check_adjacency(adjacency, info, ch_type): + """Check adjacency matrix.""" + +def apply_spatial_filter( + inst: Union[BaseRaw, BaseEpochs, CHData], + ch_type: str = "eeg", + exclude_bads: bool = True, + origin: Union[str, NDArray[float]] = "auto", + adjacency: Union[csr_matrix, str] = "auto", + n_jobs: int = 1, + verbose: Incomplete | None = None, +): + """Apply a spatial filter. + + Adapted from \\ :footcite:t:`michel2019eeg`. Apply an instantaneous filter which + interpolates channels with local neighbors while removing outliers. + The current implementation proceeds as follows: + + * An interpolation matrix is computed using + ``mne.channels.interpolation._make_interpolation_matrix``. + * An ajdacency matrix is computed using `mne.channels.find_ch_adjacency`. + * If ``exclude_bads`` is set to ``True``, bad channels are removed from the + ajdacency matrix. + * For each timepoint and each channel, a reduced adjacency matrix is built by + removing neighbors with lowest and highest value. + * For each timepoint and each channel, a reduced interpolation matrix is built by + extracting neighbor weights based on the reduced adjacency matrix. + * The reduced interpolation matrices are normalized. + * The channel's timepoints are interpolated using their reduced interpolation + matrix. + + Parameters + ---------- + inst : Raw | Epochs | ChData + Instance to filter spatially. + ch_type : str + The channel type on which to apply the spatial filter. Currently only supports + ``'eeg'``. + exclude_bads : bool + If set to ``True``, bad channels will be removed from the adjacency matrix and + therefore not used to interpolate neighbors. In addition, bad channels will not + be filtered. If set to ``False``, proceed as if all channels were good. + origin : array of shape (3,) | str + Origin of the sphere in the head coordinate frame and in meters. Can be + ``'auto'`` (default), which means a head-digitization-based origin fit. + adjacency : array or csr_matrix of shape (n_channels, n_channels) | str + An adjacency matrix. Can be created using `mne.channels.find_ch_adjacency` and + edited with `mne.viz.plot_ch_adjacency`. If ``'auto'`` (default), the matrix + will be automatically created using `mne.channels.find_ch_adjacency` and other + parameters. + n_jobs : int | None + The number of jobs to run in parallel. If ``-1``, it is set + to the number of CPU cores. Requires the :mod:`joblib` package. + ``None`` (default) is a marker for 'unset' that will be interpreted + as ``n_jobs=1`` (sequential execution) unless the call is performed under + a :class:`joblib:joblib.parallel_config` context manager that sets another + value for ``n_jobs``. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + inst : Raw | Epochs| ChData + The instance modified in place. + + Notes + ----- + This function requires a full copy of the data in memory. + + References + ---------- + .. footbibliography:: + """ + +def _channel_spatial_filter(index, data, adjacency_vector, interpolate_matrix): ... diff --git a/pycrostates/segmentation/__init__.pyi b/pycrostates/segmentation/__init__.pyi new file mode 100644 index 00000000..f126d94d --- /dev/null +++ b/pycrostates/segmentation/__init__.pyi @@ -0,0 +1,8 @@ +from .segmentation import EpochsSegmentation as EpochsSegmentation +from .segmentation import RawSegmentation as RawSegmentation +from .transitions import ( + compute_expected_transition_matrix as compute_expected_transition_matrix, +) +from .transitions import compute_transition_matrix as compute_transition_matrix + +__all__: tuple[str, ...] diff --git a/pycrostates/segmentation/_base.pyi b/pycrostates/segmentation/_base.pyi new file mode 100644 index 00000000..b4387401 --- /dev/null +++ b/pycrostates/segmentation/_base.pyi @@ -0,0 +1,217 @@ +from abc import abstractmethod +from typing import Optional, Union + +from _typeshed import Incomplete +from matplotlib.axes import Axes as Axes +from mne import BaseEpochs +from mne.io import BaseRaw +from numpy.typing import NDArray + +from .._typing import Segmentation as Segmentation +from ..utils import _corr_vectors as _corr_vectors +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger +from .transitions import ( + _compute_expected_transition_matrix as _compute_expected_transition_matrix, +) +from .transitions import _compute_transition_matrix as _compute_transition_matrix + +class _BaseSegmentation(Segmentation): + """Base class for a Microstates segmentation. + + Parameters + ---------- + labels : array of shape (n_samples, ) or (n_epochs, n_samples) + Microstates labels attributed to each sample, i.e. the segmentation. + inst : Raw | Epochs + MNE instance used to predict the segmentation. + cluster_centers : array (n_clusters, n_channels) + Clusters, i.e. the microstates maps used to compute the segmentation. + cluster_names : list | None + Name of the clusters. + predict_parameters : dict | None + The prediction parameters. + """ + + _labels: Incomplete + _inst: Incomplete + _cluster_centers_: Incomplete + _cluster_names: Incomplete + _predict_parameters: Incomplete + + @abstractmethod + def __init__( + self, + labels: NDArray[int], + inst: Union[BaseRaw, BaseEpochs], + cluster_centers_: NDArray[float], + cluster_names: Optional[list[str]] = None, + predict_parameters: Optional[dict] = None, + ): ... + def __repr__(self) -> str: ... + def _repr_html_(self, caption: Incomplete | None = None): ... + def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): + """Compute microstate parameters. + + .. warning:: + + When working with `~mne.Epochs`, this method will put together segments of + all epochs. This could lead to wrong interpretation especially on state + durations. To avoid this behaviour, make sure to set the ``reject_edges`` + parameter to ``True`` when creating the segmentation. + + Parameters + ---------- + norm_gfp : bool + If True, the :term:`global field power` (GFP) is normalized. + return_dist : bool + If True, return the parameters distributions. + + Returns + ------- + dict : dict + Dictionaries containing microstate parameters as key/value pairs. + Keys are named as follow: ``'{microstate name}_{parameter name}'``. + + Available parameters are listed below: + + * ``mean_corr``: Mean correlation value for each time point assigned to a + given state. + * ``gev``: Global explained variance expressed by a given state. + It is the sum of global explained variance values of each time point + assigned to a given state. + * ``timecov``: Time coverage, the proportion of time during which + a given state is active. This metric is expressed as a ratio. + * ``meandurs``: Mean durations of segments assigned to a given + state. This metric is expressed in seconds (s). + * ``occurrences``: Occurrences per second, the mean number of + segment assigned to a given state per second. This metrics is expressed + in segment per second. + * ``dist_corr`` (req. ``return_dist=True``): Distribution of + correlations values of each time point assigned to a given state. + * ``dist_gev`` (req. ``return_dist=True``): Distribution of global + explained variances values of each time point assigned to a given state. + * ``dist_durs`` (req. ``return_dist=True``): Distribution of + durations of each segments assigned to a given state. Each value is + expressed in seconds (s). + """ + + def compute_transition_matrix( + self, stat: str = "probability", ignore_repetitions: bool = True + ): + """Compute the observed transition matrix. + + Count the number of transitions from one state to another and aggregate the + result as statistic. Transition "from" and "to" unlabeled segments ``-1`` are + ignored. + + Parameters + ---------- + %(stat_transition)s + %(ignore_repetitions)s + + Returns + ------- + %(transition_matrix)s + + Warnings + -------- + When working with `~mne.Epochs`, this method will take into account transitions + that occur between epochs. This could lead to wrong interpretation when working + with discontinuous data. To avoid this behaviour, make sure to set the + ``reject_edges`` parameter to ``True`` when predicting the segmentation. + """ + + def compute_expected_transition_matrix( + self, stat: str = "probability", ignore_repetitions: bool = True + ): + """Compute the expected transition matrix. + + Compute the theoretical transition matrix as if time course was ignored, but + microstate proportions kept (i.e. shuffled segmentation). This matrix can be + used to quantify/correct the effect of microstate time coverage on the observed + transition matrix obtained with the method + ``compute_expected_transition_matrix``. + Transition "from" and "to" unlabeled segments ``-1`` are ignored. + + Parameters + ---------- + stat : str + Aggregate statistic to compute transitions. Can be: + + * ``probability`` or ``proportion``: normalize count such as the probabilities along + the first axis is always equal to ``1``. + * ``percent``: normalize count such as the probabilities along the first axis is + always equal to ``100``. + ignore_repetitions : bool + If ``True``, ignores state repetitions. + For example, the input sequence ``AAABBCCD`` + will be transformed into ``ABCD`` before any calculation. + This is equivalent to setting the duration of all states to 1 sample. + + Returns + ------- + T : array of shape ``(n_cluster, n_cluster)`` + Array of transition probability values from one label to another. + First axis indicates state ``"from"``. Second axis indicates state ``"to"``. + """ + + def plot_cluster_centers( + self, axes: Optional[Union[Axes, NDArray[Axes]]] = None, block: bool = False + ): + """Plot cluster centers as topographic maps. + + Parameters + ---------- + axes : Axes | None + Either ``None`` to create a new figure or axes (or an array of axes) on which the + topographic map should be plotted. If the number of microstates maps to plot is + ``≥ 1``, an array of axes of size ``n_clusters`` should be provided. + block : bool + Whether to halt program execution until the figure is closed. + + Returns + ------- + fig : Figure + Matplotlib figure containing the topographic plots. + """ + + @staticmethod + def _check_cluster_names( + cluster_names: list[str], cluster_centers_: NDArray[float] + ): + """Check that the argument 'cluster_names' is valid.""" + + @staticmethod + def _check_predict_parameters(predict_parameters: dict): + """Check that the argument 'predict_parameters' is valid.""" + + @property + def predict_parameters(self) -> dict: + """Parameters used to predict the current segmentation. + + :type: `dict` + """ + + @property + def labels(self) -> NDArray[int]: + """Microstate label attributed to each sample (the segmentation). + + :type: `~numpy.array` + """ + + @property + def cluster_centers_(self) -> NDArray[float]: + """Cluster centers (i.e topographies) + used to compute the segmentation. + + :type: `~numpy.array` + """ + + @property + def cluster_names(self) -> list[str]: + """Name of the cluster centers. + + :type: `list` + """ diff --git a/pycrostates/segmentation/segmentation.pyi b/pycrostates/segmentation/segmentation.pyi new file mode 100644 index 00000000..d20f8147 --- /dev/null +++ b/pycrostates/segmentation/segmentation.pyi @@ -0,0 +1,133 @@ +from typing import Optional, Union + +from _typeshed import Incomplete +from matplotlib.axes import Axes as Axes +from mne import BaseEpochs +from mne.io import BaseRaw + +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..viz import plot_epoch_segmentation as plot_epoch_segmentation +from ..viz import plot_raw_segmentation as plot_raw_segmentation +from ._base import _BaseSegmentation as _BaseSegmentation + +class RawSegmentation(_BaseSegmentation): + """ + Contains the segmentation of a `~mne.io.Raw` instance. + + Parameters + ---------- + labels : array of shape ``(n_samples,)`` + Microstates labels attributed to each sample, i.e. the segmentation. + raw : Raw + `~mne.io.Raw` instance used for prediction. + cluster_centers : array (n_clusters, n_channels) + Clusters, i.e. the microstates maps used to compute the segmentation. + cluster_names : list | None + Name of the clusters. + predict_parameters : dict | None + The prediction parameters. + """ + + def __init__(self, *args, **kwargs) -> None: ... + def plot( + self, + tmin: Optional[Union[int, float]] = None, + tmax: Optional[Union[int, float]] = None, + cmap: Optional[str] = None, + axes: Optional[Axes] = None, + cbar_axes: Optional[Axes] = None, + *, + block: bool = False, + verbose: Optional[str] = None, + ): + """Plot the segmentation. + + Parameters + ---------- + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + cmap : str | colormap | None + The colormap to use. If None, ``viridis`` is used. + axes : Axes | None + Either ``None`` to create a new figure or axes on which the segmentation is + plotted. + cbar_axes : Axes | None + Axes on which to draw the colorbar, otherwise the colormap takes space from the main + axes. + block : bool + Whether to halt program execution until the figure is closed. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + fig : Figure + Matplotlib figure containing the segmentation. + """ + + @property + def raw(self) -> BaseRaw: + """`~mne.io.Raw` instance from which the segmentation was computed.""" + +class EpochsSegmentation(_BaseSegmentation): + """Contains the segmentation of an `~mne.Epochs` instance. + + Parameters + ---------- + labels : array of shape ``(n_epochs, n_samples)`` + Microstates labels attributed to each sample, i.e. the segmentation. + epochs : Epochs + `~mne.Epochs` instance used for prediction. + cluster_centers : array (n_clusters, n_channels) + Clusters, i.e. the microstates maps used to compute the segmentation. + cluster_names : list | None + Name of the clusters. + predict_parameters : dict | None + The prediction parameters. + """ + + def __init__(self, *args, **kwargs) -> None: ... + def plot( + self, + cmap: Optional[str] = None, + axes: Optional[Axes] = None, + cbar_axes: Optional[Axes] = None, + *, + block: bool = False, + verbose: Incomplete | None = None, + ): + """Plot segmentation. + + Parameters + ---------- + cmap : str | colormap | None + The colormap to use. If None, ``viridis`` is used. + axes : Axes | None + Either ``None`` to create a new figure or axes on which the segmentation is + plotted. + cbar_axes : Axes | None + Axes on which to draw the colorbar, otherwise the colormap takes space from the main + axes. + block : bool + Whether to halt program execution until the figure is closed. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + fig : Figure + Matplotlib figure containing the segmentation. + """ + + @property + def epochs(self) -> BaseEpochs: + """`~mne.Epochs` instance from which the segmentation was computed.""" diff --git a/pycrostates/segmentation/transitions.pyi b/pycrostates/segmentation/transitions.pyi new file mode 100644 index 00000000..8f8ff1bd --- /dev/null +++ b/pycrostates/segmentation/transitions.pyi @@ -0,0 +1,106 @@ +from numpy.typing import NDArray + +from ..utils._checks import _check_type as _check_type +from ..utils._checks import _check_value as _check_value +from ..utils._docs import fill_doc as fill_doc + +def compute_transition_matrix( + labels: NDArray[int], + n_clusters: int, + stat: str = "probability", + ignore_repetitions: bool = True, +) -> NDArray[float]: + """Compute the observed transition matrix. + + Count the number of transitions from one state to another and aggregate the result + as statistic. Transitions "from" and "to" unlabeled segments ``-1`` are ignored. + + Parameters + ---------- + labels : array of shape ``(n_samples,)`` or ``(n_epochs, n_samples)`` + Microstates labels attributed to each sample, i.e. the segmentation. + n_clusters : int + The number of clusters, i.e. the number of microstates. + stat : str + Aggregate statistic to compute transitions. Can be: + + * ``count``: show the number of observations of each transition. + * ``probability`` or ``proportion``: normalize count such as the probabilities along + the first axis is always equal to ``1``. + * ``percent``: normalize count such as the probabilities along the first axis is + always equal to ``100``. + ignore_repetitions : bool + If ``True``, ignores state repetitions. + For example, the input sequence ``AAABBCCD`` + will be transformed into ``ABCD`` before any calculation. + This is equivalent to setting the duration of all states to 1 sample. + + Returns + ------- + T : array of shape ``(n_cluster, n_cluster)`` + Array of transition probability values from one label to another. + First axis indicates state ``"from"``. Second axis indicates state ``"to"``. + """ + +def _compute_transition_matrix( + labels: NDArray[int], + n_clusters: int, + stat: str = "probability", + ignore_repetitions: bool = True, +) -> NDArray[float]: + """Compute observed transition.""" + +def compute_expected_transition_matrix( + labels: NDArray[int], + n_clusters: int, + stat: str = "probability", + ignore_repetitions: bool = True, +) -> NDArray[float]: + """Compute the expected transition matrix. + + Compute the theoretical transition matrix as if time course was ignored, but + microstate proportions was kept (i.e. shuffled segmentation). This matrix can be + used to quantify/correct the effect of microstate time coverage on the observed + transition matrix obtained with the + :func:`pycrostates.segmentation.compute_transition_matrix`. + Transition "from" and "to" unlabeled segments ``-1`` are ignored. + + Parameters + ---------- + labels : array of shape ``(n_samples,)`` or ``(n_epochs, n_samples)`` + Microstates labels attributed to each sample, i.e. the segmentation. + n_clusters : int + The number of clusters, i.e. the number of microstates. + stat : str + Aggregate statistic to compute transitions. Can be: + + * ``probability`` or ``proportion``: normalize count such as the probabilities along + the first axis is always equal to ``1``. + * ``percent``: normalize count such as the probabilities along the first axis is + always equal to ``100``. + ignore_repetitions : bool + If ``True``, ignores state repetitions. + For example, the input sequence ``AAABBCCD`` + will be transformed into ``ABCD`` before any calculation. + This is equivalent to setting the duration of all states to 1 sample. + + Returns + ------- + T : array of shape ``(n_cluster, n_cluster)`` + Array of transition probability values from one label to another. + First axis indicates state ``"from"``. Second axis indicates state ``"to"``. + """ + +def _compute_expected_transition_matrix( + labels: NDArray[int], + n_clusters: int, + stat: str = "probability", + ignore_repetitions: bool = True, +) -> NDArray[float]: + """Compute theoretical transition matrix. + + The theoretical transition matrix takes into account the time coverage. + """ + +def _check_labels_n_clusters(labels: NDArray[int], n_clusters: int) -> None: + """Checker for labels and n_clusters.""" diff --git a/pycrostates/utils/__init__.pyi b/pycrostates/utils/__init__.pyi new file mode 100644 index 00000000..dfef719b --- /dev/null +++ b/pycrostates/utils/__init__.pyi @@ -0,0 +1,6 @@ +from ._config import get_config as get_config +from .utils import _compare_infos as _compare_infos +from .utils import _corr_vectors as _corr_vectors +from .utils import _distance_matrix as _distance_matrix + +__all__: tuple[str, ...] diff --git a/pycrostates/utils/_checks.pyi b/pycrostates/utils/_checks.pyi new file mode 100644 index 00000000..7afa9228 --- /dev/null +++ b/pycrostates/utils/_checks.pyi @@ -0,0 +1,118 @@ +from typing import Any + +from _typeshed import Incomplete + +from ._docs import fill_doc as fill_doc + +def _ensure_int(item, item_name: Incomplete | None = None): + """ + Ensure a variable is an integer. + + Parameters + ---------- + item : object + Item to check. + item_name : str | None + Name of the item to show inside the error message. + + Raises + ------ + TypeError + When the type of the item is not int. + """ + +class _IntLike: + @classmethod + def __instancecheck__(cls, other): ... + +class _Callable: + @classmethod + def __instancecheck__(cls, other): ... + +_types: Incomplete + +def _check_type(item, types, item_name: Incomplete | None = None): + """ + Check that item is an instance of types. + + Parameters + ---------- + item : object + Item to check. + types : tuple of types | tuple of str + Types to be checked against. + If str, must be one of: + ('int', 'str', 'numeric', 'path-like', 'callable') + item_name : str | None + Name of the item to show inside the error message. + + Raises + ------ + TypeError + When the type of the item is not one of the valid options. + """ + +def _check_value( + item, + allowed_values, + item_name: Incomplete | None = None, + extra: Incomplete | None = None, +): + """ + Check the value of a parameter against a list of valid options. + + Parameters + ---------- + item : object + Item to check. + allowed_values : tuple of objects + Allowed values to be checked against. + item_name : str | None + Name of the item to show inside the error message. + extra : str | None + Extra string to append to the invalid value sentence, e.g. "with ico mode". + + Raises + ------ + ValueError + When the value of the item is not one of the valid options. + """ + +def _check_n_jobs(n_jobs): + """Check n_jobs parameter. + + Check that n_jobs is a positive integer or a negative integer for all cores. CUDA is + not supported. + """ + +def _check_random_state(seed): + """Turn seed into a numpy.random.mtrand.RandomState instance.""" + +def _check_axes(axes): + """Check that ax is an Axes object or an array of Axes.""" + +def _check_reject_by_annotation(reject_by_annotation: bool) -> bool: + """Check the reject_by_annotation argument.""" + +def _check_tmin_tmax(inst, tmin, tmax): + """Check tmin/tmax compared to the provided instance.""" + +def _check_picks_uniqueness(info, picks) -> None: + """Check that the provided picks yield a single channel type.""" + +def _check_verbose(verbose: Any) -> int: + """Check that the value of verbose is valid. + + Parameters + ---------- + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + verbose : int + The verbosity level as an integer. + """ diff --git a/pycrostates/utils/_config.pyi b/pycrostates/utils/_config.pyi new file mode 100644 index 00000000..5e614283 --- /dev/null +++ b/pycrostates/utils/_config.pyi @@ -0,0 +1,38 @@ +from _typeshed import Incomplete + +def _get_user_dir(): + """Get user directory.""" + +def _get_home_dir(): + """Get pycrostates config directory.""" + +def _get_config_path(): + """Get config path.""" + +def _get_data_path(): + """Get pycrostates data directory.""" + +default_config: Incomplete + +def _save_config(config) -> None: + """Save pycrostates config.""" + +def get_config(): + """Read preferences from pycrostates' config file. + + Returns + ------- + config : dict + Dictionary containing all preferences as key/values pairs. + """ + +def set_config(key, value) -> None: + """Set preference key in the pycrostates' config file. + + Parameters + ---------- + key : str + The preference key to set. Must be a valid key. + value : str | None + The value to assign to the preference key. + """ diff --git a/pycrostates/utils/_docs.pyi b/pycrostates/utils/_docs.pyi new file mode 100644 index 00000000..e2bdd31b --- /dev/null +++ b/pycrostates/utils/_docs.pyi @@ -0,0 +1,70 @@ +from typing import Callable + +docdict: dict[str, str] +keys: tuple[str, ...] +entry: str +docdict_indented: dict[int, dict[str, str]] + +def fill_doc(f: Callable) -> Callable: + """Fill a docstring with docdict entries. + + Parameters + ---------- + f : callable + The function to fill the docstring of (modified in place). + + Returns + ------- + f : callable + The function, potentially with an updated __doc__. + """ + +def _indentcount_lines(lines: list[str]) -> int: + """Minimum indent for all lines in line list. + + >>> lines = [" one", " two", " three"] + >>> indentcount_lines(lines) + 1 + >>> lines = [] + >>> indentcount_lines(lines) + 0 + >>> lines = [" one"] + >>> indentcount_lines(lines) + 1 + >>> indentcount_lines([" "]) + 0 + """ + +def copy_doc(source: Callable) -> Callable: + """Copy the docstring from another function (decorator). + + The docstring of the source function is prepepended to the docstring of the + function wrapped by this decorator. + + This is useful when inheriting from a class and overloading a method. This + decorator can be used to copy the docstring of the original method. + + Parameters + ---------- + source : callable + The function to copy the docstring from. + + Returns + ------- + wrapper : callable + The decorated function. + + Examples + -------- + >>> class A: + ... def m1(): + ... '''Docstring for m1''' + ... pass + >>> class B(A): + ... @copy_doc(A.m1) + ... def m1(): + ... '''this gets appended''' + ... pass + >>> print(B.m1.__doc__) + Docstring for m1 this gets appended + """ diff --git a/pycrostates/utils/_fixes.pyi b/pycrostates/utils/_fixes.pyi new file mode 100644 index 00000000..08a8ce6f --- /dev/null +++ b/pycrostates/utils/_fixes.pyi @@ -0,0 +1,8 @@ +class _WrapStdOut: + """Dynamically wrap to sys.stdout. + + This makes packages that monkey-patch sys.stdout (e.g.doctest, sphinx-gallery) work + properly. + """ + + def __getattr__(self, name): ... diff --git a/pycrostates/utils/_imports.pyi b/pycrostates/utils/_imports.pyi new file mode 100644 index 00000000..96572acc --- /dev/null +++ b/pycrostates/utils/_imports.pyi @@ -0,0 +1,30 @@ +from ._logs import logger as logger + +_INSTALL_MAPPING: dict[str, str] + +def import_optional_dependency(name: str, extra: str = "", raise_error: bool = True): + """ + Import an optional dependency. + + By default, if a dependency is missing an ImportError with a nice message will be + raised. + + Parameters + ---------- + name : str + The module name. + extra : str + Additional text to include in the ImportError message. + raise_error : bool + What to do when a dependency is not found. + * True : If the module is not installed, raise an ImportError, otherwise, return + the module. + * False: If the module is not installed, issue a warning and return None, + otherwise, return the module. + + Returns + ------- + maybe_module : Optional[ModuleType] + The imported module when found. + None is returned when the package is not found and raise_error is False. + """ diff --git a/pycrostates/utils/_logs.pyi b/pycrostates/utils/_logs.pyi new file mode 100644 index 00000000..05bff187 --- /dev/null +++ b/pycrostates/utils/_logs.pyi @@ -0,0 +1,119 @@ +import logging +from pathlib import Path as Path +from typing import Callable, Optional, Union + +from _typeshed import Incomplete + +from ._checks import _check_type as _check_type +from ._checks import _check_verbose as _check_verbose +from ._docs import fill_doc as fill_doc +from ._fixes import _WrapStdOut as _WrapStdOut + +def _init_logger(*, verbose: Optional[Union[bool, str, int]] = None) -> logging.Logger: + """Initialize a logger. + + Assigns sys.stdout as the first handler of the logger. + + Parameters + ---------- + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + + Returns + ------- + logger : Logger + The initialized logger. + """ + +def add_file_handler( + fname: Union[str, Path], + mode: str = "a", + encoding: Optional[str] = None, + *, + verbose: Optional[Union[bool, str, int]] = None, +) -> None: + """Add a file handler to the logger. + + Parameters + ---------- + fname : str | Path + Path to the file where the logging output is saved. + mode : str + Mode in which the file is opened. + encoding : str | None + If not None, encoding used to open the file. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + """ + +def set_log_level( + verbose: Union[bool, str, int, None], apply_to_mne: bool = True +) -> None: + """Set the log level for the logger and the first handler ``sys.stdout``. + + Parameters + ---------- + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + apply_to_mne : bool + If True, also changes the log level of MNE. + """ + +class _LoggerFormatter(logging.Formatter): + """Format string Syntax for pycrostates.""" + + _formatters: Incomplete + + def __init__(self) -> None: ... + def format(self, record): + """ + Format the received log record. + + Parameters + ---------- + record : logging.LogRecord + """ + +def verbose(f: Callable) -> Callable: + """Set the verbose for the function call from the kwargs. + + Parameters + ---------- + f : callable + The function with a verbose argument. + + Returns + ------- + f : callable + The function. + """ + +class _use_log_level: + """Context manager to change the logging level temporary. + + Parameters + ---------- + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + """ + + _old_level: Incomplete + _level: Incomplete + + def __init__(self, verbose: Optional[Union[bool, str, int]] = None) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, *args) -> None: ... + +logger: Incomplete diff --git a/pycrostates/utils/mixin.pyi b/pycrostates/utils/mixin.pyi new file mode 100644 index 00000000..b0198197 --- /dev/null +++ b/pycrostates/utils/mixin.pyi @@ -0,0 +1,16 @@ +from mne.io.meas_info import ContainsMixin as MNEContainsMixin +from mne.io.meas_info import MontageMixin as MNEMontageMixin + +from ._docs import copy_doc as copy_doc + +class ChannelsMixin: + """Channels Mixin for futur implementation.""" + +class ContainsMixin(MNEContainsMixin): + def __contains__(self, ch_type) -> bool: ... + def __getattribute__(self, name): + """Attribute getter.""" + +class MontageMixin(MNEMontageMixin): + def __getattribute__(self, name): + """Attribute getter.""" diff --git a/pycrostates/utils/sys_info.pyi b/pycrostates/utils/sys_info.pyi new file mode 100644 index 00000000..7b9374a5 --- /dev/null +++ b/pycrostates/utils/sys_info.pyi @@ -0,0 +1,22 @@ +from typing import IO, Callable, Optional + +from packaging.requirements import Requirement + +from ._checks import _check_type as _check_type + +def sys_info(fid: Optional[IO] = None, developer: bool = False): + """Print the system information for debugging. + + Parameters + ---------- + fid : file-like | None + The file to write to, passed to :func:`print`. Can be None to use + :data:`sys.stdout`. + developer : bool + If True, display information about optional dependencies. + """ + +def _list_dependencies_info( + out: Callable, ljust: int, package: str, dependencies: list[Requirement] +): + """List dependencies names and versions.""" diff --git a/pycrostates/utils/utils.pyi b/pycrostates/utils/utils.pyi new file mode 100644 index 00000000..24ca6443 --- /dev/null +++ b/pycrostates/utils/utils.pyi @@ -0,0 +1,33 @@ +from _typeshed import Incomplete + +from ._logs import logger as logger + +def _corr_vectors(A, B, axis: int = 0): + """Compute pairwise correlation of multiple pairs of vectors. + + Fast way to compute correlation of multiple pairs of vectors without computing all + pairs as would with corr(A,B). Borrowed from Oli at StackOverflow. Note the + resulting coefficients vary slightly from the ones obtained from corr due to + differences in the order of the calculations. (Differences are of a magnitude of + 1e-9 to 1e-17 depending on the tested data). + + Parameters + ---------- + A : ndarray, shape (n, m) + The first collection of vectors + B : ndarray, shape (n, m) + The second collection of vectors + axis : int + The axis that contains the elements of each vector. Defaults to 0. + + Returns + ------- + corr : ndarray, shape (m, ) + For each pair of vectors, the correlation between them. + """ + +def _distance_matrix(X, Y: Incomplete | None = None): + """Distance matrix used in metrics.""" + +def _compare_infos(cluster_info, inst_info): + """Check that channels in cluster_info are all present in inst_info.""" diff --git a/pycrostates/viz/__init__.pyi b/pycrostates/viz/__init__.pyi new file mode 100644 index 00000000..2f038a90 --- /dev/null +++ b/pycrostates/viz/__init__.pyi @@ -0,0 +1,5 @@ +from .cluster_centers import plot_cluster_centers as plot_cluster_centers +from .segmentation import plot_epoch_segmentation as plot_epoch_segmentation +from .segmentation import plot_raw_segmentation as plot_raw_segmentation + +__all__: tuple[str, ...] diff --git a/pycrostates/viz/cluster_centers.pyi b/pycrostates/viz/cluster_centers.pyi new file mode 100644 index 00000000..40d64698 --- /dev/null +++ b/pycrostates/viz/cluster_centers.pyi @@ -0,0 +1,60 @@ +from typing import Any, Optional, Union + +from matplotlib.axes import Axes +from mne import Info +from numpy.typing import NDArray + +from .._typing import CHInfo as CHInfo +from ..utils._checks import _check_axes as _check_axes +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +_GRADIENT_KWARGS_DEFAULTS: dict[str, str] + +def plot_cluster_centers( + cluster_centers: NDArray[float], + info: Union[Info, CHInfo], + cluster_names: list[str] = None, + axes: Optional[Union[Axes, NDArray[Axes]]] = None, + show_gradient: Optional[bool] = False, + gradient_kwargs: dict[str, Any] = ..., + *, + block: bool = False, + verbose: Optional[str] = None, + **kwargs, +): + """Create topographic maps for cluster centers. + + Parameters + ---------- + cluster_centers : array (n_clusters, n_channels) + Fitted clusters, i.e. the microstates maps. + info : Info | ChInfo + Info instance with a montage used to plot the topographic maps. + cluster_names : list | None + Name of the clusters. + axes : Axes | None + Either ``None`` to create a new figure or axes (or an array of axes) on which the + topographic map should be plotted. If the number of microstates maps to plot is + ``≥ 1``, an array of axes of size ``n_clusters`` should be provided. + show_gradient : bool + If True, plot a line between channel locations with highest and lowest values. + gradient_kwargs : dict + Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.plot` to plot + gradient line. + block : bool + Whether to halt program execution until the figure is closed. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + **kwargs + Additional keyword arguments are passed to :func:`mne.viz.plot_topomap`. + + Returns + ------- + fig : Figure + Matplotlib figure(s) on which topographic maps are plotted. + """ diff --git a/pycrostates/viz/segmentation.pyi b/pycrostates/viz/segmentation.pyi new file mode 100644 index 00000000..7790751d --- /dev/null +++ b/pycrostates/viz/segmentation.pyi @@ -0,0 +1,138 @@ +from typing import Optional, Union + +from matplotlib import colors +from matplotlib.axes import Axes +from mne import BaseEpochs +from mne.io import BaseRaw +from numpy.typing import NDArray + +from ..utils._checks import _check_type as _check_type +from ..utils._docs import fill_doc as fill_doc +from ..utils._logs import logger as logger + +def plot_raw_segmentation( + labels: NDArray[int], + raw: BaseRaw, + n_clusters: int, + cluster_names: list[str] = None, + tmin: Optional[Union[int, float]] = None, + tmax: Optional[Union[int, float]] = None, + cmap: Optional[str] = None, + axes: Optional[Axes] = None, + cbar_axes: Optional[Axes] = None, + *, + block: bool = False, + verbose: Optional[str] = None, + **kwargs, +): + """Plot raw segmentation. + + Parameters + ---------- + labels : array of shape ``(n_samples,)`` + Microstates labels attributed to each sample, i.e. the segmentation. + raw : Raw + MNE `~mne.io.Raw` instance. + n_clusters : int + The number of clusters, i.e. the number of microstates. + cluster_names : list | None + Name of the clusters. + tmin : float + Start time of the raw data to use in seconds (must be >= 0). + tmax : float + End time of the raw data to use in seconds (cannot exceed data duration). + cmap : str | colormap | None + The colormap to use. If None, ``viridis`` is used. + axes : Axes | None + Either ``None`` to create a new figure or axes on which the segmentation is + plotted. + cbar_axes : Axes | None + Axes on which to draw the colorbar, otherwise the colormap takes space from the main + axes. + block : bool + Whether to halt program execution until the figure is closed. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + **kwargs + Kwargs are passed to ``axes.plot``. + + Returns + ------- + fig : Figure + Matplotlib figure(s) on which topographic maps are plotted. + """ + +def plot_epoch_segmentation( + labels: NDArray[int], + epochs: BaseEpochs, + n_clusters: int, + cluster_names: list[str] = None, + cmap: Optional[str] = None, + axes: Optional[Axes] = None, + cbar_axes: Optional[Axes] = None, + *, + block: bool = False, + verbose: Optional[str] = None, + **kwargs, +): + """ + Plot epochs segmentation. + + Parameters + ---------- + labels : array of shape ``(n_epochs, n_samples)`` + Microstates labels attributed to each sample, i.e. the segmentation. + epochs : Epochs + MNE `~mne.Epochs` instance. + n_clusters : int + The number of clusters, i.e. the number of microstates. + cluster_names : list | None + Name of the clusters. + cmap : str | colormap | None + The colormap to use. If None, ``viridis`` is used. + axes : Axes | None + Either ``None`` to create a new figure or axes on which the segmentation is + plotted. + cbar_axes : Axes | None + Axes on which to draw the colorbar, otherwise the colormap takes space from the main + axes. + block : bool + Whether to halt program execution until the figure is closed. + verbose : int | str | bool | None + Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``, + ``"ERROR"``, ``"WARNING"``, ``"INFO"`` and ``"DEBUG"``. If None is provided, the + verbosity is set to ``"WARNING"``. If a bool is provided, the verbosity is set to + ``"WARNING"`` for False and to ``"INFO"`` for True. + **kwargs + Kwargs are passed to ``axes.plot``. + + Returns + ------- + fig : Figure + Matplotlib figure on which topographic maps are plotted. + """ + +def _plot_segmentation( + labels: NDArray[int], + gfp: NDArray[float], + times: NDArray[float], + n_clusters: int, + cluster_names: list[str] = None, + cmap: Optional[Union[str, colors.Colormap]] = None, + axes: Optional[Axes] = None, + cbar_axes: Optional[Axes] = None, + *, + verbose: Optional[str] = None, + **kwargs, +): + """Code snippet to plot segmentation for raw and epochs.""" + +def _compatibility_cmap(cmap: Optional[Union[str, colors.Colormap]], n_colors: int): + """Convert the 'cmap' argument to a colormap. + + Matplotlib 3.6 introduced a deprecation of plt.cm.get_cmap(). + When support for the 3.6 version is dropped, this checker can be removed. + """