From 96c0e5c49213272b08750dce38bc3baab261d796 Mon Sep 17 00:00:00 2001 From: mao3267 Date: Thu, 24 Oct 2024 15:31:43 +0800 Subject: [PATCH 1/7] fix: type matching err for union in dataclass Signed-off-by: mao3267 --- flytekit/core/type_engine.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 900afa8562..6ef2ad5b11 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -679,9 +679,24 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t # Handle Optional if UnionTransformer.is_optional_type(python_type): - if python_val is None: - return None - return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) + + def get_expected_type(python_val: T, types: tuple) -> Type[T | None]: + for t in types: + try: + trans = TypeEngine.get_transformer(t) # type: ignore + if trans: + trans.assert_type(t, python_val) + return t + except Exception: + continue + return type(None) + + # Get the expected type in the Union type + expected_type = type(None) + if python_val is not None: + expected_type = get_expected_type(python_val, get_args(python_type)) # type: ignore + + return self._make_dataclass_serializable(python_val, expected_type) if hasattr(python_type, "__origin__") and get_origin(python_type) is list: if python_val is None: From f647bd3a1b727082210454c3f5d1e652a29b1083 Mon Sep 17 00:00:00 2001 From: mao3267 Date: Thu, 24 Oct 2024 15:39:30 +0800 Subject: [PATCH 2/7] test: add Union[None, FlyteFile] to test optional in dataclass Signed-off-by: mao3267 --- tests/flytekit/unit/core/test_type_engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8721a8d4db..cb641eebe4 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -967,6 +967,7 @@ class TestFileStruct(DataClassJsonMixin): b: typing.Optional[FlyteFile] b_prime: typing.Optional[FlyteFile] c: typing.Union[FlyteFile, None] + c_prime: typing.Union[None, FlyteFile] d: typing.List[FlyteFile] e: typing.List[typing.Optional[FlyteFile]] e_prime: typing.List[typing.Optional[FlyteFile]] @@ -989,6 +990,7 @@ class TestFileStruct(DataClassJsonMixin): b=f1, b_prime=None, c=f1, + c_prime=f1, d=[f1], e=[f1], e_prime=[None], @@ -1011,6 +1013,7 @@ class TestFileStruct(DataClassJsonMixin): assert dict_obj["b"]["path"] == remote_path assert dict_obj["b_prime"] is None assert dict_obj["c"]["path"] == remote_path + assert dict_obj["c_prime"]["path"] == remote_path assert dict_obj["d"][0]["path"] == remote_path assert dict_obj["e"][0]["path"] == remote_path assert dict_obj["e_prime"][0] is None @@ -1028,6 +1031,7 @@ class TestFileStruct(DataClassJsonMixin): assert o.b.remote_path == ot.b.remote_source assert ot.b_prime is None assert o.c.remote_path == ot.c.remote_source + assert o.c_prime.remote_path == ot.c_prime.remote_source assert o.d[0].remote_path == ot.d[0].remote_source assert o.e[0].remote_path == ot.e[0].remote_source assert o.e_prime == [None] From a7ea0e906c86f5e4f1bfdcd7eca86f0520853896 Mon Sep 17 00:00:00 2001 From: mao3267 Date: Fri, 25 Oct 2024 14:41:35 +0800 Subject: [PATCH 3/7] fix: add error for multiple flytetypes in union while serializing Signed-off-by: mao3267 --- flytekit/core/type_engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 6ef2ad5b11..a49c2a1443 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -676,11 +676,15 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t """ from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile + from flytekit.types.structured import StructuredDataset # Handle Optional if UnionTransformer.is_optional_type(python_type): def get_expected_type(python_val: T, types: tuple) -> Type[T | None]: + if len(set(types) & {FlyteFile, FlyteDirectory, StructuredDataset}) > 1: + raise ValueError("Cannot have two Flyte types in a Union type") + for t in types: try: trans = TypeEngine.get_transformer(t) # type: ignore From dff082955505a666cc52d9cf67095cb1ae6f54c1 Mon Sep 17 00:00:00 2001 From: mao3267 Date: Fri, 25 Oct 2024 14:42:20 +0800 Subject: [PATCH 4/7] test: serialize multiple dataclass in union Signed-off-by: mao3267 --- tests/flytekit/unit/core/test_dataclass.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py index 58dfcd1e45..4e098c254b 100644 --- a/tests/flytekit/unit/core/test_dataclass.py +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -1118,3 +1118,17 @@ def empty_nested_dc_wf() -> NestedFlyteTypes: empty_nested_flyte_types = empty_nested_dc_wf() DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types) + +def test_dataclass_serialize_with_multiple_dataclass_union(): + @dataclass + class A(): + x: int + + @dataclass + class B(): + x: FlyteFile + + b = B(x="s3://my-bucket/my-file") + res = DataclassTransformer()._make_dataclass_serializable(b, Union[None, A, B]) + + assert res.x.path == "s3://my-bucket/my-file" From b9a4051fc1ed2ea06aa4260eb6270fbb853e7f4a Mon Sep 17 00:00:00 2001 From: mao3267 Date: Fri, 25 Oct 2024 14:43:04 +0800 Subject: [PATCH 5/7] test: error while serialize union with multiple flytetypes Signed-off-by: mao3267 --- tests/flytekit/unit/core/test_flytetypes.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 tests/flytekit/unit/core/test_flytetypes.py diff --git a/tests/flytekit/unit/core/test_flytetypes.py b/tests/flytekit/unit/core/test_flytetypes.py new file mode 100644 index 0000000000..6e2d483bd2 --- /dev/null +++ b/tests/flytekit/unit/core/test_flytetypes.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from flytekit.types.file import FlyteFile +from flytekit.types.structured.structured_dataset import StructuredDataset +from flytekit.core.type_engine import DataclassTransformer +from typing import Union +import pytest + +def test_dataclass_union_with_multiple_flytetypes_error(): + @dataclass + class DC(): + x: Union[None, FlyteFile, StructuredDataset] + + + dc = DC(x="s3://my-bucket/my-file") + with pytest.raises(ValueError, match="Cannot have two Flyte types in a Union type"): + DataclassTransformer()._make_dataclass_serializable(dc, DC) + \ No newline at end of file From c4c3ce72e77da8a878d69aa5784cac91a84a2dd6 Mon Sep 17 00:00:00 2001 From: mao3267 Date: Fri, 25 Oct 2024 14:50:42 +0800 Subject: [PATCH 6/7] refactor: lint Signed-off-by: mao3267 --- tests/flytekit/unit/core/test_flytetypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_flytetypes.py b/tests/flytekit/unit/core/test_flytetypes.py index 6e2d483bd2..a8c68e7872 100644 --- a/tests/flytekit/unit/core/test_flytetypes.py +++ b/tests/flytekit/unit/core/test_flytetypes.py @@ -14,4 +14,4 @@ class DC(): dc = DC(x="s3://my-bucket/my-file") with pytest.raises(ValueError, match="Cannot have two Flyte types in a Union type"): DataclassTransformer()._make_dataclass_serializable(dc, DC) - \ No newline at end of file + From 4894ae5c2868c6b667ee31fd2973988b30171a25 Mon Sep 17 00:00:00 2001 From: mao3267 Date: Fri, 25 Oct 2024 14:53:39 +0800 Subject: [PATCH 7/7] refactor: lint Signed-off-by: mao3267 --- tests/flytekit/unit/core/test_flytetypes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_flytetypes.py b/tests/flytekit/unit/core/test_flytetypes.py index a8c68e7872..3fe7ae94cd 100644 --- a/tests/flytekit/unit/core/test_flytetypes.py +++ b/tests/flytekit/unit/core/test_flytetypes.py @@ -14,4 +14,3 @@ class DC(): dc = DC(x="s3://my-bucket/my-file") with pytest.raises(ValueError, match="Cannot have two Flyte types in a Union type"): DataclassTransformer()._make_dataclass_serializable(dc, DC) -