Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
vferat committed Feb 20, 2024
1 parent f473bdd commit 829f87f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 16 deletions.
19 changes: 6 additions & 13 deletions pycrostates/cluster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,22 +953,15 @@ def _segment(
half_window_size: int,
) -> NDArray[int]:
"""Create segmentation. Must operate on a copy of states."""
data -= np.mean(data, axis=0)
std = np.std(data, axis=0)
std[std == 0] = 1 # std == 0 -> null map
data /= std

states -= np.mean(states, axis=1)[:, np.newaxis]
states /= np.std(states, axis=1)[:, np.newaxis]

_correlation(data, states, ignore_polarity=ignore_polarity)
labels = np.argmax(np.abs(np.dot(states, data)), axis=0)
corr = _correlation(states, data, ignore_polarity=ignore_polarity)[:len(states), :]
if ignore_polarity:
corr = np.abs(corr)
labels = np.argmax(corr, axis=0)

if factor != 0:
labels = _BaseCluster._smooth_segmentation(
data, states, labels, ignore_polarity, factor, tol, half_window_size
)

return labels

@staticmethod
Expand Down Expand Up @@ -1064,12 +1057,12 @@ def _reject_short_segments(
_correlation(
data[:, left - 1].reshape(-1, 1),
data[:, left].reshape(-1, 1),
ignore_polarity=True,
ignore_polarity=ignore_polarity,
)[0, 0]
)
right_corr = np.abs(
_correlation(
data[:, right], data[:, right + 1], ignore_polarity=True
data[:, right], data[:, right + 1], ignore_polarity=ignore_polarity
)[0, 0]
)

Expand Down
6 changes: 3 additions & 3 deletions pycrostates/segmentation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False):
# create a 1D view of the labels array
labels = labels.reshape(-1)

gfp = np.std(data, axis=0)
gfp = np.std(data, axis=0) #TODO: change gfp
if norm_gfp:
labeled = np.argwhere(labels != -1) # ignore unlabeled segments
gfp /= np.linalg.norm(gfp[labeled]) # normalize
Expand All @@ -171,9 +171,9 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False):
labeled_gfp = gfp[arg_where][:, 0]
state_array = np.array([state] * len(arg_where)).transpose()

dist_corr = _correlation(state_array, labeled_tp, ignore_polarity=True)
dist_corr = _correlation(state_array, labeled_tp, ignore_polarity=True) #TODO: ignore_polarity
params[f"{state_name}_mean_corr"] = np.mean(np.abs(dist_corr))
dist_gev = (labeled_gfp * dist_corr) ** 2 / np.sum(gfp**2)
dist_gev = (labeled_gfp * dist_corr) ** 2 / np.sum(gfp**2) #TODO: gev
params[f"{state_name}_gev"] = np.sum(dist_gev)

s_segments = np.array([len(group) for s_, group in segments if s_ == s])
Expand Down

0 comments on commit 829f87f

Please sign in to comment.