From 94dcbc1150d306c874da9579977063c1003241a3 Mon Sep 17 00:00:00 2001
From: Peter Ke
Date: Sat, 2 Nov 2024 10:42:38 -0700
Subject: [PATCH] Implement query builder

Expose a QueryBuilder API in the Python bindings, backed by a DataFusion
SessionContext: Delta tables are registered as DataFusion table providers,
and SQL queries against them return a list of Arrow record batches. Also
fix DeltaTableProvider::supports_filters_pushdown to return one (inexact)
pushdown decision per filter expression, as DataFusion expects, instead of
a hard-coded single-element vector.

Signed-off-by: Peter Ke
---
 crates/core/src/delta_datafusion/mod.rs |  7 ++-
 python/deltalake/__init__.py            |  1 +
 python/deltalake/_internal.pyi          |  5 +++
 python/deltalake/query.py               | 25 +++++++++++
 python/src/error.rs                     |  8 ++++
 python/src/lib.rs                       | 14 +++++-
 python/src/query.rs                     | 59 +++++++++++++++++++++++++
 python/tests/test_table_read.py         | 54 ++++++++++++++++++++++
 8 files changed, 169 insertions(+), 4 deletions(-)
 create mode 100644 python/deltalake/query.py
 create mode 100644 python/src/query.rs

diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index 5fba1bd608..c945a28946 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -824,9 +824,12 @@ impl TableProvider for DeltaTableProvider {
 
     fn supports_filters_pushdown(
         &self,
-        _filter: &[&Expr],
+        filter: &[&Expr],
     ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
-        Ok(vec![TableProviderFilterPushDown::Inexact])
+        Ok(filter
+            .iter()
+            .map(|_| TableProviderFilterPushDown::Inexact)
+            .collect())
     }
 
     fn statistics(&self) -> Option<Statistics> {
diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py
index 43997076b2..6e96a68afe 100644
--- a/python/deltalake/__init__.py
+++ b/python/deltalake/__init__.py
@@ -2,6 +2,7 @@
 from ._internal import __version__ as __version__
 from ._internal import rust_core_version as rust_core_version
 from .data_catalog import DataCatalog as DataCatalog
+from .query import QueryBuilder as QueryBuilder
 from .schema import DataType as DataType
 from .schema import Field as Field
 from .schema import Schema as Schema
diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index 66b5dc8f8f..3f51cf90e1 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -873,6 +873,11 @@ class DeltaFileSystemHandler:
     ) -> ObjectOutputStream:
         """Open an output stream for sequential writing."""
 
+class PyQueryBuilder:
+    def __init__(self) -> None: ...
+    def register(self, table_name: str, delta_table: RawDeltaTable) -> None: ...
+    def execute(self, sql: str) -> List[pyarrow.RecordBatch]: ...
+
 class DeltaDataChecker:
     def __init__(self, invariants: List[Tuple[str, str]]) -> None: ...
     def check_batch(self, batch: pyarrow.RecordBatch) -> None: ...
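
Aside (not part of the patch): the core change above exists because DataFusion
expects supports_filters_pushdown to return exactly one pushdown decision per
filter it offers the provider. A query with several predicates hands the
provider a slice of multiple Exprs and expects a Vec of the same length, which
the old hard-coded vec![Inexact] violated. Once the query builder below lands,
a multi-predicate query exercises exactly that path; a minimal sketch against
the new Python API, reading the same partitioned test table the tests below
use:

    from deltalake import DeltaTable, QueryBuilder

    dt = DeltaTable("../crates/test/tests/data/delta-0.8.0-partitioned")
    # Two predicates: DataFusion offers two filter expressions to the table
    # provider and requires two pushdown decisions back (both Inexact here,
    # so the filters are re-applied on top of the scan results).
    batches = (
        QueryBuilder()
        .register("tbl", dt)
        .execute("SELECT value FROM tbl WHERE year >= 2021 AND month = 12")
    )
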
diff --git a/python/deltalake/query.py b/python/deltalake/query.py
new file mode 100644
index 0000000000..884383a5c5
--- /dev/null
+++ b/python/deltalake/query.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import List
+
+import pyarrow
+
+from deltalake._internal import PyQueryBuilder
+from deltalake.table import DeltaTable
+
+
+class QueryBuilder:
+    def __init__(self) -> None:
+        self._query_builder = PyQueryBuilder()
+
+    def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
+        """Add a table to the query builder."""
+        self._query_builder.register(
+            table_name=table_name,
+            delta_table=delta_table._table,
+        )
+        return self
+
+    def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
+        """Execute the query and return a list of record batches."""
+        return self._query_builder.execute(sql)
diff --git a/python/src/error.rs b/python/src/error.rs
index a54b1e60b4..b1d22fc7ca 100644
--- a/python/src/error.rs
+++ b/python/src/error.rs
@@ -1,4 +1,5 @@
 use arrow_schema::ArrowError;
+use deltalake::datafusion::error::DataFusionError;
 use deltalake::protocol::ProtocolError;
 use deltalake::{errors::DeltaTableError, ObjectStoreError};
 use pyo3::exceptions::{
@@ -79,6 +80,10 @@ fn checkpoint_to_py(err: ProtocolError) -> PyErr {
     }
 }
 
+fn datafusion_to_py(err: DataFusionError) -> PyErr {
+    DeltaError::new_err(err.to_string())
+}
+
 #[derive(thiserror::Error, Debug)]
 pub enum PythonError {
     #[error("Error in delta table")]
@@ -89,6 +94,8 @@ pub enum PythonError {
     Arrow(#[from] ArrowError),
     #[error("Error in checkpoint")]
     Protocol(#[from] ProtocolError),
+    #[error("Error in data fusion")]
+    DataFusion(#[from] DataFusionError),
 }
 
 impl From<PythonError> for pyo3::PyErr {
@@ -98,6 +105,7 @@ impl From<PythonError> for pyo3::PyErr {
             PythonError::ObjectStore(err) => object_store_to_py(err),
             PythonError::Arrow(err) => arrow_to_py(err),
             PythonError::Protocol(err) => checkpoint_to_py(err),
+            PythonError::DataFusion(err) => datafusion_to_py(err),
         }
     }
 }
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 005076c719..59fbf2030f 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -2,12 +2,14 @@ mod error;
 mod features;
 mod filesystem;
 mod merge;
+mod query;
 mod schema;
 mod utils;
 
 use std::collections::{HashMap, HashSet};
 use std::future::IntoFuture;
 use std::str::FromStr;
+use std::sync::Arc;
 use std::time;
 use std::time::{SystemTime, UNIX_EPOCH};
 
@@ -17,12 +19,18 @@ use delta_kernel::expressions::Scalar;
 use delta_kernel::schema::StructField;
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
+use deltalake::arrow::pyarrow::ToPyArrow;
 use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
 use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
 use deltalake::checkpoints::{cleanup_metadata, create_checkpoint};
+use deltalake::datafusion::datasource::provider_as_source;
+use deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use deltalake::datafusion::physical_plan::ExecutionPlan;
-use deltalake::datafusion::prelude::SessionContext;
-use deltalake::delta_datafusion::DeltaDataChecker;
+use deltalake::datafusion::prelude::{DataFrame, SessionContext};
+use deltalake::delta_datafusion::{
+    DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
+    DeltaTableProvider,
+};
 use deltalake::errors::DeltaTableError;
 use deltalake::kernel::{
     scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, Transaction,
@@ -66,6 +74,7 @@ use crate::error::PythonError;
 use crate::features::TableFeatures;
 use crate::filesystem::FsConfig;
 use crate::merge::PyMergeBuilder;
+use crate::query::PyQueryBuilder;
 use crate::schema::{schema_to_pyobject, Field};
 use crate::utils::rt;
 
@@ -2069,6 +2078,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
     )?)?;
     m.add_class::<RawDeltaTable>()?;
     m.add_class::<PyMergeBuilder>()?;
+    m.add_class::<PyQueryBuilder>()?;
     m.add_class::<RawDeltaTableMetaData>()?;
     m.add_class::<PyDeltaDataChecker>()?;
     m.add_class::<PyTransaction>()?;
diff --git a/python/src/query.rs b/python/src/query.rs
new file mode 100644
index 0000000000..6bb7ae0166
--- /dev/null
+++ b/python/src/query.rs
@@ -0,0 +1,59 @@
+use std::sync::Arc;
+
+use deltalake::{
+    arrow::pyarrow::ToPyArrow,
+    datafusion::prelude::SessionContext,
+    delta_datafusion::{DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider},
+};
+use pyo3::prelude::*;
+
+use crate::{error::PythonError, utils::rt, RawDeltaTable};
+
+#[pyclass(module = "deltalake._internal")]
+pub(crate) struct PyQueryBuilder {
+    _ctx: SessionContext,
+}
+
+#[pymethods]
+impl PyQueryBuilder {
+    #[new]
+    pub fn new() -> Self {
+        let config = DeltaSessionConfig::default().into();
+        let _ctx = SessionContext::new_with_config(config);
+
+        PyQueryBuilder { _ctx }
+    }
+
+    pub fn register(&self, table_name: &str, delta_table: &RawDeltaTable) -> PyResult<()> {
+        let snapshot = delta_table._table.snapshot().map_err(PythonError::from)?;
+        let log_store = delta_table._table.log_store();
+
+        let scan_config = DeltaScanConfigBuilder::default()
+            .with_parquet_pushdown(false)
+            .build(snapshot)
+            .map_err(PythonError::from)?;
+
+        let provider = Arc::new(
+            DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
+                .map_err(PythonError::from)?,
+        );
+
+        self._ctx
+            .register_table(table_name, provider)
+            .map_err(PythonError::from)?;
+
+        Ok(())
+    }
+
+    pub fn execute(&self, py: Python, sql: &str) -> PyResult<PyObject> {
+        let batches = py.allow_threads(|| {
+            rt().block_on(async {
+                let df = self._ctx.sql(sql).await?;
+                df.collect().await
+            })
+            .map_err(PythonError::from)
+        })?;
+
+        batches.to_pyarrow(py)
+    }
+}
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index 5ff07ed9e8..30d7f21d7f 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -9,6 +9,7 @@
 from deltalake._util import encode_partition_value
 from deltalake.exceptions import DeltaProtocolError
+from deltalake.query import QueryBuilder
 from deltalake.table import ProtocolVersions
 from deltalake.writer import write_deltalake
 
 
@@ -946,3 +947,56 @@ def test_is_deltatable_with_storage_opts():
         "DELTA_DYNAMO_TABLE_NAME": "custom_table_name",
     }
     assert DeltaTable.is_deltatable(table_path, storage_options=storage_options)
+
+
+def test_read_query_builder():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["4", "5", "6", "7"],
+        "year": ["2021", "2021", "2021", "2021"],
+        "month": ["4", "12", "12", "12"],
+        "day": ["5", "4", "20", "20"],
+    }
+    actual = pa.Table.from_batches(
+        QueryBuilder()
+        .register("tbl", dt)
+        .execute("SELECT * FROM tbl WHERE year >= 2021 ORDER BY value")
+    ).to_pydict()
+    assert expected == actual
+
+
+def test_read_query_builder_join_multiple_tables(tmp_path):
+    table_path = "../crates/test/tests/data/delta-0.8.0-date"
+    dt1 = DeltaTable(table_path)
+
+    write_deltalake(
+        tmp_path,
+        pa.table(
+            {
+                "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-12-31"],
+                "value": ["a", "b", "c", "d"],
+            }
+        ),
+    )
+    dt2 = DeltaTable(tmp_path)
+
+    expected = {
["2021-01-01", "2021-01-02", "2021-01-03"], + "dayOfYear": [1, 2, 3], + "value": ["a", "b", "c"], + } + actual = pa.Table.from_batches( + QueryBuilder() + .register("tbl1", dt1) + .register("tbl2", dt2) + .execute( + """ + SELECT tbl2.date, tbl1.dayOfYear, tbl2.value + FROM tbl1 + INNER JOIN tbl2 ON tbl1.date = tbl2.date + ORDER BY tbl1.date + """ + ) + ).to_pydict() + assert expected == actual