Skip to content

Commit

Permalink
add support for toggling data mode for array node (#2940)
Browse files Browse the repository at this point in the history
* add support for toggling data mode for array node

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* cleanup

Signed-off-by: Paul Dittamo <[email protected]>

* Bump flyteidl lower-bound to 1.14.1

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add import of FlyteLaunchPlan back

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
pvditt and eapolinario authored Dec 23, 2024
1 parent e9a7da1 commit bc0e8c0
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 12 deletions.
27 changes: 17 additions & 10 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import ReferenceTask
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand All @@ -34,8 +35,7 @@
class ArrayNode:
def __init__(
self,
target: Union[LaunchPlan, "FlyteLaunchPlan"],
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"],
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
Expand All @@ -51,17 +51,17 @@ def __init__(
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying node
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
self.id = target.name
self._bindings = bindings or []
self.metadata = metadata
self._data_mode = None
self._execution_mode = None

if min_successes is not None:
self._min_successes = min_successes
Expand Down Expand Up @@ -92,9 +92,12 @@ def __init__(
else:
raise ValueError("No interface found for the target entity.")

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if isinstance(target, (LaunchPlan, FlyteLaunchPlan)):
self._data_mode = _core_workflow.ArrayNode.SINGLE_INPUT_FILE
self._execution_mode = _core_workflow.ArrayNode.FULL_STATE
elif isinstance(target, ReferenceTask):
self._data_mode = _core_workflow.ArrayNode.INDIVIDUAL_INPUT_FILES
self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE
else:
raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")

Expand Down Expand Up @@ -133,6 +136,10 @@ def upstream_nodes(self) -> List[Node]:
def flyte_entity(self) -> Any:
return self.target

@property
def data_mode(self) -> _core_workflow.ArrayNode.DataMode:
return self._data_mode

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self._remote_interface:
raise ValueError("Mapping over remote entities is not supported in local execution.")
Expand Down Expand Up @@ -254,7 +261,7 @@ def __call__(self, *args, **kwargs):


def array_node(
target: Union[LaunchPlan, "FlyteLaunchPlan"],
target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
Expand All @@ -275,8 +282,8 @@ def array_node(
"""
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")
if not isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
raise ValueError("Only LaunchPlans and ReferenceTasks are supported for now.")

node = ArrayNode(
target=target,
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.task import ReferenceTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import timeit
from flytekit.loggers import logger
Expand Down Expand Up @@ -390,7 +391,7 @@ def map_task(
"""
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
return array_node(
target=target,
concurrency=concurrency,
Expand Down
3 changes: 3 additions & 0 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __init__(
min_success_ratio=None,
execution_mode=None,
is_original_sub_node_interface=False,
data_mode=None,
) -> None:
"""
TODO: docstring
Expand All @@ -401,6 +402,7 @@ def __init__(
self._min_success_ratio = min_success_ratio
self._execution_mode = execution_mode
self._is_original_sub_node_interface = is_original_sub_node_interface
self._data_mode = data_mode

@property
def node(self) -> "Node":
Expand All @@ -414,6 +416,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
min_success_ratio=self._min_success_ratio,
execution_mode=self._execution_mode,
is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface),
data_mode=self._data_mode,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def get_serializable_array_node(
min_success_ratio=array_node.min_success_ratio,
execution_mode=array_node.execution_mode,
is_original_sub_node_interface=array_node.is_original_sub_node_interface,
data_mode=array_node.data_mode,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.13.9",
"flyteidl>=1.14.1",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down

0 comments on commit bc0e8c0

Please sign in to comment.