From 6cbc09ebf352e330b50dcd7334a5938cda5ae2f0 Mon Sep 17 00:00:00 2001 From: LLehner Date: Thu, 10 Oct 2024 14:12:13 +0200 Subject: [PATCH] Fix test --- src/squidpy/gr/_niche.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 10457449..e7231f9a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -317,8 +317,9 @@ def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggreg if aggregation == "mean": aggregated_matrix = normalized_adjacency_matrix @ adata.X elif aggregation == "variance": - mean_matrix = normalized_adjacency_matrix @ adata.X - mean_squared_matrix = normalized_adjacency_matrix @ (adata.X * adata.X) + mean_matrix = (normalized_adjacency_matrix @ adata.X).toarray() + X_to_arr = adata.X.toarray() + mean_squared_matrix = normalized_adjacency_matrix @ (X_to_arr * X_to_arr) aggregated_matrix = mean_squared_matrix - mean_matrix * mean_matrix else: raise ValueError(f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'.")