From de8e7f7ec29302a3d7aa19b9540f0fe76dedbe7f Mon Sep 17 00:00:00 2001 From: tgoelles Date: Thu, 12 Oct 2023 14:31:04 +0200 Subject: [PATCH] added example and test it --- .../kedro_datasets/xarray/geotiff_dataset.py | 22 +++++++++++++++++++ .../tests/xarray/test_geotiff_dataset.py | 17 +++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py b/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py index 976b79a38..f8a56e26d 100644 --- a/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py +++ b/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py @@ -59,6 +59,28 @@ def __init__( attribute is None, save version will be autogenerated. fs_args: Extra arguments to pass into underlying filesystem class constructor (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.xarray import GeoTiffDataset + >>> import xarray as xr + >>> import numpy as np + >>> + >>> data = xr.DataArray( + np.random.randn(2, 3, 2), + dims=("band", "y", "x"), + coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]}, + ) + >>> data = data.rio.write_crs("epsg:4326") + >>> data = data.rio.set_spatial_dims("x", "y") + >>> dataset = GeoTiffDataset(filepath="test.tif") + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> xr.testing.assert_allclose(data, reloaded, rtol=1e-5) """ protocol, path = get_protocol_and_path(filepath, version) self._protocol = protocol diff --git a/kedro-datasets/tests/xarray/test_geotiff_dataset.py b/kedro-datasets/tests/xarray/test_geotiff_dataset.py index 5cb7f6d11..b9f9fec61 100644 --- a/kedro-datasets/tests/xarray/test_geotiff_dataset.py +++ b/kedro-datasets/tests/xarray/test_geotiff_dataset.py @@ -1,10 +1,10 @@ from pathlib import Path +import numpy as np import pytest import rioxarray import xarray as xr from kedro.io.core import Version - from kedro_datasets._io import DatasetError from kedro_datasets.xarray import GeoTiffDataset @@ -58,3 +58,18 @@ def test_save_and_load(geotiff_dataset, cog_xarray): assert reloaded.shape == cog_xarray.shape assert reloaded.dims == cog_xarray.dims assert reloaded.equals(cog_xarray) + + +def test_example(tmp_path): + data = xr.DataArray( + np.random.randn(2, 3, 2), + dims=("band", "y", "x"), + coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]}, + ) + # Add spatial coordinates and CRS information + data = data.rio.write_crs("epsg:4326") + data = data.rio.set_spatial_dims("x", "y") + dataset = GeoTiffDataset(filepath=tmp_path.joinpath("test.tif").as_posix()) + dataset.save(data) + reloaded = dataset.load() + xr.testing.assert_allclose(data, reloaded, rtol=1e-5)