diff --git a/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py b/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py index cf0fde10c..4e4a07eaa 100644 --- a/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py +++ b/kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py @@ -15,7 +15,6 @@ from kedro_datasets._io import AbstractVersionedDataset, DatasetError - class GeoTiffDataset(AbstractVersionedDataset[xarray.DataArray, xarray.DataArray]): """``GeoTiffDataset`` loads and saves geotiff files and reads them as xarray DataArrays. @@ -88,7 +87,7 @@ def _load(self) -> xarray.DataArray: def _save(self, data: xarray.DataArray) -> None: save_path = self._get_save_path() if self._filepath.suffix in [".tif", ".tiff"]: - data.to_raster(save_path.as_posix(), **self._save_args) + data.rio.to_raster(save_path.as_posix(), **self._save_args) else: raise ValueError("expecting .tif or .tiff file suffix") diff --git a/kedro-datasets/tests/xarray/test_geotiff_dataset.py b/kedro-datasets/tests/xarray/test_geotiff_dataset.py index efdad267b..78745b42c 100644 --- a/kedro-datasets/tests/xarray/test_geotiff_dataset.py +++ b/kedro-datasets/tests/xarray/test_geotiff_dataset.py @@ -2,6 +2,7 @@ import pytest import xarray as xr +import rioxarray from kedro.io.core import Version from kedro_datasets.xarray import GeoTiffDataset @@ -12,11 +13,22 @@ def cog_file_path() -> str: cog_file_path = Path(__file__).parent / "cog.tif" return cog_file_path.as_posix() +@pytest.fixture +def cog_xarray(cog_file_path) -> xr.DataArray: + return rioxarray.open_rasterio(cog_file_path) @pytest.fixture -def geotiff_dataset(cog_file_path, save_args, fs_args) -> GeoTiffDataset: +def cog_geotiff_dataset(cog_file_path, save_args, fs_args) -> GeoTiffDataset: return GeoTiffDataset(filepath=cog_file_path, save_args=save_args) +@pytest.fixture +def filepath_geotiff(tmp_path): + return (tmp_path / "test.tiff").as_posix() + +@pytest.fixture +def geotiff_dataset(filepath_geotiff, save_args): + return GeoTiffDataset(filepath=filepath_geotiff, save_args=save_args) + @pytest.fixture def versioned_geotiff_dataset( @@ -27,7 +39,18 @@ def versioned_geotiff_dataset( ) -def test_save_and_load(geotiff_dataset): +def test_load(cog_geotiff_dataset): """Test saving and reloading the data set.""" - loaded_tiff = geotiff_dataset.load() + loaded_tiff = cog_geotiff_dataset.load() assert isinstance(loaded_tiff, xr.DataArray) + assert loaded_tiff.shape == (1,500,500) + assert loaded_tiff.dims == ('band', 'y', 'x') + + +def test_save_and_load(geotiff_dataset, cog_xarray): + """Test saving and reloading the data set.""" + geotiff_dataset.save(cog_xarray) + reloaded = geotiff_dataset.load() + assert reloaded.shape == cog_xarray.shape + assert reloaded.dims == cog_xarray.dimsts + assert reloaded.equals(cog_xarray) \ No newline at end of file