Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support raw container in the map task #1547

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
name: str,
image: str,
command: List[str],
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[Union[Dict[str, Tuple[Type, Any]], OrderedDict[str, Type]]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
Expand Down
23 changes: 17 additions & 6 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set

from flytekit import ContainerTask
from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin
Expand All @@ -33,7 +34,7 @@ class MapPythonTask(PythonTask):

def __init__(
self,
python_function_task: typing.Union[PythonFunctionTask, functools.partial],
python_function_task: typing.Union[PythonFunctionTask, ContainerTask, functools.partial],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
Expand Down Expand Up @@ -63,8 +64,8 @@ def __init__(
else:
actual_task = python_function_task

if not isinstance(actual_task, PythonFunctionTask):
raise ValueError("Map tasks can only compose of Python Functon Tasks currently")
if not isinstance(actual_task, (PythonFunctionTask, ContainerTask)):
raise ValueError("Map tasks can only compose of Python Function or Container Tasks currently")

if len(actual_task.python_interface.outputs.keys()) > 1:
raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs")
Expand All @@ -75,9 +76,13 @@ def __init__(

collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs)
self._run_task: PythonFunctionTask = actual_task
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

if isinstance(actual_task, ContainerTask):
name = f"raw_container_task.mapper_{actual_task.name}_{h}"
else:
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
name = f"{mod}.map_{f}_{h}"

self._cmd_prefix: typing.Optional[typing.List[str]] = None
self._max_concurrency: typing.Optional[int] = concurrency
Expand Down Expand Up @@ -142,14 +147,20 @@ def prepare_target(self):
self._run_task.reset_command_fn()

def get_container(self, settings: SerializationSettings) -> Container:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_container(settings)
with self.prepare_target():
return self._run_task.get_container(settings)

def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_k8s_pod(settings)
with self.prepare_target():
return self._run_task.get_k8s_pod(settings)

def get_sql(self, settings: SerializationSettings) -> Sql:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_sql(settings)
with self.prepare_target():
return self._run_task.get_sql(settings)

Expand Down Expand Up @@ -270,7 +281,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
task_function: typing.Union[PythonFunctionTask, functools.partial],
task_function: typing.Union[PythonFunctionTask, functools.partial, ContainerTask],
concurrency: int = 0,
min_success_ratio: float = 1.0,
**kwargs,
Expand Down
5 changes: 4 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def get_serializable_task(

if settings.should_fast_serialize():
# This handles container tasks.
if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask)):
if container and (
isinstance(entity, PythonAutoContainerTask)
or (isinstance(entity, MapPythonTask) and isinstance(entity.run_task, PythonAutoContainerTask))
):
# For fast registration, we'll need to muck with the command, but on
# ly for certain kinds of tasks. Specifically,
# tasks that rely on user code defined in the container. This should be encapsulated by the auto container
Expand Down
35 changes: 34 additions & 1 deletion tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import flytekit.configuration
from flytekit import LaunchPlan, map_task
from flytekit import ContainerTask, LaunchPlan, kwtypes, map_task
from flytekit.configuration import Image, ImageConfig
from flytekit.core.map_task import MapPythonTask, MapTaskResolver
from flytekit.core.task import TaskMetadata, task
Expand All @@ -25,6 +25,22 @@ def serialization_settings():
)


raw_container = ContainerTask(
name="ellipse-area-metadata-python",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(a=int),
outputs=kwtypes(area=float),
image="flyte/raw-container:v1",
command=[
"python",
"test.py",
"{{.inputs.a}}",
"/var/outputs",
],
)


@task
def t1(a: int) -> str:
b = a + 2
Expand Down Expand Up @@ -106,6 +122,23 @@ def test_serialization(serialization_settings):
]


def test_serialization_with_raw_container(serialization_settings):
maptask = map_task(raw_container, metadata=TaskMetadata(retries=1))
task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)

# By default all map_task tasks will have their custom fields set.
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "container_array"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.args is None
assert task_spec.template.container.command == [
"python",
"test.py",
"{{.inputs.a}}",
"/var/outputs",
]


@pytest.mark.parametrize(
"custom_fields_dict, expected_custom_fields",
[
Expand Down