Skip to content

Commit

Permalink
add chunks to from_array tests
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Mar 20, 2024
1 parent f5156b8 commit b8e55ab
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions tests/v04/test_multiscales.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Dict, Tuple
from typing import Dict, Literal, Tuple
from pydantic import ValidationError
import pytest
import jsonschema as jsc
from zarr.util import guess_chunks
from pydantic_zarr.v2 import ArraySpec
from pydantic_ome_ngff.v04.multiscale import (
MultiscaleMetadata,
Expand Down Expand Up @@ -383,12 +384,14 @@ def test_multiscale_group_datasets_rank(default_multiscale: MultiscaleMetadata)
@pytest.mark.parametrize("path_pattern", ["{0}", "s{0}", "foo/{0}"])
@pytest.mark.parametrize("metadata", [None, {"foo": 10}])
@pytest.mark.parametrize("ndim", [2, 3, 4, 5])
@pytest.mark.parametrize("chunks", ["auto", "tuple", "tuple-of-tuple"])
def test_from_arrays(
name: str | None,
type: str | None,
path_pattern: str,
metadata: Dict[str, int] | None,
ndim: int,
chunks: Literal["auto", "tuple", "tuple-of-tuple"],
) -> None:
arrays = [np.arange(x**ndim).reshape((x,) * ndim) for x in [3, 2, 1]]
paths = [path_pattern.format(idx) for idx in range(len(arrays))]
Expand Down Expand Up @@ -416,6 +419,18 @@ def test_from_arrays(
else:
axes = [*all_axes[4:], *all_axes[:3]]

if chunks == "auto":
chunks_arg = chunks
chunks_expected = (
guess_chunks(arrays[0].shape, arrays[0].dtype.itemsize),
) * len(arrays)
elif chunks == "tuple":
chunks_arg = (2,) * ndim
chunks_expected = (chunks_arg,) * len(arrays)
elif chunks == "tuple-of-tuple":
chunks_arg = tuple((idx,) * ndim for idx in range(1, len(arrays) + 1))
chunks_expected = chunks_arg

group = Group.from_arrays(
paths=paths,
axes=axes,
Expand All @@ -425,6 +440,7 @@ def test_from_arrays(
name=name,
type=type,
metadata=metadata,
chunks=chunks_arg,
)

group_flat = group.to_flat()
Expand All @@ -435,8 +451,10 @@ def test_from_arrays(
assert group.attributes.multiscales[0].coordinateTransformations is None
assert group.attributes.multiscales[0].axes == axes
for idx, array in enumerate(arrays):
assert array.shape == group_flat["/" + paths[idx]].shape
assert array.dtype == group_flat["/" + paths[idx]].dtype
array_model: ArraySpec = group_flat["/" + paths[idx]]
assert array.shape == array_model.shape
assert array.dtype == array_model.dtype
assert chunks_expected[idx] == array_model.chunks
assert group.attributes.multiscales[0].datasets[
idx
].coordinateTransformations == (
Expand Down

0 comments on commit b8e55ab

Please sign in to comment.