Skip to content

Commit

Permalink
[dagster-airlift] operator to proxy dag
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Oct 10, 2024
1 parent 9838517 commit 6f585b9
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json
import os
from typing import Any, Iterable, Mapping, Sequence

import requests
from airflow import DAG
from airflow.utils.context import Context

from dagster_airlift.constants import DAG_MAPPING_METADATA_KEY
from dagster_airlift.in_airflow.base_asset_operator import BaseDagsterAssetsOperator


class BaseProxyDAGToDagsterOperator(BaseDagsterAssetsOperator):
    """Operator that proxies a whole Airflow DAG to the Dagster assets mapped to it.

    Asset nodes are selected by comparing the DAG id resolved from the Airflow
    context with the DAG mappings stored in each node's metadata.
    """

    def filter_asset_nodes(
        self, context: Context, asset_nodes: Sequence[Mapping[str, Any]]
    ) -> Iterable[Mapping[str, Any]]:
        """Lazily yield the asset nodes whose DAG mapping names the current DAG.

        Args:
            context: Airflow context used to resolve the id of the running DAG.
            asset_nodes: Candidate asset nodes to filter.

        Yields:
            Every node for which ``matched_dag_id`` is true, in input order.
        """
        # The DAG id is resolved once per candidate, as each node is consumed.
        yield from (
            node
            for node in asset_nodes
            if matched_dag_id(node, self.get_airflow_dag_id(context))
        )


class DefaultProxyDAGToDagsterOperator(BaseProxyDAGToDagsterOperator):
    """Default DAG-level proxying operator.

    Uses a blank ``requests`` session and takes the Dagster URL from the
    ``DAGSTER_URL`` environment variable, which must therefore be set.
    """

    def get_dagster_url(self, context: Context) -> str:
        """Return the value of the ``DAGSTER_URL`` environment variable.

        Raises:
            KeyError: If ``DAGSTER_URL`` is not present in the environment.
        """
        dagster_url = os.environ["DAGSTER_URL"]
        return dagster_url

    def get_dagster_session(self, context: Context) -> requests.Session:
        """Return a new session with no extra configuration applied."""
        blank_session = requests.Session()
        return blank_session


def build_dag_level_proxied_task(dag: DAG) -> DefaultProxyDAGToDagsterOperator:
    """Build the DAG-level proxy operator for ``dag``.

    The operator is attached to ``dag`` and its task id is
    ``DAGSTER_OVERRIDE_DAG_`` followed by the DAG's id.
    """
    override_task_id = f"DAGSTER_OVERRIDE_DAG_{dag.dag_id}"
    return DefaultProxyDAGToDagsterOperator(task_id=override_task_id, dag=dag)


def matched_dag_id(asset_node: Mapping[str, Any], dag_id: str) -> bool:
    """Tell whether an asset node's DAG mapping metadata references ``dag_id``.

    Args:
        asset_node: Asset node whose ``metadataEntries`` are inspected. Only
            entries whose ``__typename`` is ``JsonMetadataEntry`` are considered.
        dag_id: The Airflow DAG id to look for.

    Returns:
        True when the JSON stored under ``DAG_MAPPING_METADATA_KEY`` contains a
        mapping whose ``dag_id`` equals ``dag_id``; False when there is no such
        mapping or the entry is missing or empty.
    """
    # Index the JSON payloads by label; a later entry with the same label wins.
    json_by_label = {}
    for entry in asset_node["metadataEntries"]:
        if entry["__typename"] == "JsonMetadataEntry":
            json_by_label[entry["label"]] = entry["jsonString"]

    serialized_mappings = json_by_label.get(DAG_MAPPING_METADATA_KEY)
    if not serialized_mappings:
        return False

    for mapping in json.loads(serialized_mappings):
        if mapping["dag_id"] == dag_id:
            return True
    return False
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ def get_task_proxied_state(self, *, dag_id: str, task_id: str) -> Optional[bool]
def dag_has_proxied_state(self, dag_id: str) -> bool:
    """Return True when a proxied dict exists for ``dag_id``, False otherwise."""
    proxied_dict = self.get_proxied_dict_for_dag(dag_id)
    return proxied_dict is not None

def dag_proxies_at_task_level(self, dag_id: str) -> bool:
    """Return True when the given dag proxies task-by-task.

    A dag can proxy either on a per-task basis or as a whole. The dag's
    ``proxied`` value decides which: ``None`` is taken to mean task-level
    proxying, while a boolean is taken to mean dag-level proxying.
    """
    dag_level_flag = self.dags[dag_id].proxied
    return dag_level_flag is None

