[dagster-airlift] Support dag-level overrides in builder #25243

Merged · 1 commit · Oct 14, 2024
@@ -1,3 +1,4 @@
+PEERED_DAG_MAPPING_METADATA_KEY = "dagster-airlift/peered-dag-mapping"
 DAG_MAPPING_METADATA_KEY = "dagster-airlift/dag-mapping"
 AIRFLOW_SOURCE_METADATA_KEY_PREFIX = "dagster-airlift/source"
 TASK_MAPPING_METADATA_KEY = "dagster-airlift/task-mapping"
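The new PEERED_DAG_MAPPING_METADATA_KEY separates the automatically created per-dag "peered" asset from assets a user explicitly maps to a whole dag or to a single task. As an illustrative sketch (these AssetSpec constructions are hypothetical, not part of this PR), the three granularities differ only in which metadata key carries the mapping:

from dagster import AssetSpec

# Hypothetical specs showing the three mapping granularities.
task_mapped = AssetSpec(
    key="warehouse_table",
    metadata={"dagster-airlift/task-mapping": [{"dag_id": "etl", "task_id": "load"}]},
)
dag_mapped = AssetSpec(
    key="etl_output",
    metadata={"dagster-airlift/dag-mapping": [{"dag_id": "etl"}]},
)
# New in this PR: attached to the synthesized per-dag "peered" asset itself.
peered = AssetSpec(
    key="etl__dag",
    metadata={"dagster-airlift/peered-dag-mapping": [{"dag_id": "etl"}]},
)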
@@ -11,7 +11,9 @@
 from dagster_airlift.core.utils import (
     dag_handles_for_spec,
     is_dag_mapped_asset_spec,
+    is_peered_dag_asset_spec,
     is_task_mapped_asset_spec,
+    peered_dag_handles_for_spec,
     task_handles_for_spec,
 )
 
@@ -37,7 +39,7 @@ def dag_ids_with_mapped_asset_keys(self) -> AbstractSet[str]:
         return self.mapping_info.dag_ids
 
     @cached_property
-    def asset_keys_per_task_handle(self) -> Mapping[TaskHandle, AbstractSet[AssetKey]]:
+    def mapped_asset_keys_by_task_handle(self) -> Mapping[TaskHandle, AbstractSet[AssetKey]]:
         asset_keys_per_handle = defaultdict(set)
         for spec in self.mapped_defs.get_all_asset_specs():
             if is_task_mapped_asset_spec(spec):
@@ -47,7 +49,7 @@ def asset_keys_per_task_handle(self) -> Mapping[TaskHandle, AbstractSet[AssetKey]]:
         return asset_keys_per_handle
 
     @cached_property
-    def asset_keys_per_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
+    def mapped_asset_keys_by_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
         asset_keys_per_handle = defaultdict(set)
         for spec in self.mapped_defs.get_all_asset_specs():
             if is_dag_mapped_asset_spec(spec):
@@ -56,5 +58,15 @@ def asset_keys_per_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
                     asset_keys_per_handle[dag_handle].add(spec.key)
         return asset_keys_per_handle
 
+    @cached_property
+    def peered_dag_asset_keys_by_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
+        asset_keys_per_handle = defaultdict(set)
+        for spec in self.mapped_defs.get_all_asset_specs():
+            if is_peered_dag_asset_spec(spec):
+                dag_handles = peered_dag_handles_for_spec(spec)
+                for dag_handle in dag_handles:
+                    asset_keys_per_handle[dag_handle].add(spec.key)
+        return asset_keys_per_handle
+
     def asset_keys_in_task(self, dag_id: str, task_id: str) -> AbstractSet[AssetKey]:
-        return self.asset_keys_per_task_handle[TaskHandle(dag_id=dag_id, task_id=task_id)]
+        return self.mapped_asset_keys_by_task_handle[TaskHandle(dag_id=dag_id, task_id=task_id)]
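With the renames, the three lookups on AirflowDefinitionsData are parallel: task-mapped and dag-mapped asset keys live under mapped_asset_keys_by_*, and the new peered_dag_asset_keys_by_dag_handle covers the peered asset. A hypothetical lookup, assuming an AirflowDefinitionsData value named airflow_data built elsewhere:

handle = DagHandle("etl")  # positional construction, matching its use later in this diff
always_emitted = airflow_data.peered_dag_asset_keys_by_dag_handle[handle]
emitted_only_if_not_proxied = airflow_data.mapped_asset_keys_by_dag_handle[handle]
per_task = airflow_data.asset_keys_in_task(dag_id="etl", task_id="load")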
@@ -3,7 +3,7 @@
 from dagster import AssetKey, JsonMetadataValue, MarkdownMetadataValue
 from dagster._core.definitions.metadata.metadata_value import UrlMetadataValue
 
-from dagster_airlift.constants import DAG_MAPPING_METADATA_KEY
+from dagster_airlift.constants import PEERED_DAG_MAPPING_METADATA_KEY
 from dagster_airlift.core.airflow_instance import DagInfo
 
 
@@ -13,13 +13,17 @@ def dag_description(dag_info: DagInfo) -> str:
     """
 
 
-def dag_asset_metadata(dag_info: DagInfo, source_code: str) -> Mapping[str, Any]:
-    metadata = {
+def dag_asset_metadata(dag_info: DagInfo) -> Dict[str, Any]:
+    return {
         "Dag Info (raw)": JsonMetadataValue(dag_info.metadata),
         "Dag ID": dag_info.dag_id,
         "Link to DAG": UrlMetadataValue(dag_info.url),
-        DAG_MAPPING_METADATA_KEY: [{"dag_id": dag_info.dag_id}],
     }
+
+
+def peered_dag_asset_metadata(dag_info: DagInfo, source_code: str) -> Mapping[str, Any]:
+    metadata = dag_asset_metadata(dag_info)
+    metadata[PEERED_DAG_MAPPING_METADATA_KEY] = [{"dag_id": dag_info.dag_id}]
     # Attempt to retrieve source code from the DAG.
     metadata["Source Code"] = MarkdownMetadataValue(
         f"""
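After the split, dag_asset_metadata returns only the base fields (raw dag info, dag id, link), while peered_dag_asset_metadata layers the peered-mapping key and the source code on top. A quick sketch of the relationship, assuming a dag_info value fetched from an AirflowInstance:

base = dag_asset_metadata(dag_info)
peered = peered_dag_asset_metadata(dag_info, source_code="print('hello')")
# peered contains everything in base, plus the peered-mapping key and "Source Code".
assert set(base) < set(peered)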
@@ -32,17 +32,32 @@ def get_timestamp_from_materialization(event: AssetEvent) -> float:
     )
 
 
-def materializations_for_dag_run(
+def synthetic_mats_for_peered_dag_asset_keys(
     dag_run: DagRun, airflow_data: AirflowDefinitionsData
 ) -> Sequence[AssetMaterialization]:
     return [
-        AssetMaterialization(
-            asset_key=asset_key, description=dag_run.note, metadata=get_dag_run_metadata(dag_run)
-        )
-        for asset_key in airflow_data.asset_keys_per_dag_handle[DagHandle(dag_run.dag_id)]
+        dag_synthetic_mat(dag_run, airflow_data, asset_key)
+        for asset_key in airflow_data.peered_dag_asset_keys_by_dag_handle[DagHandle(dag_run.dag_id)]
     ]
 
 
+def synthetic_mats_for_mapped_dag_asset_keys(
+    dag_run: DagRun, airflow_data: AirflowDefinitionsData
+) -> Sequence[AssetMaterialization]:
+    return [
+        dag_synthetic_mat(dag_run, airflow_data, asset_key)
+        for asset_key in airflow_data.mapped_asset_keys_by_dag_handle[DagHandle(dag_run.dag_id)]
+    ]
+
+
+def dag_synthetic_mat(
+    dag_run: DagRun, airflow_data: AirflowDefinitionsData, asset_key: AssetKey
+) -> AssetMaterialization:
+    return AssetMaterialization(
+        asset_key=asset_key, description=dag_run.note, metadata=get_dag_run_metadata(dag_run)
+    )
+
+
 def get_dag_run_metadata(dag_run: DagRun) -> Mapping[str, Any]:
     return {
         **get_common_metadata(dag_run),
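Both list builders now delegate to the shared dag_synthetic_mat constructor, so a dag run fans out to one AssetMaterialization per mapped asset key, each carrying the run's note and common metadata. A usage sketch with hypothetical dag_run and airflow_data values:

peered_mats = synthetic_mats_for_peered_dag_asset_keys(dag_run, airflow_data)
mapped_mats = synthetic_mats_for_mapped_dag_asset_keys(dag_run, airflow_data)
# Both lists contain AssetMaterialization records built by dag_synthetic_mat,
# differing only in which dag-handle mapping supplied the asset keys.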
@@ -24,7 +24,7 @@
     DagsterUserCodeExecutionError,
     user_code_error_boundary,
 )
-from dagster._core.storage.dagster_run import RunsFilter
+from dagster._core.storage.dagster_run import DagsterRun, RunsFilter
 from dagster._grpc.client import DEFAULT_SENSOR_GRPC_TIMEOUT
 from dagster._record import record
 from dagster._serdes import deserialize_value, serialize_value
@@ -43,8 +43,9 @@
     AssetEvent,
     DagsterEventTransformerFn,
     get_timestamp_from_materialization,
-    materializations_for_dag_run,
     synthetic_mats_for_mapped_asset_keys,
+    synthetic_mats_for_mapped_dag_asset_keys,
+    synthetic_mats_for_peered_dag_asset_keys,
     synthetic_mats_for_task_instance,
 )
 
@@ -260,13 +261,9 @@ def materializations_and_requests_from_batch_iter(
         context.log.info(f"Found {len(runs)} dag runs for {airflow_data.airflow_instance.name}")
         context.log.info(f"All runs {runs}")
         for i, dag_run in enumerate(runs):
-            # TODO: add pluggability here (ignoring `event_transformer_fn` for now)
-
-            dag_mats = materializations_for_dag_run(dag_run, airflow_data)
-            synthetic_mats = build_synthetic_asset_materializations(
+            mats = build_synthetic_asset_materializations(
                 context, airflow_data.airflow_instance, dag_run, airflow_data
             )
-            mats = list(dag_mats) + synthetic_mats
             context.log.info(f"Found {len(mats)} materializations for {dag_run.run_id}")
 
             all_asset_keys_materialized = {mat.asset_key for mat in mats}
@@ -307,40 +304,62 @@ def build_synthetic_asset_materializations(
     This also currently does not support dynamic tasks in Airflow, in which case
     the use should instead map at the dag-level granularity.
     """
-    task_instances = airflow_instance.get_task_instance_batch(
-        run_id=dag_run.run_id,
-        dag_id=dag_run.dag_id,
-        task_ids=[task_id for task_id in airflow_data.task_ids_in_dag(dag_run.dag_id)],
-        states=["success"],
-    )
-
-    context.log.info(f"Found {len(task_instances)} task instances for {dag_run.run_id}")
-    context.log.info(f"All task instances {task_instances}")
-
-    check.invariant(
-        len({ti.task_id for ti in task_instances}) == len(task_instances),
-        "Assuming one task instance per task_id for now. Dynamic Airflow tasks not supported.",
-    )
-
     # https://linear.app/dagster-labs/issue/FOU-444/make-sensor-work-with-an-airflow-dag-run-that-has-more-than-1000
     dagster_runs = context.instance.get_runs(
         filters=RunsFilter(tags={DAG_RUN_ID_TAG_KEY: dag_run.run_id}),
         limit=1000,
     )
-    context.log.info(f"Found {len(dagster_runs)} dagster runs for {dag_run.run_id}")
-
     context.log.info(
         f"Airlift Sensor: Found dagster run ids: {[run.run_id for run in dagster_runs]}"
         f" for airflow run id {dag_run.run_id} and dag id {dag_run.dag_id}"
     )
+    synthetic_mats = []
+    # Peered dag-level materializations will always be emitted.
+    synthetic_mats.extend(synthetic_mats_for_peered_dag_asset_keys(dag_run, airflow_data))
+    # If there is a dagster run for this dag, we don't need to synthesize materializations for mapped dag assets.
+    if not dagster_runs:
+        synthetic_mats.extend(synthetic_mats_for_mapped_dag_asset_keys(dag_run, airflow_data))
+    synthetic_mats.extend(
+        get_synthetic_task_mats(
+            airflow_instance=airflow_instance,
+            dagster_runs=dagster_runs,
+            dag_run=dag_run,
+            airflow_data=airflow_data,
+            context=context,
+        )
+    )
+    return synthetic_mats
 
-    dagster_runs_by_task_id = {run.tags[TASK_ID_TAG_KEY]: run for run in dagster_runs}
-    task_instances_by_task_id = {ti.task_id: ti for ti in task_instances}
+
+def get_synthetic_task_mats(
+    airflow_instance: AirflowInstance,
+    dagster_runs: Sequence[DagsterRun],
+    dag_run: DagRun,
+    airflow_data: AirflowDefinitionsData,
+    context: SensorEvaluationContext,
+) -> List[AssetMaterialization]:
+    task_instances = airflow_instance.get_task_instance_batch(
+        run_id=dag_run.run_id,
+        dag_id=dag_run.dag_id,
+        task_ids=[task_id for task_id in airflow_data.task_ids_in_dag(dag_run.dag_id)],
+        states=["success"],
+    )
+    check.invariant(
+        len({ti.task_id for ti in task_instances}) == len(task_instances),
+        "Assuming one task instance per task_id for now. Dynamic Airflow tasks not supported.",
+    )
+    synthetic_mats = []
+
+    context.log.info(f"Found {len(task_instances)} task instances for {dag_run.run_id}")
+    context.log.info(f"All task instances {task_instances}")
+    dagster_runs_by_task_id = {
+        run.tags[TASK_ID_TAG_KEY]: run for run in dagster_runs if TASK_ID_TAG_KEY in run.tags
+    }
+    task_instances_by_task_id = {ti.task_id: ti for ti in task_instances}
     for task_id, task_instance in task_instances_by_task_id.items():
-        # If there is no dagster_run for this task, it was not proxied.
-        # Therefore synthensize a materialization based on the task information.
+        # No dagster runs means that the computation that materializes the asset was not proxied to Dagster.
+        # Therefore the dags ran completely in Airflow, and we will synthesize materializations in Dagster corresponding to that data run.
         if task_id not in dagster_runs_by_task_id:
             context.log.info(
                 f"Synthesizing materialization for tasks {task_id} in dag {dag_run.dag_id} because no dagster run found."
@@ -361,7 +380,6 @@ def build_synthetic_asset_materializations(
             context.log.info(
                 f"Dagster run found for task {task_id} in dag {dag_run.dag_id}. Run {dagster_runs_by_task_id[task_id].run_id}"
             )
-
     return synthetic_mats


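Taken together, the sensor now applies three emission rules per dag run. A minimal restatement of the control flow above (a summary sketch, not code from this PR; the helpers are assumed imported from the module shown earlier):

def emission_rules(dag_run, dagster_runs, airflow_data):
    # 1. Peered dag-level materializations are always emitted.
    mats = list(synthetic_mats_for_peered_dag_asset_keys(dag_run, airflow_data))
    # 2. Dag-mapped materializations are emitted only when no Dagster run
    #    was proxied for this dag run.
    if not dagster_runs:
        mats.extend(synthetic_mats_for_mapped_dag_asset_keys(dag_run, airflow_data))
    # 3. Task-mapped materializations are handled per successful task instance
    #    in get_synthetic_task_mats, skipping tasks that were proxied.
    return mats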
@@ -8,13 +8,17 @@
 from dagster_airlift.core.airflow_instance import AirflowInstance, DagInfo
 from dagster_airlift.core.dag_asset import get_leaf_assets_for_dag
 from dagster_airlift.core.serialization.serialized_data import (
-    KeyScopedDataItem,
+    DagHandle,
+    KeyScopedDagHandles,
+    KeyScopedTaskHandles,
     SerializedAirflowDefinitionsData,
     SerializedDagData,
     TaskHandle,
     TaskInfo,
 )
 from dagster_airlift.core.utils import (
+    dag_handles_for_spec,
+    is_dag_mapped_asset_spec,
     is_task_mapped_asset_spec,
     spec_iterator,
     task_handles_for_spec,
@@ -26,34 +30,43 @@ class AirliftMetadataMappingInfo:
     asset_specs: List[AssetSpec]
 
     @cached_property
-    def mapped_asset_specs(self) -> List[AssetSpec]:
+    def mapped_task_asset_specs(self) -> List[AssetSpec]:
         return [spec for spec in self.asset_specs if is_task_mapped_asset_spec(spec)]
 
+    @cached_property
+    def mapped_dag_asset_specs(self) -> List[AssetSpec]:
+        return [spec for spec in self.asset_specs if is_dag_mapped_asset_spec(spec)]
+
     @cached_property
     def dag_ids(self) -> Set[str]:
-        return set(self.task_id_map.keys())
+        return set(self.all_mapped_asset_keys_by_dag_id.keys())
 
     @cached_property
     def task_id_map(self) -> Dict[str, Set[str]]:
         """Mapping of dag_id to set of task_ids in that dag. This only contains task ids mapped to assets in this object."""
         task_id_map_data = {
-            dag_id: set(ta_map.keys()) for dag_id, ta_map in self.asset_key_map.items()
+            dag_id: set(ta_map.keys())
+            for dag_id, ta_map in self.asset_keys_by_mapped_task_id.items()
         }
         return defaultdict(set, task_id_map_data)
 
     @cached_property
-    def asset_keys_per_dag_id(self) -> Dict[str, Set[AssetKey]]:
-        """Mapping of dag_id to set of asset_keys in that dag. Does not include standlone dag assets."""
-        asset_keys_per_dag_data = {
-            dag_id: {
-                asset_key for asset_keys in task_to_asset_map.values() for asset_key in asset_keys
-            }
-            for dag_id, task_to_asset_map in self.asset_key_map.items()
-        }
-        return defaultdict(set, asset_keys_per_dag_data)
+    def all_mapped_asset_keys_by_dag_id(self) -> Dict[str, Set[AssetKey]]:
+        """Mapping of dag_id to set of asset_keys which are materialized by that dag.
+
+        If assets within the dag are mapped to individual tasks, all of those assets will be included in this set.
+        If the dag itself is mapped to a set of assets, those assets will be included in this set.
+        """
+        asset_keys_in_dag_by_id = defaultdict(set)
+        for dag_id, task_to_asset_map in self.asset_keys_by_mapped_task_id.items():
+            for asset_keys in task_to_asset_map.values():
+                asset_keys_in_dag_by_id[dag_id].update(asset_keys)
+        for dag_id, asset_keys in self.asset_keys_by_mapped_dag_id.items():
+            asset_keys_in_dag_by_id[dag_id].update(asset_keys)
+        return defaultdict(set, asset_keys_in_dag_by_id)
 
     @cached_property
-    def asset_key_map(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
+    def asset_keys_by_mapped_task_id(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
         """Mapping of dag_id to task_id to set of asset_keys mapped from that task."""
         asset_key_map: Dict[str, Dict[str, Set[AssetKey]]] = defaultdict(lambda: defaultdict(set))
         for spec in self.asset_specs:
@@ -62,10 +75,20 @@ def asset_key_map(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
                     asset_key_map[task_handle.dag_id][task_handle.task_id].add(spec.key)
         return asset_key_map
 
+    @cached_property
+    def asset_keys_by_mapped_dag_id(self) -> Dict[str, Set[AssetKey]]:
+        """Mapping of dag_id to set of asset_keys mapped from that dag."""
+        asset_key_map: Dict[str, Set[AssetKey]] = defaultdict(set)
+        for spec in self.asset_specs:
+            if is_dag_mapped_asset_spec(spec):
+                for dag_handle in dag_handles_for_spec(spec):
+                    asset_key_map[dag_handle.dag_id].add(spec.key)
+        return asset_key_map
+
     @cached_property
     def task_handle_map(self) -> Dict[AssetKey, Set[TaskHandle]]:
         task_handle_map = defaultdict(set)
-        for dag_id, asset_key_by_task_id in self.asset_key_map.items():
+        for dag_id, asset_key_by_task_id in self.asset_keys_by_mapped_task_id.items():
             for task_id, asset_keys in asset_key_by_task_id.items():
                 for asset_key in asset_keys:
                     task_handle_map[asset_key].add(TaskHandle(dag_id=dag_id, task_id=task_id))
@@ -94,7 +117,15 @@ class FetchedAirflowData:
     @cached_property
     def all_mapped_tasks(self) -> Dict[AssetKey, AbstractSet[TaskHandle]]:
         return {
-            spec.key: task_handles_for_spec(spec) for spec in self.mapping_info.mapped_asset_specs
+            spec.key: task_handles_for_spec(spec)
+            for spec in self.mapping_info.mapped_task_asset_specs
         }
+
+    @cached_property
+    def all_mapped_dags(self) -> Dict[AssetKey, AbstractSet[DagHandle]]:
+        return {
+            spec.key: dag_handles_for_spec(spec)
+            for spec in self.mapping_info.mapped_dag_asset_specs
+        }
 
 
@@ -123,17 +154,21 @@ def compute_serialized_data(
     fetched_airflow_data = fetch_all_airflow_data(airflow_instance, mapping_info)
     return SerializedAirflowDefinitionsData(
         instance_name=airflow_instance.name,
-        key_scoped_data_items=[
-            KeyScopedDataItem(asset_key=k, mapped_tasks=v)
+        key_scoped_task_handles=[
+            KeyScopedTaskHandles(asset_key=k, mapped_tasks=v)
             for k, v in fetched_airflow_data.all_mapped_tasks.items()
         ],
+        key_scoped_dag_handles=[
+            KeyScopedDagHandles(asset_key=k, mapped_dags=v)
+            for k, v in fetched_airflow_data.all_mapped_dags.items()
+        ],
         dag_datas={
             dag_id: SerializedDagData(
                 dag_id=dag_id,
                 dag_info=dag_info,
                 source_code=airflow_instance.get_dag_source_code(dag_info.metadata["file_token"]),
                 leaf_asset_keys=get_leaf_assets_for_dag(
-                    asset_keys_in_dag=mapping_info.asset_keys_per_dag_id[dag_id],
+                    asset_keys_in_dag=mapping_info.all_mapped_asset_keys_by_dag_id[dag_id],
                     downstreams_asset_dependency_graph=mapping_info.downstream_deps,
                 ),
                 task_infos=fetched_airflow_data.task_info_map[dag_id],
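To see how the renamed mapping-info properties compose, a sketch of AirliftMetadataMappingInfo in isolation (the record-style keyword construction and raw-metadata handling are assumptions; the metadata keys come from constants.py above):

from dagster import AssetKey, AssetSpec

task_spec = AssetSpec(
    key="warehouse_table",
    metadata={"dagster-airlift/task-mapping": [{"dag_id": "etl", "task_id": "load"}]},
)
dag_spec = AssetSpec(
    key="etl_output",
    metadata={"dagster-airlift/dag-mapping": [{"dag_id": "etl"}]},
)

info = AirliftMetadataMappingInfo(asset_specs=[task_spec, dag_spec])
# all_mapped_asset_keys_by_dag_id unions the task-level and dag-level maps:
#   {"etl": {AssetKey(["warehouse_table"]), AssetKey(["etl_output"])}}
print(info.all_mapped_asset_keys_by_dag_id["etl"])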