From 173493e2109f29705b116dfb7e7a6e67c93c32c2 Mon Sep 17 00:00:00 2001
From: SemyonSinchenko
Date: Thu, 16 Nov 2023 19:51:51 +0000
Subject: [PATCH 1/3] Fix tests and linter

On branch feature/fix-tests
Changes to be committed:
    modified:   pyproject.toml
    modified:   quinn/dataframe_validator.py
    modified:   quinn/schema_helpers.py
    modified:   quinn/split_columns.py
    modified:   quinn/transformations.py
    modified:   tests/test_split_columns.py

---
 pyproject.toml               | 13 +++++++++----
 quinn/dataframe_validator.py |  2 +-
 quinn/schema_helpers.py      |  5 ++---
 quinn/split_columns.py       |  3 +--
 quinn/transformations.py     | 38 ++++++++++++++++++--------------------
 tests/test_split_columns.py  |  4 +++-
 6 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 93170dd1..6da00031 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,10 +66,15 @@ select = ["ALL"]
 line-length = 150
 ignore = [
   "D100",
-  "D203", # Ignore blank line before summary of class
-  "D213", # Ignore multiline summary second line
-  "T201", # Allow print() in code.
-  "D401", # Docstrings should be in imperative modes
+  "D203",   # Ignore blank line before summary of class
+  "D213",   # Ignore multiline summary second line
+  "T201",   # Allow print() in code.
+  "D401",   # Docstrings should be in imperative modes
+  "D404",   # Boring thing about how to write docsrings
+  "FBT001", # Boolean positional arg is OK
+  "FBT002", # Boolean default arg value is OK
+  "D205",   # It is broken
+  "TCH003", # I have no idea what is it about
 ]
 extend-exclude = ["tests", "docs"]
 
diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py
index b591a9ec..450b6d10 100644
--- a/quinn/dataframe_validator.py
+++ b/quinn/dataframe_validator.py
@@ -41,7 +41,7 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -
 def validate_schema(
     df: DataFrame,
     required_schema: StructType,
-    ignore_nullable: bool = False,  # noqa: FBT001,FBT002
+    ignore_nullable: bool = False,
 ) -> None:
     """Function that validate if a given DataFrame has a given StructType as its schema.
 
diff --git a/quinn/schema_helpers.py b/quinn/schema_helpers.py
index b3bc6ab6..6fe382d3 100644
--- a/quinn/schema_helpers.py
+++ b/quinn/schema_helpers.py
@@ -4,7 +4,6 @@
 
 from pyspark.sql import SparkSession
 from pyspark.sql import types as T  # noqa: N812
-from typing import Union
 
 
 def print_schema_as_code(dtype: T.DataType) -> str:
@@ -56,7 +55,7 @@ def print_schema_as_code(dtype: T.DataType) -> str:
 def _repr_column(column: T.StructField) -> str:
     res = []
 
-    if isinstance(column.dataType, (T.ArrayType, T.MapType, T.StructType)):
+    if isinstance(column.dataType, (T.ArrayType | T.MapType | T.StructType)):
         res.append(f'StructField(\n\t"{column.name}",')
         for line in print_schema_as_code(column.dataType).split("\n"):
             res.append("\n\t")
@@ -166,6 +165,6 @@ def complex_fields(schema: T.StructType) -> dict[str, object]:
     return {
         field.name: field.dataType
         for field in schema.fields
-        if isinstance(field.dataType, (T.ArrayType, T.StructType, T.MapType))
+        if isinstance(field.dataType, (T.ArrayType | T.StructType | T.MapType))
     }
 
diff --git a/quinn/split_columns.py b/quinn/split_columns.py
index 96b5e5c5..43f6d982 100644
--- a/quinn/split_columns.py
+++ b/quinn/split_columns.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
-from typing import Optional
 
 from pyspark.sql.functions import length, split, trim, udf, when
 from pyspark.sql.types import IntegerType
@@ -16,7 +15,7 @@ def split_col(  # noqa: PLR0913
     delimiter: str,
     new_col_names: list[str],
     mode: str = "permissive",
-    default: Optional[str] = None,
+    default: str | None = None,
 ) -> DataFrame:
     """Splits the given column based on the delimiter and creates new columns with the split values.
 
diff --git a/quinn/transformations.py b/quinn/transformations.py
index 7c5c7803..fbda2eee 100644
--- a/quinn/transformations.py
+++ b/quinn/transformations.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
+
 import re
-import pyspark.sql.functions as F  # noqa: N812
 from collections.abc import Callable
+
 from pyspark.sql import DataFrame
+from pyspark.sql import functions as F  # noqa: N812
 from pyspark.sql.types import ArrayType, MapType, StructField, StructType
+
 from quinn.schema_helpers import complex_fields
 
 
@@ -83,8 +86,8 @@ def to_snake_case(s: str) -> str:
     return s.lower().replace(" ", "_")
 
 
