Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(datasets): don't create connection until need #281

Merged
merged 17 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 60 additions & 53 deletions kedro-datasets/kedro_datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""``SQLDataset`` to load and save data to a SQL backend."""
from __future__ import annotations

import copy
import datetime as dt
import re
import warnings
from pathlib import PurePosixPath
from typing import Any, Dict, NoReturn, Optional
from typing import TYPE_CHECKING, Any, NoReturn

import fsspec
import pandas as pd
Expand All @@ -15,6 +17,9 @@
from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._io import AbstractDataset, DatasetError

if TYPE_CHECKING:
from sqlalchemy.engine.base import Engine

__all__ = ["SQLTableDataset", "SQLQueryDataset"]

KNOWN_PIP_INSTALL = {
Expand All @@ -33,7 +38,7 @@
"""


def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
def _find_known_drivers(module_import_error: ImportError) -> str | None:
"""Looks up known keywords in a ``ModuleNotFoundError`` so that it can
provide better guideline for the user.

Expand Down Expand Up @@ -148,19 +153,18 @@ class SQLTableDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):

"""

DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}
DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}

engines: dict[str, Engine] = {}
deepyaman marked this conversation as resolved.
Show resolved Hide resolved

def __init__( # noqa: PLR0913
self,
table_name: str,
credentials: Dict[str, Any],
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
metadata: Dict[str, Any] = None,
credentials: dict[str, Any],
load_args: dict[str, Any] = None,
save_args: dict[str, Any] = None,
metadata: dict[str, Any] = None,
) -> None:
"""Creates a new ``SQLTableDataset``.

Expand Down Expand Up @@ -215,29 +219,33 @@ def __init__( # noqa: PLR0913
self._save_args["name"] = table_name

self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

self.metadata = metadata

@classmethod
def create_connection(cls, connection_str: str) -> None:
def create_connection(cls, connection_str: str) -> Engine:
"""Given a connection string, create singleton connection
to be used across all instances of ``SQLTableDataset`` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return
if connection_str not in cls.engines:
try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
cls.engines[connection_str] = engine

cls.engines[connection_str] = engine
return cls.engines[connection_str]

def _describe(self) -> Dict[str, Any]:
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
return self.create_connection(self._connection_str)

def _describe(self) -> dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
save_args = copy.deepcopy(self._save_args)
del load_args["table_name"]
Expand All @@ -249,16 +257,13 @@ def _describe(self) -> Dict[str, Any]:
}

