def _get_pickled_target_dict(root_entity: typing.Union[WorkflowBase, PythonTask]) -> typing.Dict[str, typing.Any]:
    """
    Collect the tasks reachable from an entity so they can be pickled together.

    Starting at ``root_entity``, workflows are expanded into their nodes and nodes
    into the entity they wrap. Every ``PythonAutoContainerTask`` found on the way is
    switched to ``default_notebook_task_resolver`` and stored under its name; for an
    ``ArrayNodeMapTask`` the wrapped ``_run_task`` is switched and stored instead of
    the map task itself. Other kinds of ``PythonTask`` are ignored.

    :param root_entity: The workflow or task to start the traversal from.
    :return: Mapping of task name to the task object that should be pickled.
    :raises FlyteAssertion: If a ``PythonFunctionTask`` in dynamic execution mode is encountered.
    """
    targets = {}
    # Used as a LIFO stack: the most recently discovered entity is visited next.
    pending = [root_entity]

    while pending:
        entity = pending.pop()

        is_dynamic = (
            isinstance(entity, PythonFunctionTask)
            and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC
        )
        if is_dynamic:
            raise FlyteAssertion(
                f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task."
            )

        if isinstance(entity, PythonTask):
            # Tasks are leaves of the traversal; decide which object (if any) gets recorded.
            if isinstance(entity, ArrayNodeMapTask):
                target = entity._run_task
            elif isinstance(entity, PythonAutoContainerTask):
                target = entity
            else:
                continue
            target.set_resolver(default_notebook_task_resolver)
            targets[target.name] = target
            continue

        if isinstance(entity, WorkflowBase):
            pending.extend(entity.nodes)
        elif isinstance(entity, CoreNode):
            pending.append(entity.flyte_entity)

    return targets
- """ - queue = [root_entity] - pickled_target_dict = {} - while queue: - entity = queue.pop() - if isinstance(entity, PythonTask): - if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)): - if isinstance(entity, ArrayNodeMapTask): - entity._run_task.set_resolver(default_notebook_task_resolver) - pickled_target_dict[entity._run_task.name] = entity._run_task - else: - entity.set_resolver(default_notebook_task_resolver) - pickled_target_dict[entity.name] = entity - elif isinstance(entity, WorkflowBase): - for task in entity.nodes: - queue.append(task) - elif isinstance(entity, CoreNode): - queue.append(entity.flyte_entity) - return pickled_target_dict - - def _pickle_and_upload_entity(self, entity: typing.Any) -> typing.Tuple[bytes, FastSerializationSettings]: + def _pickle_and_upload_entity( + self, entity: typing.Union[WorkflowBase, PythonTask] + ) -> typing.Tuple[bytes, FastSerializationSettings]: """ Pickle the entity to the specified location. This is useful for debugging and for sharing entities across different environments. 
:param entity: The entity to pickle """ # get all entity tasks - pickled_dict = self._get_pickled_target_dict(entity) + pickled_dict = _get_pickled_target_dict(entity) with tempfile.TemporaryDirectory() as tmp_dir: dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH) with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped: diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 98b50bbc2b..4df39329c7 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -7,6 +7,7 @@ import uuid from collections import OrderedDict from datetime import datetime, timedelta +from functools import partial import mock import pytest @@ -15,13 +16,13 @@ from mock import ANY, MagicMock, patch import flytekit.configuration -from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task +from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map_task, dynamic from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions -from flytekit.exceptions.user import FlyteEntityNotExistException +from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteAssertion from flytekit.models import common as common_models from flytekit.models import security from flytekit.models.admin.workflow import Workflow, WorkflowClosure @@ -33,7 +34,7 @@ from flytekit.models.task import Task from flytekit.remote import FlyteTask from flytekit.remote.lazy_entity import LazyEntity -from flytekit.remote.remote import FlyteRemote, _get_git_repo_url +from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict from 
def test_get_pickled_target_dict():
    """Both tasks called by a plain workflow are collected under their own names."""

    @task
    def t1() -> int:
        return 1

    @task
    def t2(a: int) -> int:
        return a + 2

    @workflow
    def w() -> int:
        return t2(a=t1())

    collected = _get_pickled_target_dict(w)

    assert len(collected) == 2
    for expected_task in (t1, t2):
        assert expected_task.name in collected
        assert collected[expected_task.name] == expected_task


def test_get_pickled_target_dict_with_map_task():
    """A mapped task is collected as the underlying task, not as the map wrapper."""

    @task
    def t1(x: int, y: int) -> int:
        return x + y

    @workflow
    def w() -> int:
        return map_task(partial(t1, y=2))(x=[1, 2, 3])

    collected = _get_pickled_target_dict(w)

    assert len(collected) == 1
    assert t1.name in collected
    assert collected[t1.name] == t1


def test_get_pickled_target_dict_with_dynamic():
    """A workflow that calls a dynamic task is rejected with FlyteAssertion."""

    @task
    def t1(a: int) -> str:
        return "fast-" + str(a + 2)

    @workflow
    def subwf(a: int):
        t1(a=a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        results = [t1(a=i) for i in range(a)]
        subwf(a=a)
        return results

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        return my_subwf(a=a)

    with pytest.raises(FlyteAssertion):
        _get_pickled_target_dict(my_wf)