Batches to zarr #40

Open · wants to merge 9 commits into base: main · showing changes from all commits
2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -1,6 +1,8 @@
 pytest
 coverage
 torch
 coverage
 pytest-cov
+adlfs
+zarr
 -r requirements.txt
96 changes: 91 additions & 5 deletions xbatcher/generators.py
@@ -1,6 +1,7 @@
"""Classes for iterating through xarray datarrays / datasets in batches."""

import itertools
import json
from collections import OrderedDict
from typing import Any, Dict, Hashable, Iterator

@@ -65,7 +66,21 @@ def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
     return ds_stack.transpose(*dim_order)
 
 
-class BatchGenerator:
+class BatchGeneratorBase:
+    def __init__(
+        self,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+    ):
+        self.input_dims = OrderedDict(input_dims)
+        self.input_overlap = input_overlap
+        self.batch_dims = OrderedDict(batch_dims)
+        self.concat_input_dims = concat_input_dims
+
+
+class BatchGenerator(BatchGeneratorBase):
     """Create generator for iterating through xarray dataarrays / datasets in
     batches.
 
@@ -107,13 +122,15 @@ def __init__(
         concat_input_dims: bool = False,
         preload_batch: bool = True,
     ):
+        super().__init__(
+            input_dims=input_dims,
+            input_overlap=input_overlap,
+            batch_dims=batch_dims,
+            concat_input_dims=concat_input_dims,
+        )
 
         self.ds = _as_xarray_dataset(ds)
-        # should be a dict
-        self.input_dims = OrderedDict(input_dims)
-        self.input_overlap = input_overlap
-        self.batch_dims = OrderedDict(batch_dims)
-        self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
 
         self._batches: Dict[
@@ -178,3 +195,72 @@ def _iterate_batch_dims(self, ds):
 
     def _iterate_input_dims(self, ds):
         return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)

+    def to_zarr(self, path, chunks=None):
+        """
+        Store batches into a zarr datastore in `path`. To speed up loading
+        of batches it is recommended to set the chunking across batches
+        close to the available RAM on the computer where you will be doing
+        the ML model training.
+        """
+        if chunks is None:
+            chunks = {'batch': '1Gb'}
+        else:
+            # copy so that the `pop` below doesn't mutate the caller's dict
+            chunks = dict(chunks)
+        batch_datasets = list(self)
+        # can't call the batch dimension `batch` because Dataset.batch is
+        # used for the batch accessor. Instead we'll call it `batch_number`
+        ds_all = xr.concat(batch_datasets, dim='batch_number').reset_index(
+            'sample'
+        )
+        if 'batch' in chunks:

Contributor commented: test when 'batch' not in chunks

+            chunks['batch_number'] = chunks.pop('batch')
+
+        if len(chunks) > 0:

Contributor commented: test when len(chunks) == 0

+            ds_all = ds_all.chunk(chunks)
+
+        for v in StoredBatchesGenerator.INIT_ARGS_TO_SERIALIZE:
+            ds_all.attrs[v] = json.dumps(getattr(self, v))
+        ds_all.to_zarr(path)
+
+    @staticmethod
+    def from_zarr(path):
+        """
+        Load a batch generator from the zarr datastore at a given `path`.
+        """
+        return StoredBatchesGenerator(path=path)
+
+
+class StoredBatchesGenerator(BatchGeneratorBase):
+    """
+    Create a generator which mimics the behaviour of BatchGenerator but
+    loads the batches from a zarr store that was previously created with
+    `BatchGenerator.to_zarr`. The arguments that the original BatchGenerator
+    was created with are serialized with json and saved as attributes in
+    the zarr store.
+    """
+
+    INIT_ARGS_TO_SERIALIZE = [
+        'input_dims',
+        'input_overlap',
+        'batch_dims',
+        'concat_input_dims',
+    ]
+
+    def __init__(self, path):
+        self.ds_batches = xr.open_zarr(path)
+        self.path = path
+
+        init_kws = {
+            v: json.loads(self.ds_batches.attrs[v])
+            for v in self.INIT_ARGS_TO_SERIALIZE
+        }
+        super().__init__(**init_kws)
+
+    def __iter__(self):
+        for batch_id in self.ds_batches.batch_number.values:

Contributor commented: I'm not exactly sure why, but codecov thinks something in this for loop is not being covered by the existing tests. Perhaps it's the empty iterable (.values) or it could be the `if` statement in line 194. Any thoughts?

+            ds_batch = self.ds_batches.sel(batch_number=batch_id)
+            # recreate the MultiIndex we had before storing the batches
+            stacked_coords = [
+                d
+                for d in ds_batch.coords
+                if d not in ['sample', 'batch_number']
+            ]
+            ds_batch = ds_batch.set_index(sample=stacked_coords)
+            yield ds_batch
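
To make the new round trip concrete, here is a minimal usage sketch of the API added above. The store path 'batches.zarr' and the array shape are invented for illustration; everything else is the to_zarr / from_zarr interface from this diff.

    import numpy as np
    import xarray as xr

    import xbatcher

    # hypothetical training data: 1000 timesteps on a 100x100 grid
    da = xr.DataArray(
        np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
    )

    # cut the data into batches of 10 timesteps each
    bgen = xbatcher.BatchGenerator(da, input_dims={'time': 10})

    # write all batches to a zarr store; the default chunks={'batch': '1Gb'}
    # rechunks across batches so one chunk fits comfortably in RAM
    bgen.to_zarr('batches.zarr')

    # later, e.g. in the training job, recreate the generator from the store
    bgen_loaded = xbatcher.BatchGenerator.from_zarr('batches.zarr')
    for batch in bgen_loaded:
        pass  # each batch is an xarray.Dataset ready to feed to the model

Writing the batches once and re-reading them with from_zarr means the slicing work is not repeated at training time, which is the use case the to_zarr docstring describes.
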
35 changes: 35 additions & 0 deletions xbatcher/tests/test_to_zarr.py
@@ -0,0 +1,35 @@
+import tempfile
+
+import numpy as np
+import xarray as xr
+
+import xbatcher
+
+
+def test_to_zarr():
+    da = xr.DataArray(
+        np.random.rand(1000, 100, 100), name='foo', dims=['time', 'y', 'x']
+    ).chunk({'time': 1})
+
+    bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)
+
+    for ds_batch in bgen:
+        ds_first_batch = ds_batch
+        break
+
+    tempdir = tempfile.mkdtemp()
+    bgen.to_zarr(tempdir, chunks={})
+
+    bgen_loaded = xbatcher.BatchGenerator.from_zarr(tempdir)
+
+    for loaded_batch in bgen_loaded:
+        loaded_first_batch = loaded_batch
+        break
+
+    # DataArray.equals doesn't work while the DataArrays are still stacked
+    da_first_batch = ds_first_batch.unstack()
+    da_loaded_first_batch = loaded_first_batch.unstack()
+    # For some reason DataArray.equals doesn't work here, but DataArray.broadcast_equals does
+    assert da_loaded_first_batch.broadcast_equals(da_first_batch)
+    # I think this should mean that DataArray.equals should work
+    assert (da_loaded_first_batch - da_first_batch).max() == 0.0
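
Regarding the two review comments above: the existing test passes chunks={}, so it already exercises the case where 'batch' is not in chunks and len(chunks) == 0. What remains uncovered is the default chunks={'batch': '1Gb'} path, where both the 'batch' rename and the rechunking actually run. A rough sketch of such a test, assuming it lives in this same module (so the imports above are available); the test name and array sizes are invented:

    def test_to_zarr_default_chunks():
        da = xr.DataArray(
            np.random.rand(100, 10, 10), name='foo', dims=['time', 'y', 'x']
        ).chunk({'time': 1})

        bgen = xbatcher.BatchGenerator(da, {'time': 10}, preload_batch=False)

        tempdir = tempfile.mkdtemp()
        # default chunks={'batch': '1Gb'}: takes both the `'batch' in chunks`
        # rename and the `len(chunks) > 0` rechunking branch
        bgen.to_zarr(tempdir)

        bgen_loaded = xbatcher.BatchGenerator.from_zarr(tempdir)
        # 100 timesteps in batches of 10 -> 10 stored batches
        assert sum(1 for _ in bgen_loaded) == 10
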