Skip to content

Commit

Permalink
Fix format of stubs file (#148)
Browse files Browse the repository at this point in the history
* fix format of stubs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try config to run stubs generation on PR

* remove push on main to avoid spamming the main branch on every change

* fix condition

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mscheltienne and pre-commit-ci[bot] authored Jan 17, 2024
1 parent 5e3b810 commit b1d7996
Show file tree
Hide file tree
Showing 33 changed files with 476 additions and 192 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/stubs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }}
cancel-in-progress: true
on: # yamllint disable-line rule:truthy
pull_request:
schedule:
- cron: '0 3 * * *'
workflow_dispatch:
Expand Down Expand Up @@ -31,6 +32,7 @@ jobs:
- name: Generate stub files
run: python tools/stubgen.py
- name: Push stub files
if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
run: |
git config --global user.name 'github-actions[bot]'
git config --global user.email 'github-actions[bot]@users.noreply.github.com'
Expand Down
3 changes: 2 additions & 1 deletion pycrostates/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ class Cluster(ABC):

class Segmentation(ABC):
"""Typing for a clustering class."""

RANDomState = Optional[Union[int, RandomState, Generator]]
Picks = Optional[Union[str, NDArray[int]]]
Picks = Optional[Union[str, NDArray[int]]]
116 changes: 97 additions & 19 deletions pycrostates/cluster/_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ from .utils import optimize_order as optimize_order

class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
"""Base Class for Microstates Clustering algorithms."""

_n_clusters: Incomplete
_cluster_names: Incomplete
_cluster_centers_: Incomplete
Expand All @@ -38,13 +39,11 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
_fitted: bool

@abstractmethod
def __init__(self):
...

def __init__(self): ...
def __repr__(self) -> str:
"""String representation."""

def _repr_html_(self, caption: Incomplete | None=None):
def _repr_html_(self, caption: Incomplete | None = None):
"""HTML representation."""

def __eq__(self, other: Any) -> bool:
Expand All @@ -53,7 +52,7 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
def __ne__(self, other: Any) -> bool:
"""Different != method."""

def copy(self, deep: bool=True):
def copy(self, deep: bool = True):
"""Return a copy of the instance.
Parameters
Expand All @@ -69,7 +68,16 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
"""Check if the cluster is unfitted."""

@abstractmethod
def fit(self, inst: Union[BaseRaw, BaseEpochs, CHData], picks: Picks='eeg', tmin: Optional[Union[int, float]]=None, tmax: Optional[Union[int, float]]=None, reject_by_annotation: bool=True, *, verbose: Optional[str]=None) -> NDArray[float]:
def fit(
self,
inst: Union[BaseRaw, BaseEpochs, CHData],
picks: Picks = "eeg",
tmin: Optional[Union[int, float]] = None,
tmax: Optional[Union[int, float]] = None,
reject_by_annotation: bool = True,
*,
verbose: Optional[str] = None,
) -> NDArray[float]:
"""Compute cluster centers.
Parameters
Expand All @@ -93,7 +101,7 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
Whether to omit bad segments from the data before fitting. If ``True``
(default), annotated segments whose description begins with ``'bad'`` are
omitted. If ``False``, no rejection based on annotations is performed.
Has no effect if ``inst`` is not a :class:`mne.io.Raw` object.
verbose : int | str | bool | None
Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``,
Expand All @@ -102,7 +110,11 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
``"WARNING"`` for False and to ``"INFO"`` for True.
"""

def rename_clusters(self, mapping: Optional[dict[str, str]]=None, new_names: Optional[Union[list[str], tuple[str, ...]]]=None) -> None:
def rename_clusters(
self,
mapping: Optional[dict[str, str]] = None,
new_names: Optional[Union[list[str], tuple[str, ...]]] = None,
) -> None:
"""Rename the clusters.
Parameters
Expand All @@ -119,7 +131,12 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
Operates in-place.
"""

def reorder_clusters(self, mapping: Optional[dict[int, int]]=None, order: Optional[Union[list[int], tuple[int, ...], NDArray[int]]]=None, template: Optional[Cluster]=None) -> None:
def reorder_clusters(
self,
mapping: Optional[dict[int, int]] = None,
order: Optional[Union[list[int], tuple[int, ...], NDArray[int]]] = None,
template: Optional[Cluster] = None,
) -> None:
"""
Reorder the clusters of the fitted model.
Expand Down Expand Up @@ -149,7 +166,9 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
Operates in-place.
"""

def invert_polarity(self, invert: Union[bool, list[bool], tuple[bool, ...], NDArray[bool]]) -> None:
def invert_polarity(
self, invert: Union[bool, list[bool], tuple[bool, ...], NDArray[bool]]
) -> None:
"""Invert map polarities.
Parameters
Expand All @@ -169,7 +188,20 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
an article).
"""

def plot(self, axes: Optional[Union[Axes, NDArray[Axes]]]=None, show_gradient: Optional[bool]=False, gradient_kwargs: dict[str, Any]={'color': 'black', 'linestyle': '-', 'marker': 'P'}, *, block: bool=False, verbose: Optional[str]=None, **kwargs):
def plot(
self,
axes: Optional[Union[Axes, NDArray[Axes]]] = None,
show_gradient: Optional[bool] = False,
gradient_kwargs: dict[str, Any] = {
"color": "black",
"linestyle": "-",
"marker": "P",
},
*,
block: bool = False,
verbose: Optional[str] = None,
**kwargs,
):
"""
Plot cluster centers as topographic maps.
Expand Down Expand Up @@ -211,7 +243,19 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
Path to the ``.fif`` file where the clustering solution is saved.
"""

def predict(self, inst: Union[BaseRaw, BaseEpochs], picks: Picks=None, factor: int=0, half_window_size: int=1, tol: Union[int, float]=1e-05, min_segment_length: int=0, reject_edges: bool=True, reject_by_annotation: bool=True, *, verbose: Optional[str]=None):
def predict(
self,
inst: Union[BaseRaw, BaseEpochs],
picks: Picks = None,
factor: int = 0,
half_window_size: int = 1,
tol: Union[int, float] = 1e-05,
min_segment_length: int = 0,
reject_edges: bool = True,
reject_by_annotation: bool = True,
*,
verbose: Optional[str] = None,
):
"""Segment `~mne.io.Raw` or `~mne.Epochs` into microstate sequence.
Segment instance into microstate sequence using the segmentation smoothing
Expand Down Expand Up @@ -249,7 +293,7 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
Whether to omit bad segments from the data before fitting. If ``True``
(default), annotated segments whose description begins with ``'bad'`` are
omitted. If ``False``, no rejection based on annotations is performed.
Has no effect if ``inst`` is not a :class:`mne.io.Raw` object.
verbose : int | str | bool | None
Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``,
Expand All @@ -269,18 +313,50 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
.. footbibliography::
"""

def _predict_raw(self, raw: BaseRaw, picks_data: NDArray[int], factor: int, tol: Union[int, float], half_window_size: int, min_segment_length: int, reject_edges: bool, reject_by_annotation: bool) -> RawSegmentation:
def _predict_raw(
self,
raw: BaseRaw,
picks_data: NDArray[int],
factor: int,
tol: Union[int, float],
half_window_size: int,
min_segment_length: int,
reject_edges: bool,
reject_by_annotation: bool,
) -> RawSegmentation:
"""Create segmentation for raw."""

def _predict_epochs(self, epochs: BaseEpochs, picks_data: NDArray[int], factor: int, tol: Union[int, float], half_window_size: int, min_segment_length: int, reject_edges: bool) -> EpochsSegmentation:
def _predict_epochs(
self,
epochs: BaseEpochs,
picks_data: NDArray[int],
factor: int,
tol: Union[int, float],
half_window_size: int,
min_segment_length: int,
reject_edges: bool,
) -> EpochsSegmentation:
"""Create segmentation for epochs."""

@staticmethod
def _segment(data: NDArray[float], states: NDArray[float], factor: int, tol: Union[int, float], half_window_size: int) -> NDArray[int]:
def _segment(
data: NDArray[float],
states: NDArray[float],
factor: int,
tol: Union[int, float],
half_window_size: int,
) -> NDArray[int]:
"""Create segmentation. Must operate on a copy of states."""

@staticmethod
def _smooth_segmentation(data: NDArray[float], states: NDArray[float], labels: NDArray[int], factor: int, tol: Union[int, float], half_window_size: int) -> NDArray[int]:
def _smooth_segmentation(
data: NDArray[float],
states: NDArray[float],
labels: NDArray[int],
factor: int,
tol: Union[int, float],
half_window_size: int,
) -> NDArray[int]:
"""Apply smoothing.
Adapted from [1].
Expand All @@ -296,7 +372,9 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
"""

@staticmethod
def _reject_short_segments(segmentation: NDArray[int], data: NDArray[float], min_segment_length: int) -> NDArray[int]:
def _reject_short_segments(
segmentation: NDArray[int], data: NDArray[float], min_segment_length: int
) -> NDArray[int]:
"""Reject segments that are too short.
Reject segments that are too short by replacing the labels with the adjacent
Expand Down Expand Up @@ -374,4 +452,4 @@ class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):

