Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Merel Theisen <[email protected]>
  • Loading branch information
merelcht committed Feb 29, 2024
1 parent 691f413 commit 44801b0
Showing 1 changed file with 61 additions and 39 deletions.
100 changes: 61 additions & 39 deletions kedro-datasets/tests/xarray/test_geotiff_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import rioxarray
import xarray as xr
from kedro.io import DatasetError

from kedro_datasets.xarray import GeoTiffDataset

Expand All @@ -30,45 +31,66 @@ def filepath_geotiff(tmp_path):


@pytest.fixture
def geotiff_dataset(filepath_geotiff, save_args):
return GeoTiffDataset(filepath=filepath_geotiff, save_args=save_args)


def test_load(cog_geotiff_dataset):
"""Test saving and reloading the data set."""
loaded_tiff = cog_geotiff_dataset.load()
assert isinstance(loaded_tiff, xr.DataArray)
assert loaded_tiff.shape == (1, 500, 500)
assert loaded_tiff.dims == ("band", "y", "x")


def test_exists(geotiff_dataset, cog_xarray):
"""Test `exists` method invocation for both existing and
nonexistent data set."""
assert not geotiff_dataset.exists()
geotiff_dataset.save(cog_xarray)
assert geotiff_dataset.exists()

def geotiff_dataset(filepath_geotiff, load_args, save_args):
return GeoTiffDataset(filepath=filepath_geotiff, load_args=load_args, save_args=save_args)


class TestGeoTiffDataset:
def test_load(self, cog_geotiff_dataset):
"""Test saving and reloading the data set."""
loaded_tiff = cog_geotiff_dataset.load()
assert isinstance(loaded_tiff, xr.DataArray)
assert loaded_tiff.shape == (1, 500, 500)
assert loaded_tiff.dims == ("band", "y", "x")

def test_exists(self, geotiff_dataset, cog_xarray):
"""Test `exists` method invocation for both existing and
nonexistent data set."""
assert not geotiff_dataset.exists()
geotiff_dataset.save(cog_xarray)
assert geotiff_dataset.exists()

def test_save_and_load(self, geotiff_dataset, cog_xarray):
"""Test saving and reloading the data set."""
geotiff_dataset.save(cog_xarray)
reloaded = geotiff_dataset.load()
assert reloaded.shape == cog_xarray.shape
assert reloaded.dims == cog_xarray.dims
assert reloaded.equals(cog_xarray)

def test_example(self, tmp_path):
data = xr.DataArray(
np.random.randn(2, 3, 2),
dims=("band", "y", "x"),
coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]},
)
# Add spatial coordinates and CRS information
data = data.rio.write_crs("epsg:4326")
data = data.rio.set_spatial_dims("x", "y")
dataset = GeoTiffDataset(filepath=tmp_path.joinpath("test.tif").as_posix())
dataset.save(data)
reloaded = dataset.load()
xr.testing.assert_allclose(data, reloaded, rtol=1e-5)

@pytest.mark.parametrize(
"load_args", [{"k1": "v1", "index": "value"}], indirect=True
)
def test_load_extra_params(self, geotiff_dataset, load_args):
"""Test overriding the default load arguments."""
for key, value in load_args.items():
assert geotiff_dataset._load_args[key] == value

def test_save_and_load(geotiff_dataset, cog_xarray):
"""Test saving and reloading the data set."""
geotiff_dataset.save(cog_xarray)
reloaded = geotiff_dataset.load()
assert reloaded.shape == cog_xarray.shape
assert reloaded.dims == cog_xarray.dims
assert reloaded.equals(cog_xarray)
@pytest.mark.parametrize(
"save_args", [{"k1": "v1", "index": "value"}], indirect=True
)
def test_save_extra_params(self, geotiff_dataset, save_args):
"""Test overriding the default save arguments."""
for key, value in save_args.items():
assert geotiff_dataset._save_args[key] == value

def test_load_missing_file(self, geotiff_dataset):
"""Check the error when trying to load missing file."""
pattern = r"Failed while loading data from data set GeoTiffDataset\(.*\)"
with pytest.raises(DatasetError, match=pattern):
geotiff_dataset.load()

def test_example(tmp_path):
data = xr.DataArray(
np.random.randn(2, 3, 2),
dims=("band", "y", "x"),
coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]},
)
# Add spatial coordinates and CRS information
data = data.rio.write_crs("epsg:4326")
data = data.rio.set_spatial_dims("x", "y")
dataset = GeoTiffDataset(filepath=tmp_path.joinpath("test.tif").as_posix())
dataset.save(data)
reloaded = dataset.load()
xr.testing.assert_allclose(data, reloaded, rtol=1e-5)

0 comments on commit 44801b0

Please sign in to comment.