Skip to content

Commit

Permalink
perf(datasets): don't create connection until need (kedro-org#281)
Browse files Browse the repository at this point in the history
* perf(datasets): delay `Engine` creation until need

Signed-off-by: Deepyaman Datta <[email protected]>

* chore: don't check coverage in TYPE_CHECKING block

Signed-off-by: Deepyaman Datta <[email protected]>

* perf(datasets): don't connect in `__init__` method

Signed-off-by: Deepyaman Datta <[email protected]>

* test(datasets): fix tests to touch `create_engine`

Signed-off-by: Deepyaman Datta <[email protected]>

* perf(datasets): don't connect in `__init__` method

Signed-off-by: Deepyaman Datta <[email protected]>

* style(datasets): exec Ruff on sql_dataset.py 🐶

Signed-off-by: Deepyaman Datta <[email protected]>

* Undo changes to `engines` values type (for Sphinx)

Signed-off-by: Deepyaman Datta <[email protected]>

* Patch Sphinx build by removing `Engine` references

* perf(datasets): don't connect in `__init__` method

Signed-off-by: Deepyaman Datta <[email protected]>

* chore(datasets): don't require coverage for import

* chore(datasets): del unused `TYPE_CHECKING` import

* docs(datasets): document lazy connection in README

* perf(datasets): remove create in `SQLQueryDataset`

Signed-off-by: Deepyaman Datta <[email protected]>

* refactor(datasets): do not return the created conn

Signed-off-by: Deepyaman Datta <[email protected]>

---------

Signed-off-by: Deepyaman Datta <[email protected]>
  • Loading branch information
deepyaman authored and tgoelles committed Jun 6, 2024
1 parent 2cd79bb commit 4214432
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 48 deletions.
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements
* Moved `PartitionedDataSet` and `IncrementalDataSet` from the core Kedro repo to `kedro-datasets` and renamed to `PartitionedDataset` and `IncrementalDataset`.
* Delayed backend connection for `pandas.SQLTableDataset`, `pandas.SQLQueryDataset`, and `snowflake.SnowparkTableDataset`. In practice, this means that a dataset's connection details aren't used (or validated) until the dataset is accessed. On the plus side, the cost of connection isn't incurred regardless of when or whether the dataset is used.
* Added xarray.GeoTiffDataset to handle GeoTIFF files.

## Bug fixes and other changes
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
try:
# Custom `KedroDeprecationWarning` class was added in Kedro 0.18.14.
from kedro import KedroDeprecationWarning
except ImportError:
except ImportError: # pragma: no cover

class KedroDeprecationWarning(DeprecationWarning):
"""Custom class for warnings about deprecated Kedro features."""
Expand Down
83 changes: 46 additions & 37 deletions kedro-datasets/kedro_datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""``SQLDataset`` to load and save data to a SQL backend."""
from __future__ import annotations

import copy
import datetime as dt
import re
import warnings
from pathlib import PurePosixPath
from typing import Any, Dict, NoReturn, Optional
from typing import Any, NoReturn

import fsspec
import pandas as pd
Expand Down Expand Up @@ -33,7 +35,7 @@
"""


def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
def _find_known_drivers(module_import_error: ImportError) -> str | None:
"""Looks up known keywords in a ``ModuleNotFoundError`` so that it can
provide better guideline for the user.
Expand Down Expand Up @@ -145,19 +147,19 @@ class SQLTableDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
"""

DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}
# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}
engines: dict[str, Any] = {}

def __init__( # noqa: PLR0913
self,
table_name: str,
credentials: Dict[str, Any],
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
metadata: Dict[str, Any] = None,
credentials: dict[str, Any],
load_args: dict[str, Any] = None,
save_args: dict[str, Any] = None,
metadata: dict[str, Any] = None,
) -> None:
"""Creates a new ``SQLTableDataset``.
Expand Down Expand Up @@ -212,7 +214,6 @@ def __init__( # noqa: PLR0913
self._save_args["name"] = table_name

self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

self.metadata = metadata

Expand All @@ -222,9 +223,6 @@ def create_connection(cls, connection_str: str) -> None:
to be used across all instances of ``SQLTableDataset`` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return

try:
engine = create_engine(connection_str)
except ImportError as import_error:
Expand All @@ -234,7 +232,17 @@ def create_connection(cls, connection_str: str) -> None:

cls.engines[connection_str] = engine

def _describe(self) -> Dict[str, Any]:
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
cls = type(self)

if self._connection_str not in cls.engines:
self.create_connection(self._connection_str)

return cls.engines[self._connection_str]

def _describe(self) -> dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
save_args = copy.deepcopy(self._save_args)
del load_args["table_name"]
Expand All @@ -246,16 +254,13 @@ def _describe(self) -> Dict[str, Any]:
}

def _load(self) -> pd.DataFrame:
engine = self.engines[self._connection_str] # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)
return pd.read_sql_table(con=self.engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
engine = self.engines[self._connection_str] # type: ignore
data.to_sql(con=engine, **self._save_args)
data.to_sql(con=self.engine, **self._save_args)

def _exists(self) -> bool:
engine = self.engines[self._connection_str] # type: ignore
insp = inspect(engine)
insp = inspect(self.engine)
schema = self._load_args.get("schema", None)
return insp.has_table(self._load_args["table_name"], schema)

Expand All @@ -273,7 +278,6 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
It does not support save method so it is a read only data set.
To save data to a SQL server use ``SQLTableDataset``.
Example usage for the
`YAML API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog_yaml_examples.html>`_:
Expand Down Expand Up @@ -375,17 +379,17 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):

# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}
engines: dict[str, Any] = {}

def __init__( # noqa: PLR0913
self,
sql: str = None,
credentials: Dict[str, Any] = None,
load_args: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
credentials: dict[str, Any] = None,
load_args: dict[str, Any] = None,
fs_args: dict[str, Any] = None,
filepath: str = None,
execution_options: Optional[Dict[str, Any]] = None,
metadata: Dict[str, Any] = None,
execution_options: dict[str, Any] | None = None,
metadata: dict[str, Any] = None,
) -> None:
"""Creates a new ``SQLQueryDataset``.
Expand Down Expand Up @@ -441,7 +445,7 @@ def __init__( # noqa: PLR0913
"provide a SQLAlchemy connection string."
)

default_load_args: Dict[str, Any] = {}
default_load_args: dict[str, Any] = {}

self._load_args = (
{**default_load_args, **load_args}
Expand All @@ -466,7 +470,6 @@ def __init__( # noqa: PLR0913
self._filepath = path
self._connection_str = credentials["con"]
self._execution_options = execution_options or {}
self.create_connection(self._connection_str)
if "mssql" in self._connection_str:
self.adapt_mssql_date_params()

Expand All @@ -476,9 +479,6 @@ def create_connection(cls, connection_str: str) -> None:
to be used across all instances of `SQLQueryDataset` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return

try:
engine = create_engine(connection_str)
except ImportError as import_error:
Expand All @@ -488,7 +488,17 @@ def create_connection(cls, connection_str: str) -> None:

cls.engines[connection_str] = engine

def _describe(self) -> Dict[str, Any]:
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
cls = type(self)

if self._connection_str not in cls.engines:
self.create_connection(self._connection_str)

return cls.engines[self._connection_str]

def _describe(self) -> dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
return {
"sql": str(load_args.pop("sql", None)),
Expand All @@ -499,16 +509,15 @@ def _describe(self) -> Dict[str, Any]:

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
engine = self.engines[self._connection_str].execution_options(
**self._execution_options
) # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

return pd.read_sql_query(con=engine, **load_args)
return pd.read_sql_query(
con=self.engine.execution_options(**self._execution_options), **load_args
)

def _save(self, data: None) -> NoReturn:
raise DatasetError("'save' is not supported on SQLQueryDataset")
Expand Down
7 changes: 5 additions & 2 deletions kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def __init__( # noqa: PLR0913
{"database": self._database, "schema": self._schema}
)
self._connection_parameters = connection_parameters
self._session = self._get_session(self._connection_parameters)

self.metadata = metadata

Expand Down Expand Up @@ -207,10 +206,14 @@ def _get_session(connection_parameters) -> sp.Session:
logger.debug("Trying to reuse active snowpark session...")
session = sp.context.get_active_session()
except sp.exceptions.SnowparkSessionException:
logger.debug("No active snowpark session found. Creating")
logger.debug("No active snowpark session found. Creating...")
session = sp.Session.builder.configs(connection_parameters).create()
return session

@property
def _session(self) -> sp.Session:
return self._get_session(self._connection_parameters)

def _load(self) -> sp.DataFrame:
table_name = [
self._database,
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ version = {attr = "kedro_datasets.__version__"}
fail_under = 100
show_missing = true
omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/databricks/*"]
exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"]

[tool.pytest.ini_options]
addopts = """
Expand Down
Loading

0 comments on commit 4214432

Please sign in to comment.