Fix tests and linter
 On branch feature/fix-tests
 Changes to be committed:
	modified:   pyproject.toml
	modified:   quinn/dataframe_validator.py
	modified:   quinn/schema_helpers.py
	modified:   quinn/split_columns.py
	modified:   quinn/transformations.py
	modified:   tests/test_split_columns.py
SemyonSinchenko committed Nov 16, 2023
1 parent 79262db commit 173493e
Showing 6 changed files with 34 additions and 31 deletions.
13 changes: 9 additions & 4 deletions pyproject.toml
@@ -66,10 +66,15 @@ select = ["ALL"]
line-length = 150
ignore = [
"D100",
"D203", # Ignore blank line before summary of class
"D213", # Ignore multiline summary second line
"T201", # Allow print() in code.
"D401", # Docstrings should be in imperative modes
"D203", # Ignore blank line before summary of class
"D213", # Ignore multiline summary second line
"T201", # Allow print() in code.
"D401", # Docstrings should be in imperative modes
"D404", # Boring thing about how to write docsrings
"FBT001", # Boolean positional arg is OK
"FBT002", # Boolean default arg value is OK
"D205", # It is broken
"TCH003", # I have no idea what is it about
]
extend-exclude = ["tests", "docs"]

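The new FBT001/FBT002 entries come from ruff's flake8-boolean-trap rules, which flag boolean positional arguments and boolean default values; ignoring them project-wide is why the per-line "# noqa: FBT001,FBT002" comments disappear from function signatures later in this diff. A minimal illustration of what the rules flag (hypothetical function, assuming standard flake8-boolean-trap semantics):

    def validate(df, ignore_nullable: bool = False):  # FBT001: boolean positional arg; FBT002: boolean default value
        ...

    validate(df, True)                  # the "boolean trap": the call site hides what True means
    validate(df, ignore_nullable=True)  # the keyword form these rules nudge callers toward
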
2 changes: 1 addition & 1 deletion quinn/dataframe_validator.py
@@ -41,7 +41,7 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -
def validate_schema(
df: DataFrame,
required_schema: StructType,
ignore_nullable: bool = False, # noqa: FBT001,FBT002
ignore_nullable: bool = False,
) -> None:
"""Function that validate if a given DataFrame has a given StructType as its schema.
5 changes: 2 additions & 3 deletions quinn/schema_helpers.py
@@ -4,7 +4,6 @@

from pyspark.sql import SparkSession
from pyspark.sql import types as T # noqa: N812
from typing import Union


def print_schema_as_code(dtype: T.DataType) -> str:
@@ -56,7 +55,7 @@ def print_schema_as_code(dtype: T.DataType) -> str:
def _repr_column(column: T.StructField) -> str:
res = []

if isinstance(column.dataType, (T.ArrayType, T.MapType, T.StructType)):
if isinstance(column.dataType, (T.ArrayType | T.MapType | T.StructType)):
res.append(f'StructField(\n\t"{column.name}",')
for line in print_schema_as_code(column.dataType).split("\n"):
res.append("\n\t")
@@ -166,6 +165,6 @@ def complex_fields(schema: T.StructType) -> dict[str, object]:
return {
field.name: field.dataType
for field in schema.fields
if isinstance(field.dataType, (T.ArrayType, T.StructType, T.MapType))
if isinstance(field.dataType, (T.ArrayType | T.StructType | T.MapType))
}
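
A note on the isinstance changes above: passing a union such as T.ArrayType | T.MapType | T.StructType to isinstance only works on Python 3.10+, where it is equivalent to the old tuple-of-classes form; wrapping the union in parentheses does not create a tuple. A minimal sketch, assuming Python 3.10+ and a PySpark installation:

    from pyspark.sql import types as T

    field_type = T.ArrayType(T.StringType())
    print(isinstance(field_type, (T.ArrayType, T.MapType, T.StructType)))  # tuple form: any Python 3
    print(isinstance(field_type, T.ArrayType | T.MapType | T.StructType))  # union form: Python 3.10+ only
    # Both print True; (A | B | C) is a single union object, not a tuple.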

3 changes: 1 addition & 2 deletions quinn/split_columns.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Optional

from pyspark.sql.functions import length, split, trim, udf, when
from pyspark.sql.types import IntegerType
@@ -16,7 +15,7 @@ def split_col( # noqa: PLR0913
delimiter: str,
new_col_names: list[str],
mode: str = "permissive",
default: Optional[str] = None,
default: str | None = None,
) -> DataFrame:
"""Splits the given column based on the delimiter and creates new columns with the split values.
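
A note on the Optional[str] -> str | None change: because split_columns.py keeps its from __future__ import annotations line, the annotation is stored as a string and never evaluated, so the new spelling is safe even on interpreters older than 3.10. A small sketch under that assumption (the function name is made up for illustration):

    from __future__ import annotations

    def describe(default: str | None = None) -> str:
        # The annotation above is never evaluated at runtime, so this runs on Python 3.7+.
        return default if default is not None else "<no default>"

    print(describe())       # <no default>
    print(describe("n/a"))  # n/a
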
38 changes: 18 additions & 20 deletions quinn/transformations.py
@@ -1,9 +1,12 @@
from __future__ import annotations

import re
import pyspark.sql.functions as F # noqa: N812
from collections.abc import Callable

from pyspark.sql import DataFrame
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import ArrayType, MapType, StructField, StructType

from quinn.schema_helpers import complex_fields


@@ -83,8 +86,8 @@ def to_snake_case(s: str) -> str:
return s.lower().replace(" ", "_")


def sort_columns(
df: DataFrame, sort_order: str, sort_nested: bool = False
def sort_columns( # noqa: C901,PLR0915
df: DataFrame, sort_order: str, sort_nested: bool = False,
) -> DataFrame:
"""This function sorts the columns of a given DataFrame based on a given sort
order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
@@ -101,13 +104,13 @@ def sort_columns(
:rtype: pyspark.sql.DataFrame
"""

def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:
def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]: # noqa: ANN001
# recursively check nested fields and sort them
# https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
# Credits: @pault for logic

def parse_fields(
fields_to_sort: list, parent_struct, is_reversed: bool
fields_to_sort: list, parent_struct, is_reversed: bool, # noqa: ANN001
) -> list:
sorted_fields: list = sorted(
fields_to_sort,
@@ -123,7 +126,7 @@ def parse_fields(
new_base_field = base_field + "." + new_base_field

results.extend(
sort_nested_cols(new_struct, is_reversed, base_field=new_base_field)
sort_nested_cols(new_struct, is_reversed, base_field=new_base_field),
)
return results

@@ -157,11 +160,10 @@ def parse_fields(
# ex: struct(address.zip.first5, address.zip.last4) AS zip
result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}"

elif base_field:
result = f"{base_field}.{parent_struct.name}"
else:
if base_field:
result = f"{base_field}.{parent_struct.name}"
else:
result = parent_struct.name
result = parent_struct.name
select_cols.append(result)

return select_cols
@@ -173,7 +175,7 @@ def get_original_nullability(field: StructField, result_dict: dict) -> None:
result_dict[field.name] = True

if not isinstance(field.dataType, StructType) and not isinstance(
field.dataType, ArrayType
field.dataType, ArrayType,
):
return

@@ -188,7 +190,7 @@ def get_original_nullability(field: StructField, result_dict: dict) -> None:
def fix_nullability(field: StructField, result_dict: dict) -> None:
field.nullable = result_dict[field.name]
if not isinstance(field.dataType, StructType) and not isinstance(
field.dataType, ArrayType
field.dataType, ArrayType,
):
return

@@ -218,10 +220,8 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
return top_level_sorted_df

is_nested: bool = any(
[
isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType)
isinstance(i.dataType, StructType | ArrayType)
for i in top_level_sorted_df.schema
]
)

if not is_nested:
@@ -236,9 +236,7 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
for field in output.schema:
fix_nullability(field, result_dict)

final_df = output.sparkSession.createDataFrame(output.rdd, output.schema)
return final_df
return df.select(*sorted_col_names)
return output.sparkSession.createDataFrame(output.rdd, output.schema)


def flatten_struct(df: DataFrame, col_name: str, separator: str = ":") -> DataFrame:
@@ -287,7 +285,7 @@ def flatten_dataframe(
df: DataFrame,
separator: str = ":",
replace_char: str = "_",
sanitized_columns: bool = False, # noqa: FBT001, FBT002
sanitized_columns: bool = False,
) -> DataFrame:
"""Flattens the complex columns in the DataFrame.
@@ -356,7 +354,7 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
:rtype: DataFrame
"""
return df.select(
"*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)
"*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name),
).drop(
col_name,
)
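
A sketch of the simplified nestedness check in sort_columns after this change: any() accepts a generator directly, so the surrounding list brackets were redundant, and one isinstance call with a union replaces the chained checks. Illustrative schema, assuming Python 3.10+ and PySpark installed:

    from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

    schema = StructType([
        StructField("name", StringType()),
        StructField("scores", ArrayType(IntegerType())),
    ])

    # Iterating a StructType yields its StructFields.
    is_nested = any(isinstance(f.dataType, StructType | ArrayType) for f in schema)
    print(is_nested)  # True, because "scores" is an ArrayType column
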
4 changes: 3 additions & 1 deletion tests/test_split_columns.py
@@ -3,6 +3,8 @@
import chispa
import pytest

from pyspark.errors.exceptions.captured import PythonException


@auto_inject_fixtures("spark")
def test_split_columns(spark):
@@ -50,5 +52,5 @@ def test_split_columns_strict(spark):
delimiter="XX",
new_col_names=["student_first_name", "student_middle_name", "student_last_name"],
mode="strict", default="hi")
with pytest.raises(IndexError):
with pytest.raises(PythonException):
df2.show()
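
Why the expected exception changes: when a Python UDF raises inside an executor, PySpark re-raises it on the driver wrapped as PythonException, with the original IndexError visible only in the error message. A hedged sketch of the same behaviour, assuming Spark 3.4+ (where pyspark.errors exists) and the suite's injected spark fixture; names below are illustrative:

    import pytest
    from pyspark.errors.exceptions.captured import PythonException
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType

    def test_udf_error_is_wrapped(spark):
        third_part = udf(lambda s: s.split("XX")[2], StringType())  # IndexError for strings without two delimiters
        df = spark.createDataFrame([("one",)], ["student_name"]).withColumn("x", third_part(col("student_name")))
        with pytest.raises(PythonException):  # not IndexError: the executor-side error arrives wrapped
            df.show()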
