From 91518e09809fbc0686b577794a68a4ec6d123172 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Wed, 7 Aug 2024 13:00:27 +0200 Subject: [PATCH 1/7] relabel block --- src/spatialdata/_core/operations/map.py | 67 +++++++++++++++++++++++++ tests/core/operations/test_map.py | 59 ++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index b3810352..398eeab5 100644 --- a/src/spatialdata/_core/operations/map.py +++ b/src/spatialdata/_core/operations/map.py @@ -1,14 +1,19 @@ from __future__ import annotations +import math +import operator from collections.abc import Iterable, Mapping +from functools import reduce from types import MappingProxyType from typing import TYPE_CHECKING, Any, Callable import dask.array as da +import numpy as np from dask.array.overlap import coerce_depth from datatree import DataTree from xarray import DataArray +from spatialdata._types import ArrayLike from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims from spatialdata.transformations import get_transformation @@ -25,6 +30,7 @@ def map_raster( c_coords: Iterable[int] | Iterable[str] | None = None, dims: tuple[str, ...] | None = None, transformations: dict[str, Any] | None = None, + relabel: bool = True, **kwargs: Any, ) -> DataArray: """ @@ -69,6 +75,11 @@ def map_raster( transformations The transformations of the output data. If not provided, the transformations of the input data are copied to the output data. It should be specified if the callable changes the data transformations. + relabel + Whether to relabel the blocks of the output data. + This option is ignored when the output data is not a labels layer (i.e., when `dims` contains `c`). + It is recommended to enable relabeling if `func` returns labels that are not unique across chunks. + Relabeling will be done by performing a bit shift. 
kwargs Additional keyword arguments to pass to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`. Ignored if `blockwise` is set to `False`. @@ -131,6 +142,9 @@ def map_raster( assert isinstance(d, dict) transformations = d + if "c" not in dims and relabel: + arr = _relabel(arr) + model_kwargs = { "chunks": arr.chunksize, "c_coords": c_coords, @@ -139,3 +153,56 @@ def map_raster( } model = get_raster_model_from_data_dims(dims) return model.parse(arr, **model_kwargs) + + +def _relabel(arr: da.Array) -> da.Array: + num_blocks = arr.numblocks + + shift = (math.prod(num_blocks) - 1).bit_length() + + meta = np.empty((0,) * arr.ndim, dtype=arr.dtype) + + def _relabel_block( + block: ArrayLike, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int + ) -> ArrayLike: + def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int: + if len(num_blocks) != len(block_id): + raise ValueError("num_blocks and block_id must have the same length") + block_num = 0 + for i in range(len(num_blocks)): + multiplier = reduce(operator.mul, num_blocks[i + 1 :], 1) + block_num += block_id[i] * multiplier + return block_num + + available_bits = np.iinfo(block.dtype).max.bit_length() + max_bits_block = int(block.max()).bit_length() + + if max_bits_block + shift > available_bits: + raise ValueError( + f"Relabel was set to True, but " + f"max bits required to represent the labels in the block ({max_bits_block}) " + f"+ required shift ({shift}) > " + f"available_bits ({available_bits}). " + "To solve this issue, please consider rechunking using a larger chunk size, " + "resulting in a fewer number of blocks and thus a lower value for the required shift; " + f"cast to a data type (current data type is {block.dtype}) with a higher maximum value " + "(resulting in more available bits for the bit shift); " + "or consider a sequential relabeling of the dask array, " + "which could result in a lower maximum value of the labels in the block." 
+ ) + + block_num = _calculate_block_num(block_id=block_id, num_blocks=num_blocks) + + mask = block > 0 + block[mask] = (block[mask] << shift) | block_num + + return block + + return da.map_blocks( + _relabel_block, + arr, + dtype=arr.dtype, + num_blocks=num_blocks, + shift=shift, + meta=meta, + ) diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index fea9deea..eb57bc70 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -1,3 +1,4 @@ +import math import re import numpy as np @@ -27,6 +28,11 @@ def _multiply_to_labels(arr, parameter=10): return arr[0].astype(np.int32) +def _to_constant(arr, constant=5): + arr[arr > 0] = constant + return arr + + @pytest.mark.parametrize( "depth", [ @@ -46,6 +52,7 @@ def test_map_raster(sdata_blobs, depth, element_name): func_kwargs=func_kwargs, c_coords=None, depth=depth, + relabel=False, ) assert isinstance(se, DataArray) @@ -161,6 +168,7 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis): chunks=chunks, drop_axis=drop_axis, dims=("y", "x"), + relabel=False, ) data = sdata_blobs[img_layer].data.compute() @@ -248,3 +256,54 @@ def test_invalid_map_raster(sdata_blobs): c_coords=["c"], depth=(0, 60, 60), ) + + +def test_map_raster_relabel(sdata_blobs): + constant = 2047 + func_kwargs = {"constant": constant} + + element_name = "blobs_labels" + se = map_raster( + sdata_blobs[element_name].chunk((100, 100)), + func=_to_constant, + func_kwargs=func_kwargs, + c_coords=None, + depth=None, + relabel=True, + ) + + # check if labels in different blocks are all mapped to a different value + assert isinstance(se, DataArray) + se.data.compute() + a = set() + for chunk in se.data.to_delayed().flatten(): + chunk = chunk.compute() + b = set(np.unique(chunk)) + b.remove(0) + assert not b.intersection(a) + a.update(b) + # 9 blocks, each block contains 'constant' left shifted by (9-1).bit_length() + block_num. 
+ shift = (math.prod(se.data.numblocks) - 1).bit_length() + assert a == set(range(constant << shift, (constant << shift) + math.prod(se.data.numblocks))) + + +def test_map_raster_relabel_fail(sdata_blobs): + constant = 2048 + func_kwargs = {"constant": constant} + + element_name = "blobs_labels" + + with pytest.raises( + ValueError, + match=re.escape("Relabel was set to True, but max bits"), + ): + se = map_raster( + sdata_blobs[element_name].chunk((100, 100)), + func=_to_constant, + func_kwargs=func_kwargs, + c_coords=None, + depth=None, + relabel=True, + ) + + se.data.compute() From 206c44573b92f469b86652a80cfd2ffa0e08939a Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Wed, 7 Aug 2024 16:22:11 +0200 Subject: [PATCH 2/7] dtype check array --- src/spatialdata/_core/operations/map.py | 2 ++ tests/core/operations/test_map.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index 398eeab5..2b87d97b 100644 --- a/src/spatialdata/_core/operations/map.py +++ b/src/spatialdata/_core/operations/map.py @@ -156,6 +156,8 @@ def map_raster( def _relabel(arr: da.Array) -> da.Array: + if not np.issubdtype(arr.dtype, np.integer): + raise ValueError(f"Relabeling is only supported for arrays of type {np.integer}.") num_blocks = arr.numblocks shift = (math.prod(num_blocks) - 1).bit_length() diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index eb57bc70..933b8ea7 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -307,3 +307,20 @@ def test_map_raster_relabel_fail(sdata_blobs): ) se.data.compute() + + constant = 2047 + func_kwargs = {"constant": constant} + + element_name = "blobs_labels" + with pytest.raises( + ValueError, + match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."), + ): + se = map_raster( + sdata_blobs[element_name].astype(float).chunk((100, 100)), + func=_to_constant, 
+ func_kwargs=func_kwargs, + c_coords=None, + depth=None, + relabel=True, + ) From ef933d828c015c3ed12cd22d39f4941c545ba8c7 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Wed, 7 Aug 2024 16:43:29 +0200 Subject: [PATCH 3/7] fix mypy --- src/spatialdata/_core/operations/map.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index 2b87d97b..76cc969a 100644 --- a/src/spatialdata/_core/operations/map.py +++ b/src/spatialdata/_core/operations/map.py @@ -11,9 +11,9 @@ import numpy as np from dask.array.overlap import coerce_depth from datatree import DataTree +from numpy.typing import NDArray from xarray import DataArray -from spatialdata._types import ArrayLike from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims from spatialdata.transformations import get_transformation @@ -165,8 +165,8 @@ def _relabel(arr: da.Array) -> da.Array: meta = np.empty((0,) * arr.ndim, dtype=arr.dtype) def _relabel_block( - block: ArrayLike, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int - ) -> ArrayLike: + block: NDArray[np.int64], block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int + ) -> NDArray[np.int64]: def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int: if len(num_blocks) != len(block_id): raise ValueError("num_blocks and block_id must have the same length") From 4d0ea37cc0c1d11d35d7ea2775770142015269d5 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 8 Aug 2024 08:40:16 +0200 Subject: [PATCH 4/7] type IntArray --- src/spatialdata/_core/operations/map.py | 6 ++---- src/spatialdata/_types.py | 5 ++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index 76cc969a..57f3a454 100644 --- a/src/spatialdata/_core/operations/map.py +++ 
b/src/spatialdata/_core/operations/map.py @@ -11,9 +11,9 @@ import numpy as np from dask.array.overlap import coerce_depth from datatree import DataTree -from numpy.typing import NDArray from xarray import DataArray +from spatialdata._types import IntArray from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims from spatialdata.transformations import get_transformation @@ -164,9 +164,7 @@ def _relabel(arr: da.Array) -> da.Array: meta = np.empty((0,) * arr.ndim, dtype=arr.dtype) - def _relabel_block( - block: NDArray[np.int64], block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int - ) -> NDArray[np.int64]: + def _relabel_block(block: IntArray, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int) -> IntArray: def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int: if len(num_blocks) != len(block_id): raise ValueError("num_blocks and block_id must have the same length") diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py index ae6b0a34..db98b66e 100644 --- a/src/spatialdata/_types.py +++ b/src/spatialdata/_types.py @@ -6,14 +6,17 @@ from datatree import DataTree from xarray import DataArray -__all__ = ["ArrayLike", "DTypeLike", "Raster_T"] +__all__ = ["ArrayLike", "IntArray", "DTypeLike", "Raster_T"] try: from numpy.typing import DTypeLike, NDArray ArrayLike = NDArray[np.float64] + IntArray = NDArray[np.int64] # or any np.integer + except (ImportError, TypeError): ArrayLike = np.ndarray # type: ignore[misc] DTypeLike = np.dtype # type: ignore[misc] + IntArray = np.ndarray # type: ignore[misc] Raster_T = Union[DataArray, DataTree] From 633a132ddfb6353925bc099114b893df43ad3f2e Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 8 Aug 2024 11:08:22 +0200 Subject: [PATCH 5/7] add sequential relabeling helper function --- src/spatialdata/_core/operations/map.py | 33 +++++++++++++++++++ tests/core/operations/test_map.py | 42 ++++++++++++++++++++++++- 2 
files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index 57f3a454..197650f8 100644 --- a/src/spatialdata/_core/operations/map.py +++ b/src/spatialdata/_core/operations/map.py @@ -206,3 +206,36 @@ def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) shift=shift, meta=meta, ) + + +def _relabel_sequential(arr: da.Array) -> da.Array: + """ + Relabels integers in a Dask array sequentially. + + This function assigns sequential labels, starting from 1, to the non-zero integers in a Dask array; 0 is kept as 0. + For example, if the unique values in the input array are [0, 5, 9], + they will be relabeled to [0, 1, 2] respectively. + + Parameters + ---------- + arr + input array. + + Returns + ------- + The relabeled array. + """ + if not np.issubdtype(arr.dtype, np.integer): + raise ValueError(f"Sequential relabeling is only supported for arrays of type {np.integer}.") + unique_labels = da.unique(arr).compute() + if 0 not in unique_labels: + # otherwise first non zero label would be relabeled to 0 + unique_labels = np.insert(unique_labels, 0, 0) + + max_label = unique_labels[-1] + + new_labeling = da.full(max_label + 1, -1, dtype=arr.dtype) + + new_labeling[unique_labels] = da.arange(len(unique_labels), dtype=arr.dtype) + + return da.map_blocks(operator.getitem, new_labeling, arr, dtype=arr.dtype, chunks=arr.chunks) diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index 933b8ea7..ceae6a97 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -1,9 +1,10 @@ import math import re +import dask.array as da import numpy as np import pytest -from spatialdata._core.operations.map import map_raster +from spatialdata._core.operations.map import _relabel_sequential, map_raster from spatialdata.transformations import Translation, get_transformation, set_transformation from xarray import DataArray @@ -324,3 +325,42 @@ def
test_map_raster_relabel_fail(sdata_blobs): depth=None, relabel=True, ) + + +def test_relabel_sequential(sdata_blobs): + def _is_sequential(arr): + if arr.ndim != 1: + raise ValueError("Input array must be one-dimensional") + sorted_arr = np.sort(arr) + expected_sequence = np.arange(sorted_arr[0], sorted_arr[0] + len(sorted_arr)) + return np.array_equal(sorted_arr, expected_sequence) + + arr = sdata_blobs["blobs_labels"].data.rechunk(100) + + arr_relabeled = _relabel_sequential(arr) + + labels_relabeled = da.unique(arr_relabeled).compute() + labels_original = da.unique(arr).compute() + + assert labels_relabeled.shape == labels_original.shape + assert _is_sequential(labels_relabeled) + + # test some edge cases + arr = da.asarray(np.array([0])) + assert np.array_equal(_relabel_sequential(arr).compute(), np.array([0])) + + arr = da.asarray(np.array([1])) + assert np.array_equal(_relabel_sequential(arr).compute(), np.array([1])) + + arr = da.asarray(np.array([2])) + assert np.array_equal(_relabel_sequential(arr).compute(), np.array([1])) + + arr = da.asarray(np.array([2, 0])) + assert np.array_equal(_relabel_sequential(arr).compute(), np.array([1, 0])) + + +def test_relabel_sequential_fails(sdata_blobs): + with pytest.raises( + ValueError, match=re.escape(f"Sequential relabeling is only supported for arrays of type {np.integer}.") + ): + _relabel_sequential(sdata_blobs["blobs_labels"].data.astype(float)) From a5ac4ad37b01eb50cda60520ec22d6271c64bf66 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:53:53 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/operations/test_map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index 45ab550a..a329c8f6 100644 --- a/tests/core/operations/test_map.py +++ 
b/tests/core/operations/test_map.py @@ -4,6 +4,7 @@ import dask.array as da import numpy as np import pytest + from spatialdata._core.operations.map import _relabel_sequential, map_raster from spatialdata.transformations import Translation, get_transformation, set_transformation From cad5ce3d6e34faf955a7bb51a8c099fb65750d6c Mon Sep 17 00:00:00 2001 From: Giovanni Palla <25887487+giovp@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:00:26 -0700 Subject: [PATCH 7/7] reintroduce DataArray --- tests/core/operations/test_map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index a329c8f6..43ef1fd3 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -4,6 +4,7 @@ import dask.array as da import numpy as np import pytest +from xarray import DataArray from spatialdata._core.operations.map import _relabel_sequential, map_raster from spatialdata.transformations import Translation, get_transformation, set_transformation