@staticmethod
def _check_n_clusters(n_clusters: int) -> int:
"""Check that the number of clusters is a positive integer."""
"""Check that the number of clusters is a positive integer."""
38 changes: 27 additions & 11 deletions pycrostates/cluster/aahc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,15 @@ class AAHCluster(_BaseCluster):
----------
.. footbibliography::
"""

_n_clusters: Incomplete
_cluster_names: Incomplete
_ignore_polarity: bool
_normalize_input: Incomplete
_GEV_: Incomplete

def __init__(self, n_clusters: int, normalize_input: bool=False) -> None:
...

def _repr_html_(self, caption: Incomplete | None=None):
...

def __init__(self, n_clusters: int, normalize_input: bool = False) -> None: ...
def _repr_html_(self, caption: Incomplete | None = None): ...
def __eq__(self, other: Any) -> bool:
"""Equality == method."""

Expand All @@ -54,7 +51,16 @@ class AAHCluster(_BaseCluster):
_labels_: Incomplete
_fitted: bool

def fit(self, inst: Union[BaseRaw, BaseEpochs], picks: Picks='eeg', tmin: Optional[Union[int, float]]=None, tmax: Optional[Union[int, float]]=None, reject_by_annotation: bool=True, *, verbose: Optional[str]=None) -> None:
def fit(
self,
inst: Union[BaseRaw, BaseEpochs],
picks: Picks = "eeg",
tmin: Optional[Union[int, float]] = None,
tmax: Optional[Union[int, float]] = None,
reject_by_annotation: bool = True,
*,
verbose: Optional[str] = None,
) -> None:
"""Compute cluster centers.
Parameters
Expand All @@ -78,7 +84,7 @@ class AAHCluster(_BaseCluster):
Whether to omit bad segments from the data before fitting. If ``True``
(default), annotated segments whose description begins with ``'bad'`` are
omitted. If ``False``, no rejection based on annotations is performed.
Has no effect if ``inst`` is not a :class:`mne.io.Raw` object.
verbose : int | str | bool | None
Sets the verbosity level. The verbosity increases gradually between ``"CRITICAL"``,
Expand All @@ -97,11 +103,21 @@ class AAHCluster(_BaseCluster):
"""

@staticmethod
def _aahc(data: NDArray[float], n_clusters: int, ignore_polarity: bool, normalize_input: bool) -> tuple[float, NDArray[float], NDArray[int]]:
def _aahc(
data: NDArray[float],
n_clusters: int,
ignore_polarity: bool,
normalize_input: bool,
) -> tuple[float, NDArray[float], NDArray[int]]:
"""Run the AAHC algorithm."""

@staticmethod
def _compute_maps(data: NDArray[float], n_clusters: int, ignore_polarity: bool, normalize_input: bool) -> tuple[NDArray[float], NDArray[int]]:
def _compute_maps(
data: NDArray[float],
n_clusters: int,
ignore_polarity: bool,
normalize_input: bool,
) -> tuple[NDArray[float], NDArray[int]]:
"""Compute microstates maps."""

@property
Expand Down Expand Up @@ -131,4 +147,4 @@ class AAHCluster(_BaseCluster):

@staticmethod
def _check_normalize_input(normalize_input: bool) -> bool:
"""Check that normalize_input is a boolean."""
"""Check that normalize_input is a boolean."""
Loading

0 comments on commit b1d7996

Please sign in to comment.