[dagster-airlift] Multi code locations working

dpeng817 committed Oct 18, 2024
1 parent 5d3e709 commit 0c97247

Showing 21 changed files with 263 additions and 126 deletions.

----
dagster_airlift/core/load_defs.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Iterable, Iterator, Optional
+from typing import Callable, Iterable, Iterator, Optional

 from dagster import (
     AssetsDefinition,
@@ -20,13 +20,16 @@
     DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
     build_airflow_polling_sensor_defs,
 )
-from dagster_airlift.core.serialization.compute import compute_serialized_data
+from dagster_airlift.core.serialization.compute import DagSelectorFn, compute_serialized_data
 from dagster_airlift.core.serialization.defs_construction import (
     construct_automapped_dag_assets_defs,
     construct_dag_assets_defs,
     get_airflow_data_to_spec_mapper,
 )
-from dagster_airlift.core.serialization.serialized_data import SerializedAirflowDefinitionsData
+from dagster_airlift.core.serialization.serialized_data import (
+    DagInfo,
+    SerializedAirflowDefinitionsData,
+)
 from dagster_airlift.core.utils import get_metadata_key


@@ -35,14 +38,17 @@ class AirflowInstanceDefsLoader(StateBackedDefinitionsLoader[SerializedAirflowDefinitionsData]):
     airflow_instance: AirflowInstance
     explicit_defs: Definitions
     sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS
+    dag_selector_fn: Optional[Callable[[DagInfo], bool]] = None

     @property
     def defs_key(self) -> str:
         return get_metadata_key(self.airflow_instance.name)

     def fetch_state(self) -> SerializedAirflowDefinitionsData:
         return compute_serialized_data(
-            airflow_instance=self.airflow_instance, defs=self.explicit_defs
+            airflow_instance=self.airflow_instance,
+            defs=self.explicit_defs,
+            dag_selector_fn=self.dag_selector_fn,
         )

     def defs_from_state(
@@ -58,10 +64,12 @@ def build_airflow_mapped_defs(
     *,
     airflow_instance: AirflowInstance,
     defs: Optional[Definitions] = None,
+    dag_selector_fn: Optional[DagSelectorFn] = None,
 ) -> Definitions:
     return AirflowInstanceDefsLoader(
         airflow_instance=airflow_instance,
         explicit_defs=defs or Definitions(),
+        dag_selector_fn=dag_selector_fn,
     ).build_defs()


@@ -72,8 +80,11 @@ def build_defs_from_airflow_instance(
     defs: Optional[Definitions] = None,
     sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
     event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
+    dag_selector_fn: Optional[DagSelectorFn] = None,
 ) -> Definitions:
-    mapped_defs = build_airflow_mapped_defs(airflow_instance=airflow_instance, defs=defs)
+    mapped_defs = build_airflow_mapped_defs(
+        airflow_instance=airflow_instance, defs=defs, dag_selector_fn=dag_selector_fn
+    )
     return Definitions.merge(
         mapped_defs,
         build_airflow_polling_sensor_defs(
@@ -97,7 +108,7 @@ def defs_key(self) -> str:

     def fetch_state(self) -> SerializedAirflowDefinitionsData:
         return compute_serialized_data(
-            airflow_instance=self.airflow_instance, defs=self.explicit_defs
+            airflow_instance=self.airflow_instance, defs=self.explicit_defs, dag_selector_fn=None
         )

     def defs_from_state(
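
The thread running through this file: `build_defs_from_airflow_instance` now accepts an optional `dag_selector_fn`, so a code location can load only the dags it owns from a shared Airflow instance. A minimal call-site sketch, where the `team_a_` prefix and the `my_project.airflow_instance` helper are illustrative assumptions rather than part of this commit:

from dagster import Definitions
from dagster_airlift.core import build_defs_from_airflow_instance

# Assumed helper returning an AirflowInstance for your deployment; not part of this commit.
from my_project.airflow_instance import local_airflow_instance

defs = build_defs_from_airflow_instance(
    airflow_instance=local_airflow_instance(),
    defs=Definitions(),
    # DagInfo carries the dag's metadata; filtering on dag_id scopes this code location.
    dag_selector_fn=lambda dag_info: dag_info.dag_id.startswith("team_a_"),
)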

----
dagster_airlift/core/serialization/compute.py

@@ -1,6 +1,6 @@
 from collections import defaultdict
 from functools import cached_property
-from typing import AbstractSet, Dict, List, Set
+from typing import AbstractSet, Callable, Dict, List, Optional, Set

 from dagster import AssetKey, AssetSpec, Definitions
 from dagster._record import record
@@ -24,6 +24,8 @@
     task_handles_for_spec,
 )

+DagSelectorFn = Callable[[DagInfo], bool]
+

 @record
 class AirliftMetadataMappingInfo:
@@ -130,9 +132,15 @@ def all_mapped_dags(self) -> Dict[AssetKey, AbstractSet[DagHandle]]:


 def fetch_all_airflow_data(
-    airflow_instance: AirflowInstance, mapping_info: AirliftMetadataMappingInfo
+    airflow_instance: AirflowInstance,
+    mapping_info: AirliftMetadataMappingInfo,
+    dag_selector_fn: Optional[DagSelectorFn],
 ) -> FetchedAirflowData:
-    dag_infos = {dag.dag_id: dag for dag in airflow_instance.list_dags()}
+    dag_infos = {
+        dag.dag_id: dag
+        for dag in airflow_instance.list_dags()
+        if dag_selector_fn is None or dag_selector_fn(dag)
+    }
     task_info_map = defaultdict(dict)
     for dag_id in dag_infos:
         task_info_map[dag_id] = {
@@ -148,10 +156,10 @@ def fetch_all_airflow_data(


 def compute_serialized_data(
-    airflow_instance: AirflowInstance, defs: Definitions
+    airflow_instance: AirflowInstance, defs: Definitions, dag_selector_fn: Optional[DagSelectorFn]
 ) -> "SerializedAirflowDefinitionsData":
     mapping_info = build_airlift_metadata_mapping_info(defs)
-    fetched_airflow_data = fetch_all_airflow_data(airflow_instance, mapping_info)
+    fetched_airflow_data = fetch_all_airflow_data(airflow_instance, mapping_info, dag_selector_fn)
     return SerializedAirflowDefinitionsData(
         instance_name=airflow_instance.name,
         key_scoped_task_handles=[

----

@@ -192,6 +192,7 @@ def test_produce_fetched_airflow_data() -> None:
     fetched_airflow_data = fetch_all_airflow_data(
         airflow_instance=instance,
         mapping_info=mapping_info,
+        dag_selector_fn=None,
     )

     assert len(fetched_airflow_data.mapping_info.mapped_task_asset_specs) == 1

----
Makefile

@@ -36,6 +36,10 @@ run_dagster_automapped:
 run_observation_defs:
 	dagster dev -m kitchen_sink.dagster_defs.observation_defs -p 3333

+# Command to point at a workspace.yaml
+run_dagster_multi_code_locations:
+	dagster dev -w $(MAKEFILE_DIR)/kitchen_sink/dagster_multi_code_locations/workspace.yaml -p 3333
+
 wipe: ## Wipe out all the files created by the Makefile
 	rm -rf $(AIRFLOW_HOME) $(DAGSTER_HOME)


----

@@ -0,0 +1,41 @@
+from datetime import datetime
+from pathlib import Path
+
+from airflow import DAG
+from airflow.operators.python import PythonOperator
+from dagster_airlift.in_airflow import proxying_to_dagster
+from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml
+
+
+def print_hello() -> None:
+    print("Hello")  # noqa: T201
+
+
+default_args = {
+    "owner": "airflow",
+    "depends_on_past": False,
+    "start_date": datetime(2023, 1, 1),
+    "retries": 0,
+}
+
+with DAG(
+    "dag_first_code_location",
+    default_args=default_args,
+    schedule_interval=None,
+    is_paused_upon_creation=False,
+) as first_dag:
+    PythonOperator(task_id="task", python_callable=print_hello)
+
+with DAG(
+    "dag_second_code_location",
+    default_args=default_args,
+    schedule_interval=None,
+    is_paused_upon_creation=False,
+) as second_dag:
+    PythonOperator(task_id="task", python_callable=print_hello)
+
+
+proxying_to_dagster(
+    proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
+    global_vars=globals(),
+)

----
proxied_state/dag_first_code_location.yaml

@@ -0,0 +1,3 @@
+tasks:
+  - id: task
+    proxied: False

----
proxied_state/dag_second_code_location.yaml

@@ -0,0 +1,3 @@
+tasks:
+  - id: task
+    proxied: False

----

@@ -2,7 +2,7 @@
 from dagster_airlift.core import dag_defs, task_defs
 from dagster_airlift.core.load_defs import build_full_automapped_dags_from_airflow_instance

-from .airflow_instance import local_airflow_instance
+from ..airflow_instance import local_airflow_instance


 @asset

----

@@ -14,7 +14,7 @@
 )
 from dagster_airlift.core.multiple_tasks import targeted_by_multiple_tasks

-from .airflow_instance import local_airflow_instance
+from ..airflow_instance import local_airflow_instance


 def make_print_asset(key: str) -> AssetsDefinition:

----
kitchen_sink/dagster_defs/observation_defs.py

@@ -15,7 +15,7 @@
     task_defs,
 )

-from .airflow_instance import local_airflow_instance
+from ..airflow_instance import local_airflow_instance


 def observations_from_materializations(

----
kitchen_sink/dagster_multi_code_locations/first_dag_defs.py

@@ -0,0 +1,17 @@
+from dagster import AssetSpec, Definitions
+from dagster_airlift.core import assets_with_task_mappings, build_defs_from_airflow_instance
+
+from kitchen_sink.airflow_instance import local_airflow_instance
+
+defs = build_defs_from_airflow_instance(
+    airflow_instance=local_airflow_instance(),
+    defs=Definitions(
+        assets=assets_with_task_mappings(
+            dag_id="dag_first_code_location",
+            task_mappings={
+                "task": [AssetSpec(key="dag_first_code_location__asset")],
+            },
+        ),
+    ),
+    dag_selector_fn=lambda dag_info: dag_info.dag_id == "dag_first_code_location",
+)

----
kitchen_sink/dagster_multi_code_locations/second_dag_defs.py

@@ -0,0 +1,17 @@
+from dagster import AssetSpec, Definitions
+from dagster_airlift.core import assets_with_task_mappings, build_defs_from_airflow_instance
+
+from kitchen_sink.airflow_instance import local_airflow_instance
+
+defs = build_defs_from_airflow_instance(
+    airflow_instance=local_airflow_instance(),
+    defs=Definitions(
+        assets=assets_with_task_mappings(
+            dag_id="dag_second_code_location",
+            task_mappings={
+                "task": [AssetSpec(key="dag_second_code_location__asset")],
+            },
+        ),
+    ),
+    dag_selector_fn=lambda dag_info: dag_info.dag_id == "dag_second_code_location",
+)

----
kitchen_sink/dagster_multi_code_locations/workspace.yaml

@@ -0,0 +1,7 @@
+load_from:
+  - python_module:
+      module_name: kitchen_sink.dagster_multi_code_locations.first_dag_defs
+      location_name: first_dag_location
+  - python_module:
+      module_name: kitchen_sink.dagster_multi_code_locations.second_dag_defs
+      location_name: second_dag_location

----

@@ -1,11 +1,16 @@
 import os
 import subprocess
 import time
+from datetime import timedelta
 from pathlib import Path
-from typing import Generator
+from typing import Generator, List, Mapping, NamedTuple, Sequence, Union

 import pytest
+from dagster import AssetKey, DagsterInstance
+from dagster._core.events.log import EventLogEntry
 from dagster._core.test_utils import environ
+from dagster._time import get_current_datetime
+from dagster_airlift.constants import DAG_RUN_ID_TAG_KEY
 from dagster_airlift.core.airflow_instance import AirflowInstance
 from dagster_airlift.test.shared_fixtures import stand_up_airflow

@@ -37,6 +42,11 @@ def airflow_home_fixture(local_env: None) -> Path:
     return Path(os.environ["AIRFLOW_HOME"])


+@pytest.fixture(name="dagster_home")
+def dagster_home_fixture(local_env: None) -> str:
+    return os.environ["DAGSTER_HOME"]
+
+
 @pytest.fixture(name="airflow_instance")
 def airflow_instance_fixture(local_env: None) -> Generator[subprocess.Popen, None, None]:
     with stand_up_airflow(
@@ -59,3 +69,76 @@ def poll_for_airflow_run_existence_and_completion(
         except Exception:
             time.sleep(0.1)
             continue
+
+
+class ExpectedMat(NamedTuple):
+    asset_key: AssetKey
+    runs_in_dagster: bool
+
+
+def poll_for_expected_mats(
+    af_instance: AirflowInstance,
+    expected_mats_per_dag: Mapping[str, Sequence[Union[ExpectedMat, AssetKey]]],
+) -> None:
+    resolved_expected_mats_per_dag: Mapping[str, List[ExpectedMat]] = {
+        dag_id: [
+            expected_mat
+            if isinstance(expected_mat, ExpectedMat)
+            else ExpectedMat(expected_mat, True)
+            for expected_mat in expected_mats
+        ]
+        for dag_id, expected_mats in expected_mats_per_dag.items()
+    }
+    for dag_id, expected_mats in resolved_expected_mats_per_dag.items():
+        airflow_run_id = af_instance.trigger_dag(dag_id=dag_id)
+        af_instance.wait_for_run_completion(dag_id=dag_id, run_id=airflow_run_id, timeout=60)
+        dagster_instance = DagsterInstance.get()
+
+        dag_asset_key = AssetKey([af_instance.name, "dag", dag_id])
+        assert poll_for_materialization(dagster_instance, dag_asset_key)
+
+        for expected_mat in expected_mats:
+            mat_event_log_entry = poll_for_materialization(dagster_instance, expected_mat.asset_key)
+            assert mat_event_log_entry.asset_materialization
+            assert mat_event_log_entry.asset_materialization.asset_key == expected_mat.asset_key
+
+            assert mat_event_log_entry.asset_materialization
+            dagster_run_id = mat_event_log_entry.run_id
+
+            all_materializations = dagster_instance.fetch_materializations(
+                records_filter=expected_mat.asset_key, limit=10
+            )
+
+            assert all_materializations
+
+            if expected_mat.runs_in_dagster:
+                assert dagster_run_id
+                dagster_run = dagster_instance.get_run_by_id(dagster_run_id)
+                assert dagster_run
+                run_ids = dagster_instance.get_run_ids()
+                assert (
+                    dagster_run
+                ), f"Could not find dagster run {dagster_run_id} All run_ids {run_ids}"
+                assert (
+                    DAG_RUN_ID_TAG_KEY in dagster_run.tags
+                ), f"Could not find dagster run tag: dagster_run.tags {dagster_run.tags}"
+                assert (
+                    dagster_run.tags[DAG_RUN_ID_TAG_KEY] == airflow_run_id
+                ), "dagster run tag does not match dag run id"
+
+
+def poll_for_materialization(
+    dagster_instance: DagsterInstance,
+    asset_key: AssetKey,
+) -> EventLogEntry:
+    start_time = get_current_datetime()
+    while get_current_datetime() - start_time < timedelta(seconds=30):
+        asset_materialization = dagster_instance.get_latest_materialization_event(
+            asset_key=asset_key
+        )
+
+        time.sleep(0.1)
+        if asset_materialization:
+            return asset_materialization
+
+    raise Exception(f"Timeout waiting for materialization event on {asset_key}")
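
Taken together, the new helpers encode the trigger-then-poll pattern the multi-code-location test needs: trigger each dag, wait for the Airflow run to finish, then poll Dagster for the dag asset and each mapped asset. A hypothetical call, reusing the dag id and asset key from the kitchen_sink example above and assuming it runs where these conftest helpers are importable:

from dagster import AssetKey
from kitchen_sink.airflow_instance import local_airflow_instance

poll_for_expected_mats(
    af_instance=local_airflow_instance(),
    expected_mats_per_dag={
        # A bare AssetKey is resolved to ExpectedMat(key, runs_in_dagster=True).
        "dag_first_code_location": [AssetKey("dag_first_code_location__asset")],
    },
)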