diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 87199eb415..459a895546 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -663,7 +663,7 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.
         elif dataclasses.is_dataclass(python_type):
             for field in dataclasses.fields(python_type):
                 val = python_val.__getattribute__(field.name)
-                python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val))
+                object.__setattr__(python_val, field.name, self._fix_structured_dataset_type(field.type, val))
         return python_val
 
     def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any:
@@ -714,7 +714,7 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
         dataclass_attributes = typing.get_type_hints(python_type)
         for n, t in dataclass_attributes.items():
             val = python_val.__getattribute__(n)
-            python_val.__setattr__(n, self._make_dataclass_serializable(val, t))
+            object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t))
         return python_val
 
     def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
@@ -757,7 +757,7 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An
         # Thus we will have to walk the given dataclass and typecast values to int, where expected.
         for f in dataclasses.fields(dc_type):
            val = getattr(dc, f.name)
-            setattr(dc, f.name, self._fix_val_int(f.type, val))
+            object.__setattr__(dc, f.name, self._fix_val_int(f.type, val))
         return dc
 
 
diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py
index 33343e7c8c..58dfcd1e45 100644
--- a/tests/flytekit/unit/core/test_dataclass.py
+++ b/tests/flytekit/unit/core/test_dataclass.py
@@ -951,3 +951,170 @@ def my_task(dc: DC) -> DC:
         return dc
 
     my_task(dc=DC())
+
+def test_frozen_dataclass():
+    @dataclass(frozen=True)
+    class FrozenDataclass:
+        a: int = 1
+        b: float = 2.0
+        c: bool = True
+        d: str = "hello"
+
+    @task
+    def t1(dc: FrozenDataclass) -> (int, float, bool, str):
+        return dc.a, dc.b, dc.c, dc.d
+
+    a, b, c, d = t1(dc=FrozenDataclass())
+    assert a == 1
+    assert b == 2.0
+    assert c == True
+    assert d == "hello"
+
+def test_pure_frozen_dataclasses_with_python_types():
+    @dataclass(frozen=True)
+    class DC:
+        string: Optional[str] = None
+
+    @dataclass(frozen=True)
+    class DCWithOptional:
+        string: Optional[str] = None
+        dc: Optional[DC] = None
+        list_dc: Optional[List[DC]] = None
+        list_list_dc: Optional[List[List[DC]]] = None
+        dict_dc: Optional[Dict[str, DC]] = None
+        dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None
+        dict_list_dc: Optional[Dict[str, List[DC]]] = None
+        list_dict_dc: Optional[List[Dict[str, DC]]] = None
+
+    @task
+    def t1() -> DCWithOptional:
+        return DCWithOptional(string="a", dc=DC(string="b"),
+                              list_dc=[DC(string="c"), DC(string="d")],
+                              list_list_dc=[[DC(string="e"), DC(string="f")]],
+                              list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")},
+                                            {"k": DC(string="l"), "m": DC(string="n")}],
+                              dict_dc={"o": DC(string="p"), "q": DC(string="r")},
+                              dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
+                              dict_list_dc={"x": [DC(string="y"), DC(string="z")],
+                                            "aa": [DC(string="bb"), DC(string="cc")]},)
+
+    @task
+    def t2() -> DCWithOptional:
+        return DCWithOptional()
+
+    output = DCWithOptional(string="a", dc=DC(string="b"),
+                            list_dc=[DC(string="c"), DC(string="d")],
+                            list_list_dc=[[DC(string="e"), DC(string="f")]],
+                            list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")},
+                                          {"k": DC(string="l"), "m": DC(string="n")}],
+                            dict_dc={"o": DC(string="p"), "q": DC(string="r")},
+                            dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
+                            dict_list_dc={"x": [DC(string="y"), DC(string="z")],
+                                          "aa": [DC(string="bb"), DC(string="cc")]}, )
+
+    dc1 = t1()
+    dc2 = t2()
+
+    assert dc1 == output
+    assert dc2.string is None
+    assert dc2.dc is None
+
+    DataclassTransformer().assert_type(DCWithOptional, dc1)
+    DataclassTransformer().assert_type(DCWithOptional, dc2)
+
+def test_pure_frozen_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory):
+    @dataclass(frozen=True)
+    class FlyteTypes:
+        flytefile: Optional[FlyteFile] = None
+        flytedir: Optional[FlyteDirectory] = None
+        structured_dataset: Optional[StructuredDataset] = None
+
+    @dataclass(frozen=True)
+    class NestedFlyteTypes:
+        flytefile: Optional[FlyteFile] = None
+        flytedir: Optional[FlyteDirectory] = None
+        structured_dataset: Optional[StructuredDataset] = None
+        flyte_types: Optional[FlyteTypes] = None
+        list_flyte_types: Optional[List[FlyteTypes]] = None
+        dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None
+        optional_flyte_types: Optional[FlyteTypes] = None
+
+    @task
+    def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes:
+        return nested_flyte_types
+
+    @task
+    def generate_sd() -> StructuredDataset:
+        return StructuredDataset(
+            uri="s3://my-s3-bucket/data/test_sd",
+            file_format="parquet")
+
+    @task
+    def create_local_dir(path: str) -> FlyteDirectory:
+        return FlyteDirectory(path=path)
+
+    @task
+    def create_local_dir_by_str(path: str) -> FlyteDirectory:
+        return path
+
+    @task
+    def create_local_file(path: str) -> FlyteFile:
+        return FlyteFile(path=path)
+
+    @task
+    def create_local_file_with_str(path: str) -> FlyteFile:
+        return path
+
+    @task
+    def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset,
+                                    local_file_by_str: FlyteFile,
+                                    local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes:
+        ft = FlyteTypes(
+            flytefile=local_file,
+            flytedir=local_dir,
+            structured_dataset=sd,
+        )
+
+        return NestedFlyteTypes(
+            flytefile=local_file,
+            flytedir=local_dir,
+            structured_dataset=sd,
+            flyte_types=FlyteTypes(
+                flytefile=local_file_by_str,
+                flytedir=local_dir_by_str,
+                structured_dataset=sd,
+            ),
+            list_flyte_types=[ft, ft, ft],
+            dict_flyte_types={"a": ft, "b": ft, "c": ft},
+        )
+
+    @workflow
+    def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes:
+        local_file = create_local_file(path=txt_path)
+        local_dir = create_local_dir(path=dir_path)
+        local_file_by_str = create_local_file_with_str(path=txt_path)
+        local_dir_by_str = create_local_dir_by_str(path=dir_path)
+        sd = generate_sd()
+        nested_flyte_types = generate_nested_flyte_types(
+            local_file=local_file,
+            local_dir=local_dir,
+            local_file_by_str=local_file_by_str,
+            local_dir_by_str=local_dir_by_str,
+            sd=sd
+        )
+        old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types)
+        return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types)
+
+    @task
+    def get_empty_nested_type() -> NestedFlyteTypes:
+        return NestedFlyteTypes()
+
+    @workflow
+    def empty_nested_dc_wf() -> NestedFlyteTypes:
+        return get_empty_nested_type()
+
+    nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory)
+    DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types)
+
+    empty_nested_flyte_types = empty_nested_dc_wf()
+    DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)
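
Why the three `object.__setattr__` replacements work: a `@dataclass(frozen=True)` class generates a `__setattr__` that raises `dataclasses.FrozenInstanceError`, so the transformer's in-place fix-ups (`python_val.__setattr__(...)` / `setattr(...)`) fail for frozen instances, while `object.__setattr__` writes the attribute directly and bypasses that check. A minimal standalone sketch of this behavior (not part of the patch; the `Point` class here is illustrative only):

    import dataclasses
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Point:
        x: int = 0

    p = Point()
    try:
        p.x = 1  # the generated frozen __setattr__ refuses the write
    except dataclasses.FrozenInstanceError:
        pass

    object.__setattr__(p, "x", 1)  # bypasses the frozen check
    assert p.x == 1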