Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/args in groupby #120

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
from setuptools import setup, find_packages

with open('requirements.txt') as f:
with open("requirements.txt") as f:
required = f.read().splitlines()

setup(
name='TensorDB',
version='0.30.4',
description='Database based in a file system storage combined with Xarray and Zarr',
author='Joseph Nowak',
author_email='[email protected]',
name="TensorDB",
version="0.31.0",
description="Database based in a file system storage combined with Xarray and Zarr",
author="Joseph Nowak",
author_email="[email protected]",
classifiers=[
'Development Status :: 1 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'Intended Audience :: General',
'Natural Language :: English',
'Programming Language :: Python :: 3.9',
"Development Status :: 1 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: General",
"Natural Language :: English",
"Programming Language :: Python :: 3.9",
],
keywords='Database Files Xarray Handler Zarr Store Read Write Append Update Upsert Backup Delete S3',
keywords="Database Files Xarray Handler Zarr Store Read Write Append Update Upsert Backup Delete S3",
packages=find_packages(),
install_requires=required
install_requires=required,
)
102 changes: 59 additions & 43 deletions tensordb/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import bottleneck as bn
import dask
import dask.array as da
import numbagg as nba
import numpy as np
import xarray as xr
import numbagg as nba
from dask.distributed import Client
from scipy.stats import rankdata

Expand Down Expand Up @@ -143,8 +143,9 @@ def map_blocks_along_axis(
dim: str,
dtype,
drop_dim: bool = False,
**kwargs,
kwargs: Dict = None,
) -> xr.DataArray:
kwargs = kwargs or {}
template = new_data.chunk({dim: -1})
data = template.data

Expand Down Expand Up @@ -211,12 +212,14 @@ def rank(
new_data,
func=NumpyAlgorithms.rank,
dtype=np.float64,
axis=new_data.dims.index(dim),
dim=dim,
method=method,
ascending=ascending,
nan_policy=nan_policy,
use_bottleneck=use_bottleneck,
kwargs=dict(
axis=new_data.dims.index(dim),
method=method,
ascending=ascending,
nan_policy=nan_policy,
use_bottleneck=use_bottleneck,
),
)

@classmethod
Expand Down Expand Up @@ -397,7 +400,8 @@ def apply_on_groups(
func: Union[str, Callable],
keep_shape: bool = False,
unique_groups: np.ndarray = None,
**kwargs,
kwargs: Dict[str, Any] = None,
template: Union[xr.DataArray, xr.Dataset, str] = None,
):
"""
This method was created as a replacement of the groupby of Xarray when the group is only
Expand Down Expand Up @@ -432,19 +436,17 @@ def apply_on_groups(
Useful when the group array has the same shape as the data and more than one dim, for this case
is necessary extract the unique elements, so you can provide them here (optional).

**kwargs
template: Union[xr.DataArray, xr.Dataset]: str = None
If the template is not set then is going to be generated internally based on the keep_shape
parameter and the data vars inside the template (if a Dataset).
If a string is set then the template is going to be generated internally but based
on the var name specified

kwargs: Dict[str, Any] = None,
Any extra parameter to send to the function

"""
if isinstance(new_data, xr.Dataset):
return new_data.map(
cls.apply_on_groups,
groups=groups,
dim=dim,
func=func,
keep_shape=keep_shape,
unique_groups=unique_groups,
)
kwargs = kwargs or dict()

