From 1192f87ebeee7c1ea28cf68142d3a0e6b8192d58 Mon Sep 17 00:00:00 2001 From: Joseph Gonzalez Date: Mon, 19 Aug 2024 11:19:48 -0400 Subject: [PATCH] Allow the use of datasets in the apply_on_groups method and also use kwargs as parameter instead of **kwargs --- tensordb/storages/base_storage.py | 36 +++++------- tensordb/storages/cached_storage.py | 77 ++++++++++++------------- tensordb/storages/mapping.py | 82 +++++++++++++-------------- tensordb/storages/zarr_storage.py | 2 +- tensordb/tests/test_cached_storage.py | 46 ++++++++------- tensordb/tests/test_json_storage.py | 23 ++++---- tensordb/utils/dag.py | 39 +++++++------ 7 files changed, 151 insertions(+), 154 deletions(-) diff --git a/tensordb/storages/base_storage.py b/tensordb/storages/base_storage.py index fbacaaa..b1e4f19 100644 --- a/tensordb/storages/base_storage.py +++ b/tensordb/storages/base_storage.py @@ -31,11 +31,11 @@ class BaseStorage: """ def __init__( - self, - base_map: Union[Mapping, MutableMapping], - tmp_map: Union[Mapping, MutableMapping], - data_names: Union[str, List[str]] = "data", - **kwargs + self, + base_map: Union[Mapping, MutableMapping], + tmp_map: Union[Mapping, MutableMapping], + data_names: Union[str, List[str]] = "data", + **kwargs ): if not isinstance(base_map, Mapping): base_map = Mapping(base_map) @@ -47,7 +47,9 @@ def __init__( self.group = None def get_data_names_list(self) -> List[str]: - return self.data_names if isinstance(self.data_names, list) else [self.data_names] + return ( + self.data_names if isinstance(self.data_names, list) else [self.data_names] + ) def delete_tensor(self): """ @@ -64,9 +66,7 @@ def delete_tensor(self): @abstractmethod def append( - self, - new_data: Union[xr.DataArray, xr.Dataset], - **kwargs + self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs ) -> List[xr.backends.common.AbstractWritableDataStore]: """ This abstractmethod must be overwritten to append new_data to an existing file, the way that it append the data @@ -90,9 +90,7 @@ def append( @abstractmethod def update( - self, - new_data: Union[xr.DataArray, xr.Dataset], - **kwargs + self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs ) -> xr.backends.common.AbstractWritableDataStore: """ This abstractmethod must be overwritten to update new_data to an existing file, so it must not insert any new @@ -116,9 +114,7 @@ def update( @abstractmethod def store( - self, - new_data: Union[xr.DataArray, xr.Dataset], - **kwargs + self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs ) -> xr.backends.common.AbstractWritableDataStore: """ This abstractmethod must be overwritten to store new_data to an existing file, so it must create @@ -140,9 +136,7 @@ def store( @abstractmethod def upsert( - self, - new_data: Union[xr.DataArray, xr.Dataset], - **kwargs + self, new_data: Union[xr.DataArray, xr.Dataset], **kwargs ) -> List[xr.backends.common.AbstractWritableDataStore]: """ This abstractmethod must be overwritten to update and append new_data to an existing file, @@ -163,11 +157,7 @@ def upsert( pass @abstractmethod - def drop( - self, - coords, - **kwargs - ) -> xr.backends.common.AbstractWritableDataStore: + def drop(self, coords, **kwargs) -> xr.backends.common.AbstractWritableDataStore: """ Drop coords of the tensor, this can rewrite the hole file depending on the storage diff --git a/tensordb/storages/cached_storage.py b/tensordb/storages/cached_storage.py index f6bf939..61e7a2a 100644 --- a/tensordb/storages/cached_storage.py +++ b/tensordb/storages/cached_storage.py @@ -37,13 +37,13 @@ class 
CachedStorage: """ def __init__( - self, - storage: BaseStorage, - max_cached_in_dim: int, - dim: str, - sort_dims: List[str] = None, - merge_cache: bool = False, - update_logic: Literal["keep_last", "combine_first"] = "combine_first" + self, + storage: BaseStorage, + max_cached_in_dim: int, + dim: str, + sort_dims: List[str] = None, + merge_cache: bool = False, + update_logic: Literal["keep_last", "combine_first"] = "combine_first", ): self.storage = storage self.max_cached_in_dim = max_cached_in_dim @@ -57,69 +57,70 @@ def __init__( def _clean_cache(self): self._cache = { - 'store': {'new_data': []}, - 'append': {'new_data': []}, - 'update': {'new_data': []} + "store": {"new_data": []}, + "append": {"new_data": []}, + "update": {"new_data": []}, } self._cached_count = 0 - def add_operation(self, type_operation: str, new_data: xr.DataArray, parameters: Dict[str, Any]): + def add_operation( + self, type_operation: str, new_data: xr.DataArray, parameters: Dict[str, Any] + ): self._cached_count += new_data.sizes[self.dim] - if type_operation == 'append' and self._cache['store']['new_data']: - type_operation = 'store' + if type_operation == "append" and self._cache["store"]["new_data"]: + type_operation = "store" self._cache[type_operation].update(parameters) - if type_operation == "update" and len(self._cache["update"]['new_data']): + if type_operation == "update" and len(self._cache["update"]["new_data"]): self.merge_updates(new_data) else: - self._cache[type_operation]['new_data'].append(new_data) + self._cache[type_operation]["new_data"].append(new_data) if self._cached_count > self.max_cached_in_dim: self.execute_operations() def merge_updates(self, new_data): - data = self._cache["update"]['new_data'][-1] + data = self._cache["update"]["new_data"][-1] if self.update_logic == "keep_last": - data = data.sel({self.dim: ~data.coords[self.dim].isin(new_data.coords[self.dim])}) + data = data.sel( + {self.dim: ~data.coords[self.dim].isin(new_data.coords[self.dim])} + ) data = new_data.combine_first(data) - self._cache["update"]['new_data'][-1] = data + self._cache["update"]["new_data"][-1] = data def merge_update_on_append(self): append_data = self._cache["append"]["new_data"] update_data = self._cache["update"]["new_data"] if not isinstance(update_data, list) and not isinstance(append_data, list): - common_coord = append_data.indexes[self.dim].intersection(update_data.indexes[self.dim]) + common_coord = append_data.indexes[self.dim].intersection( + update_data.indexes[self.dim] + ) if len(common_coord): - self._cache["append"]["new_data"] = update_data.sel(**{ - self.dim: common_coord - }).combine_first( - append_data + self._cache["append"]["new_data"] = update_data.sel( + **{self.dim: common_coord} + ).combine_first(append_data) + update_data = update_data.sel( + **{self.dim: ~update_data.coords[self.dim].isin(common_coord)} ) - update_data = update_data.sel(**{ - self.dim: ~update_data.coords[self.dim].isin(common_coord) - }) self._cache["update"]["new_data"] = [] if update_data.sizes[self.dim]: self._cache["update"]["new_data"] = update_data def execute_operations(self): - for type_operation in ['store', 'append', 'update']: + for type_operation in ["store", "append", "update"]: operation = self._cache[type_operation] - if not operation['new_data']: + if not operation["new_data"]: continue - operation['new_data'] = xr.concat( - operation['new_data'], - dim=self.dim - ) + operation["new_data"] = xr.concat(operation["new_data"], dim=self.dim) if self.sort_dims: - operation['new_data'] = 
operation['new_data'].sortby(self.sort_dims) + operation["new_data"] = operation["new_data"].sortby(self.sort_dims) if self.merge_cache: self.merge_update_on_append() - for type_operation in ['store', 'append', 'update']: + for type_operation in ["store", "append", "update"]: operation = self._cache[type_operation] - if isinstance(operation['new_data'], list): + if isinstance(operation["new_data"], list): continue try: getattr(self.storage, type_operation)(**operation) @@ -133,14 +134,14 @@ def read(self, **kwargs) -> xr.DataArray: return self.storage.read(**kwargs) def append(self, new_data: xr.DataArray, **kwargs): - self.add_operation('append', new_data, kwargs) + self.add_operation("append", new_data, kwargs) def update(self, new_data: xr.DataArray, **kwargs): - self.add_operation('update', new_data, kwargs) + self.add_operation("update", new_data, kwargs) def store(self, new_data: xr.DataArray, **kwargs): self._clean_cache() - self.add_operation('store', new_data, kwargs) + self.add_operation("store", new_data, kwargs) def close(self): self.execute_operations() diff --git a/tensordb/storages/mapping.py b/tensordb/storages/mapping.py index 791e86d..c1a29e1 100644 --- a/tensordb/storages/mapping.py +++ b/tensordb/storages/mapping.py @@ -10,31 +10,31 @@ class Mapping(MutableMapping): def __init__( - self, - mapper: MutableMapping, - sub_path: str = None, - read_lock: PrefixLock = None, - write_lock: PrefixLock = None, - root: str = None, - enable_sub_map: bool = True, + self, + mapper: MutableMapping, + sub_path: str = None, + read_lock: PrefixLock = None, + write_lock: PrefixLock = None, + root: str = None, + enable_sub_map: bool = True, ): self.mapper = mapper self.sub_path = sub_path self.read_lock = PrefixLock("") if read_lock is None else read_lock self.write_lock = self.read_lock if write_lock is None else write_lock self._root = root - self.enable_sub_map = enable_sub_map and hasattr(mapper, 'fs') + self.enable_sub_map = enable_sub_map and hasattr(mapper, "fs") @property def root(self): if self._root is not None: return self._root root = None - if hasattr(self.mapper, 'root'): + if hasattr(self.mapper, "root"): root = self.mapper.root - elif hasattr(self.mapper, 'path'): + elif hasattr(self.mapper, "path"): root = self.mapper.path - elif hasattr(self.mapper, 'url'): + elif hasattr(self.mapper, "url"): root = self.mapper.url self._root = root @@ -45,7 +45,7 @@ def sub_map(self, sub_path): root = self.root if self.enable_sub_map: if root is not None: - root = f'{root}/{sub_path}' + root = f"{root}/{sub_path}" mapper = FSStore(root, fs=mapper.fs) sub_path = self.add_sub_path(sub_path) @@ -55,7 +55,7 @@ def sub_map(self, sub_path): read_lock=self.read_lock, write_lock=self.write_lock, root=root, - enable_sub_map=self.enable_sub_map + enable_sub_map=self.enable_sub_map, ) def add_root(self, key): @@ -63,7 +63,7 @@ def add_root(self, key): return self.root if self.root is None: return key - return f'{self.root}/{key}' + return f"{self.root}/{key}" def add_sub_path(self, key): if self.enable_sub_map or self.sub_path is None: @@ -72,7 +72,7 @@ def add_sub_path(self, key): if key is None: return self.sub_path - return f'{self.sub_path}/{key}' + return f"{self.sub_path}/{key}" def full_path(self, key): return self.add_root(self.add_sub_path(key)) @@ -80,7 +80,7 @@ def full_path(self, key): def add_lock_path(self, key): if self.sub_path is None: return key - return f'{self.sub_path}/{key}' + return f"{self.sub_path}/{key}" def __getitem__(self, key): with 
self.read_lock[self.add_lock_path(key)]: @@ -99,7 +99,7 @@ def __iter__(self): if self.enable_sub_map: yield key elif key.startswith(self.sub_path): - yield key[len(self.sub_path) + 1:] + yield key[len(self.sub_path) + 1 :] def __len__(self): return sum(1 for _ in self) @@ -120,7 +120,7 @@ def delitems(self, keys, **kwargs): self.mapper.delitems(keys, **kwargs) def listdir(self, path=None): - if hasattr(self.mapper, 'listdir'): + if hasattr(self.mapper, "listdir"): return self.mapper.listdir(self.add_sub_path(path)) sub_map = self.mapper if path is None else self.sub_map(path) @@ -141,7 +141,9 @@ def info(self, path): def checksum(self, key): return self.mapper.fs.checksum(self.full_path(key)) - def equal_content(self, other, path, method: Literal["checksum", "content"] = "checksum"): + def equal_content( + self, other, path, method: Literal["checksum", "content"] = "checksum" + ): if method == "checksum": return other.checksum(path) == self.checksum(path) @@ -154,11 +156,11 @@ def equal_content(self, other, path, method: Literal["checksum", "content"] = "c @staticmethod def synchronize( - remote_map: "Mapping", - local_map: "Mapping", - checksum_map: "Mapping", - to_local: bool, - force: bool = False, + remote_map: "Mapping", + local_map: "Mapping", + checksum_map: "Mapping", + to_local: bool, + force: bool = False, ): remote_paths = set(list(remote_map.keys())) local_paths = set(list(local_map.keys())) @@ -197,26 +199,22 @@ def _move_data(path): list(p.map(_move_data, total_paths)) def folders_synchronize( - self, - destination, - folders, - comparing_method: Literal["checksum", "content"], - n_threads + self, + destination, + folders, + comparing_method: Literal["checksum", "content"], + n_threads, ): - source_paths = [ - file - for folder in folders - for file in self.listdir(folder) - ] + source_paths = [file for folder in folders for file in self.listdir(folder)] destination_paths = [ - file - for folder in folders - for file in destination.listdir(folder) + file for folder in folders for file in destination.listdir(folder) ] delete_paths = sorted(set(destination_paths) - set(source_paths)) def copy_file(path): - if path in destination and self.equal_content(destination, path, comparing_method): + if path in destination and self.equal_content( + destination, path, comparing_method + ): return None destination[path] = self[path] return path @@ -226,11 +224,9 @@ def del_file(path): modified_files = list(delete_paths) with ThreadPoolExecutor(n_threads) as p: - modified_files.extend([ - path - for path in p.map(copy_file, source_paths) - if path is not None - ]) + modified_files.extend( + [path for path in p.map(copy_file, source_paths) if path is not None] + ) list(p.map(del_file, delete_paths)) return modified_files diff --git a/tensordb/storages/zarr_storage.py b/tensordb/storages/zarr_storage.py index 910f611..4b83556 100644 --- a/tensordb/storages/zarr_storage.py +++ b/tensordb/storages/zarr_storage.py @@ -405,7 +405,7 @@ def update( synchronizer=self.synchronizer, region=regions, # This option is save based on this https://github.com/pydata/xarray/issues/9072 - safe_chunks=False + safe_chunks=False, ) return delayed_write diff --git a/tensordb/tests/test_cached_storage.py b/tensordb/tests/test_cached_storage.py index 9df5bd7..fa8437d 100644 --- a/tensordb/tests/test_cached_storage.py +++ b/tensordb/tests/test_cached_storage.py @@ -9,33 +9,35 @@ # TODO: Add more tests for the update cases + class TestCachedTensor: @pytest.fixture(autouse=True) def setup_tests(self, tmpdir): sub_path 
= tmpdir.strpath storage = ZarrStorage( base_map=fsspec.get_mapper(sub_path + "/store"), - tmp_map=fsspec.get_mapper(sub_path + '/tmp'), - path='zarr_cache', - dataset_names='cached_test', - chunks={'index': 3, 'columns': 2}, + tmp_map=fsspec.get_mapper(sub_path + "/tmp"), + path="zarr_cache", + dataset_names="cached_test", + chunks={"index": 3, "columns": 2}, ) self.cached_storage = CachedStorage( - storage=storage, - max_cached_in_dim=3, - dim='index' + storage=storage, max_cached_in_dim=3, dim="index" ) self.arr = xr.DataArray( - data=np.array([ - [1, 2, 7, 4, 5], - [np.nan, 3, 5, 5, 6], - [3, 3, np.nan, 5, 6], - [np.nan, 3, 10, 5, 6], - [np.nan, 7, 8, 5, 6], - ], dtype=float), - dims=['index', 'columns'], - coords={'index': [0, 1, 2, 3, 4], 'columns': [0, 1, 2, 3, 4]}, + data=np.array( + [ + [1, 2, 7, 4, 5], + [np.nan, 3, 5, 5, 6], + [3, 3, np.nan, 5, 6], + [np.nan, 3, 10, 5, 6], + [np.nan, 7, 8, 5, 6], + ], + dtype=float, + ), + dims=["index", "columns"], + coords={"index": [0, 1, 2, 3, 4], "columns": [0, 1, 2, 3, 4]}, ) def test_append(self): @@ -44,16 +46,16 @@ def test_append(self): self.cached_storage.append(self.arr.isel(index=[2])) assert self.cached_storage._cached_count == 3 - assert len(self.cached_storage._cache['append']['new_data']) == 3 + assert len(self.cached_storage._cache["append"]["new_data"]) == 3 self.cached_storage.append(self.arr.isel(index=[3])) assert self.cached_storage._cached_count == 0 - assert len(self.cached_storage._cache['append']['new_data']) == 0 + assert len(self.cached_storage._cache["append"]["new_data"]) == 0 self.cached_storage.append(self.arr.isel(index=[4])) self.cached_storage.close() assert self.cached_storage._cached_count == 0 - assert len(self.cached_storage._cache['append']['new_data']) == 0 + assert len(self.cached_storage._cache["append"]["new_data"]) == 0 assert self.cached_storage.read().equals(self.arr) @@ -62,14 +64,14 @@ def test_store(self): self.cached_storage.append(self.arr.isel(index=[1])) self.cached_storage.append(self.arr.isel(index=[2])) assert self.cached_storage._cached_count == 3 - assert len(self.cached_storage._cache['store']['new_data']) == 3 + assert len(self.cached_storage._cache["store"]["new_data"]) == 3 self.cached_storage.store(self.arr.isel(index=[3, 4])) assert self.cached_storage._cached_count == 2 - assert len(self.cached_storage._cache['store']['new_data']) == 1 + assert len(self.cached_storage._cache["store"]["new_data"]) == 1 self.cached_storage.close() assert self.cached_storage._cached_count == 0 - assert len(self.cached_storage._cache['store']['new_data']) == 0 + assert len(self.cached_storage._cache["store"]["new_data"]) == 0 assert self.cached_storage.read().equals(self.arr.isel(index=[3, 4])) diff --git a/tensordb/tests/test_json_storage.py b/tensordb/tests/test_json_storage.py index f173d68..687a167 100644 --- a/tensordb/tests/test_json_storage.py +++ b/tensordb/tests/test_json_storage.py @@ -9,21 +9,24 @@ class TestJsonStorage: def setup_tests(self, tmpdir): sub_path = tmpdir.strpath self.storage = JsonStorage( - base_map=fsspec.get_mapper(sub_path + '/json'), - tmp_map=fsspec.get_mapper(sub_path + '/tmp'), - path='json_storage' + base_map=fsspec.get_mapper(sub_path + "/json"), + tmp_map=fsspec.get_mapper(sub_path + "/tmp"), + path="json_storage", ) - self.dummy_data = {'a': 0, '1': 2, 'c': {'e': 10}} + self.dummy_data = {"a": 0, "1": 2, "c": {"e": 10}} def test_store_data(self): - self.storage.store(path='first/tensor_metadata', new_data=self.dummy_data) - assert self.dummy_data == 
self.storage.read('first/tensor_metadata') + self.storage.store(path="first/tensor_metadata", new_data=self.dummy_data) + assert self.dummy_data == self.storage.read("first/tensor_metadata") def test_upsert_data(self): - self.storage.store(path='first/tensor_metadata', new_data=self.dummy_data) - upsert_d = {'b': 5, 'g': [10, 12]} - self.storage.upsert(path='first/tensor_metadata', new_data=upsert_d) - assert self.storage.read('first/tensor_metadata') == {**self.dummy_data, **upsert_d} + self.storage.store(path="first/tensor_metadata", new_data=self.dummy_data) + upsert_d = {"b": 5, "g": [10, 12]} + self.storage.upsert(path="first/tensor_metadata", new_data=upsert_d) + assert self.storage.read("first/tensor_metadata") == { + **self.dummy_data, + **upsert_d, + } if __name__ == "__main__": diff --git a/tensordb/utils/dag.py b/tensordb/utils/dag.py index 657d2d9..2c22fb6 100644 --- a/tensordb/utils/dag.py +++ b/tensordb/utils/dag.py @@ -10,8 +10,7 @@ def get_tensor_dag( - tensors: List[TensorDefinition], - check_dependencies: bool + tensors: List[TensorDefinition], check_dependencies: bool ) -> List[List[TensorDefinition]]: tensor_search = {tensor.path: tensor for tensor in tensors} # Create the dag based on the dependencies, so the node used as Key depends on the Nodes in the values @@ -27,28 +26,30 @@ def get_tensor_dag( if not ordered: break - ordered_tensors.append([ - tensor_search[path] - for path in ordered - if check_dependencies or path in tensor_search - ]) + ordered_tensors.append( + [ + tensor_search[path] + for path in ordered + if check_dependencies or path in tensor_search + ] + ) dag = { - item: dependencies - ordered for item, dependencies in dag.items() + item: dependencies - ordered + for item, dependencies in dag.items() if item not in ordered } if dag: raise ValueError( - f'There is a cyclic dependency between the tensors, ' - f'the key is the node and the values are the dependencies: {dag}' + f"There is a cyclic dependency between the tensors, " + f"the key is the node and the values are the dependencies: {dag}" ) return ordered_tensors def add_dependencies( - tensors: List[TensorDefinition], - total_tensors: List[TensorDefinition] + tensors: List[TensorDefinition], total_tensors: List[TensorDefinition] ) -> List[TensorDefinition]: total_tensors_search = {tensor.path: tensor for tensor in total_tensors} total_paths = set(tensor.path for tensor in tensors) @@ -62,10 +63,12 @@ def add_dependencies( def get_leaf_tasks(tensors, new_dependencies=None): # Add the non blocking tasks to a final task final_tasks = set(tensor.path for tensor in tensors) - final_tasks -= set().union(*[ - set(tensor.dag.depends) | new_dependencies.get(tensor.path, set()) - for tensor in tensors - ]) + final_tasks -= set().union( + *[ + set(tensor.dag.depends) | new_dependencies.get(tensor.path, set()) + for tensor in tensors + ] + ) return final_tasks @@ -88,6 +91,8 @@ def get_limit_dependencies(total_tensors, max_parallelization_per_group): level = list(mit.sliced(level, limit)) prev_dependencies = set(level[0]) for act_tensors in level[1:]: - new_dependencies.update({tensor: prev_dependencies for tensor in act_tensors}) + new_dependencies.update( + {tensor: prev_dependencies for tensor in act_tensors} + ) prev_dependencies = set(act_tensors) return new_dependencies
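
Reviewer note: below is a minimal end-to-end sketch of the CachedStorage flow that the reformatted tests above exercise. It mirrors tensordb/tests/test_cached_storage.py; the import module paths (tensordb.storages.cached_storage / tensordb.storages.zarr_storage), the tempfile scratch directory, and the sample array are assumptions for illustration only, not part of this patch.

    # Minimal sketch, assuming the module paths below; mirrors test_cached_storage.py.
    import tempfile

    import fsspec
    import numpy as np
    import xarray as xr

    from tensordb.storages.cached_storage import CachedStorage
    from tensordb.storages.zarr_storage import ZarrStorage

    base_dir = tempfile.mkdtemp()  # hypothetical scratch location
    storage = ZarrStorage(
        base_map=fsspec.get_mapper(base_dir + "/store"),
        tmp_map=fsspec.get_mapper(base_dir + "/tmp"),
        path="zarr_cache",
        dataset_names="cached_test",
        chunks={"index": 3, "columns": 2},
    )
    cached = CachedStorage(storage=storage, max_cached_in_dim=3, dim="index")

    arr = xr.DataArray(
        data=np.arange(25, dtype=float).reshape(5, 5),
        dims=["index", "columns"],
        coords={"index": list(range(5)), "columns": list(range(5))},
    )

    # Appends are buffered along `dim`; once more than max_cached_in_dim rows are
    # cached, execute_operations() concatenates and flushes them to the storage.
    for i in range(5):
        cached.append(arr.isel(index=[i]))

    cached.close()  # flushes whatever is still buffered
    print(cached.read())

Buffering appends this way keeps the number of zarr writes proportional to max_cached_in_dim rather than to the number of append calls.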
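
A similar round trip for the JsonStorage store/upsert/read path covered by test_json_storage.py above. The module path tensordb.storages.json_storage and the scratch directory are assumed here for illustration.

    # Minimal sketch, assuming the module path below; mirrors test_json_storage.py.
    import tempfile

    import fsspec

    from tensordb.storages.json_storage import JsonStorage  # assumed location

    base_dir = tempfile.mkdtemp()
    storage = JsonStorage(
        base_map=fsspec.get_mapper(base_dir + "/json"),
        tmp_map=fsspec.get_mapper(base_dir + "/tmp"),
        path="json_storage",
    )
    storage.store(path="first/tensor_metadata", new_data={"a": 0, "c": {"e": 10}})
    storage.upsert(path="first/tensor_metadata", new_data={"b": 5, "g": [10, 12]})
    # upsert merges the new keys into the previously stored document
    print(storage.read("first/tensor_metadata"))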