Commit 1192f87

Allow the use of datasets in the apply_on_groups method and also use kwargs as parameter instead of **kwargs

josephnowak committed Aug 19, 2024
1 parent 13d178d commit 1192f87
Showing 7 changed files with 151 additions and 154 deletions.
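
The hunks below only cover the storage layer; the apply_on_groups change named in the commit title lives in one of the other changed files. As a hedged sketch of the calling convention the title describes (the signature, parameter names, and `keep_shape` option here are assumptions for illustration, not the real tensordb API):

    from typing import Any, Dict, Union

    import xarray as xr


    # Hypothetical signature illustrating the change: options arrive as one
    # explicit `kwargs` dict instead of a **kwargs catch-all, and `new_data`
    # may now be an xr.Dataset as well as an xr.DataArray.
    def apply_on_groups(
        new_data: Union[xr.DataArray, xr.Dataset],
        groups: Dict[Any, Any],
        kwargs: Dict[str, Any] = None,
    ):
        kwargs = kwargs or {}
        # ... group new_data along a dim and forward kwargs to the grouped op ...
        return new_data


    arr = xr.DataArray([1.0, 2.0, 3.0], dims="index", coords={"index": [0, 1, 2]})
    # An explicit dict is easier to serialize and relay through wrapper layers:
    apply_on_groups(arr.to_dataset(name="data"), groups={}, kwargs={"keep_shape": True})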
36 changes: 13 additions & 23 deletions tensordb/storages/base_storage.py
@@ -31,11 +31,11 @@ class BaseStorage:
     """

     def __init__(
-            self,
-            base_map: Union[Mapping, MutableMapping],
-            tmp_map: Union[Mapping, MutableMapping],
-            data_names: Union[str, List[str]] = "data",
-            **kwargs
+        self,
+        base_map: Union[Mapping, MutableMapping],
+        tmp_map: Union[Mapping, MutableMapping],
+        data_names: Union[str, List[str]] = "data",
+        **kwargs
     ):
         if not isinstance(base_map, Mapping):
             base_map = Mapping(base_map)
@@ -47,7 +47,9 @@ def __init__(
         self.group = None

     def get_data_names_list(self) -> List[str]:
-        return self.data_names if isinstance(self.data_names, list) else [self.data_names]
+        return (
+            self.data_names if isinstance(self.data_names, list) else [self.data_names]
+        )

     def delete_tensor(self):
         """
@@ -64,9 +66,7 @@ def delete_tensor(self):

     @abstractmethod
     def append(
-            self,
-            new_data: Union[xr.DataArray, xr.Dataset],
-            **kwargs
+        self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
     ) -> List[xr.backends.common.AbstractWritableDataStore]:
         """
         This abstractmethod must be overwritten to append new_data to an existing file, the way that it append the data
@@ -90,9 +90,7 @@ def append(

     @abstractmethod
     def update(
-            self,
-            new_data: Union[xr.DataArray, xr.Dataset],
-            **kwargs
+        self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
     ) -> xr.backends.common.AbstractWritableDataStore:
         """
         This abstractmethod must be overwritten to update new_data to an existing file, so it must not insert any new
@@ -116,9 +114,7 @@ def update(

     @abstractmethod
     def store(
-            self,
-            new_data: Union[xr.DataArray, xr.Dataset],
-            **kwargs
+        self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
     ) -> xr.backends.common.AbstractWritableDataStore:
         """
         This abstractmethod must be overwritten to store new_data to an existing file, so it must create
@@ -140,9 +136,7 @@ def store(

     @abstractmethod
     def upsert(
-            self,
-            new_data: Union[xr.DataArray, xr.Dataset],
-            **kwargs
+        self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs
     ) -> List[xr.backends.common.AbstractWritableDataStore]:
         """
         This abstractmethod must be overwritten to update and append new_data to an existing file,
@@ -163,11 +157,7 @@ def upsert(
         pass

     @abstractmethod
-    def drop(
-            self,
-            coords,
-            **kwargs
-    ) -> xr.backends.common.AbstractWritableDataStore:
+    def drop(self, coords, **kwargs) -> xr.backends.common.AbstractWritableDataStore:
         """
         Drop coords of the tensor, this can rewrite the hole file depending on the storage
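
Taken together, these hunks reformat the full write interface every storage must implement. A minimal, hypothetical in-memory subclass, built only from the signatures and docstrings visible above (illustration only; tensordb's real storages persist through base_map, and error handling is omitted):

    import xarray as xr


    class InMemoryStorage(BaseStorage):
        # Illustrative only: keeps the tensor in memory instead of writing
        # through base_map the way a real storage would.

        def store(self, new_data, **kwargs):
            self._data = new_data  # create/overwrite the whole tensor
            return None

        def append(self, new_data, **kwargs):
            # insert data at new coords only, never rewriting existing ones
            self._data = xr.concat([self._data, new_data], dim="index")
            return []

        def update(self, new_data, **kwargs):
            # overwrite values at existing coords, never inserting new ones
            self._data = new_data.combine_first(self._data)
            return None

        def upsert(self, new_data, **kwargs):
            # update + append, matching the docstring above
            existing = new_data.indexes["index"].isin(self._data.indexes["index"])
            self.update(new_data.sel(index=new_data.indexes["index"][existing]), **kwargs)
            self.append(new_data.sel(index=new_data.indexes["index"][~existing]), **kwargs)
            return []

        def drop(self, coords, **kwargs):
            self._data = self._data.drop_sel(coords)
            return None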
77 changes: 39 additions & 38 deletions tensordb/storages/cached_storage.py
@@ -37,13 +37,13 @@ class CachedStorage:
     """

     def __init__(
-            self,
-            storage: BaseStorage,
-            max_cached_in_dim: int,
-            dim: str,
-            sort_dims: List[str] = None,
-            merge_cache: bool = False,
-            update_logic: Literal["keep_last", "combine_first"] = "combine_first"
+        self,
+        storage: BaseStorage,
+        max_cached_in_dim: int,
+        dim: str,
+        sort_dims: List[str] = None,
+        merge_cache: bool = False,
+        update_logic: Literal["keep_last", "combine_first"] = "combine_first",
     ):
         self.storage = storage
         self.max_cached_in_dim = max_cached_in_dim
@@ -57,69 +57,70 @@ def __init__(

     def _clean_cache(self):
         self._cache = {
-            'store': {'new_data': []},
-            'append': {'new_data': []},
-            'update': {'new_data': []}
+            "store": {"new_data": []},
+            "append": {"new_data": []},
+            "update": {"new_data": []},
         }
         self._cached_count = 0

-    def add_operation(self, type_operation: str, new_data: xr.DataArray, parameters: Dict[str, Any]):
+    def add_operation(
+        self, type_operation: str, new_data: xr.DataArray, parameters: Dict[str, Any]
+    ):
         self._cached_count += new_data.sizes[self.dim]
-        if type_operation == 'append' and self._cache['store']['new_data']:
-            type_operation = 'store'
+        if type_operation == "append" and self._cache["store"]["new_data"]:
+            type_operation = "store"

         self._cache[type_operation].update(parameters)
-        if type_operation == "update" and len(self._cache["update"]['new_data']):
+        if type_operation == "update" and len(self._cache["update"]["new_data"]):
             self.merge_updates(new_data)
         else:
-            self._cache[type_operation]['new_data'].append(new_data)
+            self._cache[type_operation]["new_data"].append(new_data)

         if self._cached_count > self.max_cached_in_dim:
             self.execute_operations()

     def merge_updates(self, new_data):
-        data = self._cache["update"]['new_data'][-1]
+        data = self._cache["update"]["new_data"][-1]
         if self.update_logic == "keep_last":
-            data = data.sel({self.dim: ~data.coords[self.dim].isin(new_data.coords[self.dim])})
+            data = data.sel(
+                {self.dim: ~data.coords[self.dim].isin(new_data.coords[self.dim])}
+            )
         data = new_data.combine_first(data)
-        self._cache["update"]['new_data'][-1] = data
+        self._cache["update"]["new_data"][-1] = data

     def merge_update_on_append(self):
         append_data = self._cache["append"]["new_data"]
         update_data = self._cache["update"]["new_data"]
         if not isinstance(update_data, list) and not isinstance(append_data, list):
-            common_coord = append_data.indexes[self.dim].intersection(update_data.indexes[self.dim])
+            common_coord = append_data.indexes[self.dim].intersection(
+                update_data.indexes[self.dim]
+            )
             if len(common_coord):
-                self._cache["append"]["new_data"] = update_data.sel(**{
-                    self.dim: common_coord
-                }).combine_first(
-                    append_data
+                self._cache["append"]["new_data"] = update_data.sel(
+                    **{self.dim: common_coord}
+                ).combine_first(append_data)
+                update_data = update_data.sel(
+                    **{self.dim: ~update_data.coords[self.dim].isin(common_coord)}
                 )
-                update_data = update_data.sel(**{
-                    self.dim: ~update_data.coords[self.dim].isin(common_coord)
-                })
                 self._cache["update"]["new_data"] = []
                 if update_data.sizes[self.dim]:
                     self._cache["update"]["new_data"] = update_data

     def execute_operations(self):
-        for type_operation in ['store', 'append', 'update']:
+        for type_operation in ["store", "append", "update"]:
             operation = self._cache[type_operation]
-            if not operation['new_data']:
+            if not operation["new_data"]:
                 continue
-            operation['new_data'] = xr.concat(
-                operation['new_data'],
-                dim=self.dim
-            )
+            operation["new_data"] = xr.concat(operation["new_data"], dim=self.dim)
             if self.sort_dims:
-                operation['new_data'] = operation['new_data'].sortby(self.sort_dims)
+                operation["new_data"] = operation["new_data"].sortby(self.sort_dims)

         if self.merge_cache:
             self.merge_update_on_append()

-        for type_operation in ['store', 'append', 'update']:
+        for type_operation in ["store", "append", "update"]:
             operation = self._cache[type_operation]
-            if isinstance(operation['new_data'], list):
+            if isinstance(operation["new_data"], list):
                 continue
             try:
                 getattr(self.storage, type_operation)(**operation)
@@ -133,14 +134,14 @@ def read(self, **kwargs) -> xr.DataArray:
         return self.storage.read(**kwargs)

     def append(self, new_data: xr.DataArray, **kwargs):
-        self.add_operation('append', new_data, kwargs)
+        self.add_operation("append", new_data, kwargs)

     def update(self, new_data: xr.DataArray, **kwargs):
-        self.add_operation('update', new_data, kwargs)
+        self.add_operation("update", new_data, kwargs)

     def store(self, new_data: xr.DataArray, **kwargs):
         self._clean_cache()
-        self.add_operation('store', new_data, kwargs)
+        self.add_operation("store", new_data, kwargs)

     def close(self):
         self.execute_operations()
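
The quoting and layout changes above do not alter CachedStorage's behavior: writes are buffered per operation type and flushed in store, then append, then update order once more than max_cached_in_dim elements accumulate along dim. A small usage sketch of that batching (`storage` is assumed to be any concrete BaseStorage subclass):

    import numpy as np
    import xarray as xr

    cached = CachedStorage(storage=storage, max_cached_in_dim=500, dim="index")

    for day in range(3):
        block = xr.DataArray(
            np.arange(day * 200, (day + 1) * 200, dtype=float),
            dims="index",
            coords={"index": np.arange(day * 200, (day + 1) * 200)},
        )
        # append() only buffers; after the third block the cached count along
        # "index" reaches 600 > 500, so add_operation concatenates the buffered
        # blocks and replays them against the real storage.
        cached.append(block)

    cached.close()  # flushes anything still buffered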

0 comments on commit 1192f87
