From a9ea6a88253d0f1c6c1b5a714745d3c314fda6c8 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Wed, 4 Oct 2023 22:05:03 -0400 Subject: [PATCH 01/25] first test attempt - broken --- tests/test_transformations.py | 50 ++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index e71dd275..f5fae7de 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -1,5 +1,5 @@ import pytest -from pyspark.sql.types import StructType, StructField, StringType +from pyspark.sql.types import StructType, StructField, StringType, ArrayType, IntegerType import quinn from tests.conftest import auto_inject_fixtures @@ -222,3 +222,51 @@ def it_throws_an_error_if_the_sort_order_is_invalid(spark): excinfo.value.args[0] == "['asc', 'desc'] are the only valid sort orders and you entered a sort order of 'cats'" ) + +def test_sort_struct(spark): + # create a schema including an array of structs + unsorted_fields = StructType( + [ + StructField("b", IntegerType()), + StructField("a", ArrayType(StructType([ + StructField("d", IntegerType()), + StructField("e", IntegerType()), + StructField("c", IntegerType()), + ]))), + ] + ) + + sorted_fields = StructType( + [ + StructField("a", ArrayType(StructType([ + StructField("d", IntegerType()), + StructField("e", IntegerType()), + StructField("c", IntegerType()), + ]))), + StructField("b", IntegerType()), + ] + ) + + unsorted_data = [ + (1, [(1, 2, 3), (4, 5, 6)]), + (2, [(7, 8, 9), (10, 11, 12)]), + ] + + sorted_data = [ + ([(1, 2, 3), (4, 5, 6)], 1), + ([(7, 8, 9), (10, 11, 12)], 2), + ] + + unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) + expected_df = spark.createDataFrame(sorted_data, sorted_fields) + + sorted_df = quinn.sort_columns(unsorted_df, 'asc') + + # TODO: doesn't work b/c of nested structs + chispa.schema_comparer.assert_schema_equality(sorted_df, expected_df) + + +# create a local spark session +from pyspark.sql import SparkSession +spark = SparkSession.builder.appName('abc').getOrCreate() +test_sort_struct(spark) \ No newline at end of file From 3491909f3e0b678cd30776877d1275fa040fba78 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 12:11:54 -0400 Subject: [PATCH 02/25] create test data in function, add double nested column --- tests/test_transformations.py | 106 +++++++++++++++++++++++----------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index f5fae7de..8e651a27 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -224,43 +224,83 @@ def it_throws_an_error_if_the_sort_order_is_invalid(spark): ) def test_sort_struct(spark): - # create a schema including an array of structs - unsorted_fields = StructType( - [ - StructField("b", IntegerType()), - StructField("a", ArrayType(StructType([ - StructField("d", IntegerType()), - StructField("e", IntegerType()), - StructField("c", IntegerType()), - ]))), - ] - ) - - sorted_fields = StructType( - [ - StructField("a", ArrayType(StructType([ - StructField("d", IntegerType()), - StructField("e", IntegerType()), - StructField("c", IntegerType()), - ]))), - StructField("b", IntegerType()), - ] - ) + def _create_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + unsorted_fields = StructType( + [ + StructField("b", IntegerType()), + StructField( + "c", + ArrayType( + ArrayType( + StructType( + [ + StructField("g", IntegerType()), + StructField("f", IntegerType()), + ] + ) + ) + ), + ), + StructField( + "a", + ArrayType( + StructType( + [ + StructField("d", IntegerType()), + StructField("e", IntegerType()), + StructField("c", IntegerType()), + ] + ) + ), + ), + ] + ) + sorted_fields = StructType( + [ + StructField( + "a", + ArrayType( + StructType( + [ + StructField("c", IntegerType()), + StructField("e", IntegerType()), + StructField("d", IntegerType()), + ] + ) + ), + ), + StructField("b", IntegerType()), + StructField( + "c", + ArrayType( + ArrayType( + StructType( + [ + StructField("f", IntegerType()), + StructField("g", IntegerType()), + ] + ) + ) + ), + ), + ] + ) - unsorted_data = [ - (1, [(1, 2, 3), (4, 5, 6)]), - (2, [(7, 8, 9), (10, 11, 12)]), - ] + col_a = [(2, 3, 4)] + col_b = 1 + col_c = [[(5, 6)]] - sorted_data = [ - ([(1, 2, 3), (4, 5, 6)], 1), - ([(7, 8, 9), (10, 11, 12)], 2), - ] + unsorted_data = [ + (col_b, col_c, col_a), + ] + sorted_data = [ + (col_a, col_b, col_c), + ] - unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) - expected_df = spark.createDataFrame(sorted_data, sorted_fields) + unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) + expected_df = spark.createDataFrame(sorted_data, sorted_fields) - sorted_df = quinn.sort_columns(unsorted_df, 'asc') + return unsorted_df, expected_df # TODO: doesn't work b/c of nested structs chispa.schema_comparer.assert_schema_equality(sorted_df, expected_df) From 135156cc2d3ecbeaff6a63327f0dc68763cf70d7 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 12:43:46 -0400 Subject: [PATCH 03/25] add simple test to ensure original functionality still works --- tests/test_transformations.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 8e651a27..9fcf980e 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -302,6 +302,37 @@ def _create_test_dataframes() -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df + def _create_simple_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + unsorted_fields = StructType( + [ + StructField("b", IntegerType()), + StructField("c", IntegerType()), + StructField("a", IntegerType()), + ] + ) + sorted_fields = StructType( + [ + StructField("a", IntegerType()), + StructField("b", IntegerType()), + StructField("c", IntegerType()), + ] + ) + + col_a = 1 + col_b = 2 + col_c = 3 + + unsorted_data = [ + (col_b, col_c, col_a), + ] + sorted_data = [ + (col_a, col_b, col_c), + ] + + unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) + expected_df = spark.createDataFrame(sorted_data, sorted_fields) + + return unsorted_df, expected_df # TODO: doesn't work b/c of nested structs chispa.schema_comparer.assert_schema_equality(sorted_df, expected_df) From 782c6e875d7fc89587705a4dd77306088aa0de65 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 15:08:20 -0400 Subject: [PATCH 04/25] add working implementation for flat and nested StructType schemas --- quinn/transformations.py | 91 ++++++++++++++++++++++++++++++++++------ 1 file changed, 79 insertions(+), 12 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index 2f32f779..ab6186f8 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -2,6 +2,7 @@ import pyspark.sql.functions as F from pyspark.sql import DataFrame +from pyspark.sql.types import ArrayType, StructField, StructType def with_columns_renamed(fun: Callable[[str], str]) -> Callable[[DataFrame], DataFrame]: @@ -94,21 +95,87 @@ def sort_columns(df: DataFrame, sort_order: str) -> DataFrame: the ``sort_order`` parameter, a ``ValueError`` will be raised. :param df: A DataFrame - :type df: pandas.DataFrame + :type df: pyspark.sql.DataFrame :param sort_order: The order in which to sort the columns in the DataFrame :type sort_order: str :return: A DataFrame with the columns sorted in the chosen order - :rtype: pandas.DataFrame + :rtype: pyspark.sql.DataFrame """ - sorted_col_names = None - if sort_order == "asc": - sorted_col_names = sorted(df.columns) - elif sort_order == "desc": - sorted_col_names = sorted(df.columns, reverse=True) - else: - raise ValueError( - "['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format( - sort_order=sort_order + + def parse_sort_order(sort_order: str) -> bool: + if sort_order not in ["asc", "desc"]: + raise ValueError( + "['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format( + sort_order=sort_order + ) ) + reverse_lookup = { + "asc": False, + "desc": True, + } + return reverse_lookup[sort_order] + + def sort_top_level_cols(schema, is_reversed) -> dict: + # sort top level columns + top_sorted_fields: list = sorted( + schema.fields, key=lambda x: x.name, reverse=is_reversed + ) + + is_nested: bool = any( + [ + isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType) + for i in top_sorted_fields + ] ) - return df.select(*sorted_col_names) + + output = { + "schema": top_sorted_fields, + "is_nested": is_nested, + } + + return output + + def sort_nested_cols(schema, is_reversed, baseField="") -> list: + # TODO: get working with ArrayType + # recursively check nested fields and sort them + # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark + # Credits: @pault for logic + + select_cols = [] + for structField in sorted(schema, key=lambda x: x.name, reverse=is_reversed): + if isinstance(structField.dataType, StructType): + subFields = [] + for fld in sorted( + structField.jsonValue()["type"]["fields"], + key=lambda x: x["name"], + reverse=is_reversed, + ): + newStruct = StructType([StructField.fromJson(fld)]) + newBaseField = structField.name + if baseField: + newBaseField = baseField + "." + newBaseField + subFields.extend( + sort_nested_cols(newStruct, is_reversed, baseField=newBaseField) + ) + + select_cols.append( + "struct(" + ",".join(subFields) + ") AS {}".format(structField.name) + ) + else: + if baseField: + select_cols.append(baseField + "." + structField.name) + else: + select_cols.append(structField.name) + return select_cols + + is_reversed: bool = parse_sort_order(sort_order) + top_sorted_schema_results: dict = sort_top_level_cols(df.schema, is_reversed) + if not top_sorted_schema_results["is_nested"]: + columns: list = [i.name for i in top_sorted_schema_results["schema"]] + return df.select(*columns) + + fully_sorted_schema = sort_nested_cols( + top_sorted_schema_results["schema"], is_reversed + ) + + return df.selectExpr(fully_sorted_schema) From 983d6a593ffdf4dc7617da945e6fb392d3f1b4b2 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 15:16:36 -0400 Subject: [PATCH 05/25] separate tests into flat and nested struct, confirm tests work --- tests/test_transformations.py | 185 ++++++++++++++++++++-------------- 1 file changed, 108 insertions(+), 77 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 9fcf980e..8c988d32 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -1,9 +1,17 @@ import pytest -from pyspark.sql.types import StructType, StructField, StringType, ArrayType, IntegerType +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + ArrayType, + IntegerType, +) import quinn +from pyspark.sql import DataFrame from tests.conftest import auto_inject_fixtures import chispa +import chispa.schema_comparer @auto_inject_fixtures("spark") @@ -223,72 +231,27 @@ def it_throws_an_error_if_the_sort_order_is_invalid(spark): == "['asc', 'desc'] are the only valid sort orders and you entered a sort order of 'cats'" ) -def test_sort_struct(spark): - def _create_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + +def test_sort_struct_flat(spark): + def _get_simple_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_fields = StructType( [ StructField("b", IntegerType()), - StructField( - "c", - ArrayType( - ArrayType( - StructType( - [ - StructField("g", IntegerType()), - StructField("f", IntegerType()), - ] - ) - ) - ), - ), - StructField( - "a", - ArrayType( - StructType( - [ - StructField("d", IntegerType()), - StructField("e", IntegerType()), - StructField("c", IntegerType()), - ] - ) - ), - ), + StructField("c", IntegerType()), + StructField("a", IntegerType()), ] ) sorted_fields = StructType( [ - StructField( - "a", - ArrayType( - StructType( - [ - StructField("c", IntegerType()), - StructField("e", IntegerType()), - StructField("d", IntegerType()), - ] - ) - ), - ), + StructField("a", IntegerType()), StructField("b", IntegerType()), - StructField( - "c", - ArrayType( - ArrayType( - StructType( - [ - StructField("f", IntegerType()), - StructField("g", IntegerType()), - ] - ) - ) - ), - ), + StructField("c", IntegerType()), ] ) - col_a = [(2, 3, 4)] - col_b = 1 - col_c = [[(5, 6)]] + col_a = 1 + col_b = 2 + col_c = 3 unsorted_data = [ (col_b, col_c, col_a), @@ -302,42 +265,110 @@ def _create_test_dataframes() -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df - def _create_simple_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - unsorted_fields = StructType( + unsorted_df, expected_df = _get_simple_test_dataframes() + + unsorted_df.printSchema() + sorted_df = quinn.sort_columns(unsorted_df, "asc") + sorted_df.printSchema() + expected_df.printSchema() + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=True + ) + + +def test_sort_struct_nested(spark): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + unsorted_schema = StructType( [ - StructField("b", IntegerType()), - StructField("c", IntegerType()), - StructField("a", IntegerType()), + StructField("_id", StringType(), nullable=False), + StructField("first_name", StringType(), nullable=False), + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType( + [ + StructField( + "last4", IntegerType(), nullable=True + ), + StructField( + "first5", IntegerType(), nullable=True + ), + ] + ), + nullable=True, + ), + StructField("city", StringType(), nullable=True), + ] + ), + nullable=True, + ), ] ) - sorted_fields = StructType( + + sorted_schema = StructType( [ - StructField("a", IntegerType()), - StructField("b", IntegerType()), - StructField("c", IntegerType()), + StructField("_id", StringType(), nullable=False), + StructField( + "address", + StructType( + [ + StructField("city", StringType(), nullable=True), + StructField( + "zip", + StructType( + [ + StructField( + "first5", IntegerType(), nullable=True + ), + StructField( + "last4", IntegerType(), nullable=True + ), + ] + ), + nullable=True, + ), + ] + ), + nullable=True, + ), + StructField("first_name", StringType(), nullable=False), ] ) - col_a = 1 - col_b = 2 - col_c = 3 + _id = "12345" + city = "Fake City" + zip_first5 = 54321 + zip_last4 = 12345 + first_name = "John" unsorted_data = [ - (col_b, col_c, col_a), + (_id, first_name, (((zip_last4, zip_first5)), city)), ] sorted_data = [ - (col_a, col_b, col_c), + (_id, ((city, (zip_first5, zip_last4))), first_name), ] - unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) - expected_df = spark.createDataFrame(sorted_data, sorted_fields) + unsorted_df = spark.createDataFrame(unsorted_data, unsorted_schema) + expected_df = spark.createDataFrame(sorted_data, sorted_schema) return unsorted_df, expected_df - # TODO: doesn't work b/c of nested structs - chispa.schema_comparer.assert_schema_equality(sorted_df, expected_df) + + unsorted_df, expected_df = _get_test_dataframes() + + unsorted_df.printSchema() + sorted_df = quinn.sort_columns(unsorted_df, "asc") + sorted_df.printSchema() + + # TODO: work on assert_schema_equality to handle nested structs + assert True + # assert_schema_equality(sorted_df, expected_df) # create a local spark session -from pyspark.sql import SparkSession -spark = SparkSession.builder.appName('abc').getOrCreate() -test_sort_struct(spark) \ No newline at end of file +# from pyspark.sql import SparkSession + +# spark = SparkSession.builder.appName("abc").getOrCreate() +# test_sort_struct(spark) From 60a2400df9c4c3f4514091acad1a201846703e94 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 15:30:53 -0400 Subject: [PATCH 06/25] add test to assert correct descending sort behavior --- tests/test_transformations.py | 65 +++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 19 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 8c988d32..c678b2c0 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -232,8 +232,12 @@ def it_throws_an_error_if_the_sort_order_is_invalid(spark): ) -def test_sort_struct_flat(spark): - def _get_simple_test_dataframes() -> tuple[(DataFrame, DataFrame)]: +def _test_sort_struct_flat(spark, sort_order: str): + def _get_simple_test_dataframes(sort_order) -> tuple[(DataFrame, DataFrame)]: + col_a = 1 + col_b = 2 + col_c = 3 + unsorted_fields = StructType( [ StructField("b", IntegerType()), @@ -241,34 +245,49 @@ def _get_simple_test_dataframes() -> tuple[(DataFrame, DataFrame)]: StructField("a", IntegerType()), ] ) - sorted_fields = StructType( - [ - StructField("a", IntegerType()), - StructField("b", IntegerType()), - StructField("c", IntegerType()), - ] - ) - - col_a = 1 - col_b = 2 - col_c = 3 - unsorted_data = [ (col_b, col_c, col_a), ] - sorted_data = [ - (col_a, col_b, col_c), - ] + if sort_order == "asc": + sorted_fields = StructType( + [ + StructField("a", IntegerType()), + StructField("b", IntegerType()), + StructField("c", IntegerType()), + ] + ) + + sorted_data = [ + (col_a, col_b, col_c), + ] + elif sort_order == "desc": + sorted_fields = StructType( + [ + StructField("c", IntegerType()), + StructField("b", IntegerType()), + StructField("a", IntegerType()), + ] + ) + + sorted_data = [ + (col_c, col_b, col_a), + ] + else: + raise ValueError( + "['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format( + sort_order=sort_order + ) + ) unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) expected_df = spark.createDataFrame(sorted_data, sorted_fields) return unsorted_df, expected_df - unsorted_df, expected_df = _get_simple_test_dataframes() + unsorted_df, expected_df = _get_simple_test_dataframes(sort_order=sort_order) unsorted_df.printSchema() - sorted_df = quinn.sort_columns(unsorted_df, "asc") + sorted_df = quinn.sort_columns(unsorted_df, sort_order) sorted_df.printSchema() expected_df.printSchema() chispa.schema_comparer.assert_schema_equality( @@ -276,6 +295,14 @@ def _get_simple_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ) +def test_sort_struct_flat(spark): + _test_sort_struct_flat(spark, "asc") + + +def test_sort_struct_flat_desc(spark): + _test_sort_struct_flat(spark, "desc") + + def test_sort_struct_nested(spark): def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_schema = StructType( From 1fe19e1356de9e241fd22e1a7067138929205a5d Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 15:38:13 -0400 Subject: [PATCH 07/25] fix assert_schema_equality call - actually doesn't need to be updated --- tests/test_transformations.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index c678b2c0..03cf2c73 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -389,13 +389,12 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: sorted_df = quinn.sort_columns(unsorted_df, "asc") sorted_df.printSchema() - # TODO: work on assert_schema_equality to handle nested structs - assert True - # assert_schema_equality(sorted_df, expected_df) + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=True + ) -# create a local spark session # from pyspark.sql import SparkSession # spark = SparkSession.builder.appName("abc").getOrCreate() -# test_sort_struct(spark) +# test_sort_struct_nested(spark) From ba2b2e5215f188dd5074d9373d099c33a91daf4d Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 15:56:14 -0400 Subject: [PATCH 08/25] add descending check for nested struct --- tests/test_transformations.py | 95 +++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 03cf2c73..21db2cd4 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -394,6 +394,101 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ) +def test_sort_struct_nested_desc(spark): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + unsorted_schema = StructType( + [ + StructField("_id", StringType(), nullable=False), + StructField("first_name", StringType(), nullable=False), + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType( + [ + StructField( + "last4", IntegerType(), nullable=True + ), + StructField( + "first5", IntegerType(), nullable=True + ), + ] + ), + nullable=True, + ), + StructField("city", StringType(), nullable=True), + ] + ), + nullable=True, + ), + ] + ) + + sorted_schema = StructType( + [ + StructField("first_name", StringType(), nullable=False), + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType( + [ + StructField( + "last4", IntegerType(), nullable=True + ), + StructField( + "first5", IntegerType(), nullable=True + ), + ] + ), + nullable=True, + ), + StructField("city", StringType(), nullable=True), + ] + ), + nullable=True, + ), + StructField("_id", StringType(), nullable=False), + ] + ) + + _id = "12345" + city = "Fake City" + zip_first5 = 54321 + zip_last4 = 12345 + first_name = "John" + + unsorted_data = [ + (_id, first_name, (((zip_last4, zip_first5)), city)), + ] + sorted_data = [ + ( + first_name, + ((zip_first5, zip_last4), city), + _id, + ), + ] + + unsorted_df = spark.createDataFrame(unsorted_data, unsorted_schema) + expected_df = spark.createDataFrame(sorted_data, sorted_schema) + + return unsorted_df, expected_df + + unsorted_df, expected_df = _get_test_dataframes() + + unsorted_df.printSchema() + sorted_df = quinn.sort_columns(unsorted_df, "desc") + sorted_df.printSchema() + + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=True + ) + + # from pyspark.sql import SparkSession # spark = SparkSession.builder.appName("abc").getOrCreate() From 59a022799b05e6e65330d41b86c5ba615c76d013 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 21:30:18 -0400 Subject: [PATCH 09/25] add working arraytype implementation --- quinn/transformations.py | 36 +++++++++- tests/test_transformations.py | 129 +++++++++++++++++++++++++++++++++- 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index ab6186f8..d4b382e2 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -143,7 +143,41 @@ def sort_nested_cols(schema, is_reversed, baseField="") -> list: select_cols = [] for structField in sorted(schema, key=lambda x: x.name, reverse=is_reversed): - if isinstance(structField.dataType, StructType): + field_type = structField.dataType + if isinstance(field_type, ArrayType): + array_elements = [] + array_parent = structField.jsonValue()["type"]["elementType"] + + base_str = f"transform({structField.name}" + suffix_str = f") AS {structField.name}" + if array_parent["type"] == "struct": + array_parent = array_parent["fields"] + + base_str = f"{base_str}, x -> struct(" + suffix_str = f"){suffix_str}" + + sorted_fields: list = sorted( + array_parent, + key=lambda x: x["name"], + reverse=is_reversed, + ) + for fld in sorted_fields: + print(fld) + newStruct = StructType([StructField.fromJson(fld)]) + newBaseField = structField.name + if baseField: + newBaseField = baseField + "." + newBaseField + array_elements.extend( + sort_nested_cols(newStruct, is_reversed, baseField=newBaseField) + ) + + element_names = [i.split(".")[-1] for i in array_elements] + array_elements_formatted = [f"x.{i} as {i}" for i in element_names] + select_cols.append( + base_str + ", ".join(array_elements_formatted) + suffix_str + ) + + elif isinstance(field_type, StructType): subFields = [] for fld in sorted( structField.jsonValue()["type"]["fields"], diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 21db2cd4..d3d75cca 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -489,7 +489,134 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ) +def test_sort_struct_nested_with_arraytypes(spark): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + unsorted_schema = StructType( + [ + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType( + [ + StructField( + "first5", IntegerType(), nullable=True + ), + StructField( + "last4", IntegerType(), nullable=True + ), + ] + ), + nullable=False, + ), + StructField("city", StringType(), nullable=True), + ] + ), + nullable=False, + ), + StructField( + "phone_numbers", + ArrayType( + StructType( + [ + StructField("type", StringType(), nullable=True), + StructField("number", StringType(), nullable=True), + ] + ) + ), + nullable=True, + ), + StructField("_id", StringType(), nullable=True), + StructField("first_name", StringType(), nullable=True), + ] + ) + + sorted_schema = StructType( + [ + StructField("_id", StringType(), nullable=True), + StructField( + "address", + StructType( + [ + StructField("city", StringType(), nullable=True), + StructField( + "zip", + StructType( + [ + StructField( + "first5", IntegerType(), nullable=True + ), + StructField( + "last4", IntegerType(), nullable=True + ), + ] + ), + nullable=False, + ), + ] + ), + nullable=False, + ), + StructField("first_name", StringType(), nullable=True), + StructField( + "phone_numbers", + ArrayType( + StructType( + [ + StructField("number", StringType(), nullable=True), + StructField("type", StringType(), nullable=True), + ] + ) + ), + nullable=True, + ), + ] + ) + + _id = "12345" + city = "Fake City" + zip_first5 = 54321 + zip_last4 = 12345 + first_name = "John" + phone_type = "home" + phone_number = "555-555-5555" + + unsorted_data = [ + ( + (((zip_last4, zip_first5)), city), + [(phone_type, phone_number)], + _id, + first_name, + ), + ] + sorted_data = [ + ( + _id, + (city, ((zip_last4, zip_first5))), + first_name, + [(phone_type, phone_number)], + ), + ] + + unsorted_df = spark.createDataFrame(unsorted_data, unsorted_schema) + expected_df = spark.createDataFrame(sorted_data, sorted_schema) + + return unsorted_df, expected_df + + unsorted_df, expected_df = _get_test_dataframes() + + unsorted_df.printSchema() + sorted_df = quinn.sort_columns(unsorted_df, "asc") + sorted_df.printSchema() + + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=True + ) + + # from pyspark.sql import SparkSession # spark = SparkSession.builder.appName("abc").getOrCreate() -# test_sort_struct_nested(spark) +# test_sort_struct_nested_with_arraytypes(spark) From f865265f99f6cdc4e381a052a777db9e456bba12 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 21:59:56 -0400 Subject: [PATCH 10/25] split up logic into smaller functions --- quinn/transformations.py | 122 +++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 51 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index d4b382e2..1929a0de 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -135,71 +135,91 @@ def sort_top_level_cols(schema, is_reversed) -> dict: return output - def sort_nested_cols(schema, is_reversed, baseField="") -> list: + def sort_nested_cols(schema, is_reversed, base_field="") -> list: # TODO: get working with ArrayType # recursively check nested fields and sort them # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark # Credits: @pault for logic - select_cols = [] - for structField in sorted(schema, key=lambda x: x.name, reverse=is_reversed): - field_type = structField.dataType - if isinstance(field_type, ArrayType): - array_elements = [] - array_parent = structField.jsonValue()["type"]["elementType"] - - base_str = f"transform({structField.name}" - suffix_str = f") AS {structField.name}" - if array_parent["type"] == "struct": - array_parent = array_parent["fields"] + def parse_fields( + fields_to_sort: list, parent_struct, is_reversed: bool + ) -> list: + sorted_fields: list = sorted( + fields_to_sort, + key=lambda x: x["name"], + reverse=is_reversed, + ) - base_str = f"{base_str}, x -> struct(" - suffix_str = f"){suffix_str}" + results = [] + for field in sorted_fields: + new_struct = StructType([StructField.fromJson(field)]) + new_base_field = parent_struct.name + if base_field: + new_base_field = base_field + "." + new_base_field - sorted_fields: list = sorted( - array_parent, - key=lambda x: x["name"], - reverse=is_reversed, + results.extend( + sort_nested_cols(new_struct, is_reversed, base_field=new_base_field) ) - for fld in sorted_fields: - print(fld) - newStruct = StructType([StructField.fromJson(fld)]) - newBaseField = structField.name - if baseField: - newBaseField = baseField + "." + newBaseField - array_elements.extend( - sort_nested_cols(newStruct, is_reversed, baseField=newBaseField) - ) - - element_names = [i.split(".")[-1] for i in array_elements] + return results + + def handle_array_type(parent_struct: StructField, is_reversed: bool) -> str: + def format_array_selection( + elements: list, base_str: str, suffix_str: str + ) -> str: + element_names = [i.split(".")[-1] for i in elements] array_elements_formatted = [f"x.{i} as {i}" for i in element_names] - select_cols.append( - base_str + ", ".join(array_elements_formatted) + suffix_str + + output = ( + f"{base_str} {', '.join(array_elements_formatted)} {suffix_str}" ) + return output + + array_parent = parent_struct.jsonValue()["type"]["elementType"] + + base_str = f"transform({parent_struct.name}" + suffix_str = f") AS {parent_struct.name}" + + # if struct in array, create mapping to struct + # TODO: prob doesn't work with additional levels of nesting + if array_parent["type"] == "struct": + array_parent = array_parent["fields"] + + base_str = f"{base_str}, x -> struct(" + suffix_str = f"){suffix_str}" + + array_elements = parse_fields(array_parent, parent_struct, is_reversed) + formatted_array_selection = format_array_selection( + array_elements, base_str, suffix_str + ) + return formatted_array_selection + + def handle_struct_type(parent_struct: StructField, is_reversed: bool) -> str: + def format_struct_selection(elements: list, struct_name: str) -> str: + output: str = f"struct( {', '.join(elements)} ) AS {struct_name}" + return output + + field_list = parent_struct.jsonValue()["type"]["fields"] + sub_fields = parse_fields(field_list, parent_struct, is_reversed) + formatted_sub_fields = format_struct_selection( + sub_fields, parent_struct.name + ) + return formatted_sub_fields + + select_cols = [] + for parent_struct in sorted(schema, key=lambda x: x.name, reverse=is_reversed): + field_type = parent_struct.dataType + if isinstance(field_type, ArrayType): + result = handle_array_type(parent_struct, is_reversed) elif isinstance(field_type, StructType): - subFields = [] - for fld in sorted( - structField.jsonValue()["type"]["fields"], - key=lambda x: x["name"], - reverse=is_reversed, - ): - newStruct = StructType([StructField.fromJson(fld)]) - newBaseField = structField.name - if baseField: - newBaseField = baseField + "." + newBaseField - subFields.extend( - sort_nested_cols(newStruct, is_reversed, baseField=newBaseField) - ) - - select_cols.append( - "struct(" + ",".join(subFields) + ") AS {}".format(structField.name) - ) + result = handle_struct_type(parent_struct, is_reversed) else: - if baseField: - select_cols.append(baseField + "." + structField.name) + if base_field: + result = f"{base_field}.{parent_struct.name}" else: - select_cols.append(structField.name) + result = parent_struct.name + select_cols.append(result) + return select_cols is_reversed: bool = parse_sort_order(sort_order) From c2c0616bf4115ccda916347e106579e945519a0f Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 22:05:03 -0400 Subject: [PATCH 11/25] add bool flag to skip nested sorting by default --- quinn/transformations.py | 13 +++++++++++-- tests/test_transformations.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index 1929a0de..149c4b44 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -88,7 +88,9 @@ def to_snake_case(s: str) -> str: return s.lower().replace(" ", "_") -def sort_columns(df: DataFrame, sort_order: str) -> DataFrame: +def sort_columns( + df: DataFrame, sort_order: str, sort_nested_structs: bool = False +) -> DataFrame: """This function sorts the columns of a given DataFrame based on a given sort order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to ascending and descending order, respectively. If any other value is provided for @@ -98,6 +100,8 @@ def sort_columns(df: DataFrame, sort_order: str) -> DataFrame: :type df: pyspark.sql.DataFrame :param sort_order: The order in which to sort the columns in the DataFrame :type sort_order: str + :param sort_nested_structs: Whether to sort nested structs or not. Defaults to false. + :type sort_nested_structs: bool :return: A DataFrame with the columns sorted in the chosen order :rtype: pyspark.sql.DataFrame """ @@ -224,7 +228,12 @@ def format_struct_selection(elements: list, struct_name: str) -> str: is_reversed: bool = parse_sort_order(sort_order) top_sorted_schema_results: dict = sort_top_level_cols(df.schema, is_reversed) - if not top_sorted_schema_results["is_nested"]: + + skip_nested_sorting = ( + not top_sorted_schema_results["is_nested"] or not sort_nested_structs + ) + + if skip_nested_sorting: columns: list = [i.name for i in top_sorted_schema_results["schema"]] return df.select(*columns) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index d3d75cca..e2b67173 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -386,7 +386,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() unsorted_df.printSchema() - sorted_df = quinn.sort_columns(unsorted_df, "asc") + sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) sorted_df.printSchema() chispa.schema_comparer.assert_schema_equality( @@ -608,7 +608,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() unsorted_df.printSchema() - sorted_df = quinn.sort_columns(unsorted_df, "asc") + sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) sorted_df.printSchema() chispa.schema_comparer.assert_schema_equality( From bf7bd8e4fd9c4931f77b12dff14580df756d9e33 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 22:33:03 -0400 Subject: [PATCH 12/25] deduplicate testing code --- tests/test_transformations.py | 387 +++++++++++++--------------------- 1 file changed, 151 insertions(+), 236 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index e2b67173..b30cdc71 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -303,91 +303,86 @@ def test_sort_struct_flat_desc(spark): _test_sort_struct_flat(spark, "desc") -def test_sort_struct_nested(spark): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - unsorted_schema = StructType( - [ - StructField("_id", StringType(), nullable=False), - StructField("first_name", StringType(), nullable=False), - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType( - [ - StructField( - "last4", IntegerType(), nullable=True - ), - StructField( - "first5", IntegerType(), nullable=True - ), - ] - ), - nullable=True, - ), - StructField("city", StringType(), nullable=True), - ] - ), - nullable=True, - ), - ] - ) - - sorted_schema = StructType( - [ - StructField("_id", StringType(), nullable=False), - StructField( - "address", - StructType( - [ - StructField("city", StringType(), nullable=True), - StructField( - "zip", - StructType( - [ - StructField( - "first5", IntegerType(), nullable=True - ), - StructField( - "last4", IntegerType(), nullable=True - ), - ] - ), - nullable=True, - ), - ] +def _get_test_dataframes_schemas() -> dict: + elements = { + "_id": (StructField("_id", StringType(), nullable=False)), + "first_name": (StructField("first_name", StringType(), nullable=False)), + "city": (StructField("city", StringType(), nullable=False)), + "last4": (StructField("last4", IntegerType(), nullable=True)), + "first5": (StructField("first5", IntegerType(), nullable=True)), + "type": (StructField("type", StringType(), nullable=True)), + "number": (StructField("number", StringType(), nullable=True)), + } + + return elements + + +def _get_test_dataframes_data() -> tuple[(str, str, int, int, str)]: + _id = "12345" + city = "Fake City" + zip_first5 = 54321 + zip_last4 = 12345 + first_name = "John" + + return _id, city, zip_first5, zip_last4, first_name + + +def _get_unsorted_nested_struct_fields(elements: dict): + unsorted_fields = [ + elements["_id"], + elements["first_name"], + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType([elements["last4"], elements["first5"]]), + nullable=True, ), - nullable=True, - ), - StructField("first_name", StringType(), nullable=False), - ] - ) + elements["city"], + ] + ), + nullable=True, + ), + ] + return unsorted_fields - _id = "12345" - city = "Fake City" - zip_first5 = 54321 - zip_last4 = 12345 - first_name = "John" - unsorted_data = [ - (_id, first_name, (((zip_last4, zip_first5)), city)), - ] - sorted_data = [ - (_id, ((city, (zip_first5, zip_last4))), first_name), +def test_sort_struct_nested(spark): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + elements = _get_test_dataframes_schemas() + unsorted_fields = _get_unsorted_nested_struct_fields(elements) + sorted_fields = [ + elements["_id"], + StructField( + "address", + StructType( + [ + elements["city"], + StructField( + "zip", + StructType([elements["first5"], elements["last4"]]), + nullable=True, + ), + ] + ), + nullable=True, + ), + elements["first_name"], ] - unsorted_df = spark.createDataFrame(unsorted_data, unsorted_schema) - expected_df = spark.createDataFrame(sorted_data, sorted_schema) + _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() + unsorted_data = [(_id, first_name, (((zip_last4, zip_first5)), city))] + sorted_data = [(_id, ((city, (zip_first5, zip_last4))), first_name)] + + unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) + expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - - unsorted_df.printSchema() sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) - sorted_df.printSchema() chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable=True @@ -396,190 +391,113 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: def test_sort_struct_nested_desc(spark): def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - unsorted_schema = StructType( - [ - StructField("_id", StringType(), nullable=False), - StructField("first_name", StringType(), nullable=False), - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType( - [ - StructField( - "last4", IntegerType(), nullable=True - ), - StructField( - "first5", IntegerType(), nullable=True - ), - ] - ), - nullable=True, - ), - StructField("city", StringType(), nullable=True), - ] - ), - nullable=True, + elements = _get_test_dataframes_schemas() + unsorted_fields = _get_unsorted_nested_struct_fields(elements) + + sorted_fields = [ + elements["first_name"], + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType([elements["last4"], elements["first5"]]), + nullable=True, + ), + elements["city"], + ] ), - ] - ) - - sorted_schema = StructType( - [ - StructField("first_name", StringType(), nullable=False), - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType( - [ - StructField( - "last4", IntegerType(), nullable=True - ), - StructField( - "first5", IntegerType(), nullable=True - ), - ] - ), - nullable=True, - ), - StructField("city", StringType(), nullable=True), - ] - ), - nullable=True, - ), - StructField("_id", StringType(), nullable=False), - ] - ) + nullable=True, + ), + elements["_id"], + ] - _id = "12345" - city = "Fake City" - zip_first5 = 54321 - zip_last4 = 12345 - first_name = "John" + _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - unsorted_data = [ - (_id, first_name, (((zip_last4, zip_first5)), city)), - ] + unsorted_data = [(_id, first_name, (((zip_last4, zip_first5)), city))] sorted_data = [ ( first_name, ((zip_first5, zip_last4), city), _id, - ), + ) ] - unsorted_df = spark.createDataFrame(unsorted_data, unsorted_schema) - expected_df = spark.createDataFrame(sorted_data, sorted_schema) + unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) + expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - - unsorted_df.printSchema() sorted_df = quinn.sort_columns(unsorted_df, "desc") - sorted_df.printSchema() chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable=True ) -def test_sort_struct_nested_with_arraytypes(spark): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - unsorted_schema = StructType( - [ - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType( - [ - StructField( - "first5", IntegerType(), nullable=True - ), - StructField( - "last4", IntegerType(), nullable=True - ), - ] - ), - nullable=False, - ), - StructField("city", StringType(), nullable=True), - ] - ), - nullable=False, - ), - StructField( - "phone_numbers", - ArrayType( +def _get_unsorted_nested_array_fields(elements: dict) -> list: + unsorted_fields = [ + StructField( + "address", + StructType( + [ + StructField( + "zip", StructType( [ - StructField("type", StringType(), nullable=True), - StructField("number", StringType(), nullable=True), + elements["first5"], + elements["last4"], ] - ) + ), + nullable=False, ), - nullable=True, - ), - StructField("_id", StringType(), nullable=True), - StructField("first_name", StringType(), nullable=True), - ] - ) + ] + ), + elements["city"], + ), + StructField( + "phone_numbers", + ArrayType(StructType([elements["type"], elements["number"]])), + nullable=True, + ), + elements["_id"], + elements["first_name"], + ] + return unsorted_fields - sorted_schema = StructType( - [ - StructField("_id", StringType(), nullable=True), - StructField( - "address", - StructType( - [ - StructField("city", StringType(), nullable=True), - StructField( - "zip", - StructType( - [ - StructField( - "first5", IntegerType(), nullable=True - ), - StructField( - "last4", IntegerType(), nullable=True - ), - ] - ), - nullable=False, - ), - ] - ), - nullable=False, - ), - StructField("first_name", StringType(), nullable=True), - StructField( - "phone_numbers", - ArrayType( - StructType( - [ - StructField("number", StringType(), nullable=True), - StructField("type", StringType(), nullable=True), - ] - ) - ), - nullable=True, + +def test_sort_struct_nested_with_arraytypes(spark): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + elements = _get_test_dataframes_schemas() + unsorted_fields = _get_unsorted_nested_array_fields(elements) + + sorted_fields = [ + elements["_id"], + StructField( + "address", + StructType( + [ + elements["city"], + StructField( + "zip", + StructType([elements["first5"], elements["last4"]]), + nullable=False, + ), + ] ), - ] - ) + nullable=False, + ), + elements["first_name"], + StructField( + "phone_numbers", + ArrayType(StructType([elements["type"], elements["number"]])), + nullable=True, + ), + ] - _id = "12345" - city = "Fake City" - zip_first5 = 54321 - zip_last4 = 12345 - first_name = "John" + _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() phone_type = "home" phone_number = "555-555-5555" @@ -600,16 +518,13 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ), ] - unsorted_df = spark.createDataFrame(unsorted_data, unsorted_schema) - expected_df = spark.createDataFrame(sorted_data, sorted_schema) + unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) + expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - - unsorted_df.printSchema() sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) - sorted_df.printSchema() chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable=True From 212820c9dc2155240354c3ac412d9e54f6d0893a Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 22:51:52 -0400 Subject: [PATCH 13/25] fix array test --- tests/test_transformations.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index b30cdc71..949b431e 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -447,15 +447,15 @@ def _get_unsorted_nested_array_fields(elements: dict) -> list: "zip", StructType( [ - elements["first5"], elements["last4"], + elements["first5"], ] ), nullable=False, ), + elements["city"], ] ), - elements["city"], ), StructField( "phone_numbers", @@ -492,7 +492,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: elements["first_name"], StructField( "phone_numbers", - ArrayType(StructType([elements["type"], elements["number"]])), + ArrayType(StructType([elements["number"], elements["type"]])), nullable=True, ), ] @@ -503,7 +503,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_data = [ ( - (((zip_last4, zip_first5)), city), + ((zip_last4, zip_first5), city), [(phone_type, phone_number)], _id, first_name, @@ -512,7 +512,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: sorted_data = [ ( _id, - (city, ((zip_last4, zip_first5))), + (city, (zip_last4, zip_first5)), first_name, [(phone_type, phone_number)], ), From 636a3e723e85965925929edf1dba2884439e7ae3 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 23:07:06 -0400 Subject: [PATCH 14/25] test_sort_struct_nested_with_arraytypes_desc --- tests/test_transformations.py | 83 ++++++++++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 12 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 949b431e..c5ae0e99 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -384,9 +384,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=True - ) + chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) def test_sort_struct_nested_desc(spark): @@ -432,9 +430,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "desc") - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=True - ) + chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) def _get_unsorted_nested_array_fields(elements: dict) -> list: @@ -456,6 +452,7 @@ def _get_unsorted_nested_array_fields(elements: dict) -> list: elements["city"], ] ), + nullable=False, ), StructField( "phone_numbers", @@ -526,12 +523,74 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=True - ) + expected_df.printSchema() + sorted_df.printSchema() + + chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) + + +def test_sort_struct_nested_with_arraytypes_desc(spark): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + elements = _get_test_dataframes_schemas() + unsorted_fields = _get_unsorted_nested_array_fields(elements) + + sorted_fields = [ + StructField( + "phone_numbers", + ArrayType(StructType([elements["number"], elements["type"]])), + nullable=True, + ), + elements["first_name"], + StructField( + "address", + StructType( + [ + elements["city"], + StructField( + "zip", + StructType([elements["first5"], elements["last4"]]), + nullable=False, + ), + ] + ), + nullable=False, + ), + elements["_id"], + ] + + _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() + phone_type = "home" + phone_number = "555-555-5555" + + unsorted_data = [ + ( + ((zip_last4, zip_first5), city), + [(phone_type, phone_number)], + _id, + first_name, + ), + ] + sorted_data = [ + ( + [(phone_type, phone_number)], + first_name, + (city, (zip_last4, zip_first5)), + _id, + ), + ] + + unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) + expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) + + return unsorted_df, expected_df + + unsorted_df, expected_df = _get_test_dataframes() + sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested_structs=True) + + chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) -# from pyspark.sql import SparkSession +from pyspark.sql import SparkSession -# spark = SparkSession.builder.appName("abc").getOrCreate() -# test_sort_struct_nested_with_arraytypes(spark) +spark = SparkSession.builder.appName("abc").getOrCreate() +test_sort_struct_nested_with_arraytypes(spark) From a5fdb54393b0f6a1d86f86a97e1ccb04841add8b Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 23:32:04 -0400 Subject: [PATCH 15/25] add ignore nullable T/F tests on schema validation --- tests/test_transformations.py | 57 ++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index c5ae0e99..e93ae2c0 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -349,7 +349,7 @@ def _get_unsorted_nested_struct_fields(elements: dict): return unsorted_fields -def test_sort_struct_nested(spark): +def _test_sort_struct_nested(spark, ignore_nullable: bool): def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: elements = _get_test_dataframes_schemas() unsorted_fields = _get_unsorted_nested_struct_fields(elements) @@ -384,10 +384,12 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) - chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable + ) -def test_sort_struct_nested_desc(spark): +def _test_sort_struct_nested_desc(spark, ignore_nullable: bool): def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: elements = _get_test_dataframes_schemas() unsorted_fields = _get_unsorted_nested_struct_fields(elements) @@ -430,7 +432,9 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "desc") - chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable + ) def _get_unsorted_nested_array_fields(elements: dict) -> list: @@ -465,7 +469,7 @@ def _get_unsorted_nested_array_fields(elements: dict) -> list: return unsorted_fields -def test_sort_struct_nested_with_arraytypes(spark): +def _test_sort_struct_nested_with_arraytypes(spark, ignore_nullable: bool): def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: elements = _get_test_dataframes_schemas() unsorted_fields = _get_unsorted_nested_array_fields(elements) @@ -526,10 +530,12 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: expected_df.printSchema() sorted_df.printSchema() - chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable + ) -def test_sort_struct_nested_with_arraytypes_desc(spark): +def _test_sort_struct_nested_with_arraytypes_desc(spark, ignore_nullable: bool): def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: elements = _get_test_dataframes_schemas() unsorted_fields = _get_unsorted_nested_array_fields(elements) @@ -587,10 +593,39 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested_structs=True) - chispa.schema_comparer.assert_schema_equality(sorted_df.schema, expected_df.schema) + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable + ) + + +def test_sort_struct_nested(spark): + _test_sort_struct_nested(spark, True) + + +def test_sort_struct_nested_desc(spark): + _test_sort_struct_nested_desc(spark, True) + + +def test_sort_struct_nested_with_arraytypes(spark): + _test_sort_struct_nested_with_arraytypes(spark, True) + + +def test_sort_struct_nested_with_arraytypes_desc(spark): + _test_sort_struct_nested_with_arraytypes_desc(spark, True) + + +# broken nullable tests below ============================ +def test_sort_struct_nested_nullable(spark): + _test_sort_struct_nested(spark, False) + + +def test_sort_struct_nested_nullable_desc(spark): + _test_sort_struct_nested_desc(spark, False) + +def test_sort_struct_nested_with_arraytypes_nullable(spark): + _test_sort_struct_nested_with_arraytypes(spark, False) -from pyspark.sql import SparkSession -spark = SparkSession.builder.appName("abc").getOrCreate() -test_sort_struct_nested_with_arraytypes(spark) +def test_sort_struct_nested_with_arraytypes_nullable_desc(spark): + _test_sort_struct_nested_with_arraytypes_desc(spark, False) From ef502325952bc740ed0d6f53bc36d4e0aefc5e3d Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Sat, 7 Oct 2023 23:41:55 -0400 Subject: [PATCH 16/25] fix arraytypes_desc ignore nullable == True test --- tests/test_transformations.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index e93ae2c0..c19650c4 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -543,7 +543,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: sorted_fields = [ StructField( "phone_numbers", - ArrayType(StructType([elements["number"], elements["type"]])), + ArrayType(StructType([elements["type"], elements["number"]])), nullable=True, ), elements["first_name"], @@ -551,12 +551,12 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: "address", StructType( [ - elements["city"], StructField( "zip", - StructType([elements["first5"], elements["last4"]]), + StructType([elements["last4"], elements["first5"]]), nullable=False, ), + elements["city"], ] ), nullable=False, @@ -580,7 +580,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ( [(phone_type, phone_number)], first_name, - (city, (zip_last4, zip_first5)), + ((zip_last4, zip_first5), city), _id, ), ] @@ -629,3 +629,9 @@ def test_sort_struct_nested_with_arraytypes_nullable(spark): def test_sort_struct_nested_with_arraytypes_nullable_desc(spark): _test_sort_struct_nested_with_arraytypes_desc(spark, False) + + +# from pyspark.sql import SparkSession + +# spark = SparkSession.builder.getOrCreate() +# test_sort_struct_nested_with_arraytypes_desc(spark) From 2434e671399853109ca8a8c9f893bab606b6948a Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Tue, 10 Oct 2023 19:53:45 -0400 Subject: [PATCH 17/25] add nullability fixes --- quinn/transformations.py | 55 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index 149c4b44..bbba4d67 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -139,7 +139,7 @@ def sort_top_level_cols(schema, is_reversed) -> dict: return output - def sort_nested_cols(schema, is_reversed, base_field="") -> list: + def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]: # TODO: get working with ArrayType # recursively check nested fields and sort them # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark @@ -226,13 +226,53 @@ def format_struct_selection(elements: list, struct_name: str) -> str: return select_cols + def get_original_nullability(field: StructField, result_dict: dict) -> None: + def assign_nullability(field: StructField, result_dict: dict) -> dict: + try: + result_dict[field.name] = field.nullable + except AttributeError: + result_dict[field.name] = True + + return result_dict + + result_dict = assign_nullability(field, result_dict) + if not isinstance(field.dataType, StructType) and not isinstance( + field.dataType, ArrayType + ): + return + + if isinstance(field.dataType, ArrayType): + result_dict[f"{field.name}_element"] = field.dataType.containsNull + children = field.dataType.elementType.fields + else: + children = field.dataType.fields + for i in children: + get_original_nullability(i, result_dict) + + def fix_nullability(field: StructField, result_dict: dict) -> None: + field.nullable = result_dict[field.name] + if not isinstance(field.dataType, StructType) and not isinstance( + field.dataType, ArrayType + ): + return + + if isinstance(field.dataType, ArrayType): + # save the containsNull property of the ArrayType + field.dataType.containsNull = result_dict[f"{field.name}_element"] + children = field.dataType.elementType.fields + else: + children = field.dataType.fields + + for i in children: + fix_nullability(i, result_dict) + is_reversed: bool = parse_sort_order(sort_order) top_sorted_schema_results: dict = sort_top_level_cols(df.schema, is_reversed) - skip_nested_sorting = ( not top_sorted_schema_results["is_nested"] or not sort_nested_structs ) + # fast exit if no nested structs or if user doesn't want to sort them if skip_nested_sorting: columns: list = [i.name for i in top_sorted_schema_results["schema"]] return df.select(*columns) @@ -241,4 +281,13 @@ def format_struct_selection(elements: list, struct_name: str) -> str: top_sorted_schema_results["schema"], is_reversed ) - return df.selectExpr(fully_sorted_schema) + output = df.selectExpr(fully_sorted_schema) + result_dict = {} + for field in df.schema: + get_original_nullability(field, result_dict) + + for field in output.schema: + fix_nullability(field, result_dict) + + final_df = output.sparkSession.createDataFrame(output.rdd, output.schema) + return final_df From ee7c081f4cad27120ff1428b86aceb830f957ebc Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Tue, 10 Oct 2023 19:54:41 -0400 Subject: [PATCH 18/25] clean up tests --- tests/test_transformations.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index c19650c4..e7ac1bcb 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -285,11 +285,8 @@ def _get_simple_test_dataframes(sort_order) -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df unsorted_df, expected_df = _get_simple_test_dataframes(sort_order=sort_order) - - unsorted_df.printSchema() sorted_df = quinn.sort_columns(unsorted_df, sort_order) - sorted_df.printSchema() - expected_df.printSchema() + chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable=True ) @@ -373,8 +370,17 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ] _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - unsorted_data = [(_id, first_name, (((zip_last4, zip_first5)), city))] - sorted_data = [(_id, ((city, (zip_first5, zip_last4))), first_name)] + unsorted_data = [ + (_id, first_name, (((zip_last4, zip_first5)), city)), + (_id, first_name, (((None, zip_first5)), city)), + (_id, first_name, (None)), + ] + + sorted_data = [ + (_id, ((city, (zip_first5, zip_last4))), first_name), + (_id, ((city, (zip_first5, None))), first_name), + (_id, (None), first_name), + ] unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) @@ -509,6 +515,8 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: _id, first_name, ), + (((zip_last4, zip_first5), city), [(phone_type, None)], _id, first_name), + (((None, None), city), None, _id, first_name), ] sorted_data = [ ( @@ -517,8 +525,9 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: first_name, [(phone_type, phone_number)], ), + (_id, (city, (zip_last4, zip_first5)), first_name, [(phone_type, None)]), + (_id, (city, (None, None)), first_name, None), ] - unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) @@ -527,9 +536,6 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: unsorted_df, expected_df = _get_test_dataframes() sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) - expected_df.printSchema() - sorted_df.printSchema() - chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable ) @@ -614,7 +620,6 @@ def test_sort_struct_nested_with_arraytypes_desc(spark): _test_sort_struct_nested_with_arraytypes_desc(spark, True) -# broken nullable tests below ============================ def test_sort_struct_nested_nullable(spark): _test_sort_struct_nested(spark, False) @@ -629,9 +634,3 @@ def test_sort_struct_nested_with_arraytypes_nullable(spark): def test_sort_struct_nested_with_arraytypes_nullable_desc(spark): _test_sort_struct_nested_with_arraytypes_desc(spark, False) - - -# from pyspark.sql import SparkSession - -# spark = SparkSession.builder.getOrCreate() -# test_sort_struct_nested_with_arraytypes_desc(spark) From b8ca06c8e9e4f44697f0fb2419c13cadb100e831 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Tue, 10 Oct 2023 19:56:54 -0400 Subject: [PATCH 19/25] rename param --- quinn/transformations.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index bbba4d67..5a8666da 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -89,7 +89,7 @@ def to_snake_case(s: str) -> str: def sort_columns( - df: DataFrame, sort_order: str, sort_nested_structs: bool = False + df: DataFrame, sort_order: str, sort_nested: bool = False ) -> DataFrame: """This function sorts the columns of a given DataFrame based on a given sort order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to @@ -268,9 +268,7 @@ def fix_nullability(field: StructField, result_dict: dict) -> None: is_reversed: bool = parse_sort_order(sort_order) top_sorted_schema_results: dict = sort_top_level_cols(df.schema, is_reversed) - skip_nested_sorting = ( - not top_sorted_schema_results["is_nested"] or not sort_nested_structs - ) + skip_nested_sorting = not top_sorted_schema_results["is_nested"] or not sort_nested # fast exit if no nested structs or if user doesn't want to sort them if skip_nested_sorting: From cdb01c33fd41d6d36cb050127f69c8f6349244fb Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Tue, 10 Oct 2023 19:58:09 -0400 Subject: [PATCH 20/25] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 59a8563e..b01910c4 100644 --- a/README.md +++ b/README.md @@ -230,10 +230,10 @@ Converts all the column names in a DataFrame to snake_case. It's annoying to wri **sort_columns()** ```python -quinn.sort_columns(source_df, "asc") +quinn.sort_columns(df=source_df, sort_order="asc", sort_nested=True) ``` -Sorts the DataFrame columns in alphabetical order. Wide DataFrames are easier to navigate when they're sorted alphabetically. +Sorts the DataFrame columns in alphabetical order, including nested columns if sort_nested is set to True. Wide DataFrames are easier to navigate when they're sorted alphabetically. ### DataFrame Helpers From e819c1416e8ba7f090096b3406e1acb64341a612 Mon Sep 17 00:00:00 2001 From: Jeff Brennan Date: Tue, 10 Oct 2023 20:01:57 -0400 Subject: [PATCH 21/25] update sort param name --- tests/test_transformations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index e7ac1bcb..04c83c61 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -388,7 +388,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) + sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True) chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable @@ -534,7 +534,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested_structs=True) + sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True) chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable @@ -597,7 +597,7 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested_structs=True) + sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested=True) chispa.schema_comparer.assert_schema_equality( sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable From 9e75fb6cbb3a6da52896f6c48a685a7761aa2dc9 Mon Sep 17 00:00:00 2001 From: jeffbrennan Date: Sat, 14 Oct 2023 10:47:07 -0400 Subject: [PATCH 22/25] ensure backwards compatability for sort_nested=False --- quinn/transformations.py | 41 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index 5a8666da..832391c2 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -119,25 +119,10 @@ def parse_sort_order(sort_order: str) -> bool: } return reverse_lookup[sort_order] - def sort_top_level_cols(schema, is_reversed) -> dict: + def sort_top_level_cols(df, is_reversed) -> DataFrame: # sort top level columns - top_sorted_fields: list = sorted( - schema.fields, key=lambda x: x.name, reverse=is_reversed - ) - - is_nested: bool = any( - [ - isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType) - for i in top_sorted_fields - ] - ) - - output = { - "schema": top_sorted_fields, - "is_nested": is_nested, - } - - return output + sorted_col_names = sorted(df.columns, reverse=is_reversed) + return df.select(*sorted_col_names) def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]: # TODO: get working with ArrayType @@ -267,16 +252,22 @@ def fix_nullability(field: StructField, result_dict: dict) -> None: fix_nullability(i, result_dict) is_reversed: bool = parse_sort_order(sort_order) - top_sorted_schema_results: dict = sort_top_level_cols(df.schema, is_reversed) - skip_nested_sorting = not top_sorted_schema_results["is_nested"] or not sort_nested + top_level_sorted_df = sort_top_level_cols(df, is_reversed) + if not sort_nested: + return top_level_sorted_df + + is_nested: bool = any( + [ + isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType) + for i in top_level_sorted_df.schema + ] + ) - # fast exit if no nested structs or if user doesn't want to sort them - if skip_nested_sorting: - columns: list = [i.name for i in top_sorted_schema_results["schema"]] - return df.select(*columns) + if not is_nested: + return top_level_sorted_df fully_sorted_schema = sort_nested_cols( - top_sorted_schema_results["schema"], is_reversed + top_level_sorted_df.schema, is_reversed ) output = df.selectExpr(fully_sorted_schema) From 25a0b1c36c45e82d6b81e118110ba34c1830128b Mon Sep 17 00:00:00 2001 From: jeffbrennan Date: Sat, 14 Oct 2023 10:48:30 -0400 Subject: [PATCH 23/25] clean up function comments/docstring --- quinn/transformations.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index 832391c2..27cf205b 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -100,8 +100,8 @@ def sort_columns( :type df: pyspark.sql.DataFrame :param sort_order: The order in which to sort the columns in the DataFrame :type sort_order: str - :param sort_nested_structs: Whether to sort nested structs or not. Defaults to false. - :type sort_nested_structs: bool + :param sort_nested: Whether to sort nested structs or not. Defaults to false. + :type sort_nested: bool :return: A DataFrame with the columns sorted in the chosen order :rtype: pyspark.sql.DataFrame """ @@ -125,7 +125,6 @@ def sort_top_level_cols(df, is_reversed) -> DataFrame: return df.select(*sorted_col_names) def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]: - # TODO: get working with ArrayType # recursively check nested fields and sort them # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark # Credits: @pault for logic @@ -169,7 +168,6 @@ def format_array_selection( suffix_str = f") AS {parent_struct.name}" # if struct in array, create mapping to struct - # TODO: prob doesn't work with additional levels of nesting if array_parent["type"] == "struct": array_parent = array_parent["fields"] From 2ff0af04513eb91bcb84ead4aa7697841558f6cd Mon Sep 17 00:00:00 2001 From: jeffbrennan Date: Sun, 22 Oct 2023 18:49:19 -0400 Subject: [PATCH 24/25] add non-working example --- tests/test_transformations.py | 99 +++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 04c83c61..200c2f9c 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -604,6 +604,99 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ) +def _test_sort_struct_nested_in_arraytypes(spark, ignore_nullable: bool): + def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: + elements = _get_test_dataframes_schemas() + unsorted_fields = _get_unsorted_nested_array_fields(elements) + + # extensions = StructType( + # [ + # StructField("extension_code", StringType(), nullable=True), + # StructField( + # "extension_numbers", + # StructType( + # [ + # StructField("extension_number_one", IntegerType()), + # StructField("extension_number_two", IntegerType()), + # ] + # ), + # ), + # ] + # ) + + sorted_fields = [ + StructField( + "phone_numbers", + ArrayType(StructType([elements["type"], elements["number"]])), + ), + StructField( + "extensions", + ArrayType( + StructType( + [ + StructField("extension_number_one", IntegerType()), + StructField("extension_number_two", IntegerType()), + ] + ), + StructField("extension_code", StringType(), nullable=True), + ), + ), + elements["first_name"], + StructField( + "address", + StructType( + [ + StructField( + "zip", + StructType([elements["last4"], elements["first5"]]), + nullable=False, + ), + elements["city"], + ] + ), + nullable=False, + ), + elements["_id"], + ] + + _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() + phone_type = "home" + phone_number = "555-555-5555" + extension_code = "test" + extension_number_one = 1 + extension_number_two = 2 + + unsorted_data = [ + ( + ((zip_last4, zip_first5), city), + [(phone_type, phone_number)], + _id, + first_name, + ), + ] + sorted_data = [ + ( + [(phone_type, phone_number)], + [(extension_number_one, extension_number_two), extension_code], + first_name, + ((zip_last4, zip_first5), city), + _id, + ), + ] + + expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) + unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) + + return unsorted_df, expected_df + + unsorted_df, expected_df = _get_test_dataframes() + sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested=True) + + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable + ) + + def test_sort_struct_nested(spark): _test_sort_struct_nested(spark, True) @@ -634,3 +727,9 @@ def test_sort_struct_nested_with_arraytypes_nullable(spark): def test_sort_struct_nested_with_arraytypes_nullable_desc(spark): _test_sort_struct_nested_with_arraytypes_desc(spark, False) + + +from pyspark.sql import SparkSession + +spark = SparkSession.builder.getOrCreate() +_test_sort_struct_nested_with_arraytypes(spark, False) From 53d4738579048aa450cdb27be6b66120d7fa83e9 Mon Sep 17 00:00:00 2001 From: jeffbrennan Date: Mon, 30 Oct 2023 18:40:59 -0400 Subject: [PATCH 25/25] refactor - reorganize nested functions, add comments --- quinn/transformations.py | 119 ++++++++++++++------------------------- 1 file changed, 43 insertions(+), 76 deletions(-) diff --git a/quinn/transformations.py b/quinn/transformations.py index 27cf205b..3149cbce 100644 --- a/quinn/transformations.py +++ b/quinn/transformations.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Callable import pyspark.sql.functions as F @@ -106,24 +107,6 @@ def sort_columns( :rtype: pyspark.sql.DataFrame """ - def parse_sort_order(sort_order: str) -> bool: - if sort_order not in ["asc", "desc"]: - raise ValueError( - "['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format( - sort_order=sort_order - ) - ) - reverse_lookup = { - "asc": False, - "desc": True, - } - return reverse_lookup[sort_order] - - def sort_top_level_cols(df, is_reversed) -> DataFrame: - # sort top level columns - sorted_col_names = sorted(df.columns, reverse=is_reversed) - return df.select(*sorted_col_names) - def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]: # recursively check nested fields and sort them # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark @@ -150,56 +133,36 @@ def parse_fields( ) return results - def handle_array_type(parent_struct: StructField, is_reversed: bool) -> str: - def format_array_selection( - elements: list, base_str: str, suffix_str: str - ) -> str: - element_names = [i.split(".")[-1] for i in elements] - array_elements_formatted = [f"x.{i} as {i}" for i in element_names] - - output = ( - f"{base_str} {', '.join(array_elements_formatted)} {suffix_str}" - ) - return output - - array_parent = parent_struct.jsonValue()["type"]["elementType"] - - base_str = f"transform({parent_struct.name}" - suffix_str = f") AS {parent_struct.name}" - - # if struct in array, create mapping to struct - if array_parent["type"] == "struct": - array_parent = array_parent["fields"] - - base_str = f"{base_str}, x -> struct(" - suffix_str = f"){suffix_str}" - - array_elements = parse_fields(array_parent, parent_struct, is_reversed) - formatted_array_selection = format_array_selection( - array_elements, base_str, suffix_str - ) - return formatted_array_selection - - def handle_struct_type(parent_struct: StructField, is_reversed: bool) -> str: - def format_struct_selection(elements: list, struct_name: str) -> str: - output: str = f"struct( {', '.join(elements)} ) AS {struct_name}" - return output - - field_list = parent_struct.jsonValue()["type"]["fields"] - sub_fields = parse_fields(field_list, parent_struct, is_reversed) - formatted_sub_fields = format_struct_selection( - sub_fields, parent_struct.name - ) - return formatted_sub_fields - select_cols = [] for parent_struct in sorted(schema, key=lambda x: x.name, reverse=is_reversed): field_type = parent_struct.dataType if isinstance(field_type, ArrayType): - result = handle_array_type(parent_struct, is_reversed) + array_parent = parent_struct.jsonValue()["type"]["elementType"] + base_str = f"transform({parent_struct.name}" + suffix_str = f") AS {parent_struct.name}" + + # if struct in array, create mapping to struct + if array_parent["type"] == "struct": + array_parent = array_parent["fields"] + base_str = f"{base_str}, x -> struct(" + suffix_str = f"){suffix_str}" + + array_elements = parse_fields(array_parent, parent_struct, is_reversed) + element_names = [i.split(".")[-1] for i in array_elements] + array_elements_formatted = [f"x.{i} as {i}" for i in element_names] + + # create a string representation of the sorted array + # ex: transform(phone_numbers, x -> struct(x.number as number, x.type as type)) AS phone_numbers + result = f"{base_str}{', '.join(array_elements_formatted)}{suffix_str}" elif isinstance(field_type, StructType): - result = handle_struct_type(parent_struct, is_reversed) + field_list = parent_struct.jsonValue()["type"]["fields"] + sub_fields = parse_fields(field_list, parent_struct, is_reversed) + + # create a string representation of the sorted struct + # ex: struct(address.zip.first5, address.zip.last4) AS zip + result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}" + else: if base_field: result = f"{base_field}.{parent_struct.name}" @@ -210,15 +173,11 @@ def format_struct_selection(elements: list, struct_name: str) -> str: return select_cols def get_original_nullability(field: StructField, result_dict: dict) -> None: - def assign_nullability(field: StructField, result_dict: dict) -> dict: - try: - result_dict[field.name] = field.nullable - except AttributeError: - result_dict[field.name] = True - - return result_dict + if hasattr(field, "nullable"): + result_dict[field.name] = field.nullable + else: + result_dict[field.name] = True - result_dict = assign_nullability(field, result_dict) if not isinstance(field.dataType, StructType) and not isinstance( field.dataType, ArrayType ): @@ -249,8 +208,19 @@ def fix_nullability(field: StructField, result_dict: dict) -> None: for i in children: fix_nullability(i, result_dict) - is_reversed: bool = parse_sort_order(sort_order) - top_level_sorted_df = sort_top_level_cols(df, is_reversed) + if sort_order not in ["asc", "desc"]: + raise ValueError( + "['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format( + sort_order=sort_order + ) + ) + reverse_lookup = { + "asc": False, + "desc": True, + } + + is_reversed: bool = reverse_lookup[sort_order] + top_level_sorted_df = df.select(*sorted(df.columns, reverse=is_reversed)) if not sort_nested: return top_level_sorted_df @@ -264,10 +234,7 @@ def fix_nullability(field: StructField, result_dict: dict) -> None: if not is_nested: return top_level_sorted_df - fully_sorted_schema = sort_nested_cols( - top_level_sorted_df.schema, is_reversed - ) - + fully_sorted_schema = sort_nested_cols(top_level_sorted_df.schema, is_reversed) output = df.selectExpr(fully_sorted_schema) result_dict = {} for field in df.schema: