Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TypeTransformer] Support frozen dataclasses #2823

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.
elif dataclasses.is_dataclass(python_type):
for field in dataclasses.fields(python_type):
val = python_val.__getattribute__(field.name)
python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val))
object.__setattr__(python_val, field.name, self._fix_structured_dataset_type(field.type, val))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I learned it

return python_val

def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any:
Expand Down Expand Up @@ -714,7 +714,7 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
dataclass_attributes = typing.get_type_hints(python_type)
for n, t in dataclass_attributes.items():
val = python_val.__getattribute__(n)
python_val.__setattr__(n, self._make_dataclass_serializable(val, t))
object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t))
return python_val

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
Expand Down Expand Up @@ -757,7 +757,7 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An
# Thus we will have to walk the given dataclass and typecast values to int, where expected.
for f in dataclasses.fields(dc_type):
val = getattr(dc, f.name)
setattr(dc, f.name, self._fix_val_int(f.type, val))
object.__setattr__(dc, f.name, self._fix_val_int(f.type, val))

return dc

Expand Down
167 changes: 167 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,170 @@ def my_task(dc: DC) -> DC:
return dc

my_task(dc=DC())

def test_frozen_dataclass():
    """Pass a frozen dataclass into a task and read each of its fields back.

    The task receives a ``FrozenDataclass`` built from its defaults and
    returns the four field values, which are then compared against those
    defaults.
    """

    @dataclass(frozen=True)
    class FrozenDataclass:
        a: int = 1
        b: float = 2.0
        c: bool = True
        d: str = "hello"

    @task
    def t1(dc: FrozenDataclass) -> (int, float, bool, str):
        return dc.a, dc.b, dc.c, dc.d

    a, b, c, d = t1(dc=FrozenDataclass())
    assert a == 1
    assert b == 2.0
    # Identity comparison: the bool field must come back as the ``True``
    # singleton, not just a value that happens to compare equal to it (e.g. 1).
    assert c is True
    assert d == "hello"

def test_pure_frozen_dataclasses_with_python_types():
    """Return frozen dataclasses with nested list/dict fields from tasks.

    One task returns a fully populated ``DCWithOptional`` and another returns
    one built purely from defaults. The populated result is compared against
    an identically constructed instance, the default result is checked for
    ``None`` fields, and both are validated with
    ``DataclassTransformer().assert_type``.
    """

    @dataclass(frozen=True)
    class DC:
        string: Optional[str] = None

    @dataclass(frozen=True)
    class DCWithOptional:
        string: Optional[str] = None
        dc: Optional[DC] = None
        list_dc: Optional[List[DC]] = None
        list_list_dc: Optional[List[List[DC]]] = None
        dict_dc: Optional[Dict[str, DC]] = None
        dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None
        dict_list_dc: Optional[Dict[str, List[DC]]] = None
        list_dict_dc: Optional[List[Dict[str, DC]]] = None

    @task
    def t1() -> DCWithOptional:
        # Every optional field is populated, including nested containers.
        return DCWithOptional(
            string="a",
            dc=DC(string="b"),
            list_dc=[DC(string="c"), DC(string="d")],
            list_list_dc=[[DC(string="e"), DC(string="f")]],
            list_dict_dc=[
                {"g": DC(string="h"), "i": DC(string="j")},
                {"k": DC(string="l"), "m": DC(string="n")},
            ],
            dict_dc={"o": DC(string="p"), "q": DC(string="r")},
            dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
            dict_list_dc={
                "x": [DC(string="y"), DC(string="z")],
                "aa": [DC(string="bb"), DC(string="cc")],
            },
        )

    @task
    def t2() -> DCWithOptional:
        # Every field is left at its ``None`` default.
        return DCWithOptional()

    # Built outside of any task, with the same values ``t1`` uses.
    expected = DCWithOptional(
        string="a",
        dc=DC(string="b"),
        list_dc=[DC(string="c"), DC(string="d")],
        list_list_dc=[[DC(string="e"), DC(string="f")]],
        list_dict_dc=[
            {"g": DC(string="h"), "i": DC(string="j")},
            {"k": DC(string="l"), "m": DC(string="n")},
        ],
        dict_dc={"o": DC(string="p"), "q": DC(string="r")},
        dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
        dict_list_dc={
            "x": [DC(string="y"), DC(string="z")],
            "aa": [DC(string="bb"), DC(string="cc")],
        },
    )

    populated = t1()
    defaulted = t2()

    assert populated == expected
    assert defaulted.string is None
    assert defaulted.dc is None

    DataclassTransformer().assert_type(DCWithOptional, populated)
    DataclassTransformer().assert_type(DCWithOptional, defaulted)

