From 44801b0a2f4989d93d37ffb11f4be9ecbc05a5f0 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 29 Feb 2024 09:49:31 +0000 Subject: [PATCH] Improve test coverage Signed-off-by: Merel Theisen --- .../tests/xarray/test_geotiff_dataset.py | 100 +++++++++++------- 1 file changed, 61 insertions(+), 39 deletions(-) diff --git a/kedro-datasets/tests/xarray/test_geotiff_dataset.py b/kedro-datasets/tests/xarray/test_geotiff_dataset.py index 34d38c151..7f30ab994 100644 --- a/kedro-datasets/tests/xarray/test_geotiff_dataset.py +++ b/kedro-datasets/tests/xarray/test_geotiff_dataset.py @@ -4,6 +4,7 @@ import pytest import rioxarray import xarray as xr +from kedro.io import DatasetError from kedro_datasets.xarray import GeoTiffDataset @@ -30,45 +31,66 @@ def filepath_geotiff(tmp_path): @pytest.fixture -def geotiff_dataset(filepath_geotiff, save_args): - return GeoTiffDataset(filepath=filepath_geotiff, save_args=save_args) - - -def test_load(cog_geotiff_dataset): - """Test saving and reloading the data set.""" - loaded_tiff = cog_geotiff_dataset.load() - assert isinstance(loaded_tiff, xr.DataArray) - assert loaded_tiff.shape == (1, 500, 500) - assert loaded_tiff.dims == ("band", "y", "x") - - -def test_exists(geotiff_dataset, cog_xarray): - """Test `exists` method invocation for both existing and - nonexistent data set.""" - assert not geotiff_dataset.exists() - geotiff_dataset.save(cog_xarray) - assert geotiff_dataset.exists() - +def geotiff_dataset(filepath_geotiff, load_args, save_args): + return GeoTiffDataset(filepath=filepath_geotiff, load_args=load_args, save_args=save_args) + + +class TestGeoTiffDataset: + def test_load(self, cog_geotiff_dataset): + """Test saving and reloading the data set.""" + loaded_tiff = cog_geotiff_dataset.load() + assert isinstance(loaded_tiff, xr.DataArray) + assert loaded_tiff.shape == (1, 500, 500) + assert loaded_tiff.dims == ("band", "y", "x") + + def test_exists(self, geotiff_dataset, cog_xarray): + """Test `exists` method invocation for both existing and + nonexistent data set.""" + assert not geotiff_dataset.exists() + geotiff_dataset.save(cog_xarray) + assert geotiff_dataset.exists() + + def test_save_and_load(self, geotiff_dataset, cog_xarray): + """Test saving and reloading the data set.""" + geotiff_dataset.save(cog_xarray) + reloaded = geotiff_dataset.load() + assert reloaded.shape == cog_xarray.shape + assert reloaded.dims == cog_xarray.dims + assert reloaded.equals(cog_xarray) + + def test_example(self, tmp_path): + data = xr.DataArray( + np.random.randn(2, 3, 2), + dims=("band", "y", "x"), + coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]}, + ) + # Add spatial coordinates and CRS information + data = data.rio.write_crs("epsg:4326") + data = data.rio.set_spatial_dims("x", "y") + dataset = GeoTiffDataset(filepath=tmp_path.joinpath("test.tif").as_posix()) + dataset.save(data) + reloaded = dataset.load() + xr.testing.assert_allclose(data, reloaded, rtol=1e-5) + + @pytest.mark.parametrize( + "load_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_load_extra_params(self, geotiff_dataset, load_args): + """Test overriding the default load arguments.""" + for key, value in load_args.items(): + assert geotiff_dataset._load_args[key] == value -def test_save_and_load(geotiff_dataset, cog_xarray): - """Test saving and reloading the data set.""" - geotiff_dataset.save(cog_xarray) - reloaded = geotiff_dataset.load() - assert reloaded.shape == cog_xarray.shape - assert reloaded.dims == cog_xarray.dims - assert reloaded.equals(cog_xarray) + @pytest.mark.parametrize( + "save_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_save_extra_params(self, geotiff_dataset, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert geotiff_dataset._save_args[key] == value + def test_load_missing_file(self, geotiff_dataset): + """Check the error when trying to load missing file.""" + pattern = r"Failed while loading data from data set GeoTiffDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + geotiff_dataset.load() -def test_example(tmp_path): - data = xr.DataArray( - np.random.randn(2, 3, 2), - dims=("band", "y", "x"), - coords={"band": [1, 2], "y": [0.5, 1.5, 2.5], "x": [0.5, 1.5]}, - ) - # Add spatial coordinates and CRS information - data = data.rio.write_crs("epsg:4326") - data = data.rio.set_spatial_dims("x", "y") - dataset = GeoTiffDataset(filepath=tmp_path.joinpath("test.tif").as_posix()) - dataset.save(data) - reloaded = dataset.load() - xr.testing.assert_allclose(data, reloaded, rtol=1e-5)