Skip to content

Commit

Permalink
Async tasks and eager revamp (#2927)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Dec 9, 2024
1 parent 11fbd9a commit 276c464
Show file tree
Hide file tree
Showing 40 changed files with 1,620 additions and 1,256 deletions.
38 changes: 0 additions & 38 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,44 +162,6 @@ jobs:
fail_ci_if_error: false
files: coverage.xml

test-hypothesis:
needs:
- detect-python-versions
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v3
with:
# This path is specific to Ubuntu
path: ~/.cache/pip
# Look to see if there is a cache hit for the corresponding requirements files
key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }}
- name: Install dependencies
run: |
pip install uv
make setup-global-uv
uv pip freeze
- name: Test with coverage
env:
FLYTEKIT_HYPOTHESIS_PROFILE: ci
run: |
make unit_test_hypothesis
- name: Codecov
uses: codecov/[email protected]
with:
fail_ci_if_error: false
files: coverage.xml

test-serialization:
needs:
- detect-python-versions
Expand Down
4 changes: 0 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ unit_test:
# Run serial tests without any parallelism
$(PYTEST) -m "serial" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models --ignore=tests/flytekit/unit/extend ${CODECOV_OPTS}

.PHONY: unit_test_hypothesis
unit_test_hypothesis:
$(PYTEST_AND_OPTS) -m "hypothesis" tests/flytekit/unit/experimental ${CODECOV_OPTS}

