Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix Union type with dataclass ambiguous error and support superset comparison #5858

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

mao3267
Copy link

@mao3267 mao3267 commented Oct 18, 2024

Tracking issue

Related to #5489

Why are the changes needed?

When a function accepts a Union of two dataclasses as input, Flyte cannot distinguish which dataclass matches the user's input. This is because Flyte only compares the simple types, and both dataclasses are identified as flyte.SimpleType_STRUCT in this scenario. As a result, there will be multiple matches, causing ambiguity and leading to an error.

union_test.py

from typing import Union
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from flytekit import task, workflow

@dataclass_json 
@dataclass
class A:
    a: int
    

@dataclass_json
@dataclass
class B:
    b: int


@task
def bar() -> A:
    return A(a=1)

@task
def foo(inp: Union[A, B]):
    print(inp)

@workflow
def wf():
    v = bar()
    foo(inp=v)

What changes were proposed in this pull request?

  1. To distinguish between different dataclasses, we compare their JSON schemas generated by either marshmallow_jsonschema.JSONSchema (draft-07) or mashumaro.jsonschema.build_json_schema (draft 2020-12). To check equivalence, we compare the bytes from marshaling the json schemas if they are in the same draft version (generated by the same package). For comparing different draft versions, we are still finding a way to cover all possible types, including flyte types and nested types.
  2. We plan to support dataclass inheritance, meaning that class A and class B can be a match in the following example:
from dataclasses import dataclass
from typing import Optional

@dataclass
class A:
    a: int

@dataclass
class B(A):
    b: Optional[int]
    c: str = "Flyte"

@task
def foo() -> A:
    return A(a=1)

@task
def my_task(input: Union[int, B]):
    print(input)

@workflow
def wf():
    a = foo()
    my_task(a)
  1. Unit tests will be added for different versions of json schema.

How was this patch tested?

  1. Run an example with union input on remote (union_test.py)
  2. Run an example with two different versions of schema on remote

See Screenshot section for details

Setup process

git clone https://github.com/flyteorg/flyte.git
gh pr checkout 5858
make compile
POD_NAMESPACE=flyte ./flyte start --config flyte-single-binary-local.yaml

Screenshots

  1. Example with union input on remote (union_test.py)

image

  1. Example with two different versions of schema on remote

    union_test_diff_schema_version.py

from typing import Union
from dataclasses import dataclass
from flytekit import task, workflow
from union_test import A as jsonA

@dataclass
class A:
    a: int

@dataclass
class B:
    b: int


@task
def bar() -> jsonA:
    return jsonA(a=1)

@task
def foo(inp: Union[A, B]):
    print(inp)

@workflow
def wf():
    v = bar()
    foo(inp=v)

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

Related PRs

None

Docs link

TODO

Copy link

codecov bot commented Oct 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 24.71%. Comparing base (197ae13) to head (f06cdc6).
Report is 16 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (197ae13) and HEAD (f06cdc6). Click for more details.

HEAD has 2 uploads less than BASE
Flag BASE (197ae13) HEAD (f06cdc6)
unittests-flyteadmin 1 0
unittests-flytepropeller 1 0
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5858      +/-   ##
==========================================
- Coverage   34.48%   24.71%   -9.78%     
==========================================
  Files        1138      529     -609     
  Lines      102742    68383   -34359     
==========================================
- Hits        35434    16901   -18533     
+ Misses      63634    50091   -13543     
+ Partials     3674     1391    -2283     
Flag Coverage Δ
unittests-datacatalog 51.58% <ø> (+0.21%) ⬆️
unittests-flyteadmin ?
unittests-flytecopilot 11.73% <ø> (-0.45%) ⬇️
unittests-flyteidl 6.89% <ø> (-0.29%) ⬇️
unittests-flyteplugins 53.62% <ø> (+0.27%) ⬆️
unittests-flytepropeller ?
unittests-flytestdlib 54.78% <ø> (-0.57%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: In progress
Development

Successfully merging this pull request may close these issues.

1 participant