diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index b8307113..4ed3613d 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -852,7 +852,7 @@ def _predict_raw( data_ = data[:, onset:end] segment = _BaseCluster._segment( - data_, cluster_centers_, factor, tol, half_window_size + data_, cluster_centers_, self._ignore_polarity, factor, tol, half_window_size ) if reject_edges: segment = _BaseCluster._reject_edge_segments(segment) @@ -860,14 +860,14 @@ def _predict_raw( else: segmentation = _BaseCluster._segment( - data, cluster_centers_, factor, tol, half_window_size + data, cluster_centers_, self._ignore_polarity, factor, tol, half_window_size ) if reject_edges: segmentation = _BaseCluster._reject_edge_segments(segmentation) if 0 < min_segment_length: segmentation = _BaseCluster._reject_short_segments( - segmentation, data, min_segment_length + segmentation, data, min_segment_length, self._ignore_polarity ) # Provide properties to copy the arrays @@ -906,12 +906,12 @@ def _predict_epochs( segments = [] for epoch_data in data: segment = _BaseCluster._segment( - epoch_data, cluster_centers_, factor, tol, half_window_size + epoch_data, cluster_centers_, self._ignore_polarity, factor, tol, half_window_size ) if 0 < min_segment_length: segment = _BaseCluster._reject_short_segments( - segment, epoch_data, min_segment_length + segment, epoch_data, min_segment_length, self._ignore_polarity ) if reject_edges: segment = _BaseCluster._reject_edge_segments(segment) @@ -932,6 +932,7 @@ def _predict_epochs( def _segment( data: NDArray[float], states: NDArray[float], + ignore_polarity: bool, factor: int, tol: Union[int, float], half_window_size: int, @@ -945,11 +946,12 @@ def _segment( states -= np.mean(states, axis=1)[:, np.newaxis] states /= np.std(states, axis=1)[:, np.newaxis] + _correlation(data, states, ignore_polarity=ignore_polarity) labels = np.argmax(np.abs(np.dot(states, data)), axis=0) if factor != 0: labels = _BaseCluster._smooth_segmentation( - data, states, labels, factor, tol, half_window_size + data, states, labels, ignore_polarity, factor, tol, half_window_size ) return labels @@ -958,6 +960,7 @@ def _segment( def _smooth_segmentation( data: NDArray[float], states: NDArray[float], + ignore_polarity: bool, labels: NDArray[int], factor: int, tol: Union[int, float], @@ -976,6 +979,7 @@ def _smooth_segmentation( vol. 42, no. 7, pp. 658-665, July 1995, https://doi.org/10.1109/10.391164. """ + # TODO: ignore_polarity Ne, Nt = data.shape Nu = states.shape[0] Vvar = np.sum(data * data, axis=0) @@ -1013,6 +1017,7 @@ def _reject_short_segments( segmentation: NDArray[int], data: NDArray[float], min_segment_length: int, + ignore_polarity: bool ) -> NDArray[int]: """Reject segments that are too short. @@ -1042,13 +1047,13 @@ def _reject_short_segments( # compute correlation left/right side left_corr = np.abs( _correlation( - data[:, left - 1], data[:, left], ignore_polarity=True - ) + data[:, left - 1].reshape(-1, 1), data[:, left].reshape(-1, 1), ignore_polarity=True + )[0, 0] ) right_corr = np.abs( _correlation( data[:, right], data[:, right + 1], ignore_polarity=True - ) + )[0, 0] ) if np.abs(right_corr - left_corr) <= 1e-8: diff --git a/pycrostates/metrics/davies_bouldin.py b/pycrostates/metrics/davies_bouldin.py index 6859d3cb..c784e718 100644 --- a/pycrostates/metrics/davies_bouldin.py +++ b/pycrostates/metrics/davies_bouldin.py @@ -77,7 +77,7 @@ def _davies_bouldin_score(X, labels, ignore_polarity): # Calculate the centroids of the clusters centroids = np.array([np.mean(X[labels == i], axis=0) for i in range(num_clusters)]) - # Calculate pairwise distances between centroids using custom distance function + # Calculate pairwise distances between centroids centroid_distances = _distance(centroids, ignore_polarity=ignore_polarity) # Initialize array to hold scatter values for each cluster