Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type Mismatching while Serializing Dataclass with Union #2859

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,31 @@
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

# Handle Optional
if UnionTransformer.is_optional_type(python_type):
if python_val is None:
return None
return self._make_dataclass_serializable(python_val, get_args(python_type)[0])

def get_expected_type(python_val: T, types: tuple) -> Type[T | None]:
if len(set(types) & {FlyteFile, FlyteDirectory, StructuredDataset}) > 1:
raise ValueError("Cannot have two Flyte types in a Union type")

for t in types:
try:
trans = TypeEngine.get_transformer(t) # type: ignore
if trans:
trans.assert_type(t, python_val)
return t
except Exception:
continue
return type(None)

Check warning on line 696 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L696

Added line #L696 was not covered by tests

# Get the expected type in the Union type

Check warning on line 698 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L698

Added line #L698 was not covered by tests
expected_type = type(None)
if python_val is not None:
expected_type = get_expected_type(python_val, get_args(python_type)) # type: ignore

Check warning on line 702 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L701-L702

Added lines #L701 - L702 were not covered by tests
return self._make_dataclass_serializable(python_val, expected_type)

if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
if python_val is None:
Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,3 +1118,17 @@ def empty_nested_dc_wf() -> NestedFlyteTypes:

empty_nested_flyte_types = empty_nested_dc_wf()
DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)

def test_dataclass_serialize_with_multiple_dataclass_union():
@dataclass
class A():
x: int

@dataclass
class B():
x: FlyteFile

b = B(x="s3://my-bucket/my-file")
res = DataclassTransformer()._make_dataclass_serializable(b, Union[None, A, B])

assert res.x.path == "s3://my-bucket/my-file"
16 changes: 16 additions & 0 deletions tests/flytekit/unit/core/test_flytetypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from flytekit.types.file import FlyteFile
from flytekit.types.structured.structured_dataset import StructuredDataset
from flytekit.core.type_engine import DataclassTransformer
from typing import Union
import pytest

def test_dataclass_union_with_multiple_flytetypes_error():
@dataclass
class DC():
x: Union[None, FlyteFile, StructuredDataset]


dc = DC(x="s3://my-bucket/my-file")
with pytest.raises(ValueError, match="Cannot have two Flyte types in a Union type"):
DataclassTransformer()._make_dataclass_serializable(dc, DC)
4 changes: 4 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ class TestFileStruct(DataClassJsonMixin):
b: typing.Optional[FlyteFile]
b_prime: typing.Optional[FlyteFile]
c: typing.Union[FlyteFile, None]
c_prime: typing.Union[None, FlyteFile]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make this

Suggested change
c_prime: typing.Union[None, FlyteFile]
c_prime: typing.Union[None, StructuredDataset, int, FlyteFile]

Copy link
Contributor

@wild-endeavor wild-endeavor Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually can you write one more unit test for me please? (and add it under test_dataclass.py this file is getting too big).

@dataclass
class A():
  x: int

@dataclass
class B():
   x: FlyteFile

then call _make_dataclass_serializable on Union[None, A, B] where b = B(x="s3://tmp) or something.

d: typing.List[FlyteFile]
e: typing.List[typing.Optional[FlyteFile]]
e_prime: typing.List[typing.Optional[FlyteFile]]
Expand All @@ -989,6 +990,7 @@ class TestFileStruct(DataClassJsonMixin):
b=f1,
b_prime=None,
c=f1,
c_prime=f1,
d=[f1],
e=[f1],
e_prime=[None],
Expand All @@ -1011,6 +1013,7 @@ class TestFileStruct(DataClassJsonMixin):
assert dict_obj["b"]["path"] == remote_path
assert dict_obj["b_prime"] is None
assert dict_obj["c"]["path"] == remote_path
assert dict_obj["c_prime"]["path"] == remote_path
assert dict_obj["d"][0]["path"] == remote_path
assert dict_obj["e"][0]["path"] == remote_path
assert dict_obj["e_prime"][0] is None
Expand All @@ -1028,6 +1031,7 @@ class TestFileStruct(DataClassJsonMixin):
assert o.b.remote_path == ot.b.remote_source
assert ot.b_prime is None
assert o.c.remote_path == ot.c.remote_source
assert o.c_prime.remote_path == ot.c_prime.remote_source
assert o.d[0].remote_path == ot.d[0].remote_source
assert o.e[0].remote_path == ot.e[0].remote_source
assert o.e_prime == [None]
Expand Down
Loading