Skip to content

Commit

Permalink
Allow the use of datasets in the apply_on_groups method and also use …
Browse files Browse the repository at this point in the history
…kwargs as parameter instead of **kwargs
  • Loading branch information
josephnowak committed Aug 19, 2024
1 parent 315d415 commit 13d178d
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 69 deletions.
28 changes: 14 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
from setuptools import setup, find_packages

with open('requirements.txt') as f:
with open("requirements.txt") as f:
required = f.read().splitlines()

setup(
name='TensorDB',
version='0.30.4',
description='Database based in a file system storage combined with Xarray and Zarr',
author='Joseph Nowak',
author_email='[email protected]',
name="TensorDB",
version="0.31.0",
description="Database based in a file system storage combined with Xarray and Zarr",
author="Joseph Nowak",
author_email="[email protected]",
classifiers=[
'Development Status :: 1 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'Intended Audience :: General',
'Natural Language :: English',
'Programming Language :: Python :: 3.9',
"Development Status :: 1 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: General",
"Natural Language :: English",
"Programming Language :: Python :: 3.9",
],
keywords='Database Files Xarray Handler Zarr Store Read Write Append Update Upsert Backup Delete S3',
keywords="Database Files Xarray Handler Zarr Store Read Write Append Update Upsert Backup Delete S3",
packages=find_packages(),
install_requires=required
install_requires=required,
)
102 changes: 59 additions & 43 deletions tensordb/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import bottleneck as bn
import dask
import dask.array as da
import numbagg as nba
import numpy as np
import xarray as xr
import numbagg as nba
from dask.distributed import Client
from scipy.stats import rankdata

Expand Down Expand Up @@ -143,8 +143,9 @@ def map_blocks_along_axis(
dim: str,
dtype,
drop_dim: bool = False,
**kwargs,
kwargs: Dict = None,
) -> xr.DataArray:
kwargs = kwargs or {}
template = new_data.chunk({dim: -1})
data = template.data

Expand Down Expand Up @@ -211,12 +212,14 @@ def rank(
new_data,
func=NumpyAlgorithms.rank,
dtype=np.float64,
axis=new_data.dims.index(dim),
dim=dim,
method=method,
ascending=ascending,
nan_policy=nan_policy,
use_bottleneck=use_bottleneck,
kwargs=dict(
axis=new_data.dims.index(dim),
method=method,
ascending=ascending,
nan_policy=nan_policy,
use_bottleneck=use_bottleneck,
),
)

@classmethod
Expand Down Expand Up @@ -397,7 +400,8 @@ def apply_on_groups(
func: Union[str, Callable],
keep_shape: bool = False,
unique_groups: np.ndarray = None,
**kwargs,
kwargs: Dict[str, Any] = None,
template: Union[xr.DataArray, xr.Dataset, str] = None,
):
"""
This method was created as a replacement of the groupby of Xarray when the group is only
Expand Down Expand Up @@ -432,19 +436,17 @@ def apply_on_groups(
Useful when the group array has the same shape as the data and more than one dim; in this case
it is necessary to extract the unique elements, so you can provide them here (optional).
**kwargs
template: Union[xr.DataArray, xr.Dataset, str] = None
If the template is not set, it is generated internally based on the keep_shape
parameter and the data vars inside new_data (if it is a Dataset).
If a string is set, the template is also generated internally, but based only
on the data var with the specified name (when new_data is a Dataset).
kwargs: Dict[str, Any] = None
Any extra parameters to send to the function
"""
if isinstance(new_data, xr.Dataset):
return new_data.map(
cls.apply_on_groups,
groups=groups,
dim=dim,
func=func,
keep_shape=keep_shape,
unique_groups=unique_groups,
)
kwargs = kwargs or dict()

