[dagster-airlift][rfc] Proxy operator launches partitioned runs (#25324)
## Summary & Motivation
Adds a pluggable implementation to BaseAssetsOperator that handles mapping the
current Airflow run to a partitioned run in Dagster. By default, we do the same
thing we do in the sensor: attempt to map the logical date directly to a
partition (for example, a daily logical date of 2024-10-22T00:00:00+00:00 maps
to the daily partition key 2024-10-22).

Important points to note:
- I make the simplifying assumption that all assets within a given task
share the same partitions definition. This lets us keep the "one run"
constraint from a previous PR.
- There are two points of pluggability that I think make sense to expose.
The first is the method get_partition_key(context, partitioning_info),
which lets users pick which partition key to use for the run.
The second is the default implementation it delegates to,
translate_logical_date_to_partition_key, which resolves the Airflow logical
date against the asset's partition keys using the partition key format. This
is to support TimeWindowPartitionsDefinitions that use a custom format / cron
schedule without needing a full reimplementation: users just override
get_partition_key to call translate_logical_date_to_partition_key with their
custom format (see the sketch below).
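
To make that concrete, here is a minimal sketch of such an override. The import paths and the subclass are assumptions for illustration, not code from this PR:

```python
from airflow.utils.context import Context

# NOTE: import paths below are assumptions; adjust them to the actual package layout.
from dagster_airlift.in_airflow import BaseAssetsOperator
from dagster_airlift.in_airflow.partition_utils import (
    PartitioningInformation,
    translate_logical_date_to_partition_key,
)


class CustomFormatAssetsOperator(BaseAssetsOperator):
    """Hypothetical operator for time-window partition keys written as YYYY.MM.DD."""

    def get_partition_key(
        self, context: Context, partitioning_info: PartitioningInformation
    ) -> str:
        # Swap in the custom strftime format, then reuse the default logical-date translation.
        custom_info = partitioning_info._replace(
            additional_info=partitioning_info.time_window_partitioning_info._replace(
                fmt="%Y.%m.%d"
            )
        )
        return translate_logical_date_to_partition_key(
            self.get_airflow_logical_date(context), custom_info
        )
```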
## How I Tested These Changes
Added a new test which takes a daily dag and constructs a daily
partitioned materialization. Might be worth testing all the other
formatting cases, as well as pluggability.
## Changelog
NOCHANGELOG
dpeng817 authored Oct 23, 2024
1 parent 4ce255b commit 6c307d7
Showing 15 changed files with 306 additions and 8 deletions.
@@ -270,6 +270,17 @@ def get_dag_run(self, dag_id: str, run_id: str) -> "DagRun":
            metadata=response.json(),
        )

    def unpause_dag(self, dag_id: str) -> None:
        response = self.auth_backend.get_session().patch(
            f"{self.get_api_url()}/dags",
            json={"is_paused": False},
            params={"dag_id_pattern": dag_id},
        )
        if response.status_code != 200:
            raise DagsterError(
                f"Failed to unpause dag {dag_id}. Status code: {response.status_code}, Message: {response.text}"
            )

    def wait_for_run_completion(self, dag_id: str, run_id: str, timeout: int = 30) -> None:
        start_time = get_current_datetime()
        while get_current_datetime() - start_time < datetime.timedelta(seconds=timeout):
@@ -2,6 +2,7 @@
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, Iterable, Mapping, Sequence, Tuple

import requests
@@ -12,6 +13,11 @@
from dagster_airlift.constants import DAG_ID_TAG_KEY, DAG_RUN_ID_TAG_KEY, TASK_ID_TAG_KEY

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION, VERIFICATION_QUERY
from .partition_utils import (
    PARTITION_NAME_TAG,
    PartitioningInformation,
    translate_logical_date_to_partition_key,
)

logger = logging.getLogger(__name__)

@@ -62,6 +68,19 @@ def filter_asset_nodes(
    ) -> Iterable[Mapping[str, Any]]:
        """Filters the asset nodes to only include those that should be triggered by the current task."""

    def get_partition_key(
        self, context: Context, partitioning_info: PartitioningInformation
    ) -> str:
        """Overrideable method to determine the partition key to use to trigger the dagster run.
        This method will only be called if the underlying asset is partitioned.
        """
        if not partitioning_info:
            return None
        return translate_logical_date_to_partition_key(
            self.get_airflow_logical_date(context), partitioning_info
        )

    def get_valid_graphql_response(self, response: Response, key: str) -> Any:
        response_json = response.json()
        if not response_json.get("data"):
@@ -128,6 +147,9 @@ def get_airflow_dag_id(self, context: Context) -> str:
    def get_airflow_task_id(self, context: Context) -> str:
        return self.get_attribute_from_airflow_context(context, "task").task_id

    def get_airflow_logical_date(self, context: Context) -> datetime:
        return self.get_attribute_from_airflow_context(context, "logical_date")

    def default_dagster_run_tags(self, context: Context) -> Dict[str, str]:
        return {
            DAG_ID_TAG_KEY: self.get_airflow_dag_id(context),
@@ -165,12 +187,17 @@ def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> None:
        job_identifier = _get_implicit_job_identifier(next(iter(filtered_asset_nodes)))
        asset_key_paths = [asset_node["assetKey"]["path"] for asset_node in filtered_asset_nodes]
        logger.info(f"Triggering run for {job_identifier} with assets {asset_key_paths}")
        tags = self.default_dagster_run_tags(context)
        partitioning_info = PartitioningInformation.from_asset_node_graphql(filtered_asset_nodes)
        if partitioning_info:
            tags[PARTITION_NAME_TAG] = self.get_partition_key(context, partitioning_info)
        logger.info(f"Using tags {tags}")
        run_id = self.launch_dagster_run(
            context,
            session,
            dagster_url,
            _build_dagster_run_execution_params(
                self.default_dagster_run_tags(context),
                tags,
                job_identifier,
                asset_key_paths=asset_key_paths,
            ),
@@ -34,6 +34,13 @@
}
}
}
isPartitioned
partitionDefinition {
type
name
fmt
}
partitionKeys
}
}
"""
@@ -0,0 +1,124 @@
from datetime import (
    datetime,
    timezone as tz,
)
from enum import Enum
from typing import Any, Mapping, NamedTuple, Optional, Sequence

PARTITION_NAME_TAG = "dagster/partition"


class PartitionDefinitionType(Enum):
    TIME_WINDOW = "TIME_WINDOW"
    STATIC = "STATIC"
    MULTIPARTITIONED = "MULTIPARTITIONED"
    DYNAMIC = "DYNAMIC"


class TimeWindowPartitioningInformation(NamedTuple):
    fmt: str


class PartitioningInformation(NamedTuple):
    partitioning_type: PartitionDefinitionType
    partition_keys: Sequence[str]
    # Eventually we can add more of these for different partitioning types
    additional_info: Optional[TimeWindowPartitioningInformation]

    @staticmethod
    def from_asset_node_graphql(
        asset_nodes: Sequence[Mapping[str, Any]],
    ) -> Optional["PartitioningInformation"]:
        assets_partitioned = [_asset_is_partitioned(asset_node) for asset_node in asset_nodes]
        if any(assets_partitioned) and not all(assets_partitioned):
            raise Exception(
                "Found some unpartitioned assets and some partitioned assets in the same task. "
                "For a given task, all assets must have the same partitions definition. "
            )
        partition_keys_per_asset = [
            set(asset_node["partitionKeys"])
            for asset_node in asset_nodes
            if asset_node["isPartitioned"]
        ]
        if not all_sets_equal(partition_keys_per_asset):
            raise Exception(
                "Found differing partition keys across assets in this task. "
                "For a given task, all assets must have the same partitions definition. "
            )
        # Now we can proceed with the assumption that all assets are partitioned and have the same
        # partition keys. Thus, we only look at the first asset node.
        asset_node = next(iter(asset_nodes))
        if not asset_node["isPartitioned"]:
            return None
        partitioning_type = PartitionDefinitionType(asset_node["partitionDefinition"]["type"])
        return PartitioningInformation(
            partitioning_type=partitioning_type,
            partition_keys=asset_node["partitionKeys"],
            additional_info=_build_additional_info_for_type(asset_node, partitioning_type),
        )

    @property
    def time_window_partitioning_info(self) -> TimeWindowPartitioningInformation:
        if self.partitioning_type != PartitionDefinitionType.TIME_WINDOW:
            raise Exception(
                f"Partitioning type is {self.partitioning_type}, but expected {PartitionDefinitionType.TIME_WINDOW}"
            )
        if self.additional_info is None:
            raise Exception(
                f"Partitioning type is {self.partitioning_type}, but no additional info was provided."
            )
        return self.additional_info


def _build_additional_info_for_type(
    asset_node: Mapping[str, Any], partitioning_type: PartitionDefinitionType
) -> Optional[TimeWindowPartitioningInformation]:
    if partitioning_type != PartitionDefinitionType.TIME_WINDOW:
        return None
    return TimeWindowPartitioningInformation(fmt=asset_node["partitionDefinition"]["fmt"])


def all_sets_equal(list_of_sets):
    if not list_of_sets:
        return True
    return len(set.union(*list_of_sets)) == len(set.intersection(*list_of_sets))


def translate_logical_date_to_partition_key(
    logical_date: datetime, partitioning_info: PartitioningInformation
) -> str:
    if not partitioning_info.partitioning_type == PartitionDefinitionType.TIME_WINDOW:
        raise Exception(
            "Only time-window partitioned assets or non-partitioned assets are supported out of the box."
        )
    fmt = partitioning_info.time_window_partitioning_info.fmt
    partitions_and_datetimes = [
        (_get_partition_datetime(partition_key, fmt), partition_key)
        for partition_key in partitioning_info.partition_keys
    ]
    matching_partition = next(
        (
            partition_key
            for datetime, partition_key in partitions_and_datetimes
            if datetime.timestamp() == logical_date.timestamp()
        ),
        None,
    )
    if matching_partition is None:
        raise Exception(f"No partition key found for logical date {logical_date}")
    return matching_partition


def _asset_is_partitioned(asset_node: Mapping[str, Any]) -> bool:
    return asset_node["isPartitioned"]


def _get_partition_datetime(partition_key: str, fmt: str) -> datetime:
    try:
        return _add_default_utc_timezone_if_none(datetime.strptime(partition_key, fmt))
    except ValueError:
        raise Exception(f"Could not parse partition key {partition_key} with format {fmt}.")


def _add_default_utc_timezone_if_none(dt: datetime) -> datetime:
    return dt.replace(tzinfo=tz.utc) if dt.tzinfo is None else dt
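
As a quick illustration of the helpers above (not part of the diff), resolving an Airflow logical date against a daily partition set might look like this:

```python
from datetime import datetime, timezone

# Assumes the definitions above (PartitioningInformation, etc.) are in scope.
# Built by hand here; normally this comes from the asset-node GraphQL response.
info = PartitioningInformation(
    partitioning_type=PartitionDefinitionType.TIME_WINDOW,
    partition_keys=["2024-10-21", "2024-10-22"],
    additional_info=TimeWindowPartitioningInformation(fmt="%Y-%m-%d"),
)

# The logical date matches the partition key whose parsed datetime has the same timestamp.
key = translate_logical_date_to_partition_key(datetime(2024, 10, 22, tzinfo=timezone.utc), info)
assert key == "2024-10-22"
```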
@@ -126,7 +126,15 @@ def dagster_dev_cmd(dagster_defs_path: str) -> List[str]:


@pytest.fixture(name="dagster_dev")
def setup_dagster(dagster_home: str, dagster_dev_cmd: List[str]) -> Generator[Any, None, None]:
def setup_dagster(
    airflow_instance: None, dagster_home: str, dagster_dev_cmd: List[str]
) -> Generator[Any, None, None]:
    with stand_up_dagster(dagster_dev_cmd) as process:
        yield process


@contextmanager
def stand_up_dagster(dagster_dev_cmd: List[str]) -> Generator[subprocess.Popen, None, None]:
    """Stands up a dagster instance using the dagster dev CLI. dagster_defs_path must be provided
    by a fixture included in the callsite.
    """
@@ -17,7 +17,7 @@ dev_install:
uv pip install -e ../../../dagster-airlift
uv pip install -e .

setup_local_env:
setup_local_env:
$(MAKE) wipe
mkdir -p $(AIRFLOW_HOME)
mkdir -p $(DAGSTER_HOME)
@@ -0,0 +1,35 @@
from datetime import timedelta
from pathlib import Path

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster._time import get_current_datetime_midnight
from dagster_airlift.in_airflow import proxying_to_dagster
from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml


def print_hello() -> None:
    print("Hello")  # noqa: T201


default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "retries": 0,
}

with DAG(
    dag_id="migrated_daily_interval_dag",
    default_args=default_args,
    schedule="@daily",
    start_date=get_current_datetime_midnight() - timedelta(days=1),
    # We pause this dag upon creation to avoid running it immediately
    is_paused_upon_creation=True,
) as minute_dag:
    PythonOperator(task_id="my_task", python_callable=print_hello)


proxying_to_dagster(
    proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
    global_vars=globals(),
)
@@ -0,0 +1,3 @@
tasks:
  - id: my_task
    proxied: True
@@ -15,13 +15,12 @@ def print_hello() -> None:
"retries": 1,
}


with DAG(
"simple_unproxied_dag",
default_args=default_args,
schedule_interval=None,
is_paused_upon_creation=False,
) as dag:
) as the_dag:
PythonOperator(task_id="print_task", python_callable=print_hello) >> PythonOperator(
task_id="downstream_print_task", python_callable=print_hello
) # type: ignore
@@ -1,3 +1,5 @@
from datetime import timedelta

from dagster import Definitions, asset, define_asset_job
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.assets import AssetsDefinition
@@ -60,6 +62,16 @@ def multi_job__c() -> None:
job2 = define_asset_job("job2", [multi_job__b, multi_job__c])


# Partitioned assets for migrated_daily_interval_dag
@asset(
    partitions_def=DailyPartitionsDefinition(
        start_date=get_current_datetime_midnight() - timedelta(days=2)
    )
)
def migrated_daily_interval_dag__partitioned() -> None:
    print("Materialized daily_interval_dag__partitioned")


def build_mapped_defs() -> Definitions:
    return build_defs_from_airflow_instance(
        airflow_instance=local_airflow_instance(),
@@ -129,6 +141,12 @@ def build_mapped_defs() -> Definitions:
                ),
                jobs=[job1, job2],
            ),
            Definitions(
                assets=assets_with_task_mappings(
                    dag_id="migrated_daily_interval_dag",
                    task_mappings={"my_task": [migrated_daily_interval_dag__partitioned]},
                ),
            ),
        ),
    )

@@ -1,10 +1,12 @@
import os
import subprocess
import time
from pathlib import Path
from typing import Generator

import pytest
from dagster._core.test_utils import environ
from dagster_airlift.core.airflow_instance import AirflowInstance
from dagster_airlift.test.shared_fixtures import stand_up_airflow


@@ -41,3 +43,19 @@ def airflow_instance_fixture(local_env: None) -> Generator[subprocess.Popen, None, None]:
        airflow_cmd=["make", "run_airflow"], env=os.environ, cwd=makefile_dir()
    ) as process:
        yield process


def poll_for_airflow_run_existence_and_completion(
    af_instance: AirflowInstance, dag_id: str, af_run_id: str, duration: int
) -> None:
    start_time = time.time()
    while time.time() - start_time < duration:
        try:
            af_instance.wait_for_run_completion(
                dag_id=dag_id, run_id=af_run_id, timeout=int(time.time() - start_time)
            )
            return
        # Run may not exist yet
        except Exception:
            time.sleep(0.1)
            continue