def _load(self) -> pd.DataFrame:
engine = self.engines[self._connection_str] # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)
return pd.read_sql_table(con=self.engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
engine = self.engines[self._connection_str] # type: ignore
data.to_sql(con=engine, **self._save_args)
data.to_sql(con=self.engine, **self._save_args)

def _exists(self) -> bool:
engine = self.engines[self._connection_str] # type: ignore
insp = inspect(engine)
insp = inspect(self.engine)
schema = self._load_args.get("schema", None)
return insp.has_table(self._load_args["table_name"], schema)

Expand Down Expand Up @@ -372,19 +377,17 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
date: "%Y-%m-%d %H:%M:%S.%f0 %z"
"""

# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}
engines: dict[str, Engine] = {}
deepyaman marked this conversation as resolved.
Show resolved Hide resolved

def __init__( # noqa: PLR0913
self,
sql: str = None,
credentials: Dict[str, Any] = None,
load_args: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
credentials: dict[str, Any] = None,
load_args: dict[str, Any] = None,
fs_args: dict[str, Any] = None,
filepath: str = None,
execution_options: Optional[Dict[str, Any]] = None,
metadata: Dict[str, Any] = None,
execution_options: dict[str, Any] | None = None,
metadata: dict[str, Any] = None,
) -> None:
"""Creates a new ``SQLQueryDataset``.

Expand Down Expand Up @@ -440,7 +443,7 @@ def __init__( # noqa: PLR0913
"provide a SQLAlchemy connection string."
)

default_load_args: Dict[str, Any] = {}
default_load_args: dict[str, Any] = {}

self._load_args = (
{**default_load_args, **load_args}
Expand Down Expand Up @@ -470,24 +473,29 @@ def __init__( # noqa: PLR0913
self.adapt_mssql_date_params()

@classmethod
def create_connection(cls, connection_str: str) -> None:
def create_connection(cls, connection_str: str) -> Engine:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLQueryDataset` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return
if connection_str not in cls.engines:
try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
cls.engines[connection_str] = engine

cls.engines[connection_str] = engine
return cls.engines[connection_str]

def _describe(self) -> Dict[str, Any]:
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
return self.create_connection(self._connection_str)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
return self.create_connection(self._connection_str)
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
if self._connection_str not in self.engines:
self.create_connection(self._connection_str)
return self.engines[self._connection_str]

The current refactoring works perfectly fine, but I feel that create_connection shouldn't return a connection and this is more natural. (The `if connection_str not in cls.engines` check inside `create_connection` should also be removed.)


def _describe(self) -> dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
return {
"sql": str(load_args.pop("sql", None)),
Expand All @@ -498,16 +506,15 @@ def _describe(self) -> Dict[str, Any]:

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
engine = self.engines[self._connection_str].execution_options(
**self._execution_options
) # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

return pd.read_sql_query(con=engine, **load_args)
return pd.read_sql_query(
con=self.engine.execution_options(**self._execution_options), **load_args
)

def _save(self, data: None) -> NoReturn:
raise DatasetError("'save' is not supported on SQLQueryDataset")
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ version = {attr = "kedro_datasets.__version__"}
fail_under = 100
show_missing = true
omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*"]
exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"]

[tool.pytest.ini_options]
addopts = """
Expand Down
31 changes: 25 additions & 6 deletions kedro-datasets/tests/pandas/test_sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def test_driver_missing(self, mocker):
side_effect=ImportError("No module named 'mysqldb'"),
)
with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"):
SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION})
SQLTableDataset(
table_name=TABLE_NAME, credentials={"con": CONNECTION}
).exists()

def test_unknown_sql(self):
"""Check the error when unknown sql dialect is provided;
Expand All @@ -116,7 +118,9 @@ def test_unknown_sql(self):
"""
pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy"
with pytest.raises(DatasetError, match=pattern):
SQLTableDataset(table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR})
SQLTableDataset(
table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR}
).exists()

def test_unknown_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
Expand All @@ -127,7 +131,9 @@ def test_unknown_module(self, mocker):
)
pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'"
with pytest.raises(DatasetError, match=pattern):
SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION})
SQLTableDataset(
table_name=TABLE_NAME, credentials={"con": CONNECTION}
).exists()

def test_str_representation_table(self, table_dataset):
"""Test the data set instance string representation"""
Expand Down Expand Up @@ -215,6 +221,7 @@ def test_single_connection(self, dummy_dataframe, mocker):
kwargs = {"table_name": TABLE_NAME, "credentials": {"con": CONNECTION}}

first = SQLTableDataset(**kwargs)
assert not first.exists() # Do something to create the `Engine`
unique_connection = first.engines[CONNECTION]
datasets = [SQLTableDataset(**kwargs) for _ in range(10)]

Expand All @@ -236,12 +243,18 @@ def test_create_connection_only_once(self, mocker):
(but different tables, for example) only create a connection once.
"""
mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine")
mock_inspector = mocker.patch(
"kedro_datasets.pandas.sql_dataset.inspect"
).return_value
mock_inspector.has_table.return_value = False
first = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION})
assert not first.exists() # Do something to create the `Engine`
assert len(first.engines) == 1

second = SQLTableDataset(
table_name="other_table", credentials={"con": CONNECTION}
)
assert not second.exists() # Do something to fetch the `Engine`
assert len(second.engines) == 1
assert len(first.engines) == 1

Expand All @@ -252,11 +265,17 @@ def test_multiple_connections(self, mocker):
only create one connection per db.
"""
mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine")
mock_inspector = mocker.patch(
"kedro_datasets.pandas.sql_dataset.inspect"
).return_value
mock_inspector.has_table.return_value = False
first = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION})
assert not first.exists() # Do something to create the `Engine`
assert len(first.engines) == 1

second_con = f"other_{CONNECTION}"
second = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": second_con})
assert not second.exists() # Do something to create the `Engine`
assert len(second.engines) == 2
assert len(first.engines) == 2

Expand Down Expand Up @@ -337,7 +356,7 @@ def test_load_driver_missing(self, mocker):
"kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err
)
with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"):
SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()

def test_invalid_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
Expand All @@ -348,7 +367,7 @@ def test_invalid_module(self, mocker):
)
pattern = ERROR_PREFIX + r"Invalid module some\_module"
with pytest.raises(DatasetError, match=pattern):
SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()

def test_load_unknown_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
Expand All @@ -359,7 +378,7 @@ def test_load_unknown_module(self, mocker):
)
pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'"
with pytest.raises(DatasetError, match=pattern):
SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION})
SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()

def test_load_unknown_sql(self):
"""Check the error when unknown SQL dialect is provided
Expand Down
Loading