if isinstance(groups, dict):
groups = xr.DataArray(
Expand All @@ -458,20 +460,44 @@ def apply_on_groups(
f"but got {groups.dims} and {new_data.dims}"
)

axis = new_data.dims.index(dim)
groups.name = "group"

if len(groups.dims) != len(new_data.dims):
groups = groups.compute()

if unique_groups is None:
unique_groups = da.unique(groups.data).compute()

output_coord = new_data.coords[dim].values
output_coord = new_data.coords[dim].to_numpy()
if not keep_shape:
# In case of grouping by an array of more than 1 dimension and the keep_shape is False.
output_coord = unique_groups

chunks = (
new_data.chunks[:axis] + (len(output_coord),) + new_data.chunks[axis + 1 :]
)
def generate_template(x):
    """
    Build an empty float64 DataArray used as the output template for ``x``.

    The result keeps the dims of ``x`` and all of its coordinates, except that
    the coordinate of ``dim`` is replaced by ``output_coord`` (both names come
    from the enclosing scope).
    """
    dim_axis = x.dims.index(dim)

    # Same chunking as x, but the grouped dimension becomes a single chunk
    # whose length is the number of output coordinates.
    out_chunks = (
        x.chunks[:dim_axis] + (len(output_coord),) + x.chunks[dim_axis + 1 :]
    )

    out_coords = {}
    for coord_name, coord in x.coords.items():
        out_coords[coord_name] = output_coord if coord_name == dim else coord

    out_shape = [len(out_coords[d]) for d in x.dims]
    placeholder = da.empty(shape=out_shape, chunks=out_chunks, dtype=np.float64)
    return xr.DataArray(placeholder, coords=out_coords, dims=x.dims)

if template is None or isinstance(template, str):
if isinstance(new_data, xr.Dataset):
var_name = template
template = new_data.map(generate_template)
if isinstance(var_name, str):
template = template[var_name]
else:
template = generate_template(new_data)

def _reduce(x, g, func, **kwargs):
if len(g.dims) == 1:
Expand All @@ -497,6 +523,7 @@ def _reduce(x, g, func, **kwargs):
arr.coords[dim] = g.coords[dim]
else:
arr = arr.reindex({dim: unique_groups})
arr = arr.transpose(*x.dims)

return arr

Expand All @@ -514,25 +541,12 @@ def _reduce(x, g, func, **kwargs):
data = new_data.chunk({dim: -1})
if len(groups.dims) == len(data.dims):
groups = groups.chunk(data.chunksizes)
else:
groups = groups.compute()
new_coords = {
k: output_coord if k == dim else v for k, v in new_data.coords.items()
}

data = data.map_blocks(
_reduce,
[groups, func],
kwargs=kwargs,
template=xr.DataArray(
da.empty(
dtype=np.float64,
chunks=chunks,
shape=[len(new_coords[v]) for v in new_data.dims],
),
coords=new_coords,
dims=new_data.dims,
),
template=template,
)
return data

Expand Down Expand Up @@ -636,10 +650,12 @@ def cumulative_on_sort(
dtype=new_data.dtype,
dim=dim,
func=NumpyAlgorithms.cumulative_on_sort,
axis=new_data.dims.index(dim),
cum_func=func,
keep_nan=keep_nan,
ascending=ascending,
kwargs=dict(
axis=new_data.dims.index(dim),
cum_func=func,
keep_nan=keep_nan,
ascending=ascending,
),
)

@classmethod
Expand Down Expand Up @@ -683,7 +699,7 @@ def bitmask_topk(
(f"f{i}", new_data.dtype)
for i in range(new_data.sizes[tie_breaker_dim])
],
axis=new_data.dims.index(tie_breaker_dim),
kwargs=dict(axis=new_data.dims.index(tie_breaker_dim)),
drop_dim=True,
)

Expand Down Expand Up @@ -743,7 +759,7 @@ def rolling_overlap(
window_margin: int,
min_periods: int = None,
apply_ffill: bool = True,
validate_window_size: bool = True
validate_window_size: bool = True,
):
assert window_margin >= window

Expand Down
49 changes: 37 additions & 12 deletions tensordb/tests/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numbagg as nba
import numpy as np
import pandas as pd
import pytest
import xarray as xr
import numbagg as nba

from tensordb.algorithms import Algorithms

Expand Down Expand Up @@ -291,6 +291,8 @@ def test_vindex():
("b", False, "max"),
("a", True, "max"),
("b", True, "max"),
("a", True, "custom"),
("b", False, "custom"),
],
)
def test_apply_on_groups(dim, keep_shape, func):
Expand All @@ -308,24 +310,49 @@ def test_apply_on_groups(dim, keep_shape, func):
grouper = {"a": [1, 5, 5, 0, 1], "b": [0, 1, 1, 0, -1]}
groups = {k: v for k, v in zip(arr.coords[dim].values, grouper[dim])}

result = Algorithms.apply_on_groups(
arr, groups=groups, dim=dim, func=func, keep_shape=keep_shape
)

expected = arr.to_pandas()
axis = 0 if dim == "a" else 1

if axis == 1:
expected = expected.T

if keep_shape:
expected = expected.groupby(groups).transform(func)
if func == "custom":
arr = xr.Dataset(
{
"x": arr,
"v": arr,
}
)

def custom_func(dataset):
    # Combine the two data vars of the dataset and reduce the result along
    # the grouped dimension (``dim`` comes from the enclosing test).
    combined = dataset["x"] - dataset["v"] * 0.1 + 1
    return combined.sum(dim=dim)

func = custom_func

expected = (expected - expected * 0.1).groupby(groups)
if keep_shape:
expected = expected.transform(lambda x: (x + 1).sum())
else:
expected = expected.apply(lambda x: (x + 1).sum())

else:
expected = getattr(expected.groupby(groups), func)()
if keep_shape:
expected = expected.groupby(groups).transform(func)
else:
expected = getattr(expected.groupby(groups), func)()

if axis == 1:
expected = expected.T

result = Algorithms.apply_on_groups(
arr, groups=groups, dim=dim, func=func, keep_shape=keep_shape, template="x"
)

expected = xr.DataArray(expected.values, coords=result.coords, dims=result.dims)
assert expected.equals(result)

Expand Down Expand Up @@ -469,14 +496,12 @@ def test_rolling_overlap(window, apply_ffill):
dim="a",
window_margin=window_margin,
min_periods=1,
apply_ffill=apply_ffill
apply_ffill=apply_ffill,
)

expected = df.dropna()
expected = (
expected.groupby(level=0)
.rolling(window=window, min_periods=1)
.mean()
expected.groupby(level=0).rolling(window=window, min_periods=1).mean()
)
expected = expected.droplevel(0).unstack(0)

Expand Down

0 comments on commit 13d178d

Please sign in to comment.