
Commit

Merge pull request #148 from MrPowers/feature/fix-tests
Fix tests and linter
SemyonSinchenko authored Nov 18, 2023
2 parents 79262db + 5dd0f11 commit 2f2e012
Showing 7 changed files with 43 additions and 39 deletions.
15 changes: 11 additions & 4 deletions pyproject.toml
@@ -66,10 +66,17 @@ select = ["ALL"]
 line-length = 150
 ignore = [
     "D100",
-    "D203",    # Ignore blank line before summary of class
-    "D213",    # Ignore multiline summary second line
-    "T201",    # Allow print() in code.
-    "D401",    # Docstrings should be in imperative modes
+    "D203",    # Ignore blank line before summary of class
+    "D213",    # Ignore multiline summary second line
+    "T201",    # Allow print() in code.
+    "D401",    # Docstrings should be in imperative modes
+    "D404",    # Boring thing about how to write docsrings
+    "FBT001",  # Boolean positional arg is OK
+    "FBT002",  # Boolean default arg value is OK
+    "D205",    # It is broken
+    "TCH003",  # I have no idea what is it about
+    "PLC1901", # Strange thing
+    "UP007",   # Not supported in py3.6
 ]
 extend-exclude = ["tests", "docs"]

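FBT001/FBT002 are ruff's "boolean trap" checks; ignoring them project-wide is what lets the per-line `# noqa: FBT001,FBT002` comments be dropped in the files below. A minimal sketch of the pattern those rules flag (function and argument names are made up for illustration):

```python
# With FBT001/FBT002 in ruff's `ignore` list, a boolean positional argument
# with a boolean default no longer needs an inline suppression comment.
def describe(df_name: str, verbose: bool = False) -> str:
    """Illustrative only: ruff would normally flag `verbose` here."""
    return f"{df_name} (verbose={verbose})"


print(describe("people", True))
```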
2 changes: 1 addition & 1 deletion quinn/dataframe_validator.py
@@ -41,7 +41,7 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -
 def validate_schema(
     df: DataFrame,
     required_schema: StructType,
-    ignore_nullable: bool = False,  # noqa: FBT001,FBT002
+    ignore_nullable: bool = False,
 ) -> None:
     """Function that validate if a given DataFrame has a given StructType as its schema.
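Only the inline noqa is removed here; `validate_schema` itself is unchanged. A rough usage sketch, assuming the function is re-exported at the package level as in released quinn versions (the sample data and schema are made up):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

import quinn

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("jose",), ("li",)], ["name"])

required_schema = StructType([StructField("name", StringType(), True)])

# Passes silently when the schema matches; raises an error otherwise.
quinn.validate_schema(df, required_schema, ignore_nullable=True)
```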
2 changes: 1 addition & 1 deletion quinn/extensions/column_ext.py
@@ -63,7 +63,7 @@ def isNullOrBlank(self: Column) -> Column:
     blank characters, or ``False`` otherwise.
     :rtype: Column
     """
-    return (self.isNull()) | (trim(self) == "")  # noqa: PLC1901
+    return (self.isNull()) | (trim(self) == "")
 
 
 def isNotIn(self: Column, _list: list[Any]) -> Column:
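PLC1901 (the "compare to empty string" rule) is now ignored globally, so the suppression comment goes away and the comparison itself stays. The same predicate, sketched standalone with plain PySpark column functions (column name and data are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("  ",), (None,), ("hi",)], ["word"])

# NULL or whitespace-only values evaluate to True, mirroring isNullOrBlank().
is_null_or_blank = F.col("word").isNull() | (F.trim(F.col("word")) == "")
df.withColumn("word_is_blank", is_null_or_blank).show()
```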
14 changes: 6 additions & 8 deletions quinn/schema_helpers.py
@@ -4,7 +4,6 @@

 from pyspark.sql import SparkSession
 from pyspark.sql import types as T  # noqa: N812
-from typing import Union
 
 
 def print_schema_as_code(dtype: T.DataType) -> str:
@@ -43,20 +42,19 @@ def print_schema_as_code(dtype: T.DataType) -> str:
     elif isinstance(dtype, T.DecimalType):
         res.append(f"DecimalType({dtype.precision}, {dtype.scale})")
 
-    else:
+    elif str(dtype).endswith("()"):
         # PySpark 3.3+
-        if str(dtype).endswith("()"):  # noqa: PLR5501
-            res.append(str(dtype))
-        else:
-            res.append(f"{dtype}()")
+        res.append(str(dtype))
+    else:
+        res.append(f"{dtype}()")
 
     return "".join(res)
 
 
 def _repr_column(column: T.StructField) -> str:
     res = []
 
-    if isinstance(column.dataType, (T.ArrayType, T.MapType, T.StructType)):
+    if isinstance(column.dataType, (T.ArrayType | T.MapType | T.StructType)):
         res.append(f'StructField(\n\t"{column.name}",')
         for line in print_schema_as_code(column.dataType).split("\n"):
             res.append("\n\t")
@@ -166,6 +164,6 @@ def complex_fields(schema: T.StructType) -> dict[str, object]:
     return {
         field.name: field.dataType
         for field in schema.fields
-        if isinstance(field.dataType, (T.ArrayType, T.StructType, T.MapType))
+        if isinstance(field.dataType, (T.ArrayType | T.StructType | T.MapType))
     }
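The isinstance calls now use PEP 604 unions (`A | B`) instead of a plain tuple of types. Worth noting, as an observation rather than anything stated in the PR: passing a union to isinstance only works on Python 3.10+, while the tuple form works on every version. A quick illustration:

```python
from pyspark.sql import types as T  # noqa: N812

field_type = T.ArrayType(T.StringType())

# Tuple of classes: accepted by isinstance on all supported Python versions.
print(isinstance(field_type, (T.ArrayType, T.MapType, T.StructType)))  # True

# PEP 604 union (the form used after this commit): requires Python 3.10+.
print(isinstance(field_type, T.ArrayType | T.MapType | T.StructType))  # True
```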

7 changes: 3 additions & 4 deletions quinn/split_columns.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from pyspark.sql.functions import length, split, trim, udf, when
 from pyspark.sql.types import IntegerType
@@ -69,7 +68,7 @@ def _num_delimiter(col_value1: str) -> int:

     # If the length of split_value is same as new_col_names, check if any of the split values is None or empty string
     elif any(  # noqa: RET506
-        x is None or x.strip() == "" for x in split_value[: len(new_col_names)]  # noqa: PLC1901
+        x is None or x.strip() == "" for x in split_value[: len(new_col_names)]
     ):
         msg = "Null or empty values are not accepted for columns in strict mode"
         raise ValueError(
@@ -94,7 +93,7 @@ def _num_delimiter(col_value1: str) -> int:
if mode == "strict":
# Create an array of select expressions to create new columns from the split values
select_exprs = [
when(split_col_expr.getItem(i) != "", split_col_expr.getItem(i)).alias( # noqa: PLC1901
when(split_col_expr.getItem(i) != "", split_col_expr.getItem(i)).alias(
new_col_names[i],
)
for i in range(len(new_col_names))
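Both hunks in this file only drop PLC1901 suppressions; the splitting logic is untouched. For context, a standalone sketch of the split/getItem/when pattern the strict mode relies on (delimiter, column names, and data are made up):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("johnXXdoe",), ("maryXX",)], ["full_name"])

parts = F.split(F.trim(F.col("full_name")), "XX")
df.select(
    # Empty pieces become NULL, mirroring the when(... != "", ...) expressions above.
    F.when(parts.getItem(0) != "", parts.getItem(0)).alias("first_name"),
    F.when(parts.getItem(1) != "", parts.getItem(1)).alias("last_name"),
).show()
```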
38 changes: 18 additions & 20 deletions quinn/transformations.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
 import re
-import pyspark.sql.functions as F  # noqa: N812
 from collections.abc import Callable
 
 from pyspark.sql import DataFrame
+from pyspark.sql import functions as F  # noqa: N812
 from pyspark.sql.types import ArrayType, MapType, StructField, StructType
 
+from quinn.schema_helpers import complex_fields
+

@@ -83,8 +86,8 @@ def to_snake_case(s: str) -> str:
     return s.lower().replace(" ", "_")
 
 
-def sort_columns(
-    df: DataFrame, sort_order: str, sort_nested: bool = False
+def sort_columns(  # noqa: C901,PLR0915
+    df: DataFrame, sort_order: str, sort_nested: bool = False,
 ) -> DataFrame:
     """This function sorts the columns of a given DataFrame based on a given sort
     order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
@@ -101,13 +104,13 @@ def sort_columns(
     :rtype: pyspark.sql.DataFrame
     """
 
-    def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:
+    def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:  # noqa: ANN001
         # recursively check nested fields and sort them
         # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
         # Credits: @pault for logic
 
         def parse_fields(
-            fields_to_sort: list, parent_struct, is_reversed: bool
+            fields_to_sort: list, parent_struct, is_reversed: bool,  # noqa: ANN001
         ) -> list:
             sorted_fields: list = sorted(
                 fields_to_sort,
@@ -123,7 +126,7 @@ def parse_fields(
                 new_base_field = base_field + "." + new_base_field
 
             results.extend(
-                sort_nested_cols(new_struct, is_reversed, base_field=new_base_field)
+                sort_nested_cols(new_struct, is_reversed, base_field=new_base_field),
             )
         return results
 
@@ -157,11 +160,10 @@ def parse_fields(
                 # ex: struct(address.zip.first5, address.zip.last4) AS zip
                 result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}"
 
+            elif base_field:
+                result = f"{base_field}.{parent_struct.name}"
             else:
-                if base_field:
-                    result = f"{base_field}.{parent_struct.name}"
-                else:
-                    result = parent_struct.name
+                result = parent_struct.name
             select_cols.append(result)
 
         return select_cols
@@ -173,7 +175,7 @@ def get_original_nullability(field: StructField, result_dict: dict) -> None:
             result_dict[field.name] = True
 
         if not isinstance(field.dataType, StructType) and not isinstance(
-            field.dataType, ArrayType
+            field.dataType, ArrayType,
         ):
             return
 
@@ -188,7 +190,7 @@ def get_original_nullability(field: StructField, result_dict: dict) -> None:
     def fix_nullability(field: StructField, result_dict: dict) -> None:
         field.nullable = result_dict[field.name]
         if not isinstance(field.dataType, StructType) and not isinstance(
-            field.dataType, ArrayType
+            field.dataType, ArrayType,
         ):
             return
 
@@ -218,10 +220,8 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
         return top_level_sorted_df
 
     is_nested: bool = any(
-        [
-            isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType)
-            for i in top_level_sorted_df.schema
-        ]
+        isinstance(i.dataType, StructType | ArrayType)
+        for i in top_level_sorted_df.schema
     )
 
     if not is_nested:
@@ -236,9 +236,7 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
     for field in output.schema:
         fix_nullability(field, result_dict)
 
-    final_df = output.sparkSession.createDataFrame(output.rdd, output.schema)
-    return final_df
-    return df.select(*sorted_col_names)
+    return output.sparkSession.createDataFrame(output.rdd, output.schema)
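The deleted lines included an unreachable second `return`; what remains is the trick of rebuilding the DataFrame from its own RDD so that the nullability flags patched onto the schema actually take effect (mutating `StructField.nullable` alone does not change the existing DataFrame). A toy sketch of that pattern (schema and data are illustrative, not quinn's API):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a",), ("b",)], ["letter"])  # letter: nullable by default

# Declare the column non-nullable, then rebuild the DataFrame from its RDD so
# the modified schema is actually applied.
strict_schema = StructType([StructField("letter", StringType(), nullable=False)])
strict_df = df.sparkSession.createDataFrame(df.rdd, strict_schema)
strict_df.printSchema()  # letter: string (nullable = false)
```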


def flatten_struct(df: DataFrame, col_name: str, separator: str = ":") -> DataFrame:
@@ -287,7 +285,7 @@ def flatten_dataframe(
     df: DataFrame,
     separator: str = ":",
     replace_char: str = "_",
-    sanitized_columns: bool = False,  # noqa: FBT001, FBT002
+    sanitized_columns: bool = False,
 ) -> DataFrame:
     """Flattens the complex columns in the DataFrame.
@@ -356,7 +354,7 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
     :rtype: DataFrame
     """
     return df.select(
-        "*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)
+        "*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name),
     ).drop(
         col_name,
     )
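Most hunks in this file just add trailing commas or relocate noqa comments, with no behavior change. As a refresher on the explode_outer call touched in the last hunk, a small standalone example (toy data, not from the PR):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, ["a", "b"]), (2, [])], ["id", "letters"])

# explode_outer keeps rows with an empty or NULL array (they get a NULL letter),
# where a plain explode would drop them.
df.select("id", F.explode_outer("letters").alias("letter")).show()
```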
4 changes: 3 additions & 1 deletion tests/test_split_columns.py
@@ -3,6 +3,8 @@
 import chispa
 import pytest
 
+from pyspark.errors.exceptions.captured import PythonException
+

@auto_inject_fixtures("spark")
def test_split_columns(spark):
@@ -50,5 +52,5 @@ def test_split_columns_strict(spark):
         delimiter="XX",
         new_col_names=["student_first_name", "student_middle_name", "student_last_name"],
         mode="strict", default="hi")
-    with pytest.raises(IndexError):
+    with pytest.raises(PythonException):
         df2.show()
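The stricter assertion reflects how errors raised inside a Python UDF reach the driver: they arrive wrapped in PythonException (importable from `pyspark.errors` in PySpark 3.4+), not as the original IndexError. A minimal sketch of that assertion style with a throwaway UDF (not the quinn test itself):

```python
import pytest
from pyspark.errors.exceptions.captured import PythonException
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf


def test_udf_error_is_wrapped():
    spark = SparkSession.builder.getOrCreate()
    boom = udf(lambda s: s[100])  # raises IndexError inside the Python worker

    df = spark.createDataFrame([("hi",)], ["word"]).select(boom("word"))
    with pytest.raises(PythonException):  # the wrapped error, not the raw IndexError
        df.show()
```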
