From d11d0464110df4bac737554ec021ea86c0f9bcf9 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 22 Oct 2024 22:55:06 +0800 Subject: [PATCH] add tests Signed-off-by: Future-Outlier --- tests/flytekit/unit/core/test_dataclass.py | 149 +++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py index 1a3ba71f67..58dfcd1e45 100644 --- a/tests/flytekit/unit/core/test_dataclass.py +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -969,3 +969,152 @@ def t1(dc: FrozenDataclass) -> (int, float, bool, str): assert b == 2.0 assert c == True assert d == "hello" + +def test_pure_frozen_dataclasses_with_python_types(): + @dataclass(frozen=True) + class DC: + string: Optional[str] = None + + @dataclass(frozen=True) + class DCWithOptional: + string: Optional[str] = None + dc: Optional[DC] = None + list_dc: Optional[List[DC]] = None + list_list_dc: Optional[List[List[DC]]] = None + dict_dc: Optional[Dict[str, DC]] = None + dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None + dict_list_dc: Optional[Dict[str, List[DC]]] = None + list_dict_dc: Optional[List[Dict[str, DC]]] = None + + @task + def t1() -> DCWithOptional: + return DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]},) + + @task + def t2() -> DCWithOptional: + return DCWithOptional() + + output = DCWithOptional(string="a", dc=DC(string="b"), + list_dc=[DC(string="c"), DC(string="d")], + list_list_dc=[[DC(string="e"), DC(string="f")]], + list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")}, + {"k": DC(string="l"), "m": DC(string="n")}], + dict_dc={"o": DC(string="p"), "q": DC(string="r")}, + dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}}, + dict_list_dc={"x": [DC(string="y"), DC(string="z")], + "aa": [DC(string="bb"), DC(string="cc")]}, ) + + dc1 = t1() + dc2 = t2() + + assert dc1 == output + assert dc2.string is None + assert dc2.dc is None + + DataclassTransformer().assert_type(DCWithOptional, dc1) + DataclassTransformer().assert_type(DCWithOptional, dc2) + +def test_pure_frozen_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory): + @dataclass(frozen=True) + class FlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + + @dataclass(frozen=True) + class NestedFlyteTypes: + flytefile: Optional[FlyteFile] = None + flytedir: Optional[FlyteDirectory] = None + structured_dataset: Optional[StructuredDataset] = None + flyte_types: Optional[FlyteTypes] = None + list_flyte_types: Optional[List[FlyteTypes]] = None + dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None + optional_flyte_types: Optional[FlyteTypes] = None + + @task + def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes: + return nested_flyte_types + + @task + def generate_sd() -> StructuredDataset: + return StructuredDataset( + uri="s3://my-s3-bucket/data/test_sd", + file_format="parquet") + + @task + def create_local_dir(path: str) -> FlyteDirectory: + return FlyteDirectory(path=path) + + @task + def create_local_dir_by_str(path: str) -> FlyteDirectory: + return path + + @task + def create_local_file(path: str) -> FlyteFile: + return FlyteFile(path=path) + + @task + def create_local_file_with_str(path: str) -> FlyteFile: + return path + + @task + def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset, + local_file_by_str: FlyteFile, + local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes: + ft = FlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + ) + + return NestedFlyteTypes( + flytefile=local_file, + flytedir=local_dir, + structured_dataset=sd, + flyte_types=FlyteTypes( + flytefile=local_file_by_str, + flytedir=local_dir_by_str, + structured_dataset=sd, + ), + list_flyte_types=[ft, ft, ft], + dict_flyte_types={"a": ft, "b": ft, "c": ft}, + ) + + @workflow + def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes: + local_file = create_local_file(path=txt_path) + local_dir = create_local_dir(path=dir_path) + local_file_by_str = create_local_file_with_str(path=txt_path) + local_dir_by_str = create_local_dir_by_str(path=dir_path) + sd = generate_sd() + nested_flyte_types = generate_nested_flyte_types( + local_file=local_file, + local_dir=local_dir, + local_file_by_str=local_file_by_str, + local_dir_by_str=local_dir_by_str, + sd=sd + ) + old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types) + return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types) + + @task + def get_empty_nested_type() -> NestedFlyteTypes: + return NestedFlyteTypes() + + @workflow + def empty_nested_dc_wf() -> NestedFlyteTypes: + return get_empty_nested_type() + + nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory) + DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types) + + empty_nested_flyte_types = empty_nested_dc_wf() + DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)