
Commit

wip
vferat committed Feb 20, 2024
1 parent 4275359 commit 22a3a7e
Showing 2 changed files with 15 additions and 10 deletions.
23 changes: 14 additions & 9 deletions pycrostates/cluster/_base.py
@@ -852,22 +852,22 @@ def _predict_raw(

                data_ = data[:, onset:end]
                segment = _BaseCluster._segment(
-                    data_, cluster_centers_, factor, tol, half_window_size
+                    data_, cluster_centers_, self._ignore_polarity, factor, tol, half_window_size
                )
                if reject_edges:
                    segment = _BaseCluster._reject_edge_segments(segment)
                segmentation[onset:end] = segment

        else:
            segmentation = _BaseCluster._segment(
-                data, cluster_centers_, factor, tol, half_window_size
+                data, cluster_centers_, self._ignore_polarity, factor, tol, half_window_size
            )
            if reject_edges:
                segmentation = _BaseCluster._reject_edge_segments(segmentation)

        if 0 < min_segment_length:
            segmentation = _BaseCluster._reject_short_segments(
-                segmentation, data, min_segment_length
+                segmentation, data, min_segment_length, self._ignore_polarity
            )

        # Provide properties to copy the arrays
@@ -906,12 +906,12 @@ def _predict_epochs(
        segments = []
        for epoch_data in data:
            segment = _BaseCluster._segment(
-                epoch_data, cluster_centers_, factor, tol, half_window_size
+                epoch_data, cluster_centers_, self._ignore_polarity, factor, tol, half_window_size
            )

            if 0 < min_segment_length:
                segment = _BaseCluster._reject_short_segments(
-                    segment, epoch_data, min_segment_length
+                    segment, epoch_data, min_segment_length, self._ignore_polarity
                )
            if reject_edges:
                segment = _BaseCluster._reject_edge_segments(segment)
@@ -932,6 +932,7 @@ def _predict_epochs(
    def _segment(
        data: NDArray[float],
        states: NDArray[float],
+        ignore_polarity: bool,
        factor: int,
        tol: Union[int, float],
        half_window_size: int,
@@ -945,11 +946,12 @@ def _segment(
        states -= np.mean(states, axis=1)[:, np.newaxis]
        states /= np.std(states, axis=1)[:, np.newaxis]

+        _correlation(data, states, ignore_polarity=ignore_polarity)
        labels = np.argmax(np.abs(np.dot(states, data)), axis=0)

        if factor != 0:
            labels = _BaseCluster._smooth_segmentation(
-                data, states, labels, factor, tol, half_window_size
+                data, states, labels, ignore_polarity, factor, tol, half_window_size
            )

        return labels
@@ -958,6 +960,7 @@ def _segment(
    def _smooth_segmentation(
        data: NDArray[float],
        states: NDArray[float],
+        ignore_polarity: bool,
        labels: NDArray[int],
        factor: int,
        tol: Union[int, float],
@@ -976,6 +979,7 @@
        vol. 42, no. 7, pp. 658-665, July 1995,
        https://doi.org/10.1109/10.391164.
        """
+        # TODO: ignore_polarity
        Ne, Nt = data.shape
        Nu = states.shape[0]
        Vvar = np.sum(data * data, axis=0)
@@ -1013,6 +1017,7 @@ def _reject_short_segments(
        segmentation: NDArray[int],
        data: NDArray[float],
        min_segment_length: int,
+        ignore_polarity: bool
    ) -> NDArray[int]:
        """Reject segments that are too short.
@@ -1042,13 +1047,13 @@ def _reject_short_segments(
            # compute correlation left/right side
            left_corr = np.abs(
                _correlation(
-                    data[:, left - 1], data[:, left], ignore_polarity=True
-                )
+                    data[:, left - 1].reshape(-1, 1), data[:, left].reshape(-1, 1), ignore_polarity=True
+                )[0, 0]
            )
            right_corr = np.abs(
                _correlation(
                    data[:, right], data[:, right + 1], ignore_polarity=True
-                )
+                )[0, 0]
            )

            if np.abs(right_corr - left_corr) <= 1e-8:
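To make the intent of the new ignore_polarity argument easier to follow, here is a minimal, self-contained sketch of polarity-invariant labeling. It is not the pycrostates implementation: assign_labels is a hypothetical helper, and the normalization only approximates what _segment does above (zero-mean, unit-variance maps). When polarity is ignored, a topography and its sign-flipped copy match a cluster center equally well, so the absolute correlation is maximized instead of the signed one.

    import numpy as np

    def assign_labels(data, centers, ignore_polarity=True):
        """Assign each time sample to the best-matching cluster center.

        data    : (n_channels, n_samples) array.
        centers : (n_clusters, n_channels) array of cluster centers.
        """
        # Normalize each sample and each center so the dot product acts as a
        # spatial correlation.
        data = (data - data.mean(axis=0)) / data.std(axis=0)
        centers = (centers - centers.mean(axis=1, keepdims=True)) / centers.std(
            axis=1, keepdims=True
        )
        corr = centers @ data / data.shape[0]  # (n_clusters, n_samples)
        if ignore_polarity:
            corr = np.abs(corr)  # sign flips are treated as the same map
        return np.argmax(corr, axis=0)

    # Toy usage: 3 centers, 32 channels, 100 samples.
    rng = np.random.default_rng(0)
    labels = assign_labels(
        rng.standard_normal((32, 100)), rng.standard_normal((3, 32))
    )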
2 changes: 1 addition & 1 deletion pycrostates/metrics/davies_bouldin.py
@@ -77,7 +77,7 @@ def _davies_bouldin_score(X, labels, ignore_polarity):
    # Calculate the centroids of the clusters
    centroids = np.array([np.mean(X[labels == i], axis=0) for i in range(num_clusters)])

-    # Calculate pairwise distances between centroids using custom distance function
+    # Calculate pairwise distances between centroids
    centroid_distances = _distance(centroids, ignore_polarity=ignore_polarity)

    # Initialize array to hold scatter values for each cluster
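For context on the davies_bouldin.py change: the score depends on a distance between centroids that can optionally ignore polarity. The sketch below is only an illustration of that idea, not the _distance helper used by pycrostates; the 1 - |correlation| form is an assumption made here for clarity.

    import numpy as np

    def centroid_distances(centroids, ignore_polarity=True):
        """Pairwise distances between centroids of shape (n_clusters, n_channels)."""
        c = centroids - centroids.mean(axis=1, keepdims=True)
        c /= np.linalg.norm(c, axis=1, keepdims=True)
        corr = c @ c.T  # correlation between every pair of centroids
        if ignore_polarity:
            corr = np.abs(corr)  # a centroid and its sign flip count as identical
        return 1.0 - corr

    rng = np.random.default_rng(0)
    print(centroid_distances(rng.standard_normal((4, 32)), ignore_polarity=True))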

