Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(): Task links #216

Merged
merged 10 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
be passed as a parameter into any task instead of any other pipedag table reference.
- Fixed bug that caused a crash when retrieving a polars dataframe from SQL using polars >= 1
- Fix warning about `ignore_position_hashes` being printed even if the flag was not set.
- Added support for `inputs` argument for `flow.run()` allowing to pass `ExternalTableReference`
objects to the flow that override the outputs of selected tasks.

## 0.9.5 (2024-07-22)
- Fixed a bug in primary key generation when materializing pandas dataframe to postgres database
Expand Down
18 changes: 9 additions & 9 deletions docs/source/examples/raw_sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import sqlalchemy as sa

from pydiverse.pipedag import Flow, Stage, materialize
from pydiverse.pipedag.context import ConfigContext, StageLockContext
from pydiverse.pipedag.materialize.container import RawSql
from pydiverse.pipedag.container import RawSql
from tests.fixtures.instances import with_instances

"""
Expand All @@ -38,13 +38,13 @@ they can be gradually converted from text SQL to programmatically created SQL (p

@materialize(input_type=sa.Table, lazy=True)
def tsql(
name: str,
script_directory: Path,
*,
out_stage: Stage | None = None,
in_sql=None,
helper_sql=None,
depend=None,
name: str,
script_directory: Path,
*,
out_stage: Stage | None = None,
in_sql=None,
helper_sql=None,
depend=None,
):
_ = depend # only relevant for adding additional task dependency
script_path = script_directory / name
Expand All @@ -56,7 +56,7 @@ def tsql(


def raw_sql_bind_schema(
sql, prefix: str, stage: Stage | RawSql | None, *, transaction=False
sql, prefix: str, stage: Stage | RawSql | None, *, transaction=False
):
if isinstance(stage, RawSql):
stage = stage.stage
Expand Down
12 changes: 7 additions & 5 deletions src/pydiverse/pipedag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from __future__ import annotations

from .container import (
Blob,
ExternalTableReference,
RawSql,
Schema,
Table,
)
from .context import ConfigContext, StageLockContext
from .core import (
Flow,
Expand All @@ -11,11 +18,6 @@
VisualizationStyle,
)
from .materialize import (
Blob,
ExternalTableReference,
RawSql,
Schema,
Table,
input_stage_versions,
materialize,
)
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/pipedag/backend/lock/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from pydiverse.pipedag.backend.table.sql.ddl import (
CreateSchema,
)
from pydiverse.pipedag.container import Schema
from pydiverse.pipedag.errors import LockError
from pydiverse.pipedag.materialize.container import Schema

DISABLE_DIALECT_REGISTRATION = "__DISABLE_DIALECT_REGISTRATION"

Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/pipedag/backend/table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

from pydiverse.pipedag import ConfigContext
from pydiverse.pipedag._typing import T, TableHookResolverT
from pydiverse.pipedag.container import RawSql, Table
from pydiverse.pipedag.context import RunContext, TaskContext
from pydiverse.pipedag.errors import CacheError
from pydiverse.pipedag.materialize.cache import TaskCacheInfo, lazy_table_cache_key
from pydiverse.pipedag.materialize.container import RawSql, Table
from pydiverse.pipedag.materialize.metadata import (
LazyTableMetadata,
RawSqlMetadata,
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/pipedag/backend/table/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from pydiverse.pipedag import Stage
from pydiverse.pipedag._typing import T
from pydiverse.pipedag.backend.table.base import TableHookResolver
from pydiverse.pipedag.container import Table
from pydiverse.pipedag.context import RunContext
from pydiverse.pipedag.materialize.container import Table
from pydiverse.pipedag.materialize.core import MaterializingTask
from pydiverse.pipedag.util import Disposable

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from pydiverse.pipedag.backend.table.sql.sql import SQLTableStore
from pydiverse.pipedag.backend.table.util import DType
from pydiverse.pipedag.materialize.container import Schema
from pydiverse.pipedag.container import Schema
from pydiverse.pipedag.materialize.details import resolve_materialization_details_label

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from pydiverse.pipedag.backend.table.sql.reflection import PipedagDB2Reflection
from pydiverse.pipedag.backend.table.sql.sql import SQLTableStore
from pydiverse.pipedag.backend.table.util import DType
from pydiverse.pipedag.materialize import Table
from pydiverse.pipedag.materialize.container import Schema
from pydiverse.pipedag.container import Schema, Table
from pydiverse.pipedag.materialize.details import (
BaseMaterializationDetails,
resolve_materialization_details_label,
Expand Down
3 changes: 1 addition & 2 deletions src/pydiverse/pipedag/backend/table/sql/dialects/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from pydiverse.pipedag.backend.table.sql.reflection import PipedagMSSqlReflection
from pydiverse.pipedag.backend.table.sql.sql import SQLTableStore
from pydiverse.pipedag.backend.table.util import DType
from pydiverse.pipedag.materialize import Table
from pydiverse.pipedag.materialize.container import RawSql, Schema
from pydiverse.pipedag.container import RawSql, Schema, Table


class MSSqlTableStore(SQLTableStore):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
)
from pydiverse.pipedag.backend.table.sql.sql import SQLTableStore
from pydiverse.pipedag.backend.table.util import DType
from pydiverse.pipedag.materialize import Table
from pydiverse.pipedag.materialize.container import Schema
from pydiverse.pipedag.container import Schema, Table
from pydiverse.pipedag.materialize.details import (
BaseMaterializationDetails,
resolve_materialization_details_label,
Expand Down
3 changes: 1 addition & 2 deletions src/pydiverse/pipedag/backend/table/sql/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
DType,
PandasDTypeBackend,
)
from pydiverse.pipedag.container import ExternalTableReference, Schema, Table
from pydiverse.pipedag.context import TaskContext
from pydiverse.pipedag.materialize import Table
from pydiverse.pipedag.materialize.container import ExternalTableReference, Schema
from pydiverse.pipedag.materialize.details import resolve_materialization_details_label
from pydiverse.pipedag.util.computation_tracing import ComputationTracer

Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/pipedag/backend/table/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
RenameTable,
split_ddl_statement,
)
from pydiverse.pipedag.container import RawSql, Schema
from pydiverse.pipedag.context import RunContext
from pydiverse.pipedag.context.context import (
CacheValidationMode,
Expand All @@ -37,7 +38,6 @@
)
from pydiverse.pipedag.context.run_context import DeferredTableStoreOp
from pydiverse.pipedag.errors import CacheError
from pydiverse.pipedag.materialize.container import RawSql, Schema
from pydiverse.pipedag.materialize.core import MaterializingTask
from pydiverse.pipedag.materialize.metadata import (
LazyTableMetadata,
Expand Down
15 changes: 14 additions & 1 deletion src/pydiverse/pipedag/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pydot
import structlog

from pydiverse.pipedag import ExternalTableReference
from pydiverse.pipedag.context import (
ConfigContext,
DAGContext,
Expand Down Expand Up @@ -257,6 +258,7 @@ def run(
disable_cache_function: bool | None = None,
ignore_task_version: bool | None = None,
ignore_position_hashes: bool = False,
inputs: dict[Task | TaskGetItem, ExternalTableReference] | None = None,
**kwargs,
) -> Result:
"""Execute the flow.
Expand Down Expand Up @@ -305,6 +307,15 @@ def run(
And for this to work, any task producing an input
for the chosen subgraph may never be used more
than once per stage.
NOTE: This is only supported for the SequentialEngine and SQLTablestore
:param inputs:
Optionally provide the outputs for a subset of tasks.
The format is expected as
dict[Task|TaskGetItem, ExternalTableReference].
Every task that is listed in this mapping
will not be executed but instead the output,
will be read from the external reference.
NOTE: This is only supported when using the SQLTablestore at the moment
:param kwargs:
Other keyword arguments that get passed on directly to the
``run()`` method of the orchestration engine. Consequently, these
Expand Down Expand Up @@ -387,7 +398,9 @@ def run(
with config, RunContextServer(subflow, trace_hook):
if orchestration_engine is None:
orchestration_engine = config.create_orchestration_engine()
result = orchestration_engine.run(subflow, ignore_position_hashes, **kwargs)
result = orchestration_engine.run(
subflow, ignore_position_hashes, inputs, **kwargs
)

visualization_url = result.visualize_url()
self.logger.info("Flow visualization", url=visualization_url)
Expand Down
1 change: 0 additions & 1 deletion src/pydiverse/pipedag/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def task_result_mapper(x):
result = self.fn(*args, **kwargs)
DominikZuercherQC marked this conversation as resolved.
Show resolved Hide resolved
else:
result = self.fn(*args, **kwargs)

return result, task_context

def __compute_position_hash(self) -> str:
Expand Down
8 changes: 7 additions & 1 deletion src/pydiverse/pipedag/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from pydiverse.pipedag import ExternalTableReference, Task
from pydiverse.pipedag.core.task import TaskGetItem
from pydiverse.pipedag.util import Disposable

if TYPE_CHECKING:
Expand All @@ -14,7 +16,11 @@ class OrchestrationEngine(Disposable, ABC):

@abstractmethod
def run(
self, flow: Subflow, ignore_position_hashes: bool = False, **kwargs
self,
flow: Subflow,
ignore_position_hashes: bool = False,
inputs: dict[Task | TaskGetItem, ExternalTableReference] | None = None,
**kwargs,
) -> Result:
"""Execute a flow

Expand Down
29 changes: 25 additions & 4 deletions src/pydiverse/pipedag/engine/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import structlog

from pydiverse.pipedag import ExternalTableReference, Table
from pydiverse.pipedag.context import ConfigContext, RunContext
from pydiverse.pipedag.core import Result
from pydiverse.pipedag.core.task import TaskGetItem
from pydiverse.pipedag.engine.base import OrchestrationEngine
from pydiverse.pipedag.util import requires

Expand Down Expand Up @@ -45,7 +47,15 @@ def __init__(self, **dask_compute_kwargs):

self.dask_compute_kwargs.update(dask_compute_kwargs)

def run(self, flow: Subflow, ignore_position_hashes: bool = False, **run_kwargs):
def run(
self,
flow: Subflow,
ignore_position_hashes: bool = False,
inputs: dict[Task | TaskGetItem, ExternalTableReference] | None = None,
**run_kwargs,
):
inputs = inputs if inputs is not None else {}
_ = ignore_position_hashes
run_context = RunContext.get()
config_context = ConfigContext.get()

Expand All @@ -72,13 +82,24 @@ def run(parent_futures, **kwargs):
return dask.delayed(run, pure=False)

for task in flow.get_tasks():
task_inputs = {
**{
in_id: Table(inputs[in_t])
for in_id, in_t in task.input_tasks.items()
if in_t in inputs
},
**{
in_id: results[in_t]
for in_id, in_t in task.input_tasks.items()
if in_t not in inputs
},
}

results[task] = bind_run(task)(
parent_futures=[
results[parent] for parent in flow.get_parent_tasks(task)
],
inputs={
in_id: results[in_t] for in_id, in_t in task.input_tasks.items()
},
inputs=task_inputs,
run_context=run_context,
config_context=config_context,
)
Expand Down
Loading