Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

relabel block #664

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/spatialdata/_core/operations/map.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from __future__ import annotations

import math
import operator
from collections.abc import Iterable, Mapping
from functools import reduce
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable

import dask.array as da
import numpy as np
from dask.array.overlap import coerce_depth
from datatree import DataTree
from xarray import DataArray

from spatialdata._types import IntArray
from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
from spatialdata.transformations import get_transformation

Expand All @@ -25,6 +30,7 @@ def map_raster(
c_coords: Iterable[int] | Iterable[str] | None = None,
dims: tuple[str, ...] | None = None,
transformations: dict[str, Any] | None = None,
relabel: bool = True,
**kwargs: Any,
) -> DataArray:
"""
Expand Down Expand Up @@ -69,6 +75,11 @@ def map_raster(
transformations
The transformations of the output data. If not provided, the transformations of the input data are copied to the
output data. It should be specified if the callable changes the data transformations.
relabel
Whether to relabel the blocks of the output data.
This option is ignored when the output data is not a labels layer (i.e., when `dims` does not contain `c`).
It is recommended to enable relabeling if `func` returns labels that are not unique across chunks.
Relabeling will be done by performing a bit shift.
kwargs
Additional keyword arguments to pass to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`.
Ignored if `blockwise` is set to `False`.
Expand Down Expand Up @@ -131,6 +142,9 @@ def map_raster(
assert isinstance(d, dict)
transformations = d

if "c" not in dims and relabel:
arr = _relabel(arr)

model_kwargs = {
"chunks": arr.chunksize,
"c_coords": c_coords,
Expand All @@ -139,3 +153,89 @@ def map_raster(
}
model = get_raster_model_from_data_dims(dims)
return model.parse(arr, **model_kwargs)


def _relabel(arr: da.Array) -> da.Array:
    """
    Relabel the blocks of an integer dask array so labels are unique across blocks.

    Every non-zero label in a block is shifted left by just enough bits to hold
    the largest block index, and the block's (row-major) index is stored in the
    freed low bits. Zero (background) is preserved.

    Parameters
    ----------
    arr
        The input label array; must have an integer dtype.

    Returns
    -------
    The relabeled array, with the same shape, chunks and dtype as ``arr``.

    Raises
    ------
    ValueError
        If ``arr`` does not have an integer dtype, or — at compute time — if a
        block's labels no longer fit in the dtype after the bit shift.
    """
    if not np.issubdtype(arr.dtype, np.integer):
        raise ValueError(f"Relabeling is only supported for arrays of type {np.integer}.")
    num_blocks = arr.numblocks

    # Number of low bits needed to encode any block index in 0 .. prod(num_blocks)-1.
    shift = (math.prod(num_blocks) - 1).bit_length()

    meta = np.empty((0,) * arr.ndim, dtype=arr.dtype)

    def _relabel_block(block: IntArray, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int) -> IntArray:
        def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int:
            # Row-major (C-order) linear index of the block within the grid of blocks.
            if len(num_blocks) != len(block_id):
                raise ValueError("num_blocks and block_id must have the same length")
            block_num = 0
            for i in range(len(num_blocks)):
                multiplier = reduce(operator.mul, num_blocks[i + 1 :], 1)
                block_num += block_id[i] * multiplier
            return block_num

        available_bits = np.iinfo(block.dtype).max.bit_length()
        max_bits_block = int(block.max()).bit_length()

        if max_bits_block + shift > available_bits:
            raise ValueError(
                f"Relabel was set to True, but "
                f"max bits required to represent the labels in the block ({max_bits_block}) "
                f"+ required shift ({shift}) > "
                f"available_bits ({available_bits}). "
                "To solve this issue, please consider rechunking using a larger chunk size, "
                "resulting in a fewer number of blocks and thus a lower value for the required shift; "
                f"cast to a data type (current data type is {block.dtype}) with a higher maximum value "
                "(resulting in more available bits for the bit shift); "
                "or consider a sequential relabeling of the dask array, "
                "which could result in a lower maximum value of the labels in the block."
            )

        block_num = _calculate_block_num(block_id=block_id, num_blocks=num_blocks)

        # Work on a copy: blocks passed to `map_blocks` may be views of the
        # underlying source buffers (e.g. when the array was created with
        # `da.from_array`), and mutating them in place would corrupt the input.
        block = block.copy()
        mask = block > 0
        block[mask] = (block[mask] << shift) | block_num

        return block

    return da.map_blocks(
        _relabel_block,
        arr,
        dtype=arr.dtype,
        num_blocks=num_blocks,
        shift=shift,
        meta=meta,
    )


def _relabel_sequential(arr: da.Array) -> da.Array:
    """
    Relabels integers in a Dask array sequentially.

    Distinct values are mapped to consecutive integers starting from 1, with 0
    (background) always kept as 0. For example, if the unique values in the
    input array are [0, 5, 9], they will be relabeled to [0, 1, 2] respectively.

    Parameters
    ----------
    arr
        input array.

    Returns
    -------
    The relabeled array.
    """
    if not np.issubdtype(arr.dtype, np.integer):
        raise ValueError(f"Sequential relabeling is only supported for arrays of type {np.integer}.")
    labels = da.unique(arr).compute()
    if 0 not in labels:
        # Keep 0 in the first slot so the first non-zero label maps to 1, not 0.
        labels = np.insert(labels, 0, 0)

    # Lookup table: entry i holds the new label for old label i (-1 marks unused slots).
    lookup = da.full(labels[-1] + 1, -1, dtype=arr.dtype)
    lookup[labels] = da.arange(len(labels), dtype=arr.dtype)

    # Replace every element by its entry in the lookup table, block by block.
    return da.map_blocks(operator.getitem, lookup, arr, dtype=arr.dtype, chunks=arr.chunks)
5 changes: 4 additions & 1 deletion src/spatialdata/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from datatree import DataTree
from xarray import DataArray

__all__ = ["ArrayLike", "DTypeLike", "Raster_T"]
__all__ = ["ArrayLike", "IntArray", "DTypeLike", "Raster_T"]

try:
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.float64]
IntArray = NDArray[np.int64] # or any np.integer

except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc]
IntArray = np.ndarray # type: ignore[misc]

Raster_T = Union[DataArray, DataTree]
118 changes: 117 additions & 1 deletion tests/core/operations/test_map.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
import re

import dask.array as da
import numpy as np
import pytest
from spatialdata._core.operations.map import map_raster
from spatialdata._core.operations.map import _relabel_sequential, map_raster
from spatialdata.transformations import Translation, get_transformation, set_transformation
from xarray import DataArray

Expand All @@ -27,6 +29,11 @@ def _multiply_to_labels(arr, parameter=10):
return arr[0].astype(np.int32)


def _to_constant(arr, constant=5):
    """Set every positive entry of ``arr`` to ``constant``, in place, and return ``arr``."""
    positive = arr > 0
    arr[positive] = constant
    return arr


@pytest.mark.parametrize(
"depth",
[
Expand All @@ -46,6 +53,7 @@ def test_map_raster(sdata_blobs, depth, element_name):
func_kwargs=func_kwargs,
c_coords=None,
depth=depth,
relabel=False,
)

assert isinstance(se, DataArray)
Expand Down Expand Up @@ -161,6 +169,7 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis):
chunks=chunks,
drop_axis=drop_axis,
dims=("y", "x"),
relabel=False,
)

data = sdata_blobs[img_layer].data.compute()
Expand Down Expand Up @@ -248,3 +257,110 @@ def test_invalid_map_raster(sdata_blobs):
c_coords=["c"],
depth=(0, 60, 60),
)


def test_map_raster_relabel(sdata_blobs):
    # 2047 = 2**11 - 1: a label that still fits in the dtype after the bit shift.
    constant = 2047
    func_kwargs = {"constant": constant}

    element_name = "blobs_labels"
    se = map_raster(
        sdata_blobs[element_name].chunk((100, 100)),
        func=_to_constant,
        func_kwargs=func_kwargs,
        c_coords=None,
        depth=None,
        relabel=True,
    )

    assert isinstance(se, DataArray)
    se.data.compute()
    # No non-zero label value may appear in more than one block.
    seen = set()
    for delayed_chunk in se.data.to_delayed().flatten():
        chunk_labels = set(np.unique(delayed_chunk.compute()))
        chunk_labels.remove(0)
        assert not chunk_labels.intersection(seen)
        seen.update(chunk_labels)
    # Each block holds 'constant' shifted left by the bits needed to encode the
    # block index, OR-ed with that block's index, so together the blocks cover a
    # contiguous range of numblocks values starting at constant << shift.
    shift = (math.prod(se.data.numblocks) - 1).bit_length()
    n_blocks = math.prod(se.data.numblocks)
    assert seen == set(range(constant << shift, (constant << shift) + n_blocks))


def test_map_raster_relabel_fail(sdata_blobs):
    # 2048 = 2**11 needs one bit more than 2047; together with the block shift
    # this exceeds the bits available in the labels' dtype.
    constant = 2048
    func_kwargs = {"constant": constant}

    element_name = "blobs_labels"

    with pytest.raises(
        ValueError,
        match=re.escape("Relabel was set to True, but max bits"),
    ):
        se = map_raster(
            sdata_blobs[element_name].chunk((100, 100)),
            func=_to_constant,
            func_kwargs=func_kwargs,
            c_coords=None,
            depth=None,
            relabel=True,
        )
        # The overflow check runs per block, so the error only surfaces at compute time.
        se.data.compute()

    # Relabeling a non-integer array must be rejected up front.
    constant = 2047
    func_kwargs = {"constant": constant}

    element_name = "blobs_labels"
    with pytest.raises(
        ValueError,
        match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."),
    ):
        map_raster(
            sdata_blobs[element_name].astype(float).chunk((100, 100)),
            func=_to_constant,
            func_kwargs=func_kwargs,
            c_coords=None,
            depth=None,
            relabel=True,
        )


def test_relabel_sequential(sdata_blobs):
    def _is_sequential(values):
        # A 1D array of unique labels is sequential iff, sorted, it forms a
        # contiguous integer range.
        if values.ndim != 1:
            raise ValueError("Input array must be one-dimensional")
        ordered = np.sort(values)
        return np.array_equal(ordered, np.arange(ordered[0], ordered[0] + len(ordered)))

    arr = sdata_blobs["blobs_labels"].data.rechunk(100)

    relabeled = _relabel_sequential(arr)

    new_labels = da.unique(relabeled).compute()
    old_labels = da.unique(arr).compute()

    # Same number of distinct labels, now forming a contiguous range.
    assert new_labels.shape == old_labels.shape
    assert _is_sequential(new_labels)

    # Edge cases: single-element arrays and background (0) handling.
    for data, expected in (
        ([0], [0]),
        ([1], [1]),
        ([2], [1]),
        ([2, 0], [1, 0]),
    ):
        result = _relabel_sequential(da.asarray(np.array(data))).compute()
        assert np.array_equal(result, np.array(expected))


def test_relabel_sequential_fails(sdata_blobs):
    # Float arrays cannot be sequentially relabeled and must be rejected.
    expected_message = f"Sequential relabeling is only supported for arrays of type {np.integer}."
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        _relabel_sequential(sdata_blobs["blobs_labels"].data.astype(float))
Loading