From 0e7b5389da4e8da5da656e91544bab7850a67f47 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Nov 2021 17:27:58 +0000
Subject: [PATCH 1/6] Functionality for storing batches with zarr

Add `BatchGenerator.to_zarr` and `BatchGenerator.from_zarr` to make it
possible to save generated batches to a zarr store and later load them
back from it. Chunking along the batch dimension enables fast data
loading at training time.
---
 xbatcher/generators.py         | 46 ++++++++++++++++++++++++++++++++++
 xbatcher/tests/test_to_zarr.py | 34 +++++++++++++++++++++++++
 2 files changed, 80 insertions(+)
 create mode 100644 xbatcher/tests/test_to_zarr.py

diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 612be61..0e3da67 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -144,3 +144,49 @@ def _iterate_batch_dims(self, ds):

     def _iterate_input_dims(self, ds):
         return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
+
+    def to_zarr(self, path, chunks={"batch": "1Gb"}):
+        """
+        Store batches into a zarr datastore in `path`. To speed up loading of
+        batches it is recommended that the chunking across batches is set close
+        to the available RAM on the computer where you are doing ML model
+        training
+        """
+        batch_datasets = list(self)
+        # can't call the batch dimension `batch` because Dataset.batch is used
+        # for the batch accessor. Instead we'll call it `batch_number`
+        ds_all = xr.concat(batch_datasets, dim="batch_number").reset_index("sample")
+        if "batch" in chunks:
+            chunks["batch_number"] = chunks.pop("batch")
+
+        if len(chunks) > 0:
+            ds_all = ds_all.chunk(chunks)
+        ds_all.to_zarr(path)
+
+    @staticmethod
+    def from_zarr(path):
+        """
+        Load a batch generator from the zarr datastore at a given `path`
+        """
+        return StoredBatchesGenerator(path=path)
+
+
+class StoredBatchesGenerator:
+    """
+    Create a generator which mimics the behaviour of BatchGenerator but loads
+    the batches from a zarr store that was previously created with
+    `BatchGenerator.to_zarr`
+    """
+    def __init__(self, path):
+        self.ds_batches = xr.open_zarr(path)
+        self.path = path
+
+    def __iter__(self):
+        for batch_id in self.ds_batches.batch_number.values:
+            ds_batch = self.ds_batches.sel(batch_number=batch_id)
+            # create a MultiIndex like we had before storing the batches
+            stacked_coords = [
+                d for d in ds_batch.coords if d not in ["sample", "batch_number"]
+            ]
+            ds_batch = ds_batch.set_index(sample=stacked_coords)
+            yield ds_batch
diff --git a/xbatcher/tests/test_to_zarr.py b/xbatcher/tests/test_to_zarr.py
new file mode 100644
index 0000000..46eadc5
--- /dev/null
+++ b/xbatcher/tests/test_to_zarr.py
@@ -0,0 +1,34 @@
+import xarray as xr
+import numpy as np
+import tempfile
+import xbatcher
+
+
+def test_to_zarr():
+    da = xr.DataArray(
+        np.random.rand(1000, 100, 100), name="foo", dims=["time", "y", "x"]
+    ).chunk({"time": 1})
+
+    bgen = xbatcher.BatchGenerator(da, {"time": 10}, preload_batch=False)
+
+    for ds_batch in bgen:
+        ds_first_batch = ds_batch
+        break
+
+    tempdir = tempfile.TemporaryDirectory().name
+    # with tempfile.TemporaryDirectory() as tempdir:
+    bgen.to_zarr(tempdir)
+
+    bgen_loaded = xbatcher.BatchGenerator.from_zarr(tempdir)
+
+    for loaded_batch in bgen_loaded:
+        loaded_first_batch = loaded_batch
+        break
+
+    # DataArray.equals doesn't work while the DataArrays are still stacked
+    da_first_batch = ds_first_batch.unstack()
+    da_loaded_first_batch = loaded_first_batch.unstack()
+    # For some reason DataArray.equals doesn't work here, but DataArray.broadcast_equals does
+    assert da_loaded_first_batch.broadcast_equals(da_first_batch)
+    # the values matching should mean that DataArray.equals works too
+    assert (da_loaded_first_batch - da_first_batch).max() == 0.0
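To make the intent of the patch concrete, here is a minimal usage sketch of the new round-trip (not part of the diff; the store path and toy data are arbitrary, and the generator arguments mirror a smaller variant of the test above):

    import numpy as np
    import xarray as xr
    import xbatcher

    # build a generator over a small DataArray; y and x become the stacked
    # `sample` dimension, time is the per-sample input dimension
    da = xr.DataArray(
        np.random.rand(100, 20, 20), name="foo", dims=["time", "y", "x"]
    )
    bgen = xbatcher.BatchGenerator(da, {"time": 10})

    # write all batches to a zarr store once...
    bgen.to_zarr("batches.zarr")

    # ...and stream them back later, e.g. inside a training loop
    for batch in xbatcher.BatchGenerator.from_zarr("batches.zarr"):
        pass  # feed `batch` to the model here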
From 45e20ec12f8e2e04a3e1de8a43662bcfce19e64e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Nov 2021 17:31:55 +0000
Subject: [PATCH 2/6] clean up test

---
 xbatcher/tests/test_to_zarr.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/xbatcher/tests/test_to_zarr.py b/xbatcher/tests/test_to_zarr.py
index 46eadc5..87b5481 100644
--- a/xbatcher/tests/test_to_zarr.py
+++ b/xbatcher/tests/test_to_zarr.py
@@ -16,7 +16,6 @@ def test_to_zarr():
         break

     tempdir = tempfile.TemporaryDirectory().name
-    # with tempfile.TemporaryDirectory() as tempdir:
     bgen.to_zarr(tempdir)

     bgen_loaded = xbatcher.BatchGenerator.from_zarr(tempdir)
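A note on the line kept above: `tempfile.TemporaryDirectory().name` hands back a path whose directory is removed as soon as the object is garbage collected, and the test only works because zarr recreates the directory on write. A sketch of the context-manager alternative that the deleted comment hinted at (same toy data as the test, just smaller):

    import tempfile

    import numpy as np
    import xarray as xr
    import xbatcher

    da = xr.DataArray(
        np.random.rand(100, 20, 20), name="foo", dims=["time", "y", "x"]
    ).chunk({"time": 1})
    bgen = xbatcher.BatchGenerator(da, {"time": 10}, preload_batch=False)

    # the directory stays alive for the whole round-trip, cleaned up on exit
    with tempfile.TemporaryDirectory() as tempdir:
        bgen.to_zarr(tempdir)
        n_batches = sum(1 for _ in xbatcher.BatchGenerator.from_zarr(tempdir))
        assert n_batches == 10  # 100 time steps / 10 per batch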
From 90838818fd2a13a6e934b58054e6284c133eade2 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Nov 2021 17:51:40 +0000
Subject: [PATCH 3/6] Apply linting etc. with pre-commit

---
 xbatcher/generators.py         | 15 ++++++++++-----
 xbatcher/tests/test_to_zarr.py | 12 +++++++-----
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 0e3da67..3ad3d9a 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -145,7 +145,7 @@ def _iterate_batch_dims(self, ds):
     def _iterate_input_dims(self, ds):
         return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)

-    def to_zarr(self, path, chunks={"batch": "1Gb"}):
+    def to_zarr(self, path, chunks={'batch': '1Gb'}):
         """
         Store batches into a zarr datastore in `path`. To speed up loading of
         batches it is recommended that the chunking across batches is set close
@@ -155,9 +155,11 @@ def to_zarr(self, path, chunks={'batch': '1Gb'}):
         batch_datasets = list(self)
         # can't call the batch dimension `batch` because Dataset.batch is used
         # for the batch accessor. Instead we'll call it `batch_number`
-        ds_all = xr.concat(batch_datasets, dim="batch_number").reset_index("sample")
-        if "batch" in chunks:
-            chunks["batch_number"] = chunks.pop("batch")
+        ds_all = xr.concat(batch_datasets, dim='batch_number').reset_index(
+            'sample'
+        )
+        if 'batch' in chunks:
+            chunks['batch_number'] = chunks.pop('batch')

         if len(chunks) > 0:
             ds_all = ds_all.chunk(chunks)
@@ -177,6 +179,7 @@ class StoredBatchesGenerator:
     the batches from a zarr store that was previously created with
     `BatchGenerator.to_zarr`
     """
+
     def __init__(self, path):
         self.ds_batches = xr.open_zarr(path)
         self.path = path
@@ -186,7 +189,9 @@ def __iter__(self):
             ds_batch = self.ds_batches.sel(batch_number=batch_id)
             # create a MultiIndex like we had before storing the batches
             stacked_coords = [
-                d for d in ds_batch.coords if d not in ["sample", "batch_number"]
+                d
+                for d in ds_batch.coords
+                if d not in ['sample', 'batch_number']
             ]
             ds_batch = ds_batch.set_index(sample=stacked_coords)
             yield ds_batch
diff --git a/xbatcher/tests/test_to_zarr.py b/xbatcher/tests/test_to_zarr.py
index 87b5481..f1b019f 100644
--- a/xbatcher/tests/test_to_zarr.py
+++ b/xbatcher/tests/test_to_zarr.py
@@ -1,15 +1,17 @@
-import xarray as xr
-import numpy as np
 import tempfile
+
+import numpy as np
+import xarray as xr
+
 import xbatcher


 def test_to_zarr():
     da = xr.DataArray(
-        np.random.rand(1000, 100, 100), name="foo", dims=["time", "y", "x"]
-    ).chunk({"time": 1})
+        np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
+    ).chunk({'time': 1})

-    bgen = xbatcher.BatchGenerator(da, {"time": 10}, preload_batch=False)
+    bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)

     for ds_batch in bgen:
         ds_first_batch = ds_batch

From 7b2341bfc7166f54130ae20efd47791d2cb10d27 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Nov 2021 17:54:15 +0000
Subject: [PATCH 4/6] add zarr to dev-requirements

---
 dev-requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dev-requirements.txt b/dev-requirements.txt
index 642f9c5..43ff679 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,3 +1,4 @@
 pytest
 coverage
+zarr
 -r requirements.txt
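Quote-style churn aside, the one piece of `to_zarr` logic that patch 3 reflows is the chunk-key rename; in isolation it is just this (a sketch of the logic as it appears in the diff):

    # `Dataset.batch` is taken by xbatcher's batch accessor, so the stored
    # concatenation dimension is named `batch_number`; a user-supplied chunk
    # spec for `batch` is renamed to match before chunking is applied
    chunks = {"batch": "1Gb"}
    if "batch" in chunks:
        chunks["batch_number"] = chunks.pop("batch")
    assert chunks == {"batch_number": "1Gb"}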
From 1ce312fae4480d6eb8935c71f603472f741d388e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 May 2022 16:08:15 +0200
Subject: [PATCH 5/6] store init attrs and create BatchGeneratorBase

---
 xbatcher/generators.py         | 82 ++++++++++++++++++++++------------
 xbatcher/tests/test_to_zarr.py |  2 +-
 2 files changed, 55 insertions(+), 29 deletions(-)

diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 87d0260..672e2dd 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -3,6 +3,7 @@
 import itertools
 from collections import OrderedDict
 from typing import Any, Dict, Hashable, Iterator
+import json

 import xarray as xr

@@ -41,7 +42,7 @@ def _iterate_through_dataset(ds, dims, overlap={}):
             yield ds.isel(**selector)


-def _drop_input_dims(ds, input_dims, suffix='_input'):
+def _drop_input_dims(ds, input_dims, suffix="_input"):
     # remove input_dims coordinates from datasets, rename the dimensions
     # then put input_dims back in as coordinates
     out = ds.copy()
@@ -55,7 +56,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
     return out


-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
+def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"):
     batch_dims = [d for d in ds.dims if d not in input_dims]
     if len(batch_dims) < 2:
         return ds
@@ -65,7 +66,21 @@ def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"):
     return ds_stack.transpose(*dim_order)


-class BatchGenerator:
+class BatchGeneratorBase:
+    def __init__(
+        self,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+    ):
+        self.input_dims = OrderedDict(input_dims)
+        self.input_overlap = input_overlap
+        self.batch_dims = OrderedDict(batch_dims)
+        self.concat_input_dims = concat_input_dims
+
+
+class BatchGenerator(BatchGeneratorBase):
     """Create generator for iterating through xarray dataarrays / datasets in
     batches.

@@ -107,18 +122,18 @@ def __init__(
         concat_input_dims: bool = False,
         preload_batch: bool = True,
     ):
+        super().__init__(
+            input_dims=input_dims,
+            input_overlap=input_overlap,
+            batch_dims=batch_dims,
+            concat_input_dims=concat_input_dims,
+        )

         self.ds = _as_xarray_dataset(ds)
         # should be a dict
-        self.input_dims = OrderedDict(input_dims)
-        self.input_overlap = input_overlap
-        self.batch_dims = OrderedDict(batch_dims)
-        self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch

-        self._batches: Dict[
-            int, Any
-        ] = self._gen_batches()  # dict cache for batches
+        self._batches: Dict[int, Any] = self._gen_batches()  # dict cache for batches
         # in the future, we can make this a lru cache or similar thing (cachey?)

     def __iter__(self) -> Iterator[xr.Dataset]:
@@ -132,7 +147,7 @@ def __getitem__(self, idx: int) -> xr.Dataset:

         if not isinstance(idx, int):
             raise NotImplementedError(
-                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+                f"{type(self).__name__}.__getitem__ currently requires a single integer key"
             )

         if idx < 0:
@@ -141,7 +156,7 @@ def __getitem__(self, idx: int) -> xr.Dataset:
         if idx in self._batches:
             return self._batches[idx]
         else:
-            raise IndexError('list index out of range')
+            raise IndexError("list index out of range")

     def _gen_batches(self) -> dict:
         # in the future, we will want to do the batch generation lazily
@@ -153,17 +168,15 @@ def _gen_batches(self) -> dict:
                 ds_batch.load()
             input_generator = self._iterate_input_dims(ds_batch)
             if self.concat_input_dims:
-                new_dim_suffix = '_input'
+                new_dim_suffix = "_input"
                 all_dsets = [
                     _drop_input_dims(
                         ds_input, list(self.input_dims), suffix=new_dim_suffix
                     )
                     for ds_input in input_generator
                 ]
-                dsc = xr.concat(all_dsets, dim='input_batch')
-                new_input_dims = [
-                    str(dim) + new_dim_suffix for dim in self.input_dims
-                ]
+                dsc = xr.concat(all_dsets, dim="input_batch")
+                new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
                 batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
             else:
                 for ds_input in input_generator:
@@ -179,7 +192,7 @@ def _iterate_batch_dims(self, ds):
     def _iterate_input_dims(self, ds):
         return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)

-    def to_zarr(self, path, chunks={'batch': '1Gb'}):
+    def to_zarr(self, path, chunks={"batch": "1Gb"}):
         """
         Store batches into a zarr datastore in `path`. To speed up loading of
         batches it is recommended that the chunking across batches is set close
@@ -189,14 +202,15 @@ def to_zarr(self, path, chunks={"batch": "1Gb"}):
         batch_datasets = list(self)
         # can't call the batch dimension `batch` because Dataset.batch is used
         # for the batch accessor. Instead we'll call it `batch_number`
-        ds_all = xr.concat(batch_datasets, dim='batch_number').reset_index(
-            'sample'
-        )
-        if 'batch' in chunks:
-            chunks['batch_number'] = chunks.pop('batch')
+        ds_all = xr.concat(batch_datasets, dim="batch_number").reset_index("sample")
+        if "batch" in chunks:
+            chunks["batch_number"] = chunks.pop("batch")

         if len(chunks) > 0:
             ds_all = ds_all.chunk(chunks)
+
+        for v in StoredBatchesGenerator.INIT_ARGS_TO_SERIALIZE:
+            ds_all.attrs[v] = json.dumps(getattr(self, v))
         ds_all.to_zarr(path)

     @staticmethod
@@ -207,25 +221,37 @@ def from_zarr(path):
         """
         return StoredBatchesGenerator(path=path)


-class StoredBatchesGenerator:
+class StoredBatchesGenerator(BatchGeneratorBase):
     """
     Create a generator which mimics the behaviour of BatchGenerator but loads
     the batches from a zarr store that was previously created with
-    `BatchGenerator.to_zarr`
+    `BatchGenerator.to_zarr`. Arguments which the original BatchGenerator was
+    created with are serialized using JSON and saved as attributes in the
+    zarr store
     """

+    INIT_ARGS_TO_SERIALIZE = [
+        "input_dims",
+        "input_overlap",
+        "batch_dims",
+        "concat_input_dims",
+    ]
+
     def __init__(self, path):
         self.ds_batches = xr.open_zarr(path)
         self.path = path

+        init_kws = {
+            v: json.loads(self.ds_batches.attrs[v]) for v in self.INIT_ARGS_TO_SERIALIZE
+        }
+        super().__init__(**init_kws)
+
     def __iter__(self):
         for batch_id in self.ds_batches.batch_number.values:
             ds_batch = self.ds_batches.sel(batch_number=batch_id)
             # create a MultiIndex like we had before storing the batches
             stacked_coords = [
-                d
-                for d in ds_batch.coords
-                if d not in ['sample', 'batch_number']
+                d for d in ds_batch.coords if d not in ["sample", "batch_number"]
             ]
             ds_batch = ds_batch.set_index(sample=stacked_coords)
             yield ds_batch
diff --git a/xbatcher/tests/test_to_zarr.py b/xbatcher/tests/test_to_zarr.py
index f1b019f..52ef495 100644
--- a/xbatcher/tests/test_to_zarr.py
+++ b/xbatcher/tests/test_to_zarr.py
@@ -18,7 +18,7 @@ def test_to_zarr():
         break

     tempdir = tempfile.TemporaryDirectory().name
-    bgen.to_zarr(tempdir)
+    bgen.to_zarr(tempdir, chunks={})

     bgen_loaded = xbatcher.BatchGenerator.from_zarr(tempdir)
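The serialization round-trip added in this patch is small enough to show on its own: each constructor argument is JSON-encoded into the store's attributes on save and decoded back into keyword arguments on load. A sketch using plain dicts in place of the generator objects:

    import json

    import xarray as xr

    INIT_ARGS_TO_SERIALIZE = [
        'input_dims',
        'input_overlap',
        'batch_dims',
        'concat_input_dims',
    ]
    init_args = {
        'input_dims': {'time': 10},
        'input_overlap': {},
        'batch_dims': {},
        'concat_input_dims': False,
    }

    # on save: encode each init argument as a JSON string attribute
    ds = xr.Dataset()
    for name in INIT_ARGS_TO_SERIALIZE:
        ds.attrs[name] = json.dumps(init_args[name])

    # on load: decode the attributes back into keyword arguments
    init_kws = {
        name: json.loads(ds.attrs[name]) for name in INIT_ARGS_TO_SERIALIZE
    }
    assert init_kws == init_args

Note that `json.dumps` turns the `OrderedDict` of `input_dims` into a plain JSON object, so the reconstructed generator relies on key order surviving the round-trip, which holds for Python 3.7+ dicts.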
From 08a9e94376af44935551f4f65fe4f7d4619e8d1c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 11 May 2022 07:55:38 +0200
Subject: [PATCH 6/6] linting fixes

---
 xbatcher/generators.py | 47 +++++++++++++++++++++++++-----------------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 672e2dd..4518e21 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -1,9 +1,9 @@
 """Classes for iterating through xarray dataarrays / datasets in batches."""

 import itertools
+import json
 from collections import OrderedDict
 from typing import Any, Dict, Hashable, Iterator
-import json

 import xarray as xr

@@ -42,7 +42,7 @@ def _iterate_through_dataset(ds, dims, overlap={}):
             yield ds.isel(**selector)


-def _drop_input_dims(ds, input_dims, suffix="_input"):
+def _drop_input_dims(ds, input_dims, suffix='_input'):
     # remove input_dims coordinates from datasets, rename the dimensions
     # then put input_dims back in as coordinates
     out = ds.copy()
@@ -56,7 +56,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
     return out


-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"):
+def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
     batch_dims = [d for d in ds.dims if d not in input_dims]
     if len(batch_dims) < 2:
         return ds
@@ -133,7 +133,9 @@ def __init__(
         # should be a dict
         self.preload_batch = preload_batch

-        self._batches: Dict[int, Any] = self._gen_batches()  # dict cache for batches
+        self._batches: Dict[
+            int, Any
+        ] = self._gen_batches()  # dict cache for batches
         # in the future, we can make this a lru cache or similar thing (cachey?)

     def __iter__(self) -> Iterator[xr.Dataset]:
@@ -147,7 +149,7 @@ def __getitem__(self, idx: int) -> xr.Dataset:

         if not isinstance(idx, int):
             raise NotImplementedError(
-                f"{type(self).__name__}.__getitem__ currently requires a single integer key"
+                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
             )

         if idx < 0:
@@ -156,7 +158,7 @@ def __getitem__(self, idx: int) -> xr.Dataset:
         if idx in self._batches:
             return self._batches[idx]
         else:
-            raise IndexError("list index out of range")
+            raise IndexError('list index out of range')

     def _gen_batches(self) -> dict:
         # in the future, we will want to do the batch generation lazily
@@ -168,15 +170,17 @@ def _gen_batches(self) -> dict:
                 ds_batch.load()
             input_generator = self._iterate_input_dims(ds_batch)
             if self.concat_input_dims:
-                new_dim_suffix = "_input"
+                new_dim_suffix = '_input'
                 all_dsets = [
                     _drop_input_dims(
                         ds_input, list(self.input_dims), suffix=new_dim_suffix
                     )
                     for ds_input in input_generator
                 ]
-                dsc = xr.concat(all_dsets, dim="input_batch")
-                new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
+                dsc = xr.concat(all_dsets, dim='input_batch')
+                new_input_dims = [
+                    str(dim) + new_dim_suffix for dim in self.input_dims
+                ]
                 batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
             else:
                 for ds_input in input_generator:
@@ -192,7 +196,7 @@ def _iterate_batch_dims(self, ds):
     def _iterate_input_dims(self, ds):
         return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)

-    def to_zarr(self, path, chunks={"batch": "1Gb"}):
+    def to_zarr(self, path, chunks={'batch': '1Gb'}):
         """
         Store batches into a zarr datastore in `path`. To speed up loading of
         batches it is recommended that the chunking across batches is set close
@@ -202,9 +206,11 @@ def to_zarr(self, path, chunks={'batch': '1Gb'}):
         batch_datasets = list(self)
         # can't call the batch dimension `batch` because Dataset.batch is used
         # for the batch accessor. Instead we'll call it `batch_number`
-        ds_all = xr.concat(batch_datasets, dim="batch_number").reset_index("sample")
-        if "batch" in chunks:
-            chunks["batch_number"] = chunks.pop("batch")
+        ds_all = xr.concat(batch_datasets, dim='batch_number').reset_index(
+            'sample'
+        )
+        if 'batch' in chunks:
+            chunks['batch_number'] = chunks.pop('batch')

         if len(chunks) > 0:
             ds_all = ds_all.chunk(chunks)
@@ -231,10 +237,10 @@ class StoredBatchesGenerator(BatchGeneratorBase):
     """

     INIT_ARGS_TO_SERIALIZE = [
-        "input_dims",
-        "input_overlap",
-        "batch_dims",
-        "concat_input_dims",
+        'input_dims',
+        'input_overlap',
+        'batch_dims',
+        'concat_input_dims',
     ]

     def __init__(self, path):
@@ -242,7 +248,8 @@ def __init__(self, path):
         self.path = path

         init_kws = {
-            v: json.loads(self.ds_batches.attrs[v]) for v in self.INIT_ARGS_TO_SERIALIZE
+            v: json.loads(self.ds_batches.attrs[v])
+            for v in self.INIT_ARGS_TO_SERIALIZE
         }
         super().__init__(**init_kws)

@@ -251,7 +258,9 @@ def __iter__(self):
             ds_batch = self.ds_batches.sel(batch_number=batch_id)
             # create a MultiIndex like we had before storing the batches
             stacked_coords = [
-                d for d in ds_batch.coords if d not in ["sample", "batch_number"]
+                d
+                for d in ds_batch.coords
+                if d not in ['sample', 'batch_number']
             ]
             ds_batch = ds_batch.set_index(sample=stacked_coords)
             yield ds_batch
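With the full series applied, the stored-batch path looks roughly like this end to end (a usage sketch mirroring the final test; the store path is arbitrary):

    import numpy as np
    import xarray as xr
    import xbatcher

    da = xr.DataArray(
        np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
    ).chunk({'time': 1})

    bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)
    bgen.to_zarr('batches.zarr', chunks={})

    # from_zarr restores input_dims etc. from the JSON-encoded zarr attributes
    bgen_loaded = xbatcher.BatchGenerator.from_zarr('batches.zarr')
    for batch in bgen_loaded:
        print(batch.sizes)  # each batch keeps the stacked `sample` dimension
        break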