diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index e095c2b2e..e565d674c 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -1,6 +1,7 @@
 # Upcoming Release
 ## Major features and improvements
 * Moved `PartitionedDataSet` and `IncrementalDataSet` from the core Kedro repo to `kedro-datasets` and renamed to `PartitionedDataset` and `IncrementalDataset`.
+* Delayed backend connection for `pandas.SQLTableDataset`, `pandas.SQLQueryDataset`, and `snowflake.SnowparkTableDataset`. In practice, this means that a dataset's connection details aren't used (or validated) until the dataset is accessed. On the plus side, the cost of connection isn't incurred unless and until the dataset is actually used.
 
 ## Bug fixes and other changes
 * Fix erroneous warning when using a cloud protocol file path with SparkDataSet on Databricks.
diff --git a/kedro-datasets/kedro_datasets/__init__.py b/kedro-datasets/kedro_datasets/__init__.py
index 60aa4afb2..6449f33c7 100644
--- a/kedro-datasets/kedro_datasets/__init__.py
+++ b/kedro-datasets/kedro_datasets/__init__.py
@@ -9,7 +9,7 @@
 try:
     # Custom `KedroDeprecationWarning` class was added in Kedro 0.18.14.
     from kedro import KedroDeprecationWarning
-except ImportError:
+except ImportError:  # pragma: no cover
 
     class KedroDeprecationWarning(DeprecationWarning):
         """Custom class for warnings about deprecated Kedro features."""
diff --git a/kedro-datasets/kedro_datasets/pandas/sql_dataset.py b/kedro-datasets/kedro_datasets/pandas/sql_dataset.py
index 5bad6e98b..8343723b0 100644
--- a/kedro-datasets/kedro_datasets/pandas/sql_dataset.py
+++ b/kedro-datasets/kedro_datasets/pandas/sql_dataset.py
@@ -1,10 +1,12 @@
 """``SQLDataset`` to load and save data to a SQL backend."""
+from __future__ import annotations
+
 import copy
 import datetime as dt
 import re
 import warnings
 from pathlib import PurePosixPath
-from typing import Any, Dict, NoReturn, Optional
+from typing import Any, NoReturn
 
 import fsspec
 import pandas as pd
@@ -33,7 +35,7 @@
 """
 
 
-def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+def _find_known_drivers(module_import_error: ImportError) -> str | None:
     """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
     provide better guidance for the user.
 
@@ -145,19 +147,19 @@ class SQLTableDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
     """
 
-    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
-    DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
+    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}
     # using Any because of Sphinx but it should be
     # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
-    engines: Dict[str, Any] = {}
+    engines: dict[str, Any] = {}
 
     def __init__(  # noqa: PLR0913
         self,
         table_name: str,
-        credentials: Dict[str, Any],
-        load_args: Dict[str, Any] = None,
-        save_args: Dict[str, Any] = None,
-        metadata: Dict[str, Any] = None,
+        credentials: dict[str, Any],
+        load_args: dict[str, Any] = None,
+        save_args: dict[str, Any] = None,
+        metadata: dict[str, Any] = None,
     ) -> None:
         """Creates a new ``SQLTableDataset``.
@@ -212,7 +214,6 @@ def __init__(  # noqa: PLR0913
         self._save_args["name"] = table_name
 
         self._connection_str = credentials["con"]
-        self.create_connection(self._connection_str)
 
         self.metadata = metadata
 
@@ -222,9 +223,6 @@ def create_connection(cls, connection_str: str) -> None:
         to be used across all instances of ``SQLTableDataset`` that
         need to connect to the same source.
         """
-        if connection_str in cls.engines:
-            return
-
         try:
             engine = create_engine(connection_str)
         except ImportError as import_error:
@@ -234,7 +232,17 @@ def create_connection(cls, connection_str: str) -> None:
 
         cls.engines[connection_str] = engine
 
-    def _describe(self) -> Dict[str, Any]:
+    @property
+    def engine(self):
+        """The ``Engine`` object for the dataset's connection string."""
+        cls = type(self)
+
+        if self._connection_str not in cls.engines:
+            self.create_connection(self._connection_str)
+
+        return cls.engines[self._connection_str]
+
+    def _describe(self) -> dict[str, Any]:
         load_args = copy.deepcopy(self._load_args)
         save_args = copy.deepcopy(self._save_args)
         del load_args["table_name"]
@@ -246,16 +254,13 @@ def _describe(self) -> dict[str, Any]:
         }
 
     def _load(self) -> pd.DataFrame:
-        engine = self.engines[self._connection_str]  # type:ignore
-        return pd.read_sql_table(con=engine, **self._load_args)
+        return pd.read_sql_table(con=self.engine, **self._load_args)
 
     def _save(self, data: pd.DataFrame) -> None:
-        engine = self.engines[self._connection_str]  # type: ignore
-        data.to_sql(con=engine, **self._save_args)
+        data.to_sql(con=self.engine, **self._save_args)
 
     def _exists(self) -> bool:
-        engine = self.engines[self._connection_str]  # type: ignore
-        insp = inspect(engine)
+        insp = inspect(self.engine)
         schema = self._load_args.get("schema", None)
         return insp.has_table(self._load_args["table_name"], schema)
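The hunks above replace the eager `create_connection(...)` call in `__init__` with a lazily evaluated `engine` property backed by the class-level `engines` cache. A minimal sketch of the pattern in isolation (the class and attribute names below are illustrative, not the dataset's actual API):

```python
from typing import Any

from sqlalchemy import create_engine


class LazyEngineHolder:
    # One `Engine` per connection string, shared across all instances.
    engines: dict[str, Any] = {}

    def __init__(self, connection_str: str) -> None:
        # The connection string is stored but neither used nor
        # validated here, so instantiation stays cheap.
        self._connection_str = connection_str

    @property
    def engine(self):
        # Created (and cached) on first access only.
        cls = type(self)
        if self._connection_str not in cls.engines:
            cls.engines[self._connection_str] = create_engine(self._connection_str)
        return cls.engines[self._connection_str]
```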
""" - if connection_str in cls.engines: - return - try: engine = create_engine(connection_str) except ImportError as import_error: @@ -234,7 +232,17 @@ def create_connection(cls, connection_str: str) -> None: cls.engines[connection_str] = engine - def _describe(self) -> Dict[str, Any]: + @property + def engine(self): + """The ``Engine`` object for the dataset's connection string.""" + cls = type(self) + + if self._connection_str not in cls.engines: + self.create_connection(self._connection_str) + + return cls.engines[self._connection_str] + + def _describe(self) -> dict[str, Any]: load_args = copy.deepcopy(self._load_args) save_args = copy.deepcopy(self._save_args) del load_args["table_name"] @@ -246,16 +254,13 @@ def _describe(self) -> Dict[str, Any]: } def _load(self) -> pd.DataFrame: - engine = self.engines[self._connection_str] # type:ignore - return pd.read_sql_table(con=engine, **self._load_args) + return pd.read_sql_table(con=self.engine, **self._load_args) def _save(self, data: pd.DataFrame) -> None: - engine = self.engines[self._connection_str] # type: ignore - data.to_sql(con=engine, **self._save_args) + data.to_sql(con=self.engine, **self._save_args) def _exists(self) -> bool: - engine = self.engines[self._connection_str] # type: ignore - insp = inspect(engine) + insp = inspect(self.engine) schema = self._load_args.get("schema", None) return insp.has_table(self._load_args["table_name"], schema) @@ -273,7 +278,6 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]): It does not support save method so it is a read only data set. To save data to a SQL server use ``SQLTableDataset``. - Example usage for the `YAML API `_: @@ -375,17 +379,17 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]): # using Any because of Sphinx but it should be # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine - engines: Dict[str, Any] = {} + engines: dict[str, Any] = {} def __init__( # noqa: PLR0913 self, sql: str = None, - credentials: Dict[str, Any] = None, - load_args: Dict[str, Any] = None, - fs_args: Dict[str, Any] = None, + credentials: dict[str, Any] = None, + load_args: dict[str, Any] = None, + fs_args: dict[str, Any] = None, filepath: str = None, - execution_options: Optional[Dict[str, Any]] = None, - metadata: Dict[str, Any] = None, + execution_options: dict[str, Any] | None = None, + metadata: dict[str, Any] = None, ) -> None: """Creates a new ``SQLQueryDataset``. @@ -441,7 +445,7 @@ def __init__( # noqa: PLR0913 "provide a SQLAlchemy connection string." ) - default_load_args: Dict[str, Any] = {} + default_load_args: dict[str, Any] = {} self._load_args = ( {**default_load_args, **load_args} @@ -466,7 +470,6 @@ def __init__( # noqa: PLR0913 self._filepath = path self._connection_str = credentials["con"] self._execution_options = execution_options or {} - self.create_connection(self._connection_str) if "mssql" in self._connection_str: self.adapt_mssql_date_params() @@ -476,9 +479,6 @@ def create_connection(cls, connection_str: str) -> None: to be used across all instances of `SQLQueryDataset` that need to connect to the same source. 
""" - if connection_str in cls.engines: - return - try: engine = create_engine(connection_str) except ImportError as import_error: @@ -488,7 +488,17 @@ def create_connection(cls, connection_str: str) -> None: cls.engines[connection_str] = engine - def _describe(self) -> Dict[str, Any]: + @property + def engine(self): + """The ``Engine`` object for the dataset's connection string.""" + cls = type(self) + + if self._connection_str not in cls.engines: + self.create_connection(self._connection_str) + + return cls.engines[self._connection_str] + + def _describe(self) -> dict[str, Any]: load_args = copy.deepcopy(self._load_args) return { "sql": str(load_args.pop("sql", None)), @@ -499,16 +509,15 @@ def _describe(self) -> Dict[str, Any]: def _load(self) -> pd.DataFrame: load_args = copy.deepcopy(self._load_args) - engine = self.engines[self._connection_str].execution_options( - **self._execution_options - ) # type: ignore if self._filepath: load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol) with self._fs.open(load_path, mode="r") as fs_file: load_args["sql"] = fs_file.read() - return pd.read_sql_query(con=engine, **load_args) + return pd.read_sql_query( + con=self.engine.execution_options(**self._execution_options), **load_args + ) def _save(self, data: None) -> NoReturn: raise DatasetError("'save' is not supported on SQLQueryDataset") diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py index 6fbfa60a0..df2eab564 100644 --- a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py +++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py @@ -172,7 +172,6 @@ def __init__( # noqa: PLR0913 {"database": self._database, "schema": self._schema} ) self._connection_parameters = connection_parameters - self._session = self._get_session(self._connection_parameters) self.metadata = metadata @@ -207,10 +206,14 @@ def _get_session(connection_parameters) -> sp.Session: logger.debug("Trying to reuse active snowpark session...") session = sp.context.get_active_session() except sp.exceptions.SnowparkSessionException: - logger.debug("No active snowpark session found. Creating") + logger.debug("No active snowpark session found. 
Creating...") session = sp.Session.builder.configs(connection_parameters).create() return session + @property + def _session(self) -> sp.Session: + return self._get_session(self._connection_parameters) + def _load(self) -> sp.DataFrame: table_name = [ self._database, diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 8cabcf1f2..51416af42 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -32,7 +32,7 @@ version = {attr = "kedro_datasets.__version__"} fail_under = 100 show_missing = true omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/databricks/*"] -exclude_lines = ["pragma: no cover", "raise NotImplementedError"] +exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"] [tool.pytest.ini_options] addopts = """ diff --git a/kedro-datasets/tests/pandas/test_sql_dataset.py b/kedro-datasets/tests/pandas/test_sql_dataset.py index a90cff0b7..ecea7ed83 100644 --- a/kedro-datasets/tests/pandas/test_sql_dataset.py +++ b/kedro-datasets/tests/pandas/test_sql_dataset.py @@ -107,7 +107,9 @@ def test_driver_missing(self, mocker): side_effect=ImportError("No module named 'mysqldb'"), ) with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"): - SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + SQLTableDataset( + table_name=TABLE_NAME, credentials={"con": CONNECTION} + ).exists() def test_unknown_sql(self): """Check the error when unknown sql dialect is provided; @@ -116,7 +118,9 @@ def test_unknown_sql(self): """ pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy" with pytest.raises(DatasetError, match=pattern): - SQLTableDataset(table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR}) + SQLTableDataset( + table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR} + ).exists() def test_unknown_module(self, mocker): """Test that if an unknown module/driver is encountered by SQLAlchemy @@ -127,7 +131,9 @@ def test_unknown_module(self, mocker): ) pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'" with pytest.raises(DatasetError, match=pattern): - SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + SQLTableDataset( + table_name=TABLE_NAME, credentials={"con": CONNECTION} + ).exists() def test_str_representation_table(self, table_dataset): """Test the data set instance string representation""" @@ -215,6 +221,7 @@ def test_single_connection(self, dummy_dataframe, mocker): kwargs = {"table_name": TABLE_NAME, "credentials": {"con": CONNECTION}} first = SQLTableDataset(**kwargs) + assert not first.exists() # Do something to create the `Engine` unique_connection = first.engines[CONNECTION] datasets = [SQLTableDataset(**kwargs) for _ in range(10)] @@ -236,12 +243,18 @@ def test_create_connection_only_once(self, mocker): (but different tables, for example) only create a connection once. 
""" mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") + mock_inspector = mocker.patch( + "kedro_datasets.pandas.sql_dataset.inspect" + ).return_value + mock_inspector.has_table.return_value = False first = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + assert not first.exists() # Do something to create the `Engine` assert len(first.engines) == 1 second = SQLTableDataset( table_name="other_table", credentials={"con": CONNECTION} ) + assert not second.exists() # Do something to fetch the `Engine` assert len(second.engines) == 1 assert len(first.engines) == 1 @@ -252,11 +265,17 @@ def test_multiple_connections(self, mocker): only create one connection per db. """ mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") + mock_inspector = mocker.patch( + "kedro_datasets.pandas.sql_dataset.inspect" + ).return_value + mock_inspector.has_table.return_value = False first = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + assert not first.exists() # Do something to create the `Engine` assert len(first.engines) == 1 second_con = f"other_{CONNECTION}" second = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": second_con}) + assert not second.exists() # Do something to create the `Engine` assert len(second.engines) == 2 assert len(first.engines) == 2 @@ -337,7 +356,7 @@ def test_load_driver_missing(self, mocker): "kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err ) with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"): - SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load() def test_invalid_module(self, mocker): """Test that if an unknown module/driver is encountered by SQLAlchemy @@ -348,7 +367,7 @@ def test_invalid_module(self, mocker): ) pattern = ERROR_PREFIX + r"Invalid module some\_module" with pytest.raises(DatasetError, match=pattern): - SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load() def test_load_unknown_module(self, mocker): """Test that if an unknown module/driver is encountered by SQLAlchemy @@ -359,14 +378,14 @@ def test_load_unknown_module(self, mocker): ) pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'" with pytest.raises(DatasetError, match=pattern): - SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load() def test_load_unknown_sql(self): """Check the error when unknown SQL dialect is provided in the connection string""" pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy" with pytest.raises(DatasetError, match=pattern): - SQLQueryDataset(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR}) + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR}).load() def test_save_error(self, query_dataset, dummy_dataframe): """Check the error when trying to save to the data set""" @@ -409,11 +428,13 @@ def test_create_connection_only_once(self, mocker): """ mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") first = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) + first.load() # Do something to create the `Engine` assert len(first.engines) == 1 # second engine has identical params to the first one # => no new engine should be created second = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) + 
@@ -337,7 +356,7 @@ def test_load_driver_missing(self, mocker):
             "kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err
         )
         with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"):
-            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
+            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()
 
     def test_invalid_module(self, mocker):
         """Test that if an unknown module/driver is encountered by SQLAlchemy
@@ -348,7 +367,7 @@ def test_invalid_module(self, mocker):
         )
         pattern = ERROR_PREFIX + r"Invalid module some\_module"
         with pytest.raises(DatasetError, match=pattern):
-            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
+            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()
 
     def test_load_unknown_module(self, mocker):
         """Test that if an unknown module/driver is encountered by SQLAlchemy
@@ -359,14 +378,14 @@ def test_load_unknown_module(self, mocker):
         )
         pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'"
         with pytest.raises(DatasetError, match=pattern):
-            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
+            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()
 
     def test_load_unknown_sql(self):
         """Check the error when unknown SQL dialect is provided
         in the connection string"""
         pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy"
         with pytest.raises(DatasetError, match=pattern):
-            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR})
+            SQLQueryDataset(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR}).load()
 
     def test_save_error(self, query_dataset, dummy_dataframe):
         """Check the error when trying to save to the data set"""
@@ -409,11 +428,13 @@ def test_create_connection_only_once(self, mocker):
         """
         mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine")
         first = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
+        first.load()  # Do something to create the `Engine`
         assert len(first.engines) == 1
 
         # second engine has identical params to the first one
         # => no new engine should be created
         second = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
+        second.load()  # Do something to fetch the `Engine`
         mock_engine.assert_called_once_with(CONNECTION)
         assert second.engines == first.engines
         assert len(first.engines) == 1
@@ -425,6 +446,7 @@ def test_create_connection_only_once(self, mocker):
             credentials={"con": CONNECTION},
             execution_options=EXECUTION_OPTIONS,
         )
+        third.load()  # Do something to fetch the `Engine`
         assert mock_engine.call_count == 1
         assert third.engines == first.engines
         assert len(first.engines) == 1
@@ -434,6 +456,7 @@ def test_create_connection_only_once(self, mocker):
         fourth = SQLQueryDataset(
             sql=SQL_QUERY, credentials={"con": "an other connection string"}
         )
+        fourth.load()  # Do something to create the `Engine`
         assert mock_engine.call_count == 2
         assert fourth.engines == first.engines
         assert len(first.engines) == 2
@@ -447,6 +470,7 @@ def test_adapt_mssql_date_params_called(self, mocker):
         )
         mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine")
         ds = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": MSSQL_CONNECTION})
+        ds.load()  # Do something to create the `Engine`
         mock_engine.assert_called_once_with(MSSQL_CONNECTION)
         assert mock_adapt_mssql_date_params.call_count == 1
         assert len(ds.engines) == 1
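Taken together, the behavioural contract these tests pin down is: instantiation is free, the first access creates the `Engine`, and the class-level cache is shared across datasets with the same connection string. A closing sketch, assuming a fresh interpreter (the sqlite URL and table names are illustrative):

```python
from kedro_datasets.pandas import SQLTableDataset

con = "sqlite:///example.db"

first = SQLTableDataset(table_name="orders", credentials={"con": con})
second = SQLTableDataset(table_name="customers", credentials={"con": con})

assert first.engines is second.engines  # cache lives on the class
assert con not in first.engines         # nothing connected yet

first.exists()                          # first backend access...
assert con in first.engines             # ...creates the one shared Engine
```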