Skip to content

Commit

Permalink
Update the failing functions & test-cases to throw exceptions for Spa…
Browse files Browse the repository at this point in the history
…rk-Connect <3.5.2
  • Loading branch information
nijanthanvijayakumar committed Aug 5, 2024
1 parent d12ab07 commit 55bc170
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
6 changes: 6 additions & 0 deletions quinn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pyspark.sql.functions import udf


import os
import sys
import uuid
from typing import Any

Expand Down Expand Up @@ -196,6 +198,10 @@ def array_choice(col: Column, seed: int | None = None) -> Column:
:return: random element from the given column
:rtype: Column
"""

if sys.modules["pyspark"].__version__ < "3.5.2" and os.getenv("SPARK_CONNECT_MODE_ENABLED"):
raise Exception("array_choice is not supported on Spark-Connect mode for Spark versions < 3.5.2")

index = (F.rand(seed) * F.size(col)).cast("int")
return col[index]

Expand Down
4 changes: 4 additions & 0 deletions quinn/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from __future__ import annotations

import re
import os
import sys
from collections.abc import Callable

from pyspark.sql import DataFrame, SparkSession
Expand Down Expand Up @@ -113,6 +115,8 @@ def sort_columns( # noqa: C901,PLR0915
:return: A DataFrame with the columns sorted in the chosen order
:rtype: pyspark.sql.DataFrame
"""
if sys.modules["pyspark"].__version__ < "3.5.2" and os.getenv("SPARK_CONNECT_MODE_ENABLED"):
raise Exception("sort_columns is not supported on Spark-Connect mode for Spark versions < 3.5.2")

def sort_nested_cols(
schema: StructType,
Expand Down
27 changes: 20 additions & 7 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import pytest

import pyspark.sql.functions as F
Expand Down Expand Up @@ -171,8 +173,8 @@ def it_errors_out_if_with_invalid_week_start_date():
"week_start_date", quinn.week_start_date(F.col("some_date"), "hello")
)
assert (
excinfo.value.args[0]
== "The day you entered 'hello' is not valid. Here are the valid days: [Mon,Tue,Wed,Thu,Fri,Sat,Sun]"
excinfo.value.args[0]
== "The day you entered 'hello' is not valid. Here are the valid days: [Mon,Tue,Wed,Thu,Fri,Sat,Sun]"
)


Expand Down Expand Up @@ -228,8 +230,8 @@ def it_errors_out_if_with_invalid_week_end_date():
"week_start_date", quinn.week_end_date(F.col("some_date"), "Friday")
)
assert (
excinfo.value.args[0]
== "The day you entered 'Friday' is not valid. Here are the valid days: [Mon,Tue,Wed,Thu,Fri,Sat,Sun]"
excinfo.value.args[0]
== "The day you entered 'Friday' is not valid. Here are the valid days: [Mon,Tue,Wed,Thu,Fri,Sat,Sun]"
)


Expand Down Expand Up @@ -281,13 +283,23 @@ def it_works_with_integer_values():

# TODO: Figure out how to make this test deterministic locally & on CI
def test_array_choice():
df = quinn.create_df(spark,
# Create the DataFrame so that it can be passed to the if & else blocks
df = quinn.create_df(
spark,
[(["a", "b", "c"], "c"), (["a", "b", "c", "d"], "a"), (["x"], "x"), ([None], None)],
[("letters", ArrayType(StringType(), True), True), ("expected", StringType(), True)],
)
actual_df = df.withColumn("random_letter", quinn.array_choice(F.col("letters"), 42))
# chispa.assert_column_equality(actual_df, "random_letter", "expected")

# Check if the SPARK_CONNECT_MODE_ENABLED environment variable is set and if the Spark version is less than 3.5.2.
# If so check for the exception and if not, run the test.
spark_version = spark.version
if spark_version < "3.5.2" and os.getenv("SPARK_CONNECT_MODE_ENABLED"):
with pytest.raises(Exception) as excinfo:
df.withColumn("random_letter", quinn.array_choice(F.col("letters"), 42))
assert excinfo.value.args[0] == "array_choice is not supported on Spark-Connect mode for Spark versions < 3.5.2"
else:
actual_df = df.withColumn("random_letter", quinn.array_choice(F.col("letters"), 42))
# chispa.assert_column_equality(actual_df, "random_letter", "expected")


def test_business_days_between():
Expand Down Expand Up @@ -361,6 +373,7 @@ def test_with_extra_string():
)
chispa.assert_column_equality(actual_df, "uuid5_of_s1", "expected")


def test_is_falsy():
source_df = quinn.create_df(
spark,
Expand Down
16 changes: 11 additions & 5 deletions tests/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import os

import pytest
import chispa
import quinn
Expand Down Expand Up @@ -535,11 +537,15 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]:
return unsorted_df, expected_df

unsorted_df, expected_df = _get_test_dataframes()
sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True)

chispa.schema_comparer.assert_schema_equality(
sorted_df.schema, expected_df.schema, ignore_nullable
)
if spark.version < "3.5.2" and os.getenv("SPARK_CONNECT_MODE_ENABLED"):
with pytest.raises(Exception) as excinfo:
quinn.sort_columns(unsorted_df, "asc", sort_nested=True)
assert str(excinfo.value) == "sort_columns is not supported on Spark-Connect mode for Spark versions < 3.5.2"
else:
sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True)
chispa.schema_comparer.assert_schema_equality(
sorted_df.schema, expected_df.schema, ignore_nullable
)


def _test_sort_struct_nested_with_arraytypes_desc(spark, ignore_nullable: bool):
Expand Down

0 comments on commit 55bc170

Please sign in to comment.