Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed Oct 22, 2024
1 parent 92a787e commit d11d046
Showing 1 changed file with 149 additions and 0 deletions.
149 changes: 149 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,3 +969,152 @@ def t1(dc: FrozenDataclass) -> (int, float, bool, str):
assert b == 2.0
assert c == True
assert d == "hello"

def test_pure_frozen_dataclasses_with_python_types():
@dataclass(frozen=True)
class DC:
string: Optional[str] = None

@dataclass(frozen=True)
class DCWithOptional:
string: Optional[str] = None
dc: Optional[DC] = None
list_dc: Optional[List[DC]] = None
list_list_dc: Optional[List[List[DC]]] = None
dict_dc: Optional[Dict[str, DC]] = None
dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None
dict_list_dc: Optional[Dict[str, List[DC]]] = None
list_dict_dc: Optional[List[Dict[str, DC]]] = None

@task
def t1() -> DCWithOptional:
return DCWithOptional(string="a", dc=DC(string="b"),
list_dc=[DC(string="c"), DC(string="d")],
list_list_dc=[[DC(string="e"), DC(string="f")]],
list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")},
{"k": DC(string="l"), "m": DC(string="n")}],
dict_dc={"o": DC(string="p"), "q": DC(string="r")},
dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
dict_list_dc={"x": [DC(string="y"), DC(string="z")],
"aa": [DC(string="bb"), DC(string="cc")]},)

@task
def t2() -> DCWithOptional:
return DCWithOptional()

output = DCWithOptional(string="a", dc=DC(string="b"),
list_dc=[DC(string="c"), DC(string="d")],
list_list_dc=[[DC(string="e"), DC(string="f")]],
list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")},
{"k": DC(string="l"), "m": DC(string="n")}],
dict_dc={"o": DC(string="p"), "q": DC(string="r")},
dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
dict_list_dc={"x": [DC(string="y"), DC(string="z")],
"aa": [DC(string="bb"), DC(string="cc")]}, )

dc1 = t1()
dc2 = t2()

assert dc1 == output
assert dc2.string is None
assert dc2.dc is None

DataclassTransformer().assert_type(DCWithOptional, dc1)
DataclassTransformer().assert_type(DCWithOptional, dc2)

def test_pure_frozen_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory):
@dataclass(frozen=True)
class FlyteTypes:
flytefile: Optional[FlyteFile] = None
flytedir: Optional[FlyteDirectory] = None
structured_dataset: Optional[StructuredDataset] = None

@dataclass(frozen=True)
class NestedFlyteTypes:
flytefile: Optional[FlyteFile] = None
flytedir: Optional[FlyteDirectory] = None
structured_dataset: Optional[StructuredDataset] = None
flyte_types: Optional[FlyteTypes] = None
list_flyte_types: Optional[List[FlyteTypes]] = None
dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None
optional_flyte_types: Optional[FlyteTypes] = None

@task
def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes:
return nested_flyte_types

@task
def generate_sd() -> StructuredDataset:
return StructuredDataset(
uri="s3://my-s3-bucket/data/test_sd",
file_format="parquet")

@task
def create_local_dir(path: str) -> FlyteDirectory:
return FlyteDirectory(path=path)

@task
def create_local_dir_by_str(path: str) -> FlyteDirectory:
return path

@task
def create_local_file(path: str) -> FlyteFile:
return FlyteFile(path=path)

@task
def create_local_file_with_str(path: str) -> FlyteFile:
return path

@task
def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset,
local_file_by_str: FlyteFile,
local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes:
ft = FlyteTypes(
flytefile=local_file,
flytedir=local_dir,
structured_dataset=sd,
)

return NestedFlyteTypes(
flytefile=local_file,
flytedir=local_dir,
structured_dataset=sd,
flyte_types=FlyteTypes(
flytefile=local_file_by_str,
flytedir=local_dir_by_str,
structured_dataset=sd,
),
list_flyte_types=[ft, ft, ft],
dict_flyte_types={"a": ft, "b": ft, "c": ft},
)

@workflow
def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes:
local_file = create_local_file(path=txt_path)
local_dir = create_local_dir(path=dir_path)
local_file_by_str = create_local_file_with_str(path=txt_path)
local_dir_by_str = create_local_dir_by_str(path=dir_path)
sd = generate_sd()
nested_flyte_types = generate_nested_flyte_types(
local_file=local_file,
local_dir=local_dir,
local_file_by_str=local_file_by_str,
local_dir_by_str=local_dir_by_str,
sd=sd
)
old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types)
return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types)

@task
def get_empty_nested_type() -> NestedFlyteTypes:
return NestedFlyteTypes()

@workflow
def empty_nested_dc_wf() -> NestedFlyteTypes:
return get_empty_nested_type()

nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory)
DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types)

empty_nested_flyte_types = empty_nested_dc_wf()
DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)

0 comments on commit d11d046

Please sign in to comment.