Skip to content

Commit

Permalink
test reloaded
Browse files Browse the repository at this point in the history
  • Loading branch information
tgoelles committed Oct 5, 2023
1 parent 262f30f commit 4e54fe9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
3 changes: 1 addition & 2 deletions kedro-datasets/kedro_datasets/xarray/geotiff_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from kedro_datasets._io import AbstractVersionedDataset, DatasetError



class GeoTiffDataset(AbstractVersionedDataset[xarray.DataArray, xarray.DataArray]):
"""``GeoTiffDataset`` loads and saves geotiff files and reads them as xarray
DataArrays.
Expand Down Expand Up @@ -88,7 +87,7 @@ def _load(self) -> xarray.DataArray:
def _save(self, data: xarray.DataArray) -> None:
save_path = self._get_save_path()
if self._filepath.suffix in [".tif", ".tiff"]:
data.to_raster(save_path.as_posix(), **self._save_args)
data.rio.to_raster(save_path.as_posix(), **self._save_args)
else:
raise ValueError("expecting .tif or .tiff file suffix")

Expand Down
29 changes: 26 additions & 3 deletions kedro-datasets/tests/xarray/test_geotiff_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import xarray as xr
import rioxarray
from kedro.io.core import Version

from kedro_datasets.xarray import GeoTiffDataset
Expand All @@ -12,11 +13,22 @@ def cog_file_path() -> str:
cog_file_path = Path(__file__).parent / "cog.tif"
return cog_file_path.as_posix()

@pytest.fixture
def cog_xarray(cog_file_path) -> xr.DataArray:
return rioxarray.open_rasterio(cog_file_path)

@pytest.fixture
def geotiff_dataset(cog_file_path, save_args, fs_args) -> GeoTiffDataset:
def cog_geotiff_dataset(cog_file_path, save_args, fs_args) -> GeoTiffDataset:
return GeoTiffDataset(filepath=cog_file_path, save_args=save_args)

@pytest.fixture
def filepath_geotiff(tmp_path):
return (tmp_path / "test.tiff").as_posix()

@pytest.fixture
def geotiff_dataset(filepath_geotiff, save_args):
return GeoTiffDataset(filepath=filepath_geotiff, save_args=save_args)


@pytest.fixture
def versioned_geotiff_dataset(
Expand All @@ -27,7 +39,18 @@ def versioned_geotiff_dataset(
)


def test_save_and_load(geotiff_dataset):
def test_load(cog_geotiff_dataset):
"""Test saving and reloading the data set."""
loaded_tiff = geotiff_dataset.load()
loaded_tiff = cog_geotiff_dataset.load()
assert isinstance(loaded_tiff, xr.DataArray)
assert loaded_tiff.shape == (1,500,500)
assert loaded_tiff.dims == ('band', 'y', 'x')


def test_save_and_load(geotiff_dataset, cog_xarray):
"""Test saving and reloading the data set."""
geotiff_dataset.save(cog_xarray)
reloaded = geotiff_dataset.load()
assert reloaded.shape == cog_xarray.shape
assert reloaded.dims == cog_xarray.dimsts
assert reloaded.equals(cog_xarray)

0 comments on commit 4e54fe9

Please sign in to comment.