-def sort_columns(
-    df: DataFrame, sort_order: str, sort_nested: bool = False
+def sort_columns(  # noqa: C901,PLR0915
+    df: DataFrame, sort_order: str, sort_nested: bool = False,
 ) -> DataFrame:
     """This function sorts the columns of a given DataFrame based on a given sort order.
     The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
@@ -101,13 +104,13 @@ def sort_columns(
     :rtype: pyspark.sql.DataFrame
     """
 
-    def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:
+    def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:  # noqa: ANN001
         # recursively check nested fields and sort them
         # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
         # Credits: @pault for logic
 
         def parse_fields(
-            fields_to_sort: list, parent_struct, is_reversed: bool
+            fields_to_sort: list, parent_struct, is_reversed: bool,  # noqa: ANN001
         ) -> list:
             sorted_fields: list = sorted(
                 fields_to_sort,
@@ -123,7 +126,7 @@ def parse_fields(
                 new_base_field = base_field + "." + new_base_field
 
             results.extend(
-                sort_nested_cols(new_struct, is_reversed, base_field=new_base_field)
+                sort_nested_cols(new_struct, is_reversed, base_field=new_base_field),
             )
 
         return results
@@ -157,11 +160,10 @@ def parse_fields(
 
                 # ex: struct(address.zip.first5, address.zip.last4) AS zip
                 result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}"
+            elif base_field:
+                result = f"{base_field}.{parent_struct.name}"
             else:
-                if base_field:
-                    result = f"{base_field}.{parent_struct.name}"
-                else:
-                    result = parent_struct.name
+                result = parent_struct.name
             select_cols.append(result)
 
         return select_cols
@@ -173,7 +175,7 @@ def get_original_nullability(field: StructField, result_dict: dict) -> None:
             result_dict[field.name] = True
 
         if not isinstance(field.dataType, StructType) and not isinstance(
-            field.dataType, ArrayType
+            field.dataType, ArrayType,
         ):
             return
 
@@ -188,7 +190,7 @@ def get_original_nullability(field: StructField, result_dict: dict) -> None:
     def fix_nullability(field: StructField, result_dict: dict) -> None:
         field.nullable = result_dict[field.name]
         if not isinstance(field.dataType, StructType) and not isinstance(
-            field.dataType, ArrayType
+            field.dataType, ArrayType,
         ):
             return
 
@@ -218,10 +220,8 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
         return top_level_sorted_df
 
     is_nested: bool = any(
-        [
-            isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType)
+        isinstance(i.dataType, StructType | ArrayType)
         for i in top_level_sorted_df.schema
-        ]
     )
 
     if not is_nested:
@@ -236,9 +236,7 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
     for field in output.schema:
         fix_nullability(field, result_dict)
 
-    final_df = output.sparkSession.createDataFrame(output.rdd, output.schema)
-    return final_df
-    return df.select(*sorted_col_names)
+    return output.sparkSession.createDataFrame(output.rdd, output.schema)
 
 
 def flatten_struct(df: DataFrame, col_name: str, separator: str = ":") -> DataFrame:
@@ -287,7 +285,7 @@ def flatten_dataframe(
     df: DataFrame,
     separator: str = ":",
     replace_char: str = "_",
-    sanitized_columns: bool = False,  # noqa: FBT001, FBT002
+    sanitized_columns: bool = False,
 ) -> DataFrame:
     """Flattens the complex columns in the DataFrame.
 
@@ -356,7 +354,7 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
     :rtype: DataFrame
     """
     return df.select(
-        "*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)
+        "*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name),
     ).drop(
         col_name,
     )
diff --git a/tests/test_split_columns.py b/tests/test_split_columns.py
index 6e0c3d19..e976df1d 100644
--- a/tests/test_split_columns.py
+++ b/tests/test_split_columns.py
@@ -3,6 +3,8 @@
 import chispa
 import pytest
 
+from pyspark.errors.exceptions.captured import PythonException
+
 
 @auto_inject_fixtures("spark")
 def test_split_columns(spark):
@@ -50,5 +52,5 @@ def test_split_columns_strict(spark):
                           delimiter="XX",
                           new_col_names=["student_first_name", "student_middle_name", "student_last_name"],
                           mode="strict", default="hi")
-    with pytest.raises(IndexError):
+    with pytest.raises(PythonException):
         df2.show()

From 98e854ce8bf67ee9a9377ca75924f44d110289df Mon Sep 17 00:00:00 2001
From: SemyonSinchenko
Date: Thu, 16 Nov 2023 19:55:50 +0000
Subject: [PATCH 2/3] Fix linter v2

On branch feature/fix-tests
Changes to be committed:
    modified:   pyproject.toml
    modified:   quinn/extensions/column_ext.py
    modified:   quinn/schema_helpers.py
    modified:   quinn/split_columns.py

---
 pyproject.toml                 | 19 ++++++++++---------
 quinn/extensions/column_ext.py |  2 +-
 quinn/schema_helpers.py        |  9 ++++-----
 quinn/split_columns.py         |  4 ++--
 4 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6da00031..b90b0cd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,15 +66,16 @@ select = ["ALL"]
 line-length = 150
 ignore = [
   "D100",
-  "D203",   # Ignore blank line before summary of class
-  "D213",   # Ignore multiline summary second line
-  "T201",   # Allow print() in code.
-  "D401",   # Docstrings should be in imperative modes
-  "D404",   # Boring thing about how to write docsrings
-  "FBT001", # Boolean positional arg is OK
-  "FBT002", # Boolean default arg value is OK
-  "D205",   # It is broken
-  "TCH003", # I have no idea what is it about
+  "D203",    # Ignore blank line before summary of class
+  "D213",    # Ignore multiline summary second line
+  "T201",    # Allow print() in code.
+  "D401",    # Docstrings should be in imperative modes
+  "D404",    # Boring thing about how to write docsrings
+  "FBT001",  # Boolean positional arg is OK
+  "FBT002",  # Boolean default arg value is OK
+  "D205",    # It is broken
+  "TCH003",  # I have no idea what is it about
+  "PLC1901", # Strange thing
 ]
 extend-exclude = ["tests", "docs"]
 
diff --git a/quinn/extensions/column_ext.py b/quinn/extensions/column_ext.py
index 99b96982..f44ea006 100644
--- a/quinn/extensions/column_ext.py
+++ b/quinn/extensions/column_ext.py
@@ -63,7 +63,7 @@ def isNullOrBlank(self: Column) -> Column:
         blank characters, or ``False`` otherwise.
     :rtype: Column
     """
-    return (self.isNull()) | (trim(self) == "")  # noqa: PLC1901
+    return (self.isNull()) | (trim(self) == "")
 
 
 def isNotIn(self: Column, _list: list[Any]) -> Column:
diff --git a/quinn/schema_helpers.py b/quinn/schema_helpers.py
index 6fe382d3..9542a377 100644
--- a/quinn/schema_helpers.py
+++ b/quinn/schema_helpers.py
@@ -42,12 +42,11 @@ def print_schema_as_code(dtype: T.DataType) -> str:
     elif isinstance(dtype, T.DecimalType):
         res.append(f"DecimalType({dtype.precision}, {dtype.scale})")
 
-    else:
+    elif str(dtype).endswith("()"):
         # PySpark 3.3+
-        if str(dtype).endswith("()"):  # noqa: PLR5501
-            res.append(str(dtype))
-        else:
-            res.append(f"{dtype}()")
+        res.append(str(dtype))
+    else:
+        res.append(f"{dtype}()")
 
     return "".join(res)
 
diff --git a/quinn/split_columns.py b/quinn/split_columns.py
index 43f6d982..7b7186ba 100644
--- a/quinn/split_columns.py
+++ b/quinn/split_columns.py
@@ -68,7 +68,7 @@ def _num_delimiter(col_value1: str) -> int:
 
             # If the length of split_value is same as new_col_names, check if any of the split values is None or empty string
             elif any(  # noqa: RET506
-                x is None or x.strip() == "" for x in split_value[: len(new_col_names)]  # noqa: PLC1901
+                x is None or x.strip() == "" for x in split_value[: len(new_col_names)]
             ):
                 msg = "Null or empty values are not accepted for columns in strict mode"
                 raise ValueError(
@@ -93,7 +93,7 @@ def _num_delimiter(col_value1: str) -> int:
     if mode == "strict":
         # Create an array of select expressions to create new columns from the split values
         select_exprs = [
-            when(split_col_expr.getItem(i) != "", split_col_expr.getItem(i)).alias(  # noqa: PLC1901
+            when(split_col_expr.getItem(i) != "", split_col_expr.getItem(i)).alias(
                 new_col_names[i],
             )
             for i in range(len(new_col_names))

From 5dd0f111ba2c916ec3d15146322e99ab67649e78 Mon Sep 17 00:00:00 2001
From: SemyonSinchenko
Date: Fri, 17 Nov 2023 15:45:02 +0000
Subject: [PATCH 3/3] Fix x | y

On branch feature/fix-tests
Changes to be committed:
    modified:   pyproject.toml
    modified:   quinn/split_columns.py

---
 pyproject.toml         | 1 +
 quinn/split_columns.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b90b0cd0..ff72e82c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,6 +76,7 @@ ignore = [
   "D205",    # It is broken
   "TCH003",  # I have no idea what is it about
   "PLC1901", # Strange thing
+  "UP007",   # Not supported in py3.6
 ]
 extend-exclude = ["tests", "docs"]
 
diff --git a/quinn/split_columns.py b/quinn/split_columns.py
index 7b7186ba..1342b642 100644
--- a/quinn/split_columns.py
+++ b/quinn/split_columns.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from pyspark.sql.functions import length, split, trim, udf, when
 from pyspark.sql.types import IntegerType
@@ -15,7 +15,7 @@ def split_col(  # noqa: PLR0913
     delimiter: str,
     new_col_names: list[str],
     mode: str = "permissive",
-    default: str | None = None,
+    default: Optional[str] = None,
 ) -> DataFrame:
     """Splits the given column based on the delimiter and creates new columns with the split values.