def test_pure_frozen_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory):
    """Pass frozen dataclasses holding Flyte types through tasks and workflows.

    A workflow builds a ``NestedFlyteTypes`` from ``FlyteFile``,
    ``FlyteDirectory`` and ``StructuredDataset`` values, then sends it through
    a pass-through task twice. A second workflow returns an instance built
    from defaults. Both results are validated with
    ``DataclassTransformer().assert_type``.
    """

    @dataclass(frozen=True)
    class FlyteTypes:
        flytefile: Optional[FlyteFile] = None
        flytedir: Optional[FlyteDirectory] = None
        structured_dataset: Optional[StructuredDataset] = None

    @dataclass(frozen=True)
    class NestedFlyteTypes:
        flytefile: Optional[FlyteFile] = None
        flytedir: Optional[FlyteDirectory] = None
        structured_dataset: Optional[StructuredDataset] = None
        flyte_types: Optional[FlyteTypes] = None
        list_flyte_types: Optional[List[FlyteTypes]] = None
        dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None
        optional_flyte_types: Optional[FlyteTypes] = None

    @task
    def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes:
        # Identity task: the value is handed back unchanged.
        return nested_flyte_types

    @task
    def generate_sd() -> StructuredDataset:
        return StructuredDataset(
            uri="s3://my-s3-bucket/data/test_sd",
            file_format="parquet",
        )

    @task
    def create_local_dir(path: str) -> FlyteDirectory:
        return FlyteDirectory(path=path)

    @task
    def create_local_dir_by_str(path: str) -> FlyteDirectory:
        # Returns the plain string path under a FlyteDirectory annotation.
        return path

    @task
    def create_local_file(path: str) -> FlyteFile:
        return FlyteFile(path=path)

    @task
    def create_local_file_with_str(path: str) -> FlyteFile:
        # Returns the plain string path under a FlyteFile annotation.
        return path

    @task
    def generate_nested_flyte_types(
        local_file: FlyteFile,
        local_dir: FlyteDirectory,
        sd: StructuredDataset,
        local_file_by_str: FlyteFile,
        local_dir_by_str: FlyteDirectory,
    ) -> NestedFlyteTypes:
        # One instance is reused for every list and dict entry below.
        shared = FlyteTypes(
            flytefile=local_file,
            flytedir=local_dir,
            structured_dataset=sd,
        )
        # The directly nested instance uses the string-derived file/dir inputs.
        from_str = FlyteTypes(
            flytefile=local_file_by_str,
            flytedir=local_dir_by_str,
            structured_dataset=sd,
        )

        return NestedFlyteTypes(
            flytefile=local_file,
            flytedir=local_dir,
            structured_dataset=sd,
            flyte_types=from_str,
            list_flyte_types=[shared, shared, shared],
            dict_flyte_types={"a": shared, "b": shared, "c": shared},
        )

    @workflow
    def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes:
        local_file = create_local_file(path=txt_path)
        local_dir = create_local_dir(path=dir_path)
        local_file_by_str = create_local_file_with_str(path=txt_path)
        local_dir_by_str = create_local_dir_by_str(path=dir_path)
        sd = generate_sd()
        generated = generate_nested_flyte_types(
            local_file=local_file,
            local_dir=local_dir,
            local_file_by_str=local_file_by_str,
            local_dir_by_str=local_dir_by_str,
            sd=sd,
        )
        # Send the dataclass through the pass-through task twice in a row.
        first_pass = pass_and_return_flyte_types(nested_flyte_types=generated)
        return pass_and_return_flyte_types(nested_flyte_types=first_pass)

    @task
    def get_empty_nested_type() -> NestedFlyteTypes:
        return NestedFlyteTypes()

    @workflow
    def empty_nested_dc_wf() -> NestedFlyteTypes:
        return get_empty_nested_type()

    populated_result = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory)
    DataclassTransformer().assert_type(NestedFlyteTypes, populated_result)

    empty_result = empty_nested_dc_wf()
    DataclassTransformer().assert_type(NestedFlyteTypes, empty_result)
Loading