diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index a010a285..4dbad41b 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -953,22 +953,15 @@ def _segment( half_window_size: int, ) -> NDArray[int]: """Create segmentation. Must operate on a copy of states.""" - data -= np.mean(data, axis=0) - std = np.std(data, axis=0) - std[std == 0] = 1 # std == 0 -> null map - data /= std - - states -= np.mean(states, axis=1)[:, np.newaxis] - states /= np.std(states, axis=1)[:, np.newaxis] - - _correlation(data, states, ignore_polarity=ignore_polarity) - labels = np.argmax(np.abs(np.dot(states, data)), axis=0) + corr = _correlation(states, data, ignore_polarity=ignore_polarity)[:len(states), :] + if ignore_polarity: + corr = np.abs(corr) + labels = np.argmax(corr, axis=0) if factor != 0: labels = _BaseCluster._smooth_segmentation( data, states, labels, ignore_polarity, factor, tol, half_window_size ) - return labels @staticmethod @@ -1064,12 +1057,12 @@ def _reject_short_segments( _correlation( data[:, left - 1].reshape(-1, 1), data[:, left].reshape(-1, 1), - ignore_polarity=True, + ignore_polarity=ignore_polarity, )[0, 0] ) right_corr = np.abs( _correlation( - data[:, right], data[:, right + 1], ignore_polarity=True + data[:, right], data[:, right + 1], ignore_polarity=ignore_polarity )[0, 0] ) diff --git a/pycrostates/segmentation/_base.py b/pycrostates/segmentation/_base.py index f1825613..95fd91d2 100644 --- a/pycrostates/segmentation/_base.py +++ b/pycrostates/segmentation/_base.py @@ -155,7 +155,7 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): # create a 1D view of the labels array labels = labels.reshape(-1) - gfp = np.std(data, axis=0) + gfp = np.std(data, axis=0) #TODO: change gfp if norm_gfp: labeled = np.argwhere(labels != -1) # ignore unlabeled segments gfp /= np.linalg.norm(gfp[labeled]) # normalize @@ -171,9 +171,9 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): labeled_gfp = gfp[arg_where][:, 0] state_array = np.array([state] * len(arg_where)).transpose() - dist_corr = _correlation(state_array, labeled_tp, ignore_polarity=True) + dist_corr = _correlation(state_array, labeled_tp, ignore_polarity=True) #TODO: ignore_polarity params[f"{state_name}_mean_corr"] = np.mean(np.abs(dist_corr)) - dist_gev = (labeled_gfp * dist_corr) ** 2 / np.sum(gfp**2) + dist_gev = (labeled_gfp * dist_corr) ** 2 / np.sum(gfp**2) #TODO: gev params[f"{state_name}_gev"] = np.sum(dist_gev) s_segments = np.array([len(group) for s_, group in segments if s_ == s])