.PHONY: unit_test_extras
unit_test_extras:
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST_AND_OPTS) tests/flytekit/unit/extras tests/flytekit/unit/extend ${CODECOV_OPTS}
Expand Down
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference
from flytekit.core.resources import Resources
from flytekit.core.schedule import CronSchedule, FixedRate
from flytekit.core.task import Secret, reference_task, task
from flytekit.core.task import Secret, eager, reference_task, task
from flytekit.core.type_engine import BatchSize
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
Expand Down
5 changes: 0 additions & 5 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import contextlib
import datetime
import inspect
import os
import pathlib
import signal
Expand Down Expand Up @@ -177,10 +176,6 @@ def _dispatch_execute(
# Step2
# Invoke task - dispatch_execute
outputs = task_def.dispatch_execute(ctx, idl_input_literals)
if inspect.iscoroutine(outputs):
# Handle eager-mode (async) tasks
logger.info("Output is a coroutine")
outputs = _get_working_loop().run_until_complete(outputs)

# Step3a
if isinstance(outputs, VoidPromise):
Expand Down
17 changes: 15 additions & 2 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
```
"""

import os
from typing import Optional, Protocol, runtime_checkable

from click import Group
Expand Down Expand Up @@ -59,10 +60,22 @@ def get_remote(
config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None
) -> FlyteRemote:
"""Get FlyteRemote object for CLI session."""

cfg_file = get_config_file(config)

# The assumption here (if there's no config file that means we want sandbox) is too broad.
# todo: can improve this in the future, rather than just checking one env var, auto() with
# nothing configured should probably not return sandbox but can consider
if cfg_file is None:
cfg_obj = Config.for_sandbox()
logger.info("No config files found, creating remote with sandbox config")
# We really are just looking for endpoint, client_id, and client_secret. These correspond to the env vars
# FLYTE_PLATFORM_URL, FLYTE_CREDENTIALS_CLIENT_ID, FLYTE_CREDENTIALS_CLIENT_SECRET
# auto() should pick these up.
if "FLYTE_PLATFORM_URL" in os.environ:
cfg_obj = Config.auto(None)
logger.warning(f"Auto-created config object to pick up env vars {cfg_obj}")
else:
cfg_obj = Config.for_sandbox()
logger.info("No config files found, creating remote with sandbox config")
else: # pragma: no cover
cfg_obj = Config.auto(config)
logger.debug(f"Creating remote with config {cfg_obj}" + (f" with file {config}" if config else ""))
Expand Down
29 changes: 2 additions & 27 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import asyncio
import collections
import datetime
import inspect
import warnings
from abc import abstractmethod
from base64 import b64encode
Expand Down Expand Up @@ -142,6 +141,7 @@ class TaskMetadata(object):
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None
is_eager: bool = False

def __post_init__(self):
if self.timeout:
Expand Down Expand Up @@ -181,6 +181,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
is_eager=self.is_eager,
)


Expand Down Expand Up @@ -340,9 +341,6 @@ def local_execute(
# if one is changed and not the other.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)

if inspect.iscoroutine(outputs_literal_map):
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down Expand Up @@ -759,29 +757,6 @@ def dispatch_execute(
raise
raise FlyteUserRuntimeException(e) from e

if inspect.iscoroutine(native_outputs):
# If native outputs is a coroutine, then this is an eager workflow.
if exec_ctx.execution_state:
if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
# Just return task outputs as a coroutine if the eager workflow is being executed locally,
# outside of a workflow. This preserves the expectation that the eager workflow is an async
# function.
return native_outputs
elif exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
# If executed inside of a workflow being executed locally, then run the coroutine to get the
# actual results.
return asyncio.run(
self._async_execute(
native_inputs,
native_outputs,
ctx,
exec_ctx,
new_user_params,
)
)

return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params)

# Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
# bubbled up to be handled at the callee layer.
native_outputs = self.post_execute(new_user_params, native_outputs)
Expand Down
11 changes: 11 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
# Set this environment variable to true to force the task to return non-zero exit code on failure.
FLYTE_FAIL_ON_ERROR = "FLYTE_FAIL_ON_ERROR"

# Executions launched by the current eager task will be tagged with this key:current_eager_exec_name
EAGER_TAG_KEY = "eager-exec"

# Executions launched by the current eager task will be tagged with this key:root_eager_exec_name, only relevant
# for nested eager tasks. This is how you identify the root execution.
EAGER_TAG_ROOT_KEY = "eager-root-exec"

# The environment variable that will be set to the root eager task execution name. This is how you pass down the
# root eager execution.
EAGER_ROOT_ENV_NAME = "_F_EE_ROOT"

# This is a special key used to store metadata about the cache key in a literal type.
CACHE_KEY_METADATA = "cache-key-metadata"

Expand Down
42 changes: 41 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging as _logging
import os
import pathlib
import signal
import tempfile
import traceback
import typing
Expand All @@ -24,6 +25,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from types import FrameType
from typing import Generator, List, Optional, Union

from flytekit.configuration import Config, SecretsConfig, SerializationSettings
Expand All @@ -37,8 +39,10 @@
from flytekit.models.core import identifier as _identifier

if typing.TYPE_CHECKING:
from flytekit import Deck
from flytekit.clients import friendly as friendly_client # noqa
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.core.worker_queue import Controller
from flytekit.deck.deck import Deck

# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin

Expand Down Expand Up @@ -526,6 +530,10 @@ class Mode(Enum):
# This is the mode that is used to indicate a dynamic task
DYNAMIC_TASK_EXECUTION = 4

EAGER_EXECUTION = 5

EAGER_LOCAL_EXECUTION = 6

mode: Optional[ExecutionState.Mode]
working_dir: Union[os.PathLike, str]
engine_dir: Optional[Union[os.PathLike, str]]
Expand Down Expand Up @@ -586,6 +594,7 @@ def is_local_execution(self) -> bool:
return (
self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
or self.mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION
)


Expand Down Expand Up @@ -663,6 +672,7 @@ class FlyteContext(object):
in_a_condition: bool = False
origin_stackframe: Optional[traceback.FrameSummary] = None
output_metadata_tracker: Optional[OutputMetadataTracker] = None
worker_queue: Optional[Controller] = None

@property
def user_space_params(self) -> Optional[ExecutionParameters]:
Expand All @@ -689,6 +699,7 @@ def new_builder(self) -> Builder:
execution_state=self.execution_state,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
worker_queue=self.worker_queue,
)

def enter_conditional_section(self) -> Builder:
Expand All @@ -713,6 +724,12 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder:
def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder:
return self.new_builder().with_output_metadata_tracker(t)

def with_worker_queue(self, wq: Controller) -> Builder:
return self.new_builder().with_worker_queue(wq)

def with_client(self, c: SynchronousFlyteClient) -> Builder:
return self.new_builder().with_client(c)

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -774,6 +791,7 @@ class Builder(object):
serialization_settings: Optional[SerializationSettings] = None
in_a_condition: bool = False
output_metadata_tracker: Optional[OutputMetadataTracker] = None
worker_queue: Optional[Controller] = None

def build(self) -> FlyteContext:
return FlyteContext(
Expand All @@ -785,6 +803,7 @@ def build(self) -> FlyteContext:
serialization_settings=self.serialization_settings,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
worker_queue=self.worker_queue,
)

def enter_conditional_section(self) -> FlyteContext.Builder:
Expand Down Expand Up @@ -833,6 +852,14 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext
self.output_metadata_tracker = t
return self

def with_worker_queue(self, wq: Controller) -> FlyteContext.Builder:
self.worker_queue = wq
return self

def with_client(self, c: SynchronousFlyteClient) -> FlyteContext.Builder:
self.flyte_client = c
return self

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -871,6 +898,12 @@ class FlyteContextManager(object):
FlyteContextManager.pop_context()
"""

signal_handlers: typing.List[typing.Callable[[int, FrameType], typing.Any]] = []

@staticmethod
def add_signal_handler(handler: typing.Callable[[int, FrameType], typing.Any]):
FlyteContextManager.signal_handlers.append(handler)

@staticmethod
def get_origin_stackframe(limit=2) -> traceback.FrameSummary:
ss = traceback.extract_stack(limit=limit + 1)
Expand Down Expand Up @@ -954,6 +987,13 @@ def initialize():
user_space_path = os.path.join(cfg.local_sandbox_path, "user_space")
pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True)

def main_signal_handler(signum: int, frame: FrameType):
for handler in FlyteContextManager.signal_handlers:
handler(signum, frame)
exit(1)

signal.signal(signal.SIGINT, main_signal_handler)

# Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
# are already acquainted with
default_context = FlyteContext(file_access=default_local_file_access_provider)
Expand Down
4 changes: 0 additions & 4 deletions flytekit/core/options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from dataclasses import dataclass
from typing import Callable, Optional

from flytekit.models import common as common_models
from flytekit.models import security
Expand Down Expand Up @@ -35,9 +34,6 @@ class Options(object):
notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None
file_uploader: Optional[Callable] = (
None # This is used by the translator to upload task files, like pickled code etc
)

@classmethod
def default_from(
Expand Down
Loading

0 comments on commit 276c464

Please sign in to comment.