Skip to content

Commit

Permalink
Reducing the execution time of the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
josephnowak committed Sep 13, 2024
1 parent 1e16925 commit 2064bd7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
8 changes: 7 additions & 1 deletion tensordb/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def merge_duplicates_coord(
new_data: Union[xr.DataArray, xr.Dataset],
dim: str,
func: str,
kwargs: Dict[str, Any] = None,
):
"""
Group and merge duplicates coord base on a function, this can be a sum or a max. Read numpy-groupies
Expand All @@ -570,7 +571,12 @@ def merge_duplicates_coord(
new_data.coords[dim] = np.arange(new_data.sizes[dim])

return cls.apply_on_groups(
new_data=new_data, groups=groups, dim=dim, func=func, keep_shape=False
new_data=new_data,
groups=groups,
dim=dim,
func=func,
keep_shape=False,
kwargs=kwargs,
)

@classmethod
Expand Down
16 changes: 13 additions & 3 deletions tensordb/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ def test_apply_on_groups(dim, keep_shape, func):
if axis == 1:
expected = expected.T

kwargs = {"engine": "numpy"}

if func == "custom":
# Avoid the use of the engine on the custom functions
kwargs = {}
arr = xr.Dataset(
{
"x": arr,
Expand Down Expand Up @@ -350,7 +354,13 @@ def custom_func(dataset):
expected = expected.T

result = Algorithms.apply_on_groups(
arr, groups=groups, dim=dim, func=func, keep_shape=keep_shape, template="x"
arr,
groups=groups,
dim=dim,
func=func,
keep_shape=keep_shape,
template="x",
kwargs=kwargs,
)

expected = xr.DataArray(expected.values, coords=result.coords, dims=result.dims)
Expand Down Expand Up @@ -438,8 +448,8 @@ def test_merge_duplicates_coord(dim):
coords={"a": [1, 5, 5, 0, 1], "b": [0, 1, 1, 0, -1]},
).chunk(a=3, b=2)

g = arr.groupby(dim).max(dim)
arr = Algorithms.merge_duplicates_coord(arr, dim, "max")
g = arr.groupby(dim).max(dim, engine="numpy")
arr = Algorithms.merge_duplicates_coord(arr, dim, "max", kwargs={"engine": "numpy"})
assert g.equals(arr)


Expand Down

0 comments on commit 2064bd7

Please sign in to comment.