Skip to content

Commit

Permalink
added example and test it
Browse files Browse the repository at this point in the history
  • Loading branch information
tgoelles committed Oct 12, 2023
1 parent 074229d commit de8e7f7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
22 changes: 22 additions & 0 deletions kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,28 @@ def __init__(
attribute is None, save version will be autogenerated.
fs_args: Extra arguments to pass into underlying filesystem class constructor
(e.g. `{"project": "my-project"}` for ``GCSFileSystem``).
Example usage for the
`Python API <https://kedro.readthedocs.io/en/stable/data/\
advanced_data_catalog_usage.html>`_:
.. code-block:: pycon
>>> from kedro_datasets.xarray import GeoTiffDataset
>>> import xarray as xr
>>> import numpy as np
>>>
>>> data = xr.DataArray(
np.random.randn(2, 3, 2),
dims=("band", "y", "x"),
coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]},
)
>>> data = data.rio.write_crs("epsg:4326")
>>> data = data.rio.set_spatial_dims("x", "y")
>>> dataset = GeoTiffDataset(filepath="test.tif")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> xr.testing.assert_allclose(data, reloaded, rtol=1e-5)
"""
protocol, path = get_protocol_and_path(filepath, version)
self._protocol = protocol
Expand Down
17 changes: 16 additions & 1 deletion kedro-datasets/tests/xarray/test_geotiff_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path

import numpy as np
import pytest
import rioxarray
import xarray as xr
from kedro.io.core import Version

from kedro_datasets._io import DatasetError
from kedro_datasets.xarray import GeoTiffDataset

Expand Down Expand Up @@ -58,3 +58,18 @@ def test_save_and_load(geotiff_dataset, cog_xarray):
assert reloaded.shape == cog_xarray.shape
assert reloaded.dims == cog_xarray.dims
assert reloaded.equals(cog_xarray)


def test_example(tmp_path):
data = xr.DataArray(
np.random.randn(2, 3, 2),
dims=("band", "y", "x"),
coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]},
)
# Add spatial coordinates and CRS information
data = data.rio.write_crs("epsg:4326")
data = data.rio.set_spatial_dims("x", "y")
dataset = GeoTiffDataset(filepath=tmp_path.joinpath("test.tif").as_posix())
dataset.save(data)
reloaded = dataset.load()
xr.testing.assert_allclose(data, reloaded, rtol=1e-5)

0 comments on commit de8e7f7

Please sign in to comment.