diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index fe94686f6..c9a0eb0b0 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -2,13 +2,17 @@ ## Major features and improvements +- Added functionality to save Pandas DataFrame directly to Snowflake, facilitating seamless `.csv` ingestion +- Added Python 3.9, 3.10 and 3.11 support for SnowparkTableDataset - Added the following new **experimental** datasets: | Type | Description | Location | | --------------------------------- | ------------------------------------------------------ | ---------------------------------------- | | `databricks.ExternalTableDataset` | A dataset for accessing external tables in Databricks. | `kedro_datasets_experimental.databricks` | + ## Bug fixes and other changes +- Implemented Snowflake's [local testing framework](https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally) for testing purposes ## Breaking Changes @@ -16,6 +20,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: +- [Thomas d'Hooghe](https://github.com/tdhooghe) - [Minura Punchihewa](https://github.com/MinuraPunchihewa) # Release 5.1.0 diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py index f47cbf2cb..f2e215d3e 100644 --- a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py +++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py @@ -1,20 +1,24 @@ -"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes -""" +"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes""" + from __future__ import annotations import logging -from typing import Any +from typing import Any, cast -import snowflake.snowpark as sp +import pandas as pd from kedro.io.core import AbstractDataset, DatasetError +from snowflake.snowpark import DataFrame, Session +from snowflake.snowpark import context as sp_context +from snowflake.snowpark import exceptions as sp_exceptions logger = logging.getLogger(__name__) class SnowparkTableDataset(AbstractDataset): - """``SnowparkTableDataset`` loads and saves Snowpark dataframes. + """``SnowparkTableDataset`` loads and saves Snowpark DataFrames. - As of Mar-2023, the snowpark connector only works with Python 3.8. + As of October 2024, the Snowpark connector works with Python 3.9, 3.10, and 3.11. + Python 3.12 is not supported yet. Example usage for the `YAML API None: - """Creates a new instance of ``SnowparkTableDataset``. + """ + Creates a new instance of ``SnowparkTableDataset``. Args: table_name: The table name to load or save data to. @@ -154,6 +160,7 @@ def __init__( # noqa: PLR0913 "'schema' must be provided by credentials or dataset."
) schema = credentials["schema"] + # Handle default load and save arguments self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})} self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})} @@ -167,6 +174,7 @@ def __init__( # noqa: PLR0913 {"database": self._database, "schema": self._schema} ) self._connection_parameters = connection_parameters + self._session = session self.metadata = metadata @@ -178,8 +186,9 @@ def _describe(self) -> dict[str, Any]: } @staticmethod - def _get_session(connection_parameters) -> sp.Session: - """Given a connection string, create singleton connection + def _get_session(connection_parameters) -> Session: + """ + Given connection parameters, create a singleton connection to be used across all instances of `SnowparkTableDataset` that need to connect to the same source. connection_parameters is a dictionary of any values @@ -199,45 +208,96 @@ def _get_session(connection_parameters) -> sp.Session: """ try: logger.debug("Trying to reuse active snowpark session...") - session = sp.context.get_active_session() - except sp.exceptions.SnowparkSessionException: + session = sp_context.get_active_session() + except sp_exceptions.SnowparkSessionException: logger.debug("No active snowpark session found. Creating...") - session = sp.Session.builder.configs(connection_parameters).create() + session = Session.builder.configs(connection_parameters).create() return session @property - def _session(self) -> sp.Session: - return self._get_session(self._connection_parameters) + def session(self) -> Session: + """ + Retrieve or create a session. + Returns: + Session: The current session associated with the object. + """ + if not self._session: + self._session = self._get_session(self._connection_parameters) + return self._session - def load(self) -> sp.DataFrame: - table_name: list = [ - self._database, - self._schema, - self._table_name, - ] + def load(self) -> DataFrame: + """ + Load data from a specified database table. - sp_df = self._session.table(".".join(table_name)) - return sp_df + Returns: + DataFrame: The loaded data as a Snowpark DataFrame. + """ + if self._session is None: + raise DatasetError( + "No active session. Please initialise a Snowpark session before loading data." + ) + return self._session.table(self._validate_and_get_table_name()) + + def save(self, data: pd.DataFrame | DataFrame) -> None: + """ + Check if the data is a Snowpark DataFrame or a Pandas DataFrame, + convert it to a Snowpark DataFrame if needed, and save it to the specified table. - def save(self, data: sp.DataFrame) -> None: - table_name = [ - self._database, - self._schema, - self._table_name, - ] + Args: + data (pd.DataFrame | DataFrame): The data to save. + """ + if self._session is None: + raise DatasetError( + "No active session. Please initialise a Snowpark session before saving data." + ) + if isinstance(data, pd.DataFrame): + snowpark_df = self._session.create_dataframe(data) + elif isinstance(data, DataFrame): + snowpark_df = data + else: + raise DatasetError( + f"Data of type {type(data)} is not supported for saving."
+ ) - data.write.save_as_table(table_name, **self._save_args) + snowpark_df.write.save_as_table( + self._validate_and_get_table_name(), **self._save_args + ) def _exists(self) -> bool: - session = self._session - query = "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES \ - WHERE TABLE_SCHEMA = '{schema}' \ - AND TABLE_NAME = '{table_name}'" - rows = session.sql( - query.format( - database=self._database, - schema=self._schema, - table_name=self._table_name, + """ + Check if a specified table exists in the database. + + Returns: + bool: True if the table exists, False otherwise. + """ + if self._session is None: + raise DatasetError( + "No active session. Please initialise a Snowpark session before checking for the table." ) - ).collect() - return rows[0][0] == 1 + try: + self._session.table( + f"{self._database}.{self._schema}.{self._table_name}" + ).show() + return True + except Exception as e: + logger.debug(f"Table {self._table_name} does not exist: {e}") + return False + + def _validate_and_get_table_name(self) -> str: + """ + Validate that all parts of the table name are not None and join them into a string. + + Returns: + str: The joined table name in the format 'database.schema.table'. + + Raises: + DatasetError: If any part of the table name is None or empty. + """ + parts: list[str | None] = [self._database, self._schema, self._table_name] + if any(part is None or part == "" for part in parts): + raise DatasetError("Database, schema or table name cannot be None or empty") + parts_str = cast(list[str], parts) # make linting happy + return ".".join(parts_str) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 839a8439e..de2e071ab 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -138,7 +138,7 @@ polars = [ redis-pickledataset = ["redis~=4.1"] redis = ["kedro-datasets[redis-pickledataset]"] -snowflake-snowparktabledataset = ["snowflake-snowpark-python~=1.0"] +snowflake-snowparktabledataset = ["snowflake-snowpark-python>=1.23"] snowflake = ["kedro-datasets[snowflake-snowparktabledataset]"] spark-deltatabledataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base,delta-base]"] @@ -205,7 +205,7 @@ test = [ "adlfs~=2023.1", "behave==1.2.6", "biopython~=1.73", - "cloudpickle<=2.0.0", + "cloudpickle~=2.2.1", "compress-pickle[lz4]~=2.1.0", "coverage>=7.2.0", "dask[complete]>=2021.10", @@ -250,7 +250,7 @@ test = [ "requests-mock~=1.6", "requests~=2.20", "s3fs>=2021.04", - "snowflake-snowpark-python~=1.0; python_version < '3.11'", + "snowflake-snowpark-python>=1.23; python_version < '3.12'", "scikit-learn>=1.0.2,<2", "scipy>=1.7.3", "packaging", @@ -320,7 +320,7 @@ version = {attr = "kedro_datasets.__version__"} fail_under = 100 show_missing = true # temporarily ignore kedro_datasets/__init__.py in coverage report -omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/conftest.py"] +omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/tensorflow/*", "kedro_datasets/snowflake/*", "kedro_datasets/__init__.py", "kedro_datasets/conftest.py"] exclude_also = ["raise NotImplementedError", "if TYPE_CHECKING:"] [tool.pytest.ini_options] diff --git a/kedro-datasets/tests/snowflake/README.md b/kedro-datasets/tests/snowflake/README.md index 69fde3fd9..bd14c5de5 100644 --- a/kedro-datasets/tests/snowflake/README.md +++
b/kedro-datasets/tests/snowflake/README.md @@ -1,34 +1,4 @@ -# Snowpark connector testing +# Snowpark Testing Omitted +As of October 2024, the Snowpark connector is compatible with Python versions 3.9, 3.10, and 3.11. Python 3.12 is not supported yet. -Execution of automated tests for Snowpark connector requires real Snowflake instance access. Therefore tests located in this folder are **disabled** by default from pytest execution scope using [conftest.py](conftest.py). - -[Makefile](/Makefile) provides separate argument ``test-snowflake-only`` to run only tests related to Snowpark connector. To run tests one need to provide Snowflake connection parameters via environment variables: -* SNOWSQL_ACCOUNT - Snowflake account name with region. Ex `ab12345.eu-central-2` -* SNOWSQL_WAREHOUSE - Snowflake virtual warehouse to use -* SNOWSQL_DATABASE - Database to use -* SNOWSQL_SCHEMA - Schema to use when creating tables for tests -* SNOWSQL_ROLE - Role to use for connection -* SNOWSQL_USER - Username to use for connection -* SNOWSQL_PWD - Plain password to use for connection - -All environment variables need to be provided for tests to run. - -Here is example shell command to run snowpark tests via make utility: -```bash -SNOWSQL_ACCOUNT='ab12345.eu-central-2' SNOWSQL_WAREHOUSE='DEV_WH' SNOWSQL_DATABASE='DEV_DB' SNOWSQL_ROLE='DEV_ROLE' SNOWSQL_USER='DEV_USER' SNOWSQL_SCHEMA='DATA' SNOWSQL_PWD='supersecret' make test-snowflake-only -``` - -Currently running tests supports only simple username & password authentication and not SSO/MFA. - -As of Mar-2023, the snowpark connector only works with Python 3.8. - -## Snowflake permissions required -Credentials provided via environment variables should have following permissions granted to run tests successfully: -* Create tables in a given schema -* Drop tables in a given schema -* Insert rows into tables in a given schema -* Query tables in a given schema -* Query `INFORMATION_SCHEMA.TABLES` of respective database - -## Extending tests -Contributors adding new tests should add `@pytest.mark.snowflake` decorator to each test. Exclusion of Snowpark-related pytests from overall execution scope in [conftest.py](conftest.py) works based on markers. +Currently, the build process of kedro-datasets does not support testing different Python versions for each dataset. Additionally, each dataset test is required to have 100% coverage. Due to these constraints, the kedro-datasets/snowflake folder is excluded from pytest's coverage report. diff --git a/kedro-datasets/tests/snowflake/conftest.py b/kedro-datasets/tests/snowflake/conftest.py deleted file mode 100644 index 704ecf482..000000000 --- a/kedro-datasets/tests/snowflake/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -We disable execution of tests that require real Snowflake instance -to run by default. 
Providing -m snowflake option explicitly to -pytest will make these and only these tests run -""" - -import pytest - - -def pytest_collection_modifyitems(config, items): - markers_arg = config.getoption("-m") - - # Naive implementation to handle basic marker expressions - # Will not work if someone will (ever) run pytest with complex marker - # expressions like "-m spark and not (snowflake or pandas)" - if ( - "snowflake" in markers_arg.lower() - and "not snowflake" not in markers_arg.lower() - ): - return - - skip_snowflake = pytest.mark.skip(reason="need -m snowflake option to run") - for item in items: - if "snowflake" in item.keywords: - item.add_marker(skip_snowflake) diff --git a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py index 09e3cfb9e..365f0aafe 100644 --- a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py +++ b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py @@ -1,166 +1,402 @@ +# ruff: noqa: E402 import datetime -import os +from unittest.mock import MagicMock, patch import pytest -from kedro.io.core import DatasetError -try: - import snowflake.snowpark as sp +snowflake = pytest.importorskip("snowflake") - from kedro_datasets.snowflake import SnowparkTableDataset as spds -except ImportError: - pass # this is only for test discovery to succeed on Python <> 3.8 +import pandas as pd +from kedro.io.core import DatasetError +from snowflake.snowpark import DataFrame, Session +from snowflake.snowpark.types import ( + DateType, + FloatType, + IntegerType, + StringType, + StructField, + StructType, + TimestampType, +) + +from kedro_datasets.snowflake.snowpark_dataset import SnowparkTableDataset + +# example dummy configuration for local testing +DUMMY_CREDENTIALS = { + "account": "DUMMY_ACCOUNT", + "warehouse": "DUMMY_WAREHOUSE", + "database": "DUMMY_DATABASE", + "schema": "DUMMY_SCHEMA", + "user": "DUMMY_USER", + "password": "DUMMY_PASSWORD", +} + + +@pytest.fixture(scope="module") +def local_snowpark_session() -> Session: + """ + Creates a local Snowflake session for testing purposes. + See https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally + + Returns: + Session: Snowflake session object configured for local testing. + """ + return Session.builder.config("local_testing", True).create() + + +@pytest.fixture(scope="module") +def snowflake_dataset(local_snowpark_session: Session) -> SnowparkTableDataset: + """ + Provides a SnowparkTableDataset fixture for testing. + + Args: + local_snowpark_session (Session): The Snowflake session used for this dataset. + + Returns: + SnowparkTableDataset: Dataset configuration for a Snowflake table. + """ + return SnowparkTableDataset( + table_name="DUMMY_TABLE", + credentials=DUMMY_CREDENTIALS, + session=local_snowpark_session, + save_args={"mode": "overwrite"}, + ) -def get_connection(): - account = os.getenv("SNOWSQL_ACCOUNT") - warehouse = os.getenv("SNOWSQL_WAREHOUSE") - database = os.getenv("SNOWSQL_DATABASE") - role = os.getenv("SNOWSQL_ROLE") - user = os.getenv("SNOWSQL_USER") - schema = os.getenv("SNOWSQL_SCHEMA") - password = os.getenv("SNOWSQL_PWD") +@pytest.fixture(scope="module") +def sample_sp_df(local_snowpark_session: Session) -> DataFrame: + """ + Creates a sample Snowpark DataFrame for testing. - if not ( - account and warehouse and database and role and user and schema and password - ): - raise DatasetError( - "Snowflake connection environment variables provided not in full" - ) + Args: + local_snowpark_session (Session): Session used to create the DataFrame.
- conn = { - "account": account, - "warehouse": warehouse, - "database": database, - "role": role, - "user": user, - "schema": schema, - "password": password, - } - return conn - - -def sf_setup_db(sf_session): - # For table exists test - run_query(sf_session, 'CREATE TABLE KEDRO_PYTEST_TESTEXISTS ("name" VARCHAR)') - - # For load test - query = 'CREATE TABLE KEDRO_PYTEST_TESTLOAD ("name" VARCHAR\ - , "age" INTEGER\ - , "bday" date\ - , "height" float\ - , "insert_dttm" timestamp)' - run_query(sf_session, query) - - query = "INSERT INTO KEDRO_PYTEST_TESTLOAD VALUES ('John'\ - , 23\ - , to_date('1999-12-02','YYYY-MM-DD')\ - , 6.5\ - , to_timestamp_ntz('2022-12-02 13:20:01',\ - 'YYYY-MM-DD hh24:mi:ss'))" - run_query(sf_session, query) - - query = "INSERT INTO KEDRO_PYTEST_TESTLOAD VALUES ('Jane'\ - , 41\ - , to_date('1981-01-03','YYYY-MM-DD')\ - , 5.7\ - , to_timestamp_ntz('2022-12-02 13:21:11',\ - 'YYYY-MM-DD hh24:mi:ss'))" - run_query(sf_session, query) - - -def sf_db_cleanup(sf_session): - run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTEXISTS") - run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTLOAD") - run_query(sf_session, "DROP TABLE IF EXISTS KEDRO_PYTEST_TESTSAVE") - - -def run_query(session, query): - df = session.sql(query) - df.collect() - return df - - -def df_equals_ignore_dtype(df1, df2): - # Pytest will show respective stdout only if test fails - # this will help to debug what was exactly not matching right away - - c1 = df1.to_pandas().values.tolist() - c2 = df2.to_pandas().values.tolist() - - print(c1) - print("--- comparing to ---") - print(c2) - - for i, row in enumerate(c1): - for j, column in enumerate(row): - if not column == c2[i][j]: - print(f"{column} not equal to {c2[i][j]}") - return False - return True - - -@pytest.fixture -def sample_sp_df(sf_session): - return sf_session.create_dataframe( + Returns: + snowpark.DataFrame: DataFrame with sample data and schema. + """ + return local_snowpark_session.create_dataframe( [ - [ + ( "John", 23, datetime.date(1999, 12, 2), 6.5, datetime.datetime(2022, 12, 2, 13, 20, 1), - ], - [ + ), + ( "Jane", 41, datetime.date(1981, 1, 3), 5.7, datetime.datetime(2022, 12, 2, 13, 21, 11), - ], + ), ], - schema=["name", "age", "bday", "height", "insert_dttm"], + schema=StructType( + [ + StructField("name", StringType()), + StructField("age", IntegerType()), + StructField("bday", DateType()), + StructField("height", FloatType()), + StructField("insert_dttm", TimestampType()), + ] + ), ) -@pytest.fixture -def sf_session(): - sf_session = sp.Session.builder.configs(get_connection()).create() +@pytest.fixture(scope="module") +def sample_pd_df() -> pd.DataFrame: + """ + Creates a sample Pandas DataFrame for testing. + + Returns: + pd.DataFrame: DataFrame with sample data. + """ + return pd.DataFrame( + { + "name": ["Alice", "Bob"], + "age": [30, 40], + "bday": [datetime.date(1993, 1, 1), datetime.date(1983, 2, 2)], + "height": [5.5, 6.0], + "insert_dttm": [ + datetime.datetime(2023, 1, 1, 10, 0), + datetime.datetime(2023, 1, 1, 12, 0), + ], + } + ) - # Running cleanup in case previous run was interrupted w/o proper cleanup - sf_db_cleanup(sf_session) - sf_setup_db(sf_session) - yield sf_session - sf_db_cleanup(sf_session) - sf_session.close() +class TestSnowparkTableDataset: + """Tests for the SnowparkTableDataset functionality.""" + + def test_save_with_snowpark( + self, sample_sp_df: DataFrame, snowflake_dataset: SnowparkTableDataset + ) -> None: + """Tests saving a Snowpark DataFrame to a Snowflake table. 
+ + Args: + sample_sp_df (snowpark.DataFrame): Sample data to save. + snowflake_dataset (SnowparkTableDataset): Dataset to test. + + Asserts: + The count of the loaded DataFrame matches the saved DataFrame. + """ + snowflake_dataset.save(sample_sp_df) + loaded_df = snowflake_dataset.load() + assert loaded_df.count() == sample_sp_df.count() + + def test_save_with_pandas( + self, sample_pd_df: pd.DataFrame, snowflake_dataset: SnowparkTableDataset + ) -> None: + """ + Tests saving a Pandas DataFrame to a Snowflake table. + + Args: + sample_pd_df (pd.DataFrame): Sample data to save. + snowflake_dataset (SnowparkTableDataset): Dataset to test. + + Asserts: + The count of the loaded DataFrame matches the number of rows in the Pandas DataFrame. + """ + snowflake_dataset.save(sample_pd_df) + loaded_df = snowflake_dataset.load() + assert loaded_df.count() == len(sample_pd_df) + + def test_save_invalid_data(self, snowflake_dataset): + """ + Test the `save` method of `SnowparkTableDataset` with invalid data. + + This test ensures that the `save` method raises a `DatasetError` when provided with + data that is neither a Pandas DataFrame nor a Snowpark DataFrame. + + Args: + snowflake_dataset (SnowparkTableDataset): Instance of the dataset being tested. + + Asserts: + A `DatasetError` is raised with the appropriate error message. + """ + invalid_data = {"name": "Alice", "age": 30} + + with pytest.raises( + DatasetError, + match="Data of type <class 'dict'> is not supported for saving.", + ): + snowflake_dataset.save(invalid_data) + + def test_load( + self, snowflake_dataset: SnowparkTableDataset, sample_sp_df: DataFrame + ) -> None: + """ + Tests loading data from a Snowflake table. + + Args: + snowflake_dataset (SnowparkTableDataset): Dataset to load data from. + sample_sp_df (snowpark.DataFrame): Sample data for reference. + + Asserts: + The count of the loaded DataFrame matches the reference sample DataFrame. + """ + loaded_df = snowflake_dataset.load() + assert loaded_df.count() == sample_sp_df.count() + + def test_exists(self, snowflake_dataset: SnowparkTableDataset) -> None: + """ + Tests if a Snowflake table exists. + + Args: + snowflake_dataset (SnowparkTableDataset): Dataset to check existence. + + Asserts: + The dataset table exists in the Snowflake environment. + """ + exists = snowflake_dataset._exists() + assert exists + + def test_not_exists(self, snowflake_dataset: SnowparkTableDataset) -> None: + """ + Tests if a non-existent Snowflake table is detected. + Args: + snowflake_dataset (SnowparkTableDataset): Dataset to check existence. + + Asserts: + The dataset table does not exist in the Snowflake environment. + """ + snowflake_dataset._table_name = "NON_EXISTENT_TABLE" + exists = snowflake_dataset._exists() + assert not exists + + def test_get_session(self, snowflake_dataset: SnowparkTableDataset) -> None: + """ + Tests getting the Snowflake session from the dataset. + + Args: + snowflake_dataset (SnowparkTableDataset): Dataset to get the session from. + + Asserts: + The session is the same as the one used to create the dataset. + """ + assert ( + snowflake_dataset._get_session(snowflake_dataset._connection_parameters) + == snowflake_dataset._session + ) + def test_missing_table_name(self): + with pytest.raises( + DatasetError, match="'table_name' argument cannot be empty." + ): + SnowparkTableDataset(table_name="", credentials=DUMMY_CREDENTIALS) + + def test_missing_credentials(self): + with pytest.raises( + DatasetError, match="'credentials' argument cannot be empty."
+ ): + SnowparkTableDataset(table_name="weather_data", credentials=None) + + def test_missing_database_in_both_parameters_and_credentials(self): + credentials = DUMMY_CREDENTIALS.copy() + credentials.pop("database") + with pytest.raises( + DatasetError, match="'database' must be provided by credentials or dataset." + ): + SnowparkTableDataset(table_name="DUMMY_TABLE", credentials=credentials) + + def test_missing_schema_in_both_parameters_and_credentials(self): + credentials = DUMMY_CREDENTIALS.copy() + credentials.pop("schema") + with pytest.raises( + DatasetError, match="'schema' must be provided by credentials or dataset." + ): + SnowparkTableDataset(table_name="DUMMY_TABLE", credentials=credentials) + + def test_validate_and_get_table_name_success(self, snowflake_dataset): + """ + Test that the `_validate_and_get_table_name` method returns the correct table name. + + This test calls the `_validate_and_get_table_name` method with a valid table name + and verifies that the method returns the correct table name. + + Args: + self: The test case instance. + snowflake_dataset: The dataset instance being tested. + + Asserts: + The method returns the correct table name. + """ + snowflake_dataset._table_name = "DUMMY_TABLE" + expected_table_name = "DUMMY_DATABASE.DUMMY_SCHEMA.DUMMY_TABLE" + + table_name = snowflake_dataset._validate_and_get_table_name() + assert table_name == expected_table_name + + @pytest.mark.parametrize( + "table_name, database, schema", + [ + ("", "DUMMY_DATABASE", "DUMMY_SCHEMA"), # Invalid table name (empty string) + ("DUMMY_TABLE", "", "DUMMY_SCHEMA"), # Invalid database (empty string) + ("DUMMY_TABLE", "DUMMY_DATABASE", ""), # Invalid schema (empty string) + (None, "DUMMY_DATABASE", "DUMMY_SCHEMA"), # Invalid table name (None) + ("DUMMY_TABLE", None, "DUMMY_SCHEMA"), # Invalid database (None) + ("DUMMY_TABLE", "DUMMY_DATABASE", None), # Invalid schema (None) + ], + ) + def test_validate_and_get_table_name_error( + self, snowflake_dataset, table_name, database, schema + ): + """ + Test that the `_validate_and_get_table_name` method raises an error for invalid table name, database, and schema. + + This test calls the `_validate_and_get_table_name` method with invalid table name, database, and schema + and verifies that the method raises a `DatasetError`. + + Args: + self: The test case instance. + snowflake_dataset: The dataset instance being tested. + table_name: The table name to test. + database: The database name to test. + schema: The schema name to test. + + Asserts: + A `DatasetError` is raised. + """ + snowflake_dataset._table_name = table_name + snowflake_dataset._database = database + snowflake_dataset._schema = schema + + with pytest.raises( + DatasetError, match="Database, schema or table name cannot be None or empty" + ): + snowflake_dataset._validate_and_get_table_name() + + def test_get_session_existing_session(self, mocker, snowflake_dataset): + """ + Test that `snowflake_dataset._get_session` returns the existing active session. + + Args: + mocker: A fixture for mocking objects. + snowflake_dataset: An instance of the Snowflake dataset. + + Asserts: + The session returned by `_get_session` is the same as the active session. + The `get_active_session` method is called exactly once. 
+ """ + mock_active_session = MagicMock() + mock_get_active_session = mocker.patch( + "snowflake.snowpark.context.get_active_session", + return_value=mock_active_session, + ) -class TestSnowparkTableDataset: - @pytest.mark.snowflake - def test_save(self, sample_sp_df, sf_session): - sp_df = spds(table_name="KEDRO_PYTEST_TESTSAVE", credentials=get_connection()) - sp_df._save(sample_sp_df) - sp_df_saved = sf_session.table("KEDRO_PYTEST_TESTSAVE") - assert sp_df_saved.count() == 2 - - @pytest.mark.snowflake - def test_load(self, sample_sp_df, sf_session): - print(sf_session) - sp_df = spds( - table_name="KEDRO_PYTEST_TESTLOAD", credentials=get_connection() - )._load() - - # Ignoring dtypes as ex. age can be int8 vs int64 and pandas.compare - # fails on that - assert df_equals_ignore_dtype(sample_sp_df, sp_df) - - @pytest.mark.snowflake - def test_exists(self, sf_session): - print(sf_session) - df_e = spds(table_name="KEDRO_PYTEST_TESTEXISTS", credentials=get_connection()) - df_ne = spds( - table_name="KEDRO_PYTEST_TESTNEXISTS", credentials=get_connection() + session = snowflake_dataset._get_session( + snowflake_dataset._connection_parameters + ) + + assert session == mock_active_session + mock_get_active_session.assert_called_once() + + @patch("snowflake.snowpark.Session.builder") + def test_get_session_no_existing_session(self, mock_builder, snowflake_dataset): + """ + Test the `_get_session` method of `SnowparkTableDataset` when there is no existing session. + + This test ensures that the `_get_session` method correctly initializes a new session + using the Snowflake Snowpark `Session.builder` when there is no existing session. + + Args: + mock_builder (MagicMock): Mocked `Session.builder` object. + snowflake_dataset (SnowparkTableDataset): Instance of the dataset being tested. + + Steps: + 1. Close the existing session and set the private session attribute to `None`. + 2. Mock the `builder`, `configs`, and `create` methods to simulate session creation. + 3. Call the `_get_session` method with the dataset's connection parameters. + 4. Assert that the `configs` method was called once with the correct parameters. + 5. Assert that the `create` method was called once. + 6. Assert that the returned session is the mocked `create` instance. + + Asserts: + - `mock_builder.configs` is called once with the dataset's connection parameters. + - `mock_configs_instance.create` is called once. + - The returned session is the `mock_create_instance` object. + """ + snowflake_dataset._session.close() + snowflake_dataset._SnowparkTableDataset__session = ( + None # Accessing the mangled private attribute + ) + + # mock the builder, configs, and create methods since we cannot create a real session + mock_configs_instance = MagicMock() + mock_create_instance = MagicMock() + mock_builder.configs.return_value = mock_configs_instance + mock_configs_instance.create.return_value = mock_create_instance + + session = snowflake_dataset._get_session( + snowflake_dataset._connection_parameters ) - assert df_e._exists() - assert not df_ne._exists() + + # assert that each part of the chain was called correctly + mock_builder.configs.assert_called_once_with( + snowflake_dataset._connection_parameters + ) + mock_configs_instance.create.assert_called_once() + + # Assert the returned session is the mock_create_instance object + assert session == mock_create_instance