def dag_proxies_at_dag_level(self, dag_id: str) -> bool:
    """Return True when the given dag proxies as a whole.

    A dag can proxy either on a per-task basis or as a whole. The dag's
    ``proxied`` value decides which: a boolean (True or False) is taken to
    mean dag-level proxying, while ``None`` is taken to mean task-level
    proxying.
    """
    dag_level_flag = self.dags[dag_id].proxied
    return not (dag_level_flag is None)

def get_proxied_dict_for_dag(
self, dag_id: str
) -> Optional[Dict[str, Sequence[Dict[str, Any]]]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from airflow.models import BaseOperator, Variable
from airflow.utils.session import create_session

from dagster_airlift.in_airflow.dag_proxy_operator import build_dag_level_proxied_task
from dagster_airlift.in_airflow.proxied_state import AirflowProxiedState, DagProxiedState
from dagster_airlift.in_airflow.task_proxy_operator import (
BaseProxyTaskToDagsterOperator,
Expand Down Expand Up @@ -40,7 +41,8 @@ def proxying_to_dagster(
if not logger:
logger = logging.getLogger("dagster_airlift")
logger.debug(f"Searching for dags proxied to dagster{suffix}...")
proxying_dags: List[DAG] = []
task_level_proxying_dags: List[DAG] = []
dag_level_proxying_dags: List[DAG] = []
all_dag_ids: Set[str] = set()
# Do a pass to collect dags and ensure that proxied information is set correctly.
for obj in global_vars.values():
Expand All @@ -53,23 +55,36 @@ def proxying_to_dagster(
continue
logger.debug(f"Dag with id `{dag.dag_id}` has proxied state.")
proxied_state_for_dag = proxied_state.dags[dag.dag_id]
for task_id in proxied_state_for_dag.tasks.keys():
if task_id not in dag.task_dict:
raise Exception(
f"Task with id `{task_id}` not found in dag `{dag.dag_id}`. Found tasks: {list(dag.task_dict.keys())}"
)
if not isinstance(dag.task_dict[task_id], BaseOperator):
raise Exception(
f"Task with id `{task_id}` in dag `{dag.dag_id}` is not an instance of BaseOperator. This likely means a MappedOperator was attempted, which is not yet supported by airlift."
)
proxying_dags.append(dag)
if proxied_state_for_dag.proxied is not None:
if proxied_state_for_dag.proxied is False:
logger.debug(f"Dag with id `{dag.dag_id}` is not proxied. Skipping...")
continue
dag_level_proxying_dags.append(dag)
else:
for task_id in proxied_state_for_dag.tasks.keys():
if task_id not in dag.task_dict:
raise Exception(
f"Task with id `{task_id}` not found in dag `{dag.dag_id}`. Found tasks: {list(dag.task_dict.keys())}"
)
if not isinstance(dag.task_dict[task_id], BaseOperator):
raise Exception(
f"Task with id `{task_id}` in dag `{dag.dag_id}` is not an instance of BaseOperator. This likely means a MappedOperator was attempted, which is not yet supported by airlift."
)
task_level_proxying_dags.append(dag)

if len(all_dag_ids) == 0:
raise Exception(
"No dags found in globals dictionary. Ensure that your dags are available from global context, and that the call to `proxying_to_dagster` is the last line in your dag file."
)

for dag in proxying_dags:
for dag in dag_level_proxying_dags:
logger.debug(f"Tagging dag {dag.dag_id} as proxied.")
dag.tags = [*dag.tags, "Dag overriden to Dagster"]
dag.task_dict = {}
dag.task_group.children = {}
override_task = build_dag_level_proxied_task(dag)
dag.task_dict[override_task.task_id] = override_task
for dag in task_level_proxying_dags:
logger.debug(f"Tagging dag {dag.dag_id} as proxied.")
set_proxied_state_for_dag_if_changed(dag.dag_id, proxied_state.dags[dag.dag_id], logger)
proxied_state_for_dag = proxied_state.dags[dag.dag_id]
Expand Down Expand Up @@ -110,7 +125,7 @@ def proxying_to_dagster(
original_op.dag = None
proxied_tasks.add(task_id)
logger.debug(f"Proxied tasks {proxied_tasks} in dag {dag.dag_id}.")
logging.debug(f"Proxied {len(proxying_dags)}.")
logging.debug(f"Proxied {len(task_level_proxying_dags)}.")
logging.debug(f"Completed switching proxied tasks to dagster{suffix}.")


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from datetime import datetime
from pathlib import Path

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster_airlift.in_airflow import proxying_to_dagster
from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml


def print_hello() -> None:
    """Write the line ``Hello`` to standard output."""
    greeting = "Hello"
    print(greeting)  # noqa: T201


# Passed to the DAG below as its ``default_args``.
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2023, 1, 1),
    "retries": 1,
}


# DAG with id "overridden_dag" and no schedule interval, created unpaused.
with DAG(
    "overridden_dag",
    default_args=default_args,
    schedule_interval=None,
    is_paused_upon_creation=False,
) as dag:
    # NOTE(review): in Airflow `a << b` makes b upstream of a, so the task named
    # "downstream_print_task" would run before "print_task" - confirm that this
    # direction is intended.
    PythonOperator(task_id="print_task", python_callable=print_hello) << PythonOperator(
        task_id="downstream_print_task", python_callable=print_hello
    )  # type: ignore


# Kept as the last statement of the file: the module globals are handed over so
# the DAG defined above can be found. Proxied state is loaded from the
# "proxied_state" path that sits next to this file.
proxying_to_dagster(
    proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
    global_vars=globals(),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
proxied: True
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from dagster import Definitions, asset
from dagster_airlift.core import build_defs_from_airflow_instance, dag_defs, task_defs
from dagster_airlift.core import (
assets_with_dag_mappings,
build_defs_from_airflow_instance,
dag_defs,
task_defs,
)
from dagster_airlift.core.multiple_tasks import targeted_by_multiple_tasks

from .airflow_instance import local_airflow_instance
Expand All @@ -22,6 +27,11 @@ def asset_one() -> None:
print("Materialized asset one")


@asset(description="Asset two is materialized by an overridden dag")
def asset_two() -> None:
    # Prints a marker line when the asset function runs.
    marker = "Materialized asset two"
    print(marker)


def build_mapped_defs() -> Definitions:
return build_defs_from_airflow_instance(
airflow_instance=local_airflow_instance(),
Expand All @@ -38,6 +48,7 @@ def build_mapped_defs() -> Definitions:
{"dag_id": "daily_dag", "task_id": "asset_one_daily"},
],
),
Definitions(assets=assets_with_dag_mappings({"overridden_dag": [asset_two]})),
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,41 @@ def test_dagster_weekly_daily_materializes(
assert final_result.records[0].event_log_entry
assert dag_id_of_mat(final_result.records[0].event_log_entry) == "daily_dag"
assert dag_id_of_mat(final_result.records[1].event_log_entry) == "weekly_dag"


def test_migrated_overridden_dag_materializes(
    airflow_instance: None,
    dagster_dev: None,
    dagster_home: str,
) -> None:
    """Test that assets are properly materialized from an overridden dag.

    Each dag is triggered in Airflow and awaited (60 second timeout). Every
    expected asset must then have a materialization whose Dagster run carries
    the Airflow dag run id under ``DAG_RUN_ID_TAG_KEY``.
    """
    from kitchen_sink.dagster_defs.airflow_instance import local_airflow_instance

    af_instance = local_airflow_instance()

    expected_mats_per_dag = {
        "overridden_dag": [AssetKey("asset_two")],
    }
    for dag_id, expected_asset_keys in expected_mats_per_dag.items():
        airflow_run_id = af_instance.trigger_dag(dag_id=dag_id)
        af_instance.wait_for_run_completion(dag_id=dag_id, run_id=airflow_run_id, timeout=60)
        dagster_instance = DagsterInstance.get()

        for expected_asset_key in expected_asset_keys:
            mat_event_log_entry = poll_for_materialization(dagster_instance, expected_asset_key)
            assert mat_event_log_entry.asset_materialization
            assert mat_event_log_entry.asset_materialization.asset_key == expected_asset_key

            # Check the dag run-id tag for every expected asset, not only the
            # last one seen by the loop.
            dagster_run_id = mat_event_log_entry.run_id
            dagster_run = dagster_instance.get_run_by_id(dagster_run_id)
            # The run id listing is only queried when the assertion fails.
            assert dagster_run, f"Could not find dagster run {dagster_run_id} All run_ids {dagster_instance.get_run_ids()}"
            assert (
                DAG_RUN_ID_TAG_KEY in dagster_run.tags
            ), f"Could not find dagster run tag: dagster_run.tags {dagster_run.tags}"
            assert (
                dagster_run.tags[DAG_RUN_ID_TAG_KEY] == airflow_run_id
            ), "dagster run tag does not match dag run id"

0 comments on commit 6f585b9

Please sign in to comment.