Skip to content

Commit

Permalink
update testing framework and docstrings
Browse files Browse the repository at this point in the history
Signed-off-by: tdhooghe <[email protected]>
  • Loading branch information
tdhooghe committed Oct 21, 2024
1 parent f31b8c3 commit 3c09845
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 201 deletions.
67 changes: 44 additions & 23 deletions kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from typing import Any, Union
from typing import Any

import pandas as pd
import snowflake.snowpark as sp
Expand Down Expand Up @@ -111,9 +111,11 @@ def __init__( # noqa: PLR0913
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
credentials: dict[str, Any] | None = None,
session: sp.Session | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Creates a new instance of ``SnowparkTableDataset``.
"""
Creates a new instance of ``SnowparkTableDataset``.
Args:
table_name: The table name to load or save data to.
Expand Down Expand Up @@ -162,6 +164,7 @@ def __init__( # noqa: PLR0913
self._table_name = table_name
self._database = database
self._schema = schema
self.__session = session or self._get_session(credentials) # for testing

connection_parameters = credentials
connection_parameters.update(
Expand All @@ -180,7 +183,8 @@ def _describe(self) -> dict[str, Any]:

@staticmethod
def _get_session(connection_parameters) -> sp.Session:
"""Given a connection string, create singleton connection
"""
Given a connection string, create singleton connection
to be used across all instances of `SnowparkTableDataset` that
need to connect to the same source.
connection_parameters is a dictionary of any values
Expand Down Expand Up @@ -208,35 +212,52 @@ def _get_session(connection_parameters) -> sp.Session:

@property
def _session(self) -> sp.Session:
return self._get_session(self._connection_parameters)
"""
Retrieve or create a session.
Returns:
sp.Session: The current session associated with the object.
"""
if not self.__session:
self.__session = self._get_session(self._connection_parameters)
return self.__session

def load(self) -> sp.DataFrame:
table_name: list = [
self._database,
self._schema,
self._table_name,
]
"""
Load data from a specified database table.
Returns:
sp.DataFrame: The loaded data as a Snowpark DataFrame.
"""
table_name = [self._database, self._schema, self._table_name]
sp_df = self._session.table(".".join(table_name))
return sp_df

def save(self, data: Union[pd.DataFrame, sp.DataFrame]) -> None:
def save(self, data: pd.DataFrame | sp.DataFrame) -> None:
"""
Save data to a specified database table.
Args:
data (pd.DataFrame | sp.DataFrame): The data to save.
"""
if isinstance(data, pd.DataFrame):
data = self._session.create_dataframe(data)

table_name = ".".join([self._database, self._schema, self._table_name])
table_name = [self._database, self._schema, self._table_name]
data.write.save_as_table(table_name, **self._save_args)

def _exists(self) -> bool:
session = self._session
query = "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES \
WHERE TABLE_SCHEMA = '{schema}' \
AND TABLE_NAME = '{table_name}'"
rows = session.sql(
query.format(
database=self._database,
schema=self._schema,
table_name=self._table_name,
)
).collect()
return rows[0][0] == 1
"""
Check if a specified table exists in the database.
Returns:
bool: True if the table exists, False otherwise.
"""
try:
self._session.table(
f"{self._database}.{self._schema}.{self._table_name}"
).show()
return True
except Exception as e:
logger.debug(f"Table {self._table_name} does not exist: {e}")
return False
25 changes: 0 additions & 25 deletions kedro-datasets/tests/snowflake/conftest.py

This file was deleted.

Loading

0 comments on commit 3c09845

Please sign in to comment.