[dagster-airlift] Proxy partitioning working
dpeng817 committed Oct 18, 2024
1 parent 0f05aee commit 2c055e5
Showing 10 changed files with 284 additions and 3 deletions.
@@ -255,6 +255,17 @@ def trigger_dag(self, dag_id: str, logical_date: Optional[datetime.datetime] = N
)
return response.json()["dag_run_id"]

def unpause_dag(self, dag_id: str) -> None:
response = self.auth_backend.get_session().patch(
f"{self.get_api_url()}/dags",
json={"is_paused": False},
params={"dag_id_pattern": dag_id},
)
if response.status_code != 200:
raise DagsterError(
f"Failed to unpause dag {dag_id}. Status code: {response.status_code}, Message: {response.text}"
)

def get_dag_run(self, dag_id: str, run_id: str) -> "DagRun":
response = self.auth_backend.get_session().get(
f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}"
@@ -2,7 +2,8 @@
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Mapping, Sequence, Tuple
from datetime import datetime
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple

import requests
from airflow.models.operator import BaseOperator
@@ -12,6 +13,11 @@
from dagster_airlift.constants import DAG_ID_TAG_KEY, DAG_RUN_ID_TAG_KEY, TASK_ID_TAG_KEY

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION, VERIFICATION_QUERY
from .partition_utils import (
PARTITION_NAME_TAG,
PartitioningInformation,
translate_logical_date_to_partition_key,
)

logger = logging.getLogger(__name__)

@@ -62,6 +68,16 @@ def filter_asset_nodes(
) -> Iterable[Mapping[str, Any]]:
"""Filters the asset nodes to only include those that should be triggered by the current task."""

def get_partition_key(
self, context: Context, partitioning_info: Optional[PartitioningInformation]
) -> Optional[str]:
"""Overrideable method to determine the partition key to use to trigger the dagster run."""
if not partitioning_info:
return None
return translate_logical_date_to_partition_key(
self.get_airflow_logical_date(context), partitioning_info
)

def get_valid_graphql_response(self, response: Response, key: str) -> Any:
response_json = response.json()
if not response_json.get("data"):
@@ -128,6 +144,9 @@ def get_airflow_dag_id(self, context: Context) -> str:
def get_airflow_task_id(self, context: Context) -> str:
return self.get_attribute_from_airflow_context(context, "task").task_id

def get_airflow_logical_date(self, context: Context) -> datetime:
return self.get_attribute_from_airflow_context(context, "logical_date")

def default_dagster_run_tags(self, context: Context) -> Dict[str, str]:
return {
DAG_ID_TAG_KEY: self.get_airflow_dag_id(context),
@@ -162,15 +181,23 @@ def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> N
"`dagster-airlift` expects that all assets mapped to a given task exist within the same code location, so that they can be executed by the same run."
)

partitioning_info = PartitioningInformation.from_asset_node_graphql(filtered_asset_nodes)
partition_key_for_run = self.get_partition_key(context, partitioning_info)
job_identifier = _get_implicit_job_identifier(next(iter(filtered_asset_nodes)))
asset_key_paths = [asset_node["assetKey"]["path"] for asset_node in filtered_asset_nodes]
logger.info(f"Triggering run for {job_identifier} with assets {asset_key_paths}")
tags = (
{**self.default_dagster_run_tags(context), PARTITION_NAME_TAG: partition_key_for_run}
if partition_key_for_run
else self.default_dagster_run_tags(context)
)
logger.info(f"Using tags {tags}")
run_id = self.launch_dagster_run(
context,
session,
dagster_url,
_build_dagster_run_execution_params(
self.default_dagster_run_tags(context),
tags,
job_identifier,
asset_key_paths=asset_key_paths,
),
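Note: `get_partition_key` is deliberately overridable, so an operator proxying a task with non-default date handling can resolve a different partition. A sketch of such an override, assuming the abstract operator in this file is `BaseDagsterAssetsOperator` and sits alongside `partition_utils` in `dagster_airlift.in_airflow` (class name and module paths inferred from the relative imports above; the operator's other required overrides are omitted):

from datetime import timedelta
from typing import Optional

from airflow.utils.context import Context
from dagster_airlift.in_airflow.base_asset_operator import BaseDagsterAssetsOperator
from dagster_airlift.in_airflow.partition_utils import (
    PartitioningInformation,
    translate_logical_date_to_partition_key,
)


class ShiftedPartitionOperator(BaseDagsterAssetsOperator):
    """Hypothetical operator that targets the partition one day after the run's logical date."""

    def get_partition_key(
        self, context: Context, partitioning_info: Optional[PartitioningInformation]
    ) -> Optional[str]:
        if not partitioning_info:
            return None
        # Resolve the partition from a shifted logical date instead of the run's own date.
        shifted_date = self.get_airflow_logical_date(context) + timedelta(days=1)
        return translate_logical_date_to_partition_key(shifted_date, partitioning_info)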
@@ -34,6 +34,13 @@
}
}
}
isPartitioned
partitionDefinition {
type
name
fmt
}
partitionKeys
}
}
"""
@@ -0,0 +1,124 @@
from datetime import (
datetime,
timezone as tz,
)
from enum import Enum
from typing import Any, Mapping, NamedTuple, Optional, Sequence

PARTITION_NAME_TAG = "dagster/partition"


class PartitionDefinitionType(Enum):
TIME_WINDOW = "TIME_WINDOW"
STATIC = "STATIC"
MULTIPARTITIONED = "MULTIPARTITIONED"
DYNAMIC = "DYNAMIC"


class TimeWindowPartitioningInformation(NamedTuple):
fmt: str


class PartitioningInformation(NamedTuple):
partitioning_type: PartitionDefinitionType
partition_keys: Sequence[str]
# Eventually we can add more of these for different partitioning types
additional_info: Optional[TimeWindowPartitioningInformation]

@staticmethod
def from_asset_node_graphql(
asset_nodes: Sequence[Mapping[str, Any]],
) -> Optional["PartitioningInformation"]:
assets_partitioned = [_asset_is_partitioned(asset_node) for asset_node in asset_nodes]
if any(assets_partitioned) and not all(assets_partitioned):
raise Exception(
"Found some unpartitioned assets and some partitioned assets in the same task. "
"For a given task, all assets must have the same partitions definition. "
)
partition_keys_per_asset = [
set(asset_node["partitionKeys"])
for asset_node in asset_nodes
if asset_node["isPartitioned"]
]
if not all_sets_equal(partition_keys_per_asset):
raise Exception(
"Found differing partition keys across assets in this task. "
"For a given task, all assets must have the same partitions definition. "
)
        # At this point, either every asset is partitioned with identical partition keys, or none are.
        # Thus, we only need to inspect the first asset node.
asset_node = next(iter(asset_nodes))
if not asset_node["isPartitioned"]:
return None
partitioning_type = PartitionDefinitionType(asset_node["partitionDefinition"]["type"])
return PartitioningInformation(
partitioning_type=partitioning_type,
partition_keys=asset_node["partitionKeys"],
additional_info=_build_additional_info_for_type(asset_node, partitioning_type),
)

@property
def time_window_partitioning_info(self) -> TimeWindowPartitioningInformation:
if self.partitioning_type != PartitionDefinitionType.TIME_WINDOW:
raise Exception(
f"Partitioning type is {self.partitioning_type}, but expected {PartitionDefinitionType.TIME_WINDOW}"
)
if self.additional_info is None:
raise Exception(
f"Partitioning type is {self.partitioning_type}, but no additional info was provided."
)
return self.additional_info


def _build_additional_info_for_type(
asset_node: Mapping[str, Any], partitioning_type: PartitionDefinitionType
) -> Optional[TimeWindowPartitioningInformation]:
if partitioning_type != PartitionDefinitionType.TIME_WINDOW:
return None
return TimeWindowPartitioningInformation(fmt=asset_node["partitionDefinition"]["fmt"])


def all_sets_equal(list_of_sets):
if not list_of_sets:
return True
return len(set.union(*list_of_sets)) == len(set.intersection(*list_of_sets))


def translate_logical_date_to_partition_key(
logical_date: datetime, partitioning_info: PartitioningInformation
) -> str:
    if partitioning_info.partitioning_type != PartitionDefinitionType.TIME_WINDOW:
raise Exception(
"Only time-window partitioned assets or non-partitioned assets are supported out of the box."
)
fmt = partitioning_info.time_window_partitioning_info.fmt
partitions_and_datetimes = [
(_get_partition_datetime(partition_key, fmt), partition_key)
for partition_key in partitioning_info.partition_keys
]
matching_partition = next(
(
partition_key
            for partition_datetime, partition_key in partitions_and_datetimes
            if partition_datetime.timestamp() == logical_date.timestamp()
),
None,
)
if matching_partition is None:
raise Exception(f"No partition key found for logical date {logical_date}")
return matching_partition


def _asset_is_partitioned(asset_node: Mapping[str, Any]) -> bool:
return asset_node["isPartitioned"]


def _get_partition_datetime(partition_key: str, fmt: str) -> datetime:
try:
return _add_default_utc_timezone_if_none(datetime.strptime(partition_key, fmt))
except ValueError:
raise Exception(f"Could not parse partition key {partition_key} with format {fmt}.")


def _add_default_utc_timezone_if_none(dt: datetime) -> datetime:
return dt.replace(tzinfo=tz.utc) if dt.tzinfo is None else dt
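Note: a self-contained sketch of how `translate_logical_date_to_partition_key` maps an Airflow logical date onto a Dagster partition key; the `dagster_airlift.in_airflow.partition_utils` module path is inferred from the relative import in the operator file above:

from datetime import datetime, timezone

from dagster_airlift.in_airflow.partition_utils import (
    PartitionDefinitionType,
    PartitioningInformation,
    TimeWindowPartitioningInformation,
    translate_logical_date_to_partition_key,
)

info = PartitioningInformation(
    partitioning_type=PartitionDefinitionType.TIME_WINDOW,
    partition_keys=["2024-10-16", "2024-10-17"],
    additional_info=TimeWindowPartitioningInformation(fmt="%Y-%m-%d"),
)
# Naive partition datetimes default to UTC, so a UTC-midnight logical date matches exactly.
key = translate_logical_date_to_partition_key(
    datetime(2024, 10, 17, tzinfo=timezone.utc), info
)
assert key == "2024-10-17"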
@@ -0,0 +1,35 @@
from datetime import timedelta
from pathlib import Path

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster._time import get_current_datetime_midnight
from dagster_airlift.in_airflow import proxying_to_dagster
from dagster_airlift.in_airflow.proxied_state import load_proxied_state_from_yaml


def print_hello() -> None:
print("Hello") # noqa: T201


default_args = {
"owner": "airflow",
"depends_on_past": False,
"retries": 0,
}

with DAG(
dag_id="migrated_daily_interval_dag",
default_args=default_args,
schedule="@daily",
start_date=get_current_datetime_midnight() - timedelta(days=1),
# We pause this dag upon creation to avoid running it immediately
is_paused_upon_creation=True,
) as daily_interval_dag:
PythonOperator(task_id="my_task", python_callable=print_hello)


proxying_to_dagster(
proxied_state=load_proxied_state_from_yaml(Path(__file__).parent / "proxied_state"),
global_vars=globals(),
)
@@ -0,0 +1,3 @@
tasks:
- id: my_task
proxied: True
@@ -1,3 +1,5 @@
from datetime import timedelta

from dagster import Definitions, asset, define_asset_job
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.assets import AssetsDefinition
@@ -60,6 +62,16 @@ def multi_job__c() -> None:
job2 = define_asset_job("job2", [multi_job__b, multi_job__c])


# Partitioned assets for migrated_daily_interval_dag
@asset(
partitions_def=DailyPartitionsDefinition(
start_date=get_current_datetime_midnight() - timedelta(days=2)
)
)
def migrated_daily_interval_dag__partitioned() -> None:
print("Materialized daily_interval_dag__partitioned")


def build_mapped_defs() -> Definitions:
return build_defs_from_airflow_instance(
airflow_instance=local_airflow_instance(),
@@ -129,6 +141,12 @@ def build_mapped_defs() -> Definitions:
),
jobs=[job1, job2],
),
Definitions(
assets=assets_with_task_mappings(
dag_id="migrated_daily_interval_dag",
task_mappings={"my_task": [migrated_daily_interval_dag__partitioned]},
),
),
),
)
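Note: the asset's partitions start two days before today's midnight while the proxied DAG (defined earlier in this commit) starts one day before, so the DAG's first scheduled run, whose logical date is yesterday's midnight, lands on an existing partition key. A quick sketch of that mapping, mirroring the assertion in the integration test below:

from datetime import timedelta

from dagster._time import get_current_datetime_midnight

# Yesterday's midnight is the logical date of the first scheduled run of an
# @daily DAG whose start_date is one day ago.
logical_date = get_current_datetime_midnight() - timedelta(days=1)
expected_partition_key = logical_date.strftime("%Y-%m-%d")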

@@ -1,10 +1,12 @@
import os
import subprocess
import time
from pathlib import Path
from typing import Generator

import pytest
from dagster._core.test_utils import environ
from dagster_airlift.core.airflow_instance import AirflowInstance
from dagster_airlift.test.shared_fixtures import stand_up_airflow


@@ -41,3 +43,19 @@ def airflow_instance_fixture(local_env: None) -> Generator[subprocess.Popen, Non
airflow_cmd=["make", "run_airflow"], env=os.environ, cwd=makefile_dir()
) as process:
yield process


def poll_for_airflow_run_existence_and_completion(
af_instance: AirflowInstance, dag_id: str, af_run_id: str, duration: int
) -> None:
start_time = time.time()
while time.time() - start_time < duration:
try:
af_instance.wait_for_run_completion(
                dag_id=dag_id,
                run_id=af_run_id,
                # Wait only for however much of the overall duration budget remains.
                timeout=int(duration - (time.time() - start_time)),
            )
return
# Run may not exist yet
except Exception:
time.sleep(0.1)
continue
@@ -11,7 +11,10 @@
from dagster_airlift.constants import DAG_RUN_ID_TAG_KEY
from dagster_airlift.core.airflow_instance import AirflowInstance

from kitchen_sink_tests.integration_tests.conftest import makefile_dir
from kitchen_sink_tests.integration_tests.conftest import (
makefile_dir,
poll_for_airflow_run_existence_and_completion,
)


def poll_for_materialization(
@@ -262,3 +265,32 @@ def test_assets_multiple_jobs_same_task(
],
}
poll_for_expected_mats(af_instance, expected_mats_per_dag)


def test_partitioned_migrated(
airflow_instance: None,
dagster_dev: None,
dagster_home: str,
) -> None:
"""Test that partitioned assets are properly materialized from a proxied task."""
from kitchen_sink.dagster_defs.airflow_instance import local_airflow_instance

af_instance = local_airflow_instance()
af_instance.unpause_dag(dag_id="migrated_daily_interval_dag")
# Wait for dag run to exist
expected_logical_date = get_current_datetime_midnight() - timedelta(days=1)
expected_run_id = f"scheduled__{expected_logical_date.isoformat()}"
poll_for_airflow_run_existence_and_completion(
af_instance=af_instance,
dag_id="migrated_daily_interval_dag",
af_run_id=expected_run_id,
duration=30,
)
dagster_instance = DagsterInstance.get()
entry = poll_for_materialization(
dagster_instance=dagster_instance,
asset_key=AssetKey("migrated_daily_interval_dag__partitioned"),
)
assert entry.asset_materialization
assert entry.asset_materialization.partition
assert entry.asset_materialization.partition == expected_logical_date.strftime("%Y-%m-%d")
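Note: the run id polled for above relies on Airflow's naming convention for scheduled runs, `scheduled__<logical date in ISO-8601>`. With a UTC-midnight logical date that looks like the following (illustrative date):

from datetime import datetime, timezone

logical_date = datetime(2024, 10, 17, tzinfo=timezone.utc)
run_id = f"scheduled__{logical_date.isoformat()}"  # "scheduled__2024-10-17T00:00:00+00:00"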