if isinstance(groups, dict):
groups = xr.DataArray(
Expand All @@ -458,20 +460,44 @@ def apply_on_groups(
f"but got {groups.dims} and {new_data.dims}"
)

axis = new_data.dims.index(dim)
groups.name = "group"

if len(groups.dims) != len(new_data.dims):
groups = groups.compute()

if unique_groups is None:
unique_groups = da.unique(groups.data).compute()

output_coord = new_data.coords[dim].values
output_coord = new_data.coords[dim].to_numpy()
if not keep_shape:
# In case of grouping by an array of more than 1 dimension and the keep_shape is False.
output_coord = unique_groups

chunks = (
new_data.chunks[:axis] + (len(output_coord),) + new_data.chunks[axis + 1 :]
)
def generate_template(x):
axis = x.dims.index(dim)

chunks = x.chunks[:axis] + (len(output_coord),) + x.chunks[axis + 1 :]
new_coords = {
k: output_coord if k == dim else v for k, v in x.coords.items()
}
return xr.DataArray(
da.empty(
dtype=np.float64,
chunks=chunks,
shape=[len(new_coords[v]) for v in x.dims],
),
coords=new_coords,
dims=x.dims,
)

if template is None or isinstance(template, str):
if isinstance(new_data, xr.Dataset):
var_name = template
template = new_data.map(generate_template)
if isinstance(var_name, str):
template = template[var_name]
else:
template = generate_template(new_data)

def _reduce(x, g, func, **kwargs):
if len(g.dims) == 1:
Expand All @@ -497,6 +523,7 @@ def _reduce(x, g, func, **kwargs):
arr.coords[dim] = g.coords[dim]
else:
arr = arr.reindex({dim: unique_groups})
arr = arr.transpose(*x.dims)

return arr

Expand All @@ -514,25 +541,12 @@ def _reduce(x, g, func, **kwargs):
data = new_data.chunk({dim: -1})
if len(groups.dims) == len(data.dims):
groups = groups.chunk(data.chunksizes)
else:
groups = groups.compute()
new_coords = {
k: output_coord if k == dim else v for k, v in new_data.coords.items()
}

data = data.map_blocks(
_reduce,
[groups, func],
kwargs=kwargs,
template=xr.DataArray(
da.empty(
dtype=np.float64,
chunks=chunks,
shape=[len(new_coords[v]) for v in new_data.dims],
),
coords=new_coords,
dims=new_data.dims,
),
template=template,
)
return data

Expand Down Expand Up @@ -636,10 +650,12 @@ def cumulative_on_sort(
dtype=new_data.dtype,
dim=dim,
func=NumpyAlgorithms.cumulative_on_sort,
axis=new_data.dims.index(dim),
cum_func=func,
keep_nan=keep_nan,
ascending=ascending,
kwargs=dict(
axis=new_data.dims.index(dim),
cum_func=func,
keep_nan=keep_nan,
ascending=ascending,
),
)

@classmethod
Expand Down Expand Up @@ -683,7 +699,7 @@ def bitmask_topk(
(f"f{i}", new_data.dtype)
for i in range(new_data.sizes[tie_breaker_dim])
],
axis=new_data.dims.index(tie_breaker_dim),
kwargs=dict(axis=new_data.dims.index(tie_breaker_dim)),
drop_dim=True,
)

Expand Down Expand Up @@ -743,7 +759,7 @@ def rolling_overlap(
window_margin: int,
min_periods: int = None,
apply_ffill: bool = True,
validate_window_size: bool = True
validate_window_size: bool = True,
):
assert window_margin >= window

Expand Down
36 changes: 13 additions & 23 deletions tensordb/storages/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ class BaseStorage:
"""

def __init__(
self,
base_map: Union[Mapping, MutableMapping],
tmp_map: Union[Mapping, MutableMapping],
data_names: Union[str, List[str]] = "data",
**kwargs
self,
base_map: Union[Mapping, MutableMapping],
tmp_map: Union[Mapping, MutableMapping],
data_names: Union[str, List[str]] = "data",
**kwargs
):
if not isinstance(base_map, Mapping):
base_map = Mapping(base_map)
Expand All @@ -47,7 +47,9 @@ def __init__(
self.group = None

def get_data_names_list(self) -> List[str]:
return self.data_names if isinstance(self.data_names, list) else [self.data_names]
return (
self.data_names if isinstance(self.data_names, list) else [self.data_names]
)

def delete_tensor(self):
"""
Expand All @@ -64,9 +66,7 @@ def delete_tensor(self):

@abstractmethod
def append(
self,
new_data: Union[xr.DataArray, xr.Dataset],
**kwargs
self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
) -> List[xr.backends.common.AbstractWritableDataStore]:
"""
This abstractmethod must be overwritten to append new_data to an existing file, the way that it append the data
Expand All @@ -90,9 +90,7 @@ def append(

@abstractmethod
def update(
self,
new_data: Union[xr.DataArray, xr.Dataset],
**kwargs
self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
) -> xr.backends.common.AbstractWritableDataStore:
"""
This abstractmethod must be overwritten to update new_data to an existing file, so it must not insert any new
Expand All @@ -116,9 +114,7 @@ def update(

@abstractmethod
def store(
self,
new_data: Union[xr.DataArray, xr.Dataset],
**kwargs
self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
) -> xr.backends.common.AbstractWritableDataStore:
"""
This abstractmethod must be overwritten to store new_data to an existing file, so it must create
Expand All @@ -140,9 +136,7 @@ def store(

@abstractmethod
def upsert(
self,
new_data: Union[xr.DataArray, xr.Dataset],
**kwargs
self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
) -> List[xr.backends.common.AbstractWritableDataStore]:
"""
This abstractmethod must be overwritten to update and append new_data to an existing file,
Expand All @@ -163,11 +157,7 @@ def upsert(
pass

@abstractmethod
def drop(
self,
coords,
**kwargs
) -> xr.backends.common.AbstractWritableDataStore:
def drop(self, coords, **kwargs) -> xr.backends.common.AbstractWritableDataStore:
"""
Drop coords of the tensor, this can rewrite the hole file depending on the storage
Expand Down
Loading
Loading