diff --git a/kedro-datasets/kedro_datasets/pandas/sql_dataset.py b/kedro-datasets/kedro_datasets/pandas/sql_dataset.py index 588dc2eee..ad41b6911 100644 --- a/kedro-datasets/kedro_datasets/pandas/sql_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/sql_dataset.py @@ -1,10 +1,11 @@ """``SQLDataSet`` to load and save data to a SQL backend.""" +from __future__ import annotations import copy import datetime as dt import re from pathlib import PurePosixPath -from typing import Any, Dict, NoReturn, Optional +from typing import TYPE_CHECKING, Any, Dict, NoReturn, Optional import fsspec import pandas as pd @@ -15,6 +16,9 @@ from .._io import AbstractDataset as AbstractDataSet from .._io import DatasetError as DataSetError +if TYPE_CHECKING: + from sqlalchemy.engine.base import Engine + __all__ = ["SQLTableDataSet", "SQLQueryDataSet"] KNOWN_PIP_INSTALL = { @@ -221,22 +225,27 @@ def __init__( self.metadata = metadata @classmethod - def create_connection(cls, connection_str: str) -> None: + def create_connection(cls, connection_str: str) -> Engine: """Given a connection string, create singleton connection to be used across all instances of ``SQLTableDataSet`` that need to connect to the same source. """ - if connection_str in cls.engines: - return + if connection_str not in cls.engines: + try: + engine = create_engine(connection_str) + except ImportError as import_error: + raise _get_missing_module_error(import_error) from import_error + except NoSuchModuleError as exc: + raise _get_sql_alchemy_missing_error() from exc - try: - engine = create_engine(connection_str) - except ImportError as import_error: - raise _get_missing_module_error(import_error) from import_error - except NoSuchModuleError as exc: - raise _get_sql_alchemy_missing_error() from exc + cls.engines[connection_str] = engine - cls.engines[connection_str] = engine + return cls.engines[connection_str] + + @property + def engine(self): + """The ``Engine`` object for the dataset's connection string.""" + return self.create_connection(self._connection_str) def _describe(self) -> Dict[str, Any]: load_args = copy.deepcopy(self._load_args) @@ -250,16 +259,13 @@ def _describe(self) -> Dict[str, Any]: } def _load(self) -> pd.DataFrame: - engine = self.engines[self._connection_str] # type:ignore - return pd.read_sql_table(con=engine, **self._load_args) + return pd.read_sql_table(con=self.engine, **self._load_args) def _save(self, data: pd.DataFrame) -> None: - engine = self.engines[self._connection_str] # type: ignore - data.to_sql(con=engine, **self._save_args) + data.to_sql(con=self.engine, **self._save_args) def _exists(self) -> bool: - engine = self.engines[self._connection_str] # type: ignore - insp = inspect(engine) + insp = inspect(self.engine) schema = self._load_args.get("schema", None) return insp.has_table(self._load_args["table_name"], schema)