Skip to content

Commit

Permalink
perf(datasets): delay Engine creation until needed
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <[email protected]>
  • Loading branch information
deepyaman committed Jul 25, 2023
1 parent 9cb6a16 commit 4129822
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions kedro-datasets/kedro_datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""``SQLDataSet`` to load and save data to a SQL backend."""
from __future__ import annotations

import copy
import datetime as dt
import re
from pathlib import PurePosixPath
from typing import Any, Dict, NoReturn, Optional
from typing import TYPE_CHECKING, Any, Dict, NoReturn, Optional

import fsspec
import pandas as pd
Expand All @@ -15,6 +16,9 @@
from .._io import AbstractDataset as AbstractDataSet
from .._io import DatasetError as DataSetError

if TYPE_CHECKING:
from sqlalchemy.engine.base import Engine

__all__ = ["SQLTableDataSet", "SQLQueryDataSet"]

KNOWN_PIP_INSTALL = {
Expand Down Expand Up @@ -221,22 +225,27 @@ def __init__(
self.metadata = metadata

@classmethod
def create_connection(cls, connection_str: str) -> None:
def create_connection(cls, connection_str: str) -> Engine:
"""Return the shared ``Engine`` for ``connection_str``, creating it on first use.

Engines are cached on the class in ``cls.engines``, keyed by connection
string, so that all instances of ``SQLTableDataSet`` that need to connect
to the same source reuse a single ``Engine``.

Args:
connection_str: Connection string passed to ``create_engine``.

Returns:
The ``Engine`` stored in ``cls.engines`` for ``connection_str``.

Raises:
The error built by ``_get_missing_module_error`` when ``create_engine``
raises ``ImportError``, or the error built by
``_get_sql_alchemy_missing_error`` when it raises ``NoSuchModuleError``.
"""
if connection_str in cls.engines:
return
# Only build an engine the first time this connection string is seen.
if connection_str not in cls.engines:
try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
cls.engines[connection_str] = engine

cls.engines[connection_str] = engine
return cls.engines[connection_str]

@property
def engine(self) -> Engine:
"""The ``Engine`` object for the dataset's connection string.

Resolved on every access by calling ``create_connection`` with
``self._connection_str``; that call creates the engine only if the class
has not already cached one for the same connection string.
"""
return self.create_connection(self._connection_str)

def _describe(self) -> Dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
Expand All @@ -250,16 +259,13 @@ def _describe(self) -> Dict[str, Any]:
}

def _load(self) -> pd.DataFrame:
"""Read data with ``pd.read_sql_table`` and return it as a ``DataFrame``.

The dataset's engine is passed as ``con`` and ``self._load_args`` is
forwarded as keyword arguments.
"""
engine = self.engines[self._connection_str] # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)
return pd.read_sql_table(con=self.engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
"""Write ``data`` to the SQL backend with ``DataFrame.to_sql``.

The dataset's engine is passed as ``con`` and ``self._save_args`` is
forwarded as keyword arguments.

Args:
data: The ``DataFrame`` to write.
"""
engine = self.engines[self._connection_str] # type: ignore
data.to_sql(con=engine, **self._save_args)
data.to_sql(con=self.engine, **self._save_args)

def _exists(self) -> bool:
"""Return whether the table named in ``self._load_args`` exists.

Asks an inspector created from the dataset's engine whether
``self._load_args["table_name"]`` is present, scoped to the ``schema``
load argument when one is set.
"""
engine = self.engines[self._connection_str] # type: ignore
insp = inspect(engine)
insp = inspect(self.engine)
# ``schema`` is optional; ``None`` is passed to ``has_table`` when it is absent.
schema = self._load_args.get("schema", None)
return insp.has_table(self._load_args["table_name"], schema)

Expand Down

0 comments on commit 4129822

Please sign in to comment.