Merge pull request #139 from jeffbrennan/sort_struct_columns
Sort struct columns
jeffbrennan authored Nov 5, 2023
2 parents fe33bbe + 841e074 commit ad271eb
Showing 3 changed files with 649 additions and 25 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -230,10 +230,10 @@ Converts all the column names in a DataFrame to snake_case. It's annoying to wri
**sort_columns()**

```python
-quinn.sort_columns(source_df, "asc")
+quinn.sort_columns(df=source_df, sort_order="asc", sort_nested=True)
```

-Sorts the DataFrame columns in alphabetical order. Wide DataFrames are easier to navigate when they're sorted alphabetically.
+Sorts the DataFrame columns in alphabetical order, including nested columns if `sort_nested` is set to `True`. Wide DataFrames are easier to navigate when they're sorted alphabetically.
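For example, with a nested schema (a minimal sketch, assuming an active `SparkSession` named `spark`; the schema is illustrative):

```python
df = spark.createDataFrame(
    [(1, ("60601", "Chicago"))],
    "id INT, address STRUCT<zip: STRING, city: STRING>",
)

quinn.sort_columns(df=df, sort_order="asc", sort_nested=True)
# top-level order becomes: address, id
# nested fields inside address become: city, zip
```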

### DataFrame Helpers

165 changes: 147 additions & 18 deletions quinn/transformations.py
@@ -1,10 +1,9 @@
-import re
-from collections.abc import Callable
-
-import pyspark.sql.functions as F  # noqa: N812
+from __future__ import annotations
+from collections.abc import Callable
from pyspark.sql import DataFrame
-from pyspark.sql.types import ArrayType, MapType, StructType
-
+from pyspark.sql.types import ArrayType, MapType, StructField, StructType
from quinn.schema_helpers import complex_fields


@@ -84,30 +83,161 @@ def to_snake_case(s: str) -> str:
    return s.lower().replace(" ", "_")


-def sort_columns(df: DataFrame, sort_order: str) -> DataFrame:
-    """Function sorts the columns of a given DataFrame based on a given sort order.
-    The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
+def sort_columns(
+    df: DataFrame, sort_order: str, sort_nested: bool = False
+) -> DataFrame:
+    """This function sorts the columns of a given DataFrame based on a given sort
+    order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
    ascending and descending order, respectively. If any other value is provided for
    the ``sort_order`` parameter, a ``ValueError`` will be raised.

    :param df: A DataFrame
-    :type df: pandas.DataFrame
+    :type df: pyspark.sql.DataFrame
    :param sort_order: The order in which to sort the columns in the DataFrame
    :type sort_order: str
+    :param sort_nested: Whether to sort nested structs or not. Defaults to false.
+    :type sort_nested: bool
    :return: A DataFrame with the columns sorted in the chosen order
-    :rtype: pandas.DataFrame
+    :rtype: pyspark.sql.DataFrame
    """
-    sorted_col_names = None
-    if sort_order == "asc":
-        sorted_col_names = sorted(df.columns)
-    elif sort_order == "desc":
-        sorted_col_names = sorted(df.columns, reverse=True)
-    else:

+    def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:
+        # recursively check nested fields and sort them
+        # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
+        # Credits: @pault for logic
+
+        def parse_fields(
+            fields_to_sort: list, parent_struct, is_reversed: bool
+        ) -> list:
+            sorted_fields: list = sorted(
+                fields_to_sort,
+                key=lambda x: x["name"],
+                reverse=is_reversed,
+            )
+
+            results = []
+            for field in sorted_fields:
+                new_struct = StructType([StructField.fromJson(field)])
+                new_base_field = parent_struct.name
+                if base_field:
+                    new_base_field = base_field + "." + new_base_field
+
+                results.extend(
+                    sort_nested_cols(new_struct, is_reversed, base_field=new_base_field)
+                )
+            return results
+
+        select_cols = []
+        for parent_struct in sorted(schema, key=lambda x: x.name, reverse=is_reversed):
+            field_type = parent_struct.dataType
+            if isinstance(field_type, ArrayType):
+                array_parent = parent_struct.jsonValue()["type"]["elementType"]
+                base_str = f"transform({parent_struct.name}"
+                suffix_str = f") AS {parent_struct.name}"
+
+                # if struct in array, create mapping to struct
+                if array_parent["type"] == "struct":
+                    array_parent = array_parent["fields"]
+                    base_str = f"{base_str}, x -> struct("
+                    suffix_str = f"){suffix_str}"
+
+                array_elements = parse_fields(array_parent, parent_struct, is_reversed)
+                element_names = [i.split(".")[-1] for i in array_elements]
+                array_elements_formatted = [f"x.{i} as {i}" for i in element_names]
+
+                # create a string representation of the sorted array
+                # ex: transform(phone_numbers, x -> struct(x.number as number, x.type as type)) AS phone_numbers
+                result = f"{base_str}{', '.join(array_elements_formatted)}{suffix_str}"
+
+            elif isinstance(field_type, StructType):
+                field_list = parent_struct.jsonValue()["type"]["fields"]
+                sub_fields = parse_fields(field_list, parent_struct, is_reversed)
+
+                # create a string representation of the sorted struct
+                # ex: struct(address.zip.first5, address.zip.last4) AS zip
+                result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}"
+
+            else:
+                if base_field:
+                    result = f"{base_field}.{parent_struct.name}"
+                else:
+                    result = parent_struct.name
+            select_cols.append(result)
+
+        return select_cols

+    def get_original_nullability(field: StructField, result_dict: dict) -> None:
+        if hasattr(field, "nullable"):
+            result_dict[field.name] = field.nullable
+        else:
+            result_dict[field.name] = True
+
+        if not isinstance(field.dataType, StructType) and not isinstance(
+            field.dataType, ArrayType
+        ):
+            return
+
+        if isinstance(field.dataType, ArrayType):
+            result_dict[f"{field.name}_element"] = field.dataType.containsNull
+            children = field.dataType.elementType.fields
+        else:
+            children = field.dataType.fields
+        for i in children:
+            get_original_nullability(i, result_dict)

+    def fix_nullability(field: StructField, result_dict: dict) -> None:
+        field.nullable = result_dict[field.name]
+        if not isinstance(field.dataType, StructType) and not isinstance(
+            field.dataType, ArrayType
+        ):
+            return
+
+        if isinstance(field.dataType, ArrayType):
+            # save the containsNull property of the ArrayType
+            field.dataType.containsNull = result_dict[f"{field.name}_element"]
+            children = field.dataType.elementType.fields
+        else:
+            children = field.dataType.fields
+
+        for i in children:
+            fix_nullability(i, result_dict)

+    if sort_order not in ["asc", "desc"]:
+        msg = f"['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'"
+        raise ValueError(
+            msg,
+        )
+    reverse_lookup = {
+        "asc": False,
+        "desc": True,
+    }
+
+    is_reversed: bool = reverse_lookup[sort_order]
+    top_level_sorted_df = df.select(*sorted(df.columns, reverse=is_reversed))
+    if not sort_nested:
+        return top_level_sorted_df
+
+    is_nested: bool = any(
+        [
+            isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType)
+            for i in top_level_sorted_df.schema
+        ]
+    )
+
+    if not is_nested:
+        return top_level_sorted_df
+
+    fully_sorted_schema = sort_nested_cols(top_level_sorted_df.schema, is_reversed)
+    output = df.selectExpr(fully_sorted_schema)
+    result_dict = {}
+    for field in df.schema:
+        get_original_nullability(field, result_dict)
+
+    for field in output.schema:
+        fix_nullability(field, result_dict)

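+    # fix_nullability mutates only the Python-side schema object, so the DataFrame
+    # is rebuilt from its RDD with the corrected schema to restore the nullability
+    # flags that selectExpr may have altered.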
+    final_df = output.sparkSession.createDataFrame(output.rdd, output.schema)
+    return final_df
-    return df.select(*sorted_col_names)

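To make the generated expressions concrete: for a hypothetical frame with a struct column and an array-of-structs column, `sort_nested_cols` builds strings like the ones below, which `sort_columns` then passes to `selectExpr` (a sketch; `df` and the field names are illustrative, not from this PR):

```python
# Illustrative input columns, sorted ascending:
#   address:  STRUCT<zip: STRING, city: STRING>
#   contacts: ARRAY<STRUCT<type: STRING, number: STRING>>
#   id:       INT
exprs = [
    "struct(address.city, address.zip) AS address",
    "transform(contacts, x -> struct(x.number as number, x.type as type)) AS contacts",
    "id",
]
sorted_df = df.selectExpr(*exprs)  # equivalent to the nested-sort output
```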

@@ -250,5 +380,4 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
    ]
    df = df.toDF(*sanitized_columns)  # noqa: PD901

-    return df
-
+    return df
