diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index e333d2787..a4043d97a 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -23,6 +23,7 @@ ## Breaking Changes * Exposed `load` and `save` publicly for each dataset. This requires Kedro version 0.19.7 or higher. +* Replaced the `geopandas.GeoJSONDataset` with `geopandas.GenericDataset` to support parquet and feather file formats. ## Community contributions Many thanks to the following Kedroids for contributing PRs to this release: @@ -32,6 +33,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: * [janickspirig](https://github.com/janickspirig) * [Galen Seilis](https://github.com/galenseilis) * [Mariusz Wojakowski](https://github.com/mariusz89016) +* [harm-matthias-harms](https://github.com/harm-matthias-harms) * [Felix Scherz](https://github.com/felixscherz) diff --git a/kedro-datasets/docs/source/api/kedro_datasets.rst b/kedro-datasets/docs/source/api/kedro_datasets.rst index 669378b7b..45b275de5 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets.rst @@ -17,7 +17,7 @@ kedro_datasets dask.ParquetDataset databricks.ManagedTableDataset email.EmailMessageDataset - geopandas.GeoJSONDataset + geopandas.GenericDataset holoviews.HoloviewsWriter huggingface.HFDataset huggingface.HFTransformerPipelineDataset diff --git a/kedro-datasets/kedro_datasets/geopandas/README.md b/kedro-datasets/kedro_datasets/geopandas/README.md deleted file mode 100644 index a7926a706..000000000 --- a/kedro-datasets/kedro_datasets/geopandas/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# GeoJSON - -``GeoJSONDataset`` loads and saves data to a local yaml file using ``geopandas``. -See [geopandas.GeoDataFrame](http://geopandas.org/reference/geopandas.GeoDataFrame.html) for details. - -#### Example use: - -```python -import geopandas as gpd -from shapely.geometry import Point -from kedro_datasets.geopandas import GeoJSONDataset - -data = gpd.GeoDataFrame( - {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}, - geometry=[Point(1, 1), Point(2, 4)], -) -dataset = GeoJSONDataset(filepath="test.geojson") -dataset.save(data) -reloaded = dataset.load() -assert data.equals(reloaded) -``` - -#### Example catalog.yml: - -```yaml -example_geojson_data: - type: geopandas.GeoJSONDataset - filepath: data/08_reporting/test.geojson -``` - -Contributed by (Luis Blanche)[https://github.com/lblanche]. diff --git a/kedro-datasets/kedro_datasets/geopandas/__init__.py b/kedro-datasets/kedro_datasets/geopandas/__init__.py index d4843aa68..444dd8d72 100644 --- a/kedro-datasets/kedro_datasets/geopandas/__init__.py +++ b/kedro-datasets/kedro_datasets/geopandas/__init__.py @@ -1,12 +1,12 @@ -"""``GeoJSONDataset`` is an ``AbstractVersionedDataset`` to save and load GeoJSON files.""" +"""``GenericDataset`` is an ``AbstractVersionedDataset`` to save and load GeoDataFrames.""" from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -GeoJSONDataset: Any +GenericDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"geojson_dataset": ["GeoJSONDataset"]} + __name__, submod_attrs={"generic_dataset": ["GenericDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py b/kedro-datasets/kedro_datasets/geopandas/generic_dataset.py similarity index 57% rename from kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py rename to kedro-datasets/kedro_datasets/geopandas/generic_dataset.py index 322fc147c..aa2e6d4cf 100644 --- a/kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py +++ b/kedro-datasets/kedro_datasets/geopandas/generic_dataset.py @@ -1,7 +1,8 @@ -"""GeoJSONDataset loads and saves data to a local geojson file. The +"""GenericDataset loads and saves data to a local file. The underlying functionality is supported by geopandas, so it supports all allowed geopandas (pandas) options for loading and saving geosjon files. """ + from __future__ import annotations import copy @@ -18,30 +19,35 @@ get_protocol_and_path, ) +# pyogrio currently supports no alternate file handlers https://github.com/geopandas/pyogrio/issues/430 +gpd.options.io_engine = "fiona" + +NON_FILE_SYSTEM_TARGETS = ["postgis"] + -class GeoJSONDataset( +class GenericDataset( AbstractVersionedDataset[ gpd.GeoDataFrame, gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame] ] ): - """``GeoJSONDataset`` loads/saves data to a GeoJSON file using an underlying filesystem + """``GenericDataset`` loads/saves data to a file using an underlying filesystem (eg: local, S3, GCS). The underlying functionality is supported by geopandas, so it supports all - allowed geopandas (pandas) options for loading and saving GeoJSON files. + allowed geopandas (pandas) options for loading and saving files. Example: .. code-block:: pycon >>> import geopandas as gpd - >>> from kedro_datasets.geopandas import GeoJSONDataset + >>> from kedro_datasets.geopandas import GenericDataset >>> from shapely.geometry import Point >>> >>> data = gpd.GeoDataFrame( ... {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}, ... geometry=[Point(1, 1), Point(2, 4)], ... ) - >>> dataset = GeoJSONDataset(filepath=tmp_path / "test.geojson", save_args=None) + >>> dataset = GenericDataset(filepath=tmp_path / "test.geojson") >>> dataset.save(data) >>> reloaded = dataset.load() >>> @@ -50,12 +56,14 @@ class GeoJSONDataset( """ DEFAULT_LOAD_ARGS: dict[str, Any] = {} - DEFAULT_SAVE_ARGS = {"driver": "GeoJSON"} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "wb"}} def __init__( # noqa: PLR0913 self, *, filepath: str, + file_format: str = "file", load_args: dict[str, Any] | None = None, save_args: dict[str, Any] | None = None, version: Version | None = None, @@ -63,22 +71,26 @@ def __init__( # noqa: PLR0913 fs_args: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, ) -> None: - """Creates a new instance of ``GeoJSONDataset`` pointing to a concrete GeoJSON file + """Creates a new instance of ``GenericDataset`` pointing to a concrete file on a specific filesystem fsspec. Args: - filepath: Filepath in POSIX format to a GeoJSON file prefixed with a protocol like + filepath: Filepath in POSIX format to a file prefixed with a protocol like `s3://`. If prefix is not provided `file` protocol (local filesystem) will be used. The prefix should be any protocol supported by ``fsspec``. Note: `http(s)` doesn't support versioning. - load_args: GeoPandas options for loading GeoJSON files. + file_format: String which is used to match the appropriate load/save method on a best + effort basis. For example if 'parquet' is passed in the `geopandas.read_parquet` and + `geopandas.DataFrame.to_parquet` will be identified. An error will be raised unless + at least one matching `read_{file_format}` or `to_{file_format}` method is + identified. Defaults to 'file'. + load_args: GeoPandas options for loading files. Here you can find all available arguments: https://geopandas.org/en/stable/docs/reference/api/geopandas.read_file.html - save_args: GeoPandas options for saving geojson files. + save_args: GeoPandas options for saving files. Here you can find all available arguments: https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_file.html - The default_save_arg driver is 'GeoJSON', all others preserved. version: If specified, should be an instance of ``kedro.io.core.Version``. If its ``load`` attribute is None, the latest version will be loaded. If its ``save`` @@ -94,6 +106,9 @@ def __init__( # noqa: PLR0913 metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. """ + + self._file_format = file_format.lower() + _fs_args = copy.deepcopy(fs_args) or {} _fs_open_args_load = _fs_args.pop("open_args_load", {}) _fs_open_args_save = _fs_args.pop("open_args_save", {}) @@ -114,28 +129,57 @@ def __init__( # noqa: PLR0913 glob_function=self._fs.glob, ) - self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS) - if load_args is not None: - self._load_args.update(load_args) - - self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS) - if save_args is not None: - self._save_args.update(save_args) + # Handle default load and save and fs arguments + self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})} + self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})} + self._fs_open_args_load = { + **self.DEFAULT_FS_ARGS.get("open_args_load", {}), + **(_fs_open_args_load or {}), + } + self._fs_open_args_save = { + **self.DEFAULT_FS_ARGS.get("open_args_save", {}), + **(_fs_open_args_save or {}), + } - _fs_open_args_save.setdefault("mode", "wb") - self._fs_open_args_load = _fs_open_args_load - self._fs_open_args_save = _fs_open_args_save + def _ensure_file_system_target(self) -> None: + # Fail fast if provided a known non-filesystem target + if self._file_format in NON_FILE_SYSTEM_TARGETS: + raise DatasetError( + f"Cannot load or save a dataset of file_format '{self._file_format}' as it " + f"does not support a filepath target/source." + ) def load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]: + self._ensure_file_system_target() + load_path = get_filepath_str(self._get_load_path(), self._protocol) - with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: - return gpd.read_file(fs_file, **self._load_args) + load_method = getattr(gpd, f"read_{self._file_format}", None) + if load_method: + with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: + return load_method(fs_file, **self._load_args) + raise DatasetError( + f"Unable to retrieve 'geopandas.read_{self._file_format}' method, please ensure that your " + "'file_format' parameter has been defined correctly as per the GeoPandas API " + "https://geopandas.org/en/stable/docs/reference/io.html" + ) def save(self, data: gpd.GeoDataFrame) -> None: + self._ensure_file_system_target() + save_path = get_filepath_str(self._get_save_path(), self._protocol) - with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: - data.to_file(fs_file, **self._save_args) - self.invalidate_cache() + save_method = getattr(data, f"to_{self._file_format}", None) + if save_method: + with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: + # KEY ASSUMPTION - first argument is path/buffer/io + save_method(fs_file, **self._save_args) + self.invalidate_cache() + else: + raise DatasetError( + f"Unable to retrieve 'geopandas.DataFrame.to_{self._file_format}' method, please " + "ensure that your 'file_format' parameter has been defined correctly as " + "per the GeoPandas API " + "https://geopandas.org/en/stable/docs/reference/io.html" + ) def _exists(self) -> bool: try: @@ -147,6 +191,7 @@ def _exists(self) -> bool: def _describe(self) -> dict[str, Any]: return { "filepath": self._filepath, + "file_format": self._file_format, "protocol": self._protocol, "load_args": self._load_args, "save_args": self._save_args, diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 6d882a2a0..b357d8038 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -40,8 +40,8 @@ dask = ["kedro-datasets[dask-parquetdataset, dask-csvdataset]"] databricks-managedtabledataset = ["kedro-datasets[spark-base,pandas-base,delta-base,hdfs-base,s3fs-base]"] databricks = ["kedro-datasets[databricks-managedtabledataset]"] -geopandas-geojsondataset = ["geopandas>=0.6.0, <1.0", "pyproj~=3.0"] -geopandas = ["kedro-datasets[geopandas-geojsondataset]"] +geopandas-genericdataset = ["geopandas>=0.8.0, <2.0", "fiona >=1.8, <2.0"] +geopandas = ["kedro-datasets[geopandas-genericdataset]"] holoviews-holoviewswriter = ["holoviews>=1.13.0"] holoviews = ["kedro-datasets[holoviews-holoviewswriter]"] @@ -215,8 +215,9 @@ test = [ "deltalake>=0.10.0", "dill~=0.3.1", "filelock>=3.4.0, <4.0", + "fiona >=1.8, <2.0", "gcsfs>=2023.1, <2023.3", - "geopandas>=0.6.0, <1.0", + "geopandas>=0.8.0, <2.0", "hdfs>=2.5.8, <3.0", "holoviews>=1.13.0", "ibis-framework[duckdb,examples]", @@ -243,7 +244,6 @@ test = [ "pyarrow>=1.0; python_version < '3.11'", "pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors "pyodbc~=5.0", - "pyproj~=3.0", "pyspark>=3.0; python_version < '3.11'", "pyspark>=3.4; python_version >= '3.11'", "pytest-cov~=3.0", diff --git a/kedro-datasets/tests/geopandas/test_geojson_dataset.py b/kedro-datasets/tests/geopandas/test_generic_dataset.py similarity index 63% rename from kedro-datasets/tests/geopandas/test_geojson_dataset.py rename to kedro-datasets/tests/geopandas/test_generic_dataset.py index 9c6cb49fe..5c4569e9c 100644 --- a/kedro-datasets/tests/geopandas/test_geojson_dataset.py +++ b/kedro-datasets/tests/geopandas/test_generic_dataset.py @@ -10,7 +10,7 @@ from s3fs import S3FileSystem from shapely.geometry import Point -from kedro_datasets.geopandas import GeoJSONDataset +from kedro_datasets.geopandas import GenericDataset @pytest.fixture(params=[None]) @@ -24,16 +24,36 @@ def save_version(request): @pytest.fixture -def filepath(tmp_path): +def filepath_geojson(tmp_path): return (tmp_path / "test.geojson").as_posix() +@pytest.fixture +def filepath_parquet(tmp_path): + return (tmp_path / "test.parquet").as_posix() + + +@pytest.fixture +def filepath_feather(tmp_path): + return (tmp_path / "test.feather").as_posix() + + +@pytest.fixture +def filepath_postgis(tmp_path): + return (tmp_path / "test.sql").as_posix() + + +@pytest.fixture +def filepath_abc(tmp_path): + return tmp_path / "test.abc" + + @pytest.fixture(params=[None]) def load_args(request): return request.param -@pytest.fixture(params=[{"driver": "GeoJSON"}]) +@pytest.fixture(params=[None]) def save_args(request): return request.param @@ -47,20 +67,77 @@ def dummy_dataframe(): @pytest.fixture -def geojson_dataset(filepath, load_args, save_args, fs_args): - return GeoJSONDataset( - filepath=filepath, load_args=load_args, save_args=save_args, fs_args=fs_args +def geojson_dataset(filepath_geojson, load_args, save_args, fs_args): + return GenericDataset( + filepath=filepath_geojson, + load_args=load_args, + save_args=save_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def parquet_dataset(filepath_parquet, load_args, save_args, fs_args): + return GenericDataset( + filepath=filepath_parquet, + file_format="parquet", + load_args=load_args, + save_args=save_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def parquet_dataset_bad_config(filepath_parquet, load_args, save_args, fs_args): + return GenericDataset( + filepath=filepath_parquet, + load_args=load_args, + save_args=save_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def feather_dataset(filepath_feather, load_args, save_args, fs_args): + return GenericDataset( + filepath=filepath_feather, + file_format="feather", + load_args=load_args, + save_args=save_args, + fs_args=fs_args, + ) + + +@pytest.fixture +def postgis_dataset(filepath_postgis, load_args, save_args, fs_args): + return GenericDataset( + filepath=filepath_postgis, + file_format="postgis", + load_args=load_args, + save_args=save_args, + fs_args=fs_args, ) @pytest.fixture -def versioned_geojson_dataset(filepath, load_version, save_version): - return GeoJSONDataset( - filepath=filepath, version=Version(load_version, save_version) +def abc_dataset(filepath_abc, load_args, save_args, fs_args): + return GenericDataset( + filepath=filepath_abc, + file_format="abc", + load_args=load_args, + save_args=save_args, + fs_args=fs_args, ) -class TestGeoJSONDataset: +@pytest.fixture +def versioned_geojson_dataset(filepath_geojson, load_version, save_version): + return GenericDataset( + filepath=filepath_geojson, version=Version(load_version, save_version) + ) + + +class TestGenericDataset: def test_save_and_load(self, geojson_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one.""" geojson_dataset.save(dummy_dataframe) @@ -72,7 +149,7 @@ def test_save_and_load(self, geojson_dataset, dummy_dataframe): @pytest.mark.parametrize("geojson_dataset", [{"index": False}], indirect=True) def test_load_missing_file(self, geojson_dataset): """Check the error while trying to load from missing source.""" - pattern = r"Failed while loading data from dataset GeoJSONDataset" + pattern = r"Failed while loading data from dataset GenericDataset" with pytest.raises(DatasetError, match=pattern): geojson_dataset.load() @@ -82,6 +159,39 @@ def test_exists(self, geojson_dataset, dummy_dataframe): geojson_dataset.save(dummy_dataframe) assert geojson_dataset.exists() + def test_load_parquet_dataset(self, parquet_dataset, dummy_dataframe): + parquet_dataset.save(dummy_dataframe) + reloaded_df = parquet_dataset.load() + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_load_feather_dataset(self, feather_dataset, dummy_dataframe): + feather_dataset.save(dummy_dataframe) + reloaded_df = feather_dataset.load() + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_bad_load( + self, parquet_dataset_bad_config, dummy_dataframe, filepath_parquet + ): + dummy_dataframe.to_parquet(filepath_parquet) + pattern = r"Failed while loading data from dataset GenericDataset(.*)" + with pytest.raises(DatasetError, match=pattern): + parquet_dataset_bad_config.load() + + def test_none_file_system_target(self, postgis_dataset, dummy_dataframe): + pattern = "Cannot load or save a dataset of file_format 'postgis' as it does not support a filepath target/source." + with pytest.raises(DatasetError, match=pattern): + postgis_dataset.save(dummy_dataframe) + + def test_unknown_file_format(self, abc_dataset, dummy_dataframe, filepath_abc): + pattern = "Unable to retrieve 'geopandas.DataFrame.to_abc' method" + with pytest.raises(DatasetError, match=pattern): + abc_dataset.save(dummy_dataframe) + + filepath_abc.write_bytes(b"") + pattern = "Unable to retrieve 'geopandas.read_abc' method" + with pytest.raises(DatasetError, match=pattern): + abc_dataset.load() + @pytest.mark.parametrize( "load_args", [{"crs": "init:4326"}, {"crs": "init:2154", "driver": "GeoJSON"}] ) @@ -118,7 +228,7 @@ def test_open_extra_args(self, geojson_dataset, fs_args): ], ) def test_protocol_usage(self, path, instance_type): - geojson_dataset = GeoJSONDataset(filepath=path) + geojson_dataset = GenericDataset(filepath=path) assert isinstance(geojson_dataset._fs, instance_type) path = path.split(PROTOCOL_DELIMITER, 1)[-1] @@ -129,18 +239,18 @@ def test_protocol_usage(self, path, instance_type): def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.geojson" - geojson_dataset = GeoJSONDataset(filepath=filepath) + geojson_dataset = GenericDataset(filepath=filepath) geojson_dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestGeoJSONDatasetVersioned: +class TestGenericDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.geojson" - ds = GeoJSONDataset(filepath=filepath) - ds_versioned = GeoJSONDataset( + ds = GenericDataset(filepath=filepath) + ds_versioned = GenericDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -149,8 +259,8 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GeoJSONDataset" in str(ds_versioned) - assert "GeoJSONDataset" in str(ds) + assert "GenericDataset" in str(ds_versioned) + assert "GenericDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) @@ -163,7 +273,7 @@ def test_save_and_load(self, versioned_geojson_dataset, dummy_dataframe): def test_no_versions(self, versioned_geojson_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GeoJSONDataset\(.+\)" + pattern = r"Did not find any versions for GenericDataset\(.+\)" with pytest.raises(DatasetError, match=pattern): versioned_geojson_dataset.load() @@ -178,7 +288,7 @@ def test_prevent_override(self, versioned_geojson_dataset, dummy_dataframe): version.""" versioned_geojson_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for GeoJSONDataset\(.+\) must not " + r"Save path \'.+\' for GenericDataset\(.+\) must not " r"exist if versioning is enabled" ) with pytest.raises(DatasetError, match=pattern): @@ -197,7 +307,7 @@ def test_save_version_warning( the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GeoJSONDataset\(.+\)" + rf"'{load_version}' for GenericDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): versioned_geojson_dataset.save(dummy_dataframe) @@ -206,7 +316,7 @@ def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." with pytest.raises(DatasetError, match=pattern): - GeoJSONDataset( + GenericDataset( filepath="https://example/file.geojson", version=Version(None, None) )