From ad657edfb52e9957b9a93b3a16fc8a87852f3f09 Mon Sep 17 00:00:00 2001
From: Ashish Patel
Date: Tue, 18 Jun 2024 16:47:55 +0530
Subject: [PATCH] speedup(~7x) of the clipping array inside scaling function (#3100)

Co-authored-by: Severin Dicks <37635888+Intron7@users.noreply.github.com>
Co-authored-by: Intron7
---
 docs/release-notes/1.10.2.md       |  1 +
 src/scanpy/preprocessing/_scale.py | 30 +++++++++++++++++++++++-------
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/docs/release-notes/1.10.2.md b/docs/release-notes/1.10.2.md
index af5ada6bab..d1d9212e8a 100644
--- a/docs/release-notes/1.10.2.md
+++ b/docs/release-notes/1.10.2.md
@@ -30,3 +30,4 @@
 * `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
 * `pp.highly_variable_genes` with `flavor=seurat_v3` now uses a numba kernel {pr}`3017` {smaller}`S Dicks`
 * Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
+* Speed up clipping of array in {func}`~scanpy.pp.scale` {pr}`3100` {smaller}`P Ashish & S Dicks`
diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py
index edd9843c59..f6f4b4e586 100644
--- a/src/scanpy/preprocessing/_scale.py
+++ b/src/scanpy/preprocessing/_scale.py
@@ -43,6 +43,24 @@ def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip):
                     data[j] /= std[indices[j]]
 
 
+@numba.njit(parallel=True, cache=True)
+def clip_array(X: np.ndarray, max_value: float | None = 10, zero_center: bool = True):
+    a_min, a_max = -max_value, max_value
+    if X.ndim > 1:
+        for r, c in numba.pndindex(X.shape):
+            if X[r, c] > a_max:
+                X[r, c] = a_max
+            elif X[r, c] < a_min and zero_center:
+                X[r, c] = a_min
+    else:
+        for i in numba.prange(X.size):
+            if X[i] > a_max:
+                X[i] = a_max
+            elif X[i] < a_min and zero_center:
+                X[i] = a_min
+    return X
+
+
 @renamed_arg("X", "data", pos_0=True)
 @old_positionals("zero_center", "max_value", "copy", "layer", "obsm")
 @singledispatch
@@ -197,14 +215,12 @@ def clip_set(x):
 
             X = da.map_blocks(clip_set, X)
         else:
-            if zero_center:
-                a_min, a_max = -max_value, max_value
-                X = np.clip(X, a_min, a_max)  # dask does not accept these as kwargs
+            if isinstance(X, DaskArray):
+                X = X.map_blocks(clip_array, max_value, zero_center)
+            elif issparse(X):
+                X.data = clip_array(X.data, max_value=max_value, zero_center=False)
             else:
-                if issparse(X):
-                    X.data[X.data > max_value] = max_value
-                else:
-                    X[X > max_value] = max_value
+                X = clip_array(X, max_value=max_value, zero_center=zero_center)
     if return_mean_std:
         return X, mean, std
     else: