Skip to content

Commit

Permalink
[dagster-airlift] Support dag-level overrides in builder
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Oct 11, 2024
1 parent 0c21efc commit 1b8156c
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
from dagster_airlift.core.airflow_instance import AirflowInstance, DagInfo
from dagster_airlift.core.dag_asset import get_leaf_assets_for_dag
from dagster_airlift.core.serialization.serialized_data import (
KeyScopedDataItem,
DagHandle,
KeyScopedDagHandles,
KeyScopedTaskHandles,
SerializedAirflowDefinitionsData,
SerializedDagData,
TaskHandle,
TaskInfo,
)
from dagster_airlift.core.utils import (
dag_handles_for_spec,
is_dag_mapped_asset_spec,
is_task_mapped_asset_spec,
spec_iterator,
task_handles_for_spec,
Expand All @@ -26,34 +30,42 @@ class AirliftMetadataMappingInfo:
asset_specs: List[AssetSpec]

@cached_property
def mapped_asset_specs(self) -> List[AssetSpec]:
def task_mapped_asset_specs(self) -> List[AssetSpec]:
return [spec for spec in self.asset_specs if is_task_mapped_asset_spec(spec)]

@cached_property
def dag_mapped_asset_specs(self) -> List[AssetSpec]:
    """All asset specs in this object that are dag-mapped, per ``is_dag_mapped_asset_spec``."""
    return list(filter(is_dag_mapped_asset_spec, self.asset_specs))

@cached_property
def dag_ids(self) -> Set[str]:
return set(self.task_id_map.keys())
return set(self.asset_keys_in_dag_by_id.keys())

@cached_property
def task_id_map(self) -> Dict[str, Set[str]]:
"""Mapping of dag_id to set of task_ids in that dag. This only contains task ids mapped to assets in this object."""
task_id_map_data = {
dag_id: set(ta_map.keys()) for dag_id, ta_map in self.asset_key_map.items()
dag_id: set(ta_map.keys()) for dag_id, ta_map in self.asset_keys_by_task.items()
}
return defaultdict(set, task_id_map_data)

@cached_property
def asset_keys_per_dag_id(self) -> Dict[str, Set[AssetKey]]:
"""Mapping of dag_id to set of asset_keys in that dag. Does not include standlone dag assets."""
asset_keys_per_dag_data = {
dag_id: {
asset_key for asset_keys in task_to_asset_map.values() for asset_key in asset_keys
}
for dag_id, task_to_asset_map in self.asset_key_map.items()
}
return defaultdict(set, asset_keys_per_dag_data)
def asset_keys_in_dag_by_id(self) -> Dict[str, Set[AssetKey]]:
    """Mapping of dag_id to every asset key materialized by that dag.

    Combines keys mapped to individual tasks within the dag with keys
    mapped to the dag as a whole.
    """
    combined: Dict[str, Set[AssetKey]] = defaultdict(set)
    for dag_id, keys_by_task in self.asset_keys_by_task.items():
        for task_keys in keys_by_task.values():
            combined[dag_id] |= task_keys
    for dag_id, dag_keys in self.asset_keys_by_dag.items():
        combined[dag_id] |= dag_keys
    # Handed back as a fresh defaultdict so lookups of unknown dag_ids yield empty sets.
    return defaultdict(set, combined)

@cached_property
def asset_key_map(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
def asset_keys_by_task(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
"""Mapping of dag_id to task_id to set of asset_keys mapped from that task."""
asset_key_map: Dict[str, Dict[str, Set[AssetKey]]] = defaultdict(lambda: defaultdict(set))
for spec in self.asset_specs:
Expand All @@ -62,10 +74,20 @@ def asset_key_map(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
asset_key_map[task_handle.dag_id][task_handle.task_id].add(spec.key)
return asset_key_map

@cached_property
def asset_keys_by_dag(self) -> Dict[str, Set[AssetKey]]:
    """Mapping of dag_id to the set of asset keys whose specs are mapped directly to that dag."""
    keys_by_dag: Dict[str, Set[AssetKey]] = defaultdict(set)
    dag_specs = (s for s in self.asset_specs if is_dag_mapped_asset_spec(s))
    for dag_spec in dag_specs:
        for handle in dag_handles_for_spec(dag_spec):
            keys_by_dag[handle.dag_id].add(dag_spec.key)
    return keys_by_dag

@cached_property
def task_handle_map(self) -> Dict[AssetKey, Set[TaskHandle]]:
task_handle_map = defaultdict(set)
for dag_id, asset_key_by_task_id in self.asset_key_map.items():
for dag_id, asset_key_by_task_id in self.asset_keys_by_task.items():
for task_id, asset_keys in asset_key_by_task_id.items():
for asset_key in asset_keys:
task_handle_map[asset_key].add(TaskHandle(dag_id=dag_id, task_id=task_id))
Expand Down Expand Up @@ -94,7 +116,15 @@ class FetchedAirflowData:
@cached_property
def all_mapped_tasks(self) -> Dict[AssetKey, AbstractSet[TaskHandle]]:
return {
spec.key: task_handles_for_spec(spec) for spec in self.mapping_info.mapped_asset_specs
spec.key: task_handles_for_spec(spec)
for spec in self.mapping_info.task_mapped_asset_specs
}

@cached_property
def all_mapped_dags(self) -> Dict[AssetKey, AbstractSet[DagHandle]]:
    """Mapping of asset key to its dag handles, for every dag-mapped asset spec."""
    mapped: Dict[AssetKey, AbstractSet[DagHandle]] = {}
    for dag_spec in self.mapping_info.dag_mapped_asset_specs:
        mapped[dag_spec.key] = dag_handles_for_spec(dag_spec)
    return mapped


Expand Down Expand Up @@ -123,17 +153,21 @@ def compute_serialized_data(
fetched_airflow_data = fetch_all_airflow_data(airflow_instance, mapping_info)
return SerializedAirflowDefinitionsData(
instance_name=airflow_instance.name,
key_scoped_data_items=[
KeyScopedDataItem(asset_key=k, mapped_tasks=v)
key_scoped_task_handles=[
KeyScopedTaskHandles(asset_key=k, mapped_tasks=v)
for k, v in fetched_airflow_data.all_mapped_tasks.items()
],
key_scoped_dag_handles=[
KeyScopedDagHandles(asset_key=k, mapped_dags=v)
for k, v in fetched_airflow_data.all_mapped_dags.items()
],
dag_datas={
dag_id: SerializedDagData(
dag_id=dag_id,
dag_info=dag_info,
source_code=airflow_instance.get_dag_source_code(dag_info.metadata["file_token"]),
leaf_asset_keys=get_leaf_assets_for_dag(
asset_keys_in_dag=mapping_info.asset_keys_per_dag_id[dag_id],
asset_keys_in_dag=mapping_info.asset_keys_in_dag_by_id[dag_id],
downstreams_asset_dependency_graph=mapping_info.downstream_deps,
),
task_infos=fetched_airflow_data.task_info_map[dag_id],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dagster_airlift.constants import AUTOMAPPED_TASK_METADATA_KEY
from dagster_airlift.core.dag_asset import dag_asset_metadata, dag_description
from dagster_airlift.core.serialization.serialized_data import (
DagHandle,
SerializedAirflowDefinitionsData,
SerializedDagData,
TaskHandle,
Expand All @@ -40,7 +41,7 @@ def metadata_for_mapped_tasks(
return task_level_metadata


def enrich_spec_with_airflow_metadata(
def enrich_spec_with_airflow_task_metadata(
spec: AssetSpec,
tasks: AbstractSet[TaskHandle],
serialized_data: SerializedAirflowDefinitionsData,
Expand All @@ -50,6 +51,24 @@ def enrich_spec_with_airflow_metadata(
)


def metadata_for_mapped_dags(
    dags: AbstractSet[DagHandle], serialized_data: SerializedAirflowDefinitionsData
) -> Mapping[str, Any]:
    """Build asset metadata from the serialized dag-level Airflow data for *dags*.

    NOTE(review): only one handle is consulted — ``next(iter(dags))`` picks an
    arbitrary element, so with multiple mapped dags the choice is unspecified;
    confirm callers pass a single-dag set or that any dag's metadata is acceptable.
    """
    chosen = next(iter(dags))
    dag_data = serialized_data.dag_datas[chosen.dag_id]
    return dag_asset_metadata(dag_data.dag_info, dag_data.source_code)


def enrich_spec_with_airflow_dag_metadata(
    spec: AssetSpec,
    dags: AbstractSet[DagHandle],
    serialized_data: SerializedAirflowDefinitionsData,
) -> AssetSpec:
    """Return a copy of ``spec`` whose metadata also carries the dag-level
    Airflow metadata for ``dags``; dag metadata wins on key collisions."""
    merged_metadata = dict(spec.metadata)
    merged_metadata.update(metadata_for_mapped_dags(dags, serialized_data))
    return spec._replace(metadata=merged_metadata)


def make_dag_external_asset(instance_name: str, dag_data: SerializedDagData) -> AssetsDefinition:
return external_asset_from_spec(
AssetSpec(
Expand All @@ -68,12 +87,15 @@ def get_airflow_data_to_spec_mapper(
"""Creates a mapping function s.t. if there is airflow data applicable to the asset key, transform the spec and apply the data."""

def _fn(spec: AssetSpec) -> AssetSpec:
mapped_tasks = serialized_data.all_mapped_tasks.get(spec.key)
return (
enrich_spec_with_airflow_metadata(spec, mapped_tasks, serialized_data)
if mapped_tasks
else spec
)
if spec.key in serialized_data.all_mapped_tasks:
return enrich_spec_with_airflow_task_metadata(
spec, serialized_data.all_mapped_tasks[spec.key], serialized_data
)
elif spec.key in serialized_data.all_mapped_dags:
return enrich_spec_with_airflow_dag_metadata(
spec, serialized_data.all_mapped_dags[spec.key], serialized_data
)
return spec

return _fn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,18 @@ class SerializedDagData:

@whitelist_for_serdes
@record
class KeyScopedDataItem:
class KeyScopedTaskHandles:
asset_key: AssetKey
mapped_tasks: AbstractSet[TaskHandle]


@whitelist_for_serdes
@record
class KeyScopedDagHandles:
    """Serializable pairing of an asset key with the set of dag handles mapped to it."""

    # The asset key being mapped.
    asset_key: AssetKey
    # Handles of the Airflow dags this asset is mapped to.
    mapped_dags: AbstractSet[DagHandle]


###################################################################################################
# Serializable data that will be cached to avoid repeated calls to the Airflow API, and to avoid
# repeated scans of passed-in Definitions objects.
Expand All @@ -90,9 +97,14 @@ class KeyScopedDataItem:
@record
class SerializedAirflowDefinitionsData:
instance_name: str
key_scoped_data_items: List[KeyScopedDataItem]
key_scoped_task_handles: List[KeyScopedTaskHandles]
key_scoped_dag_handles: List[KeyScopedDagHandles]
dag_datas: Mapping[str, SerializedDagData]

@cached_property
def all_mapped_tasks(self) -> Dict[AssetKey, AbstractSet[TaskHandle]]:
return {item.asset_key: item.mapped_tasks for item in self.key_scoped_data_items}
return {item.asset_key: item.mapped_tasks for item in self.key_scoped_task_handles}

@cached_property
def all_mapped_dags(self) -> Dict[AssetKey, AbstractSet[DagHandle]]:
    """Mapping of asset key to mapped dag handles, rebuilt from the serialized key-scoped items."""
    mapping: Dict[AssetKey, AbstractSet[DagHandle]] = {}
    for scoped in self.key_scoped_dag_handles:
        mapping[scoped.asset_key] = scoped.mapped_dags
    return mapping
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def metadata_for_task_mapping(*, task_id: str, dag_id: str) -> dict:
return {TASK_MAPPING_METADATA_KEY: [{"dag_id": dag_id, "task_id": task_id}]}


def metadata_for_dag_mapping(*, dag_id: str) -> dict:
    """Build the metadata entry that maps an asset spec to an entire Airflow dag."""
    mapping_entry = {"dag_id": dag_id}
    return {DAG_MAPPING_METADATA_KEY: [mapping_entry]}


def get_metadata_key(instance_name: str) -> str:
return f"{AIRFLOW_SOURCE_METADATA_KEY_PREFIX}/{instance_name}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
build_defs_from_airflow_instance as build_defs_from_airflow_instance,
)
from dagster_airlift.core.sensor.event_translation import DagsterEventTransformerFn
from dagster_airlift.core.utils import metadata_for_task_mapping
from dagster_airlift.core.utils import metadata_for_dag_mapping, metadata_for_task_mapping
from dagster_airlift.test import make_dag_run, make_instance


Expand All @@ -41,12 +41,14 @@ def fully_loaded_repo_from_airflow_asset_graph(
assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]],
additional_defs: Definitions = Definitions(),
create_runs: bool = True,
dag_level_asset_overrides: Optional[Dict[str, List[str]]] = None,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> RepositoryDefinition:
defs = load_definitions_airflow_asset_graph(
assets_per_task,
additional_defs=additional_defs,
create_runs=create_runs,
dag_level_asset_overrides=dag_level_asset_overrides,
event_transformer_fn=event_transformer_fn,
)
repo_def = defs.get_repository_def()
Expand All @@ -59,6 +61,7 @@ def load_definitions_airflow_asset_graph(
additional_defs: Definitions = Definitions(),
create_runs: bool = True,
create_assets_defs: bool = True,
dag_level_asset_overrides: Optional[Dict[str, List[str]]] = None,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> Definitions:
assets = []
Expand All @@ -81,6 +84,23 @@ def _asset():
assets.append(_asset)
else:
assets.append(spec)
if dag_level_asset_overrides:
for dag_id, asset_keys in dag_level_asset_overrides.items():
dag_and_task_structure[dag_id] = ["dummy_task"]
for asset in asset_keys:
spec = AssetSpec(
AssetKey.from_user_string(asset),
metadata=metadata_for_dag_mapping(dag_id=dag_id),
)
if create_assets_defs:

@multi_asset(specs=[spec], name=f"{spec.key.to_python_identifier()}_asset")
def _asset():
return None

assets.append(_asset)
else:
assets.append(spec)
runs = (
[
make_dag_run(
Expand Down Expand Up @@ -112,10 +132,14 @@ def build_and_invoke_sensor(
assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]],
instance: DagsterInstance,
additional_defs: Definitions = Definitions(),
dag_level_asset_overrides: Optional[Dict[str, List[str]]] = None,
event_transformer_fn: Optional[DagsterEventTransformerFn] = None,
) -> Tuple[SensorResult, SensorEvaluationContext]:
repo_def = fully_loaded_repo_from_airflow_asset_graph(
assets_per_task, additional_defs=additional_defs, event_transformer_fn=event_transformer_fn
assets_per_task,
additional_defs=additional_defs,
dag_level_asset_overrides=dag_level_asset_overrides,
event_transformer_fn=event_transformer_fn,
)
sensor = next(iter(repo_def.sensor_defs))
sensor_context = build_sensor_context(repository_def=repo_def, instance=instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_build_task_mapping_info_no_mapping() -> None:
defs=Definitions(assets=[AssetSpec("asset1"), AssetSpec("asset2")])
)
assert len(spec_mapping_info.dag_ids) == 0
assert not (spec_mapping_info.asset_key_map)
assert not (spec_mapping_info.asset_keys_by_task)
assert not (spec_mapping_info.task_handle_map)


Expand All @@ -51,8 +51,8 @@ def test_build_single_task_spec() -> None:
)
assert spec_mapping_info.dag_ids == {"dag1"}
assert spec_mapping_info.task_id_map == {"dag1": {"task1"}}
assert spec_mapping_info.asset_keys_per_dag_id == {"dag1": {ak("asset1")}}
assert spec_mapping_info.asset_key_map == {"dag1": {"task1": {ak("asset1")}}}
assert spec_mapping_info.asset_keys_in_dag_by_id == {"dag1": {ak("asset1")}}
assert spec_mapping_info.asset_keys_by_task == {"dag1": {"task1": {ak("asset1")}}}
assert spec_mapping_info.task_handle_map == {
ak("asset1"): set([TaskHandle(dag_id="dag1", task_id="task1")])
}
Expand All @@ -72,11 +72,11 @@ def test_task_with_multiple_assets() -> None:

assert spec_mapping_info.dag_ids == {"dag1", "dag2"}
assert spec_mapping_info.task_id_map == {"dag1": {"task1"}, "dag2": {"task1"}}
assert spec_mapping_info.asset_keys_per_dag_id == {
assert spec_mapping_info.asset_keys_in_dag_by_id == {
"dag1": {ak("asset1"), ak("asset2"), ak("asset3")},
"dag2": {ak("asset4")},
}
assert spec_mapping_info.asset_key_map == {
assert spec_mapping_info.asset_keys_by_task == {
"dag1": {"task1": {ak("asset1"), ak("asset2"), ak("asset3")}},
"dag2": {"task1": {ak("asset4")}},
}
Expand Down Expand Up @@ -106,12 +106,12 @@ def test_map_multiple_tasks_to_single_asset() -> None:

assert spec_mapping_info.dag_ids == {"dag1", "dag2"}
assert spec_mapping_info.task_id_map == {"dag1": {"task1"}, "dag2": {"task1"}}
assert spec_mapping_info.asset_keys_per_dag_id == {
assert spec_mapping_info.asset_keys_in_dag_by_id == {
"dag1": {ak("asset1")},
"dag2": {ak("asset1")},
}

assert spec_mapping_info.asset_key_map == {
assert spec_mapping_info.asset_keys_by_task == {
"dag1": {"task1": {ak("asset1")}},
"dag2": {"task1": {ak("asset1")}},
}
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_produce_fetched_airflow_data() -> None:
mapping_info=mapping_info,
)

assert len(fetched_airflow_data.mapping_info.mapped_asset_specs) == 1
assert len(fetched_airflow_data.mapping_info.task_mapped_asset_specs) == 1
assert len(fetched_airflow_data.mapping_info.asset_specs) == 2
assert fetched_airflow_data.mapping_info.downstream_deps == {ak("asset1"): {ak("asset2")}}

Expand Down
Loading

0 comments on commit 1b8156c

Please sign in to comment.