diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
index 61e9e115a..a8cb36d59 100644
--- a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
+++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Union
+from typing import Any

 import pandas as pd
 import snowflake.snowpark as sp
@@ -111,9 +111,11 @@ def __init__(  # noqa: PLR0913
         load_args: dict[str, Any] | None = None,
         save_args: dict[str, Any] | None = None,
         credentials: dict[str, Any] | None = None,
+        session: sp.Session | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
-        """Creates a new instance of ``SnowparkTableDataset``.
+        """
+        Creates a new instance of ``SnowparkTableDataset``.

         Args:
             table_name: The table name to load or save data to.
@@ -162,6 +164,7 @@ def __init__(  # noqa: PLR0913
         self._table_name = table_name
         self._database = database
         self._schema = schema
+        self.__session = session or self._get_session(credentials)  # for testing

         connection_parameters = credentials
         connection_parameters.update(
@@ -180,7 +183,8 @@ def _describe(self) -> dict[str, Any]:

     @staticmethod
     def _get_session(connection_parameters) -> sp.Session:
-        """Given a connection string, create singleton connection
+        """
+        Given a connection string, create singleton connection
         to be used across all instances of `SnowparkTableDataset` that
         need to connect to the same source.
         connection_parameters is a dictionary of any values
@@ -208,35 +212,52 @@ def _get_session(connection_parameters) -> sp.Session:

     @property
     def _session(self) -> sp.Session:
-        return self._get_session(self._connection_parameters)
+        """
+        Retrieve or create a session.
+
+        Returns:
+            sp.Session: The current session associated with the object.
+        """
+        if not self.__session:
+            self.__session = self._get_session(self._connection_parameters)
+        return self.__session

     def load(self) -> sp.DataFrame:
-        table_name: list = [
-            self._database,
-            self._schema,
-            self._table_name,
-        ]
+        """
+        Load data from a specified database table.

+        Returns:
+            sp.DataFrame: The loaded data as a Snowpark DataFrame.
+        """
+        table_name = [self._database, self._schema, self._table_name]
         sp_df = self._session.table(".".join(table_name))
         return sp_df

-    def save(self, data: Union[pd.DataFrame, sp.DataFrame]) -> None:
+    def save(self, data: pd.DataFrame | sp.DataFrame) -> None:
+        """
+        Save data to a specified database table.
+
+        Args:
+            data (pd.DataFrame | sp.DataFrame): The data to save.
+        """
         if isinstance(data, pd.DataFrame):
             data = self._session.create_dataframe(data)
-        table_name = ".".join([self._database, self._schema, self._table_name])
+        table_name = [self._database, self._schema, self._table_name]
         data.write.save_as_table(table_name, **self._save_args)

     def _exists(self) -> bool:
-        session = self._session
-        query = "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES \
-            WHERE TABLE_SCHEMA = '{schema}' \
-            AND TABLE_NAME = '{table_name}'"
-        rows = session.sql(
-            query.format(
-                database=self._database,
-                schema=self._schema,
-                table_name=self._table_name,
-            )
-        ).collect()
-        return rows[0][0] == 1
+        """
+        Check if a specified table exists in the database.
+
+        Returns:
+            bool: True if the table exists, False otherwise.
+        """
+        try:
+            self._session.table(
+                f"{self._database}.{self._schema}.{self._table_name}"
+            ).show()
+            return True
+        except Exception as e:
+            logger.debug(f"Table {self._table_name} does not exist: {e}")
+            return False
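
The injected ``session`` above is what lets the rewritten tests run against Snowpark's local testing mode instead of a live Snowflake account. A minimal sketch of that usage, assuming the dataset is constructed directly; the credential values are the same placeholders used by the test fixtures below:

    from snowflake.snowpark.session import Session
    from kedro_datasets.snowflake import SnowparkTableDataset

    # Local testing mode: no real Snowflake account is contacted.
    session = Session.builder.config("local_testing", True).create()

    dataset = SnowparkTableDataset(
        table_name="DUMMY_TABLE",
        credentials={
            "account": "DUMMY_ACCOUNT",
            "warehouse": "DUMMY_WAREHOUSE",
            "database": "DUMMY_DATABASE",
            "schema": "DUMMY_SCHEMA",
            "user": "DUMMY_USER",
            "password": "DUMMY_PASSWORD",
        },
        session=session,  # injected, so _get_session() is never called
    )
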
diff --git a/kedro-datasets/tests/snowflake/conftest.py b/kedro-datasets/tests/snowflake/conftest.py
deleted file mode 100644
index 704ecf482..000000000
--- a/kedro-datasets/tests/snowflake/conftest.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""
-We disable execution of tests that require real Snowflake instance
-to run by default. Providing -m snowflake option explicitly to
-pytest will make these and only these tests run
-"""
-
-import pytest
-
-
-def pytest_collection_modifyitems(config, items):
-    markers_arg = config.getoption("-m")
-
-    # Naive implementation to handle basic marker expressions
-    # Will not work if someone will (ever) run pytest with complex marker
-    # expressions like "-m spark and not (snowflake or pandas)"
-    if (
-        "snowflake" in markers_arg.lower()
-        and "not snowflake" not in markers_arg.lower()
-    ):
-        return
-
-    skip_snowflake = pytest.mark.skip(reason="need -m snowflake option to run")
-    for item in items:
-        if "snowflake" in item.keywords:
-            item.add_marker(skip_snowflake)
diff --git a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py
index 5dc4d2ee7..af92f20bd 100644
--- a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py
+++ b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py
@@ -1,127 +1,111 @@
 import datetime
-import os

 import pandas as pd
 import pytest
-import snowflake.snowpark as sp
-from kedro.io.core import DatasetError
-
-from kedro_datasets.snowflake import SnowparkTableDataset as spds
-
-
-def get_connection():
-    account = os.getenv("SNOWSQL_ACCOUNT")
-    warehouse = os.getenv("SNOWSQL_WAREHOUSE")
-    database = os.getenv("SNOWSQL_DATABASE")
-    role = os.getenv("SNOWSQL_ROLE")
-    user = os.getenv("SNOWSQL_USER")
-    schema = os.getenv("SNOWSQL_SCHEMA")
-    password = os.getenv("SNOWSQL_PWD")
-
-    if not (
-        account and warehouse and database and role and user and schema and password
-    ):
-        raise DatasetError(
-            "Snowflake connection environment variables provided not in full"
-        )
-
-    conn = {
-        "account": account,
-        "warehouse": warehouse,
-        "database": database,
-        "role": role,
-        "user": user,
-        "schema": schema,
-        "password": password,
-    }
-    return conn
-
-
-def sf_setup_db(sf_session):
-    # For table exists test
-    run_query(sf_session, 'CREATE TABLE KEDRO_PYTEST_TESTEXISTS ("name" VARCHAR)')
-
-    # For load test
-    query = 'CREATE TABLE KEDRO_PYTEST_TESTLOAD ("name" VARCHAR\
-        , "age" INTEGER\
-        , "bday" date\
-        , "height" float\
-        , "insert_dttm" timestamp)'
-    run_query(sf_session, query)
-
-    query = "INSERT INTO KEDRO_PYTEST_TESTLOAD VALUES ('John'\
-        , 23\
-        , to_date('1999-12-02','YYYY-MM-DD')\
-        , 6.5\
-        , to_timestamp_ntz('2022-12-02 13:20:01',\
-        'YYYY-MM-DD hh24:mi:ss'))"
-    run_query(sf_session, query)
-
-    query = "INSERT INTO KEDRO_PYTEST_TESTLOAD VALUES ('Jane'\
-        , 41\
-        , to_date('1981-01-03','YYYY-MM-DD')\
-        , 5.7\
-        , to_timestamp_ntz('2022-12-02 13:21:11',\
-        'YYYY-MM-DD hh24:mi:ss'))"
-    run_query(sf_session, query)
-
-
-def sf_db_cleanup(sf_session):
-    run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTEXISTS")
-    run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTLOAD")
-    run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTSAVE")
-
-
-def run_query(session, query):
-    df = session.sql(query)
-    df.collect()
-    return df
-
-
-def df_equals_ignore_dtype(df1, df2):
-    # Pytest will show respective stdout only if test fails
-    # this will help to debug what was exactly not matching right away
-
-    c1 = df1.to_pandas().values.tolist()
-    c2 = df2.to_pandas().values.tolist()
-
-    print(c1)
-    print("--- comparing to ---")
-    print(c2)
-
-    for i, row in enumerate(c1):
-        for j, column in enumerate(row):
-            if not column == c2[i][j]:
-                print(f"{column} not equal to {c2[i][j]}")
-                return False
-    return True
+from snowflake.snowpark import DataFrame
+from snowflake.snowpark.session import Session
+from snowflake.snowpark.types import (
+    DateType,
+    FloatType,
+    IntegerType,
+    StringType,
+    StructField,
+    StructType,
+    TimestampType,
+)
+
+from kedro_datasets.snowflake.snowpark_dataset import SnowparkTableDataset
+
+# Example dummy configuration for testing
+DUMMY_CREDENTIALS = {
+    "account": "DUMMY_ACCOUNT",
+    "warehouse": "DUMMY_WAREHOUSE",
+    "database": "DUMMY_DATABASE",
+    "schema": "DUMMY_SCHEMA",
+    "user": "DUMMY_USER",
+    "password": "DUMMY_PASSWORD",
+}
+
+
+@pytest.fixture(scope="module")
+def local_snowpark_session() -> Session:
+    """
+    Creates a Snowpark session in local testing mode for testing purposes.
+    See https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally
+
+    Returns:
+        Session: Snowpark session object configured for local testing.
+    """
+    return Session.builder.config("local_testing", True).create()
+
+
+@pytest.fixture(scope="module")
+def snowflake_dataset(local_snowpark_session: Session) -> SnowparkTableDataset:
+    """
+    Provides a SnowparkTableDataset fixture for testing.
+
+    Args:
+        local_snowpark_session (Session): The Snowpark session used for this dataset.
+
+    Returns:
+        SnowparkTableDataset: Dataset configuration for a Snowflake table.
+    """
+    return SnowparkTableDataset(
+        table_name="DUMMY_TABLE",
+        credentials=DUMMY_CREDENTIALS,
+        session=local_snowpark_session,
+        save_args={"mode": "overwrite"},
+    )


 @pytest.fixture
-def sample_sp_df(sf_session):
-    return sf_session.create_dataframe(
+def sample_sp_df(local_snowpark_session: Session) -> DataFrame:
+    """
+    Creates a sample Snowpark DataFrame for testing.
+
+    Args:
+        local_snowpark_session (Session): Session to create the DataFrame.
+
+    Returns:
+        DataFrame: Snowpark DataFrame with sample data and schema.
+    """
+    return local_snowpark_session.create_dataframe(
         [
-            [
                 "John",
                 23,
                 datetime.date(1999, 12, 2),
                 6.5,
                 datetime.datetime(2022, 12, 2, 13, 20, 1),
-            ],
-            [
+            ),
+            (
                 "Jane",
                 41,
                 datetime.date(1981, 1, 3),
                 5.7,
                 datetime.datetime(2022, 12, 2, 13, 21, 11),
-            ],
+            ),
         ],
-        schema=["name", "age", "bday", "height", "insert_dttm"],
+        schema=StructType(
+            [
+                StructField("name", StringType()),
+                StructField("age", IntegerType()),
+                StructField("bday", DateType()),
+                StructField("height", FloatType()),
+                StructField("insert_dttm", TimestampType()),
+            ]
+        ),
     )


 @pytest.fixture
-def sample_pd_df():
+def sample_pd_df() -> pd.DataFrame:
+    """
+    Creates a sample Pandas DataFrame for testing.
+
+    Returns:
+        pd.DataFrame: DataFrame with sample data.
+    """
     return pd.DataFrame(
         {
             "name": ["Alice", "Bob"],
@@ -136,56 +120,80 @@ def sample_pd_df():
     )


-@pytest.fixture
-def sf_session():
-    sf_session = sp.Session.builder.configs(get_connection()).create()
-
-    # Running cleanup in case previous run was interrupted w/o proper cleanup
-    sf_db_cleanup(sf_session)
-    sf_setup_db(sf_session)
-
-    yield sf_session
-    sf_db_cleanup(sf_session)
-    sf_session.close()
-
-
 class TestSnowparkTableDataset:
-    @pytest.mark.snowflake
-    def test_save(self, sample_sp_df, sf_session):
-        sp_df = spds(table_name="KEDRO_PYTEST_TESTSAVE", credentials=get_connection())
-        sp_df._save(sample_sp_df)
-        sp_df_saved = sf_session.table("KEDRO_PYTEST_TESTSAVE")
-        assert sp_df_saved.count() == 2
-
-    @pytest.mark.snowflake
-    def test_save_with_pandas(self, sample_pd_df, sf_session):
-        sp_df = spds(
-            table_name="KEDRO_PYTEST_TESTSAVEPANDAS", credentials=get_connection()
-        )
-        sp_df.save(sample_pd_df)
-
-        sp_df_saved = sf_session.table("KEDRO_PYTEST_TESTSAVEPANDAS")
-
-        # Assert the count matches
-        assert sp_df_saved.count() == 2
-
-    @pytest.mark.snowflake
-    def test_load(self, sample_sp_df, sf_session):
-        print(sf_session)
-        sp_df = spds(
-            table_name="KEDRO_PYTEST_TESTLOAD", credentials=get_connection()
-        )._load()
-
-        # Ignoring dtypes as ex. age can be int8 vs int64 and pandas.compare
-        # fails on that
-        assert df_equals_ignore_dtype(sample_sp_df, sp_df)
-
-    @pytest.mark.snowflake
-    def test_exists(self, sf_session):
-        print(sf_session)
-        df_e = spds(table_name="KEDRO_PYTEST_TESTEXISTS", credentials=get_connection())
-        df_ne = spds(
-            table_name="KEDRO_PYTEST_TESTNEXISTS", credentials=get_connection()
-        )
-        assert df_e._exists()
-        assert not df_ne._exists()
+    """Tests for the SnowparkTableDataset functionality."""
+
+    def test_save_with_snowpark(
+        self, sample_sp_df: DataFrame, snowflake_dataset: SnowparkTableDataset
+    ) -> None:
+        """
+        Tests saving a Snowpark DataFrame to a Snowflake table.
+
+        Args:
+            sample_sp_df (DataFrame): Sample data to save.
+            snowflake_dataset (SnowparkTableDataset): Dataset to test.
+
+        Asserts:
+            The count of the loaded DataFrame matches the saved DataFrame.
+        """
+        snowflake_dataset.save(sample_sp_df)
+        loaded_df = snowflake_dataset.load()
+        assert loaded_df.count() == sample_sp_df.count()
+
+    def test_save_with_pandas(
+        self, sample_pd_df: pd.DataFrame, snowflake_dataset: SnowparkTableDataset
+    ) -> None:
+        """
+        Tests saving a Pandas DataFrame to a Snowflake table.
+
+        Args:
+            sample_pd_df (pd.DataFrame): Sample data to save.
+            snowflake_dataset (SnowparkTableDataset): Dataset to test.
+
+        Asserts:
+            The count of the loaded DataFrame matches the number of rows in the
+            Pandas DataFrame.
+        """
+        snowflake_dataset.save(sample_pd_df)
+        loaded_df = snowflake_dataset.load()
+        assert loaded_df.count() == len(sample_pd_df)
+
+    def test_load(
+        self, snowflake_dataset: SnowparkTableDataset, sample_sp_df: DataFrame
+    ) -> None:
+        """
+        Tests loading data from a Snowflake table.
+
+        Args:
+            snowflake_dataset (SnowparkTableDataset): Dataset to load data from.
+            sample_sp_df (DataFrame): Sample data for reference.
+
+        Asserts:
+            The count of the loaded DataFrame matches the reference sample DataFrame.
+        """
+        loaded_df = snowflake_dataset.load()
+        assert loaded_df.count() == sample_sp_df.count()
+
+    def test_exists(self, snowflake_dataset: SnowparkTableDataset) -> None:
+        """
+        Tests if a Snowflake table exists.
+
+        Args:
+            snowflake_dataset (SnowparkTableDataset): Dataset to check existence.
+
+        Asserts:
+            The dataset table exists in the Snowflake environment.
+        """
+        assert snowflake_dataset._exists()
+
+    def test_not_exists(self, snowflake_dataset: SnowparkTableDataset) -> None:
+        """
+        Tests that a non-existent Snowflake table is detected.
+
+        Args:
+            snowflake_dataset (SnowparkTableDataset): Dataset to check existence.
+
+        Asserts:
+            The dataset table does not exist in the Snowflake environment.
+        """
+        original_table_name = snowflake_dataset._table_name
+        snowflake_dataset._table_name = "NON_EXISTENT_TABLE"
+        try:
+            assert not snowflake_dataset._exists()
+        finally:
+            # Restore the module-scoped fixture so later tests are unaffected
+            snowflake_dataset._table_name = original_table_name
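
Taken together, the reworked ``save``, ``load``, and ``_exists`` can be exercised end to end against a local session, mirroring the tests above. A rough sketch, assuming local testing mode supports these operations as the new tests rely on; the table name here is hypothetical:

    import pandas as pd

    from snowflake.snowpark.session import Session
    from kedro_datasets.snowflake import SnowparkTableDataset

    session = Session.builder.config("local_testing", True).create()
    dataset = SnowparkTableDataset(
        table_name="ROUND_TRIP_TABLE",  # hypothetical table name
        credentials={
            "account": "DUMMY_ACCOUNT",
            "warehouse": "DUMMY_WAREHOUSE",
            "database": "DUMMY_DATABASE",
            "schema": "DUMMY_SCHEMA",
            "user": "DUMMY_USER",
            "password": "DUMMY_PASSWORD",
        },
        session=session,
        save_args={"mode": "overwrite"},
    )

    assert not dataset._exists()  # the probe raises internally and returns False
    dataset.save(pd.DataFrame({"name": ["Alice"], "age": [30]}))  # pandas branch
    assert dataset._exists()
    assert dataset.load().count() == 1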