Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict Dynamic Workflow for Interactive Mode #2849

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
PythonAutoContainerTask,
default_notebook_task_resolver,
)
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import (
FlyteAssertion,
FlyteEntityAlreadyExistsException,
FlyteEntityNotExistException,
FlyteValueException,
Expand Down Expand Up @@ -198,6 +200,38 @@
return ""


def _get_pickled_target_dict(root_entity: typing.Union[WorkflowBase, PythonTask]) -> typing.Dict[str, typing.Any]:
"""
Get the pickled target dictionary for the entity.
:param root_entity: The entity to get the pickled target for.
:return: The pickled target dictionary.
"""
queue = [root_entity]
pickled_target_dict = {}

Check warning on line 210 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L209-L210

Added lines #L209 - L210 were not covered by tests
while queue:
entity = queue.pop()

Check warning on line 212 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L212

Added line #L212 was not covered by tests
if isinstance(entity, PythonFunctionTask):
if entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
raise FlyteAssertion(

Check warning on line 215 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L215

Added line #L215 was not covered by tests
f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task."
)

if isinstance(entity, PythonTask):
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
if isinstance(entity, ArrayNodeMapTask):
entity._run_task.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity._run_task.name] = entity._run_task

Check warning on line 223 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L222-L223

Added lines #L222 - L223 were not covered by tests
else:
entity.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity.name] = entity

Check warning on line 226 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L225-L226

Added lines #L225 - L226 were not covered by tests
elif isinstance(entity, WorkflowBase):
for task in entity.nodes:
queue.append(task)

Check warning on line 229 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L229

Added line #L229 was not covered by tests
elif isinstance(entity, CoreNode):
queue.append(entity.flyte_entity)
return pickled_target_dict

Check warning on line 232 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L231-L232

Added lines #L231 - L232 were not covered by tests


class FlyteRemote(object):
"""Main entrypoint for programmatically accessing a Flyte remote backend.

Expand Down Expand Up @@ -2583,39 +2617,16 @@
for var, literal in lm.items():
download_literal(self.file_access, var, literal, download_to)

def _get_pickled_target_dict(self, root_entity: typing.Any) -> typing.Dict[str, typing.Any]:
"""
Get the pickled target dictionary for the entity.
:param root_entity: The entity to get the pickled target for.
:return: The pickled target dictionary.
"""
queue = [root_entity]
pickled_target_dict = {}
while queue:
entity = queue.pop()
if isinstance(entity, PythonTask):
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
if isinstance(entity, ArrayNodeMapTask):
entity._run_task.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity._run_task.name] = entity._run_task
else:
entity.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity.name] = entity
elif isinstance(entity, WorkflowBase):
for task in entity.nodes:
queue.append(task)
elif isinstance(entity, CoreNode):
queue.append(entity.flyte_entity)
return pickled_target_dict

def _pickle_and_upload_entity(self, entity: typing.Any) -> typing.Tuple[bytes, FastSerializationSettings]:
def _pickle_and_upload_entity(
self, entity: typing.Union[WorkflowBase, PythonTask]
) -> typing.Tuple[bytes, FastSerializationSettings]:
"""
Pickle the entity to the specified location. This is useful for debugging and for sharing entities across
different environments.
:param entity: The entity to pickle
"""
# get all entity tasks
pickled_dict = self._get_pickled_target_dict(entity)
pickled_dict = _get_pickled_target_dict(entity)

Check warning on line 2629 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L2629

Added line #L2629 was not covered by tests
with tempfile.TemporaryDirectory() as tmp_dir:
dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH)
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
Expand Down
68 changes: 65 additions & 3 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta
from functools import partial

import mock
import pytest
Expand All @@ -15,13 +16,13 @@
from mock import ANY, MagicMock, patch

import flytekit.configuration
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task
from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map_task, dynamic
from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import FlyteEntityNotExistException
from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteAssertion
from flytekit.models import common as common_models
from flytekit.models import security
from flytekit.models.admin.workflow import Workflow, WorkflowClosure
Expand All @@ -33,7 +34,7 @@
from flytekit.models.task import Task
from flytekit.remote import FlyteTask
from flytekit.remote.lazy_entity import LazyEntity
from flytekit.remote.remote import FlyteRemote, _get_git_repo_url
from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict
from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan
from tests.flytekit.common.parameterizers import LIST_OF_TASK_CLOSURES

Expand Down Expand Up @@ -690,3 +691,64 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist
def test_fetch_active_launchplan_not_found(mock_client, remote):
mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None


def test_get_pickled_target_dict():
@task
def t1() -> int:
return 1

@task
def t2(a: int) -> int:
return a + 2

@workflow
def w() -> int:
return t2(a=t1())

target_dict = _get_pickled_target_dict(w)
assert len(target_dict) == 2
assert t1.name in target_dict
assert t2.name in target_dict
assert target_dict[t1.name] == t1
assert target_dict[t2.name] == t2

def test_get_pickled_target_dict_with_map_task():
@task
def t1(x: int, y: int) -> int:
return x + y

@workflow
def w() -> int:
return map_task(partial(t1, y=2))(x=[1, 2, 3])

target_dict = _get_pickled_target_dict(w)
assert len(target_dict) == 1
assert t1.name in target_dict
assert target_dict[t1.name] == t1

def test_get_pickled_target_dict_with_dynamic():
@task
def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@workflow
def subwf(a: int):
t1(a=a)

@dynamic
def my_subwf(a: int) -> typing.List[str]:
s = []
for i in range(a):
s.append(t1(a=i))
subwf(a=a)
return s

@workflow
def my_wf(a: int) -> typing.List[str]:
v = my_subwf(a=a)
return v

with pytest.raises(FlyteAssertion):
_get_pickled_target_dict(my_wf)
Loading