Commit 94dcbc1: Implement query builder

Signed-off-by: Peter Ke <[email protected]>
PeterKeDer committed Nov 2, 2024
1 parent a999c92 commit 94dcbc1

Showing 8 changed files with 169 additions and 4 deletions.
7 changes: 5 additions & 2 deletions crates/core/src/delta_datafusion/mod.rs
@@ -824,9 +824,12 @@ impl TableProvider for DeltaTableProvider {

     fn supports_filters_pushdown(
         &self,
-        _filter: &[&Expr],
+        filter: &[&Expr],
     ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
-        Ok(vec![TableProviderFilterPushDown::Inexact])
+        Ok(filter
+            .iter()
+            .map(|_| TableProviderFilterPushDown::Inexact)
+            .collect())
     }

     fn statistics(&self) -> Option<Statistics> {
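Note: DataFusion requires `supports_filters_pushdown` to return one pushdown decision per filter it passes in, so the previous single-element `Ok(vec![TableProviderFilterPushDown::Inexact])` only lined up for queries with exactly one pushable predicate. Returning `Inexact` for every filter lets the provider use each predicate for best-effort pruning while DataFusion still re-applies them after the scan. A minimal sketch of the kind of multi-predicate query this unblocks, via the Python `QueryBuilder` added below (the table path and column names are illustrative, not from the commit):

import pyarrow as pa

from deltalake import DeltaTable, QueryBuilder

# Hypothetical table; "year" and "month" stand in for real partition columns.
dt = DeltaTable("path/to/table")
batches = (
    QueryBuilder()
    .register("tbl", dt)
    .execute("SELECT * FROM tbl WHERE year = '2021' AND month = '12'")
)
print(pa.Table.from_batches(batches).num_rows)  # assumes a non-empty result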
1 change: 1 addition & 0 deletions python/deltalake/__init__.py
@@ -2,6 +2,7 @@
 from ._internal import __version__ as __version__
 from ._internal import rust_core_version as rust_core_version
 from .data_catalog import DataCatalog as DataCatalog
+from .query import QueryBuilder
 from .schema import DataType as DataType
 from .schema import Field as Field
 from .schema import Schema as Schema
5 changes: 5 additions & 0 deletions python/deltalake/_internal.pyi
@@ -873,6 +873,11 @@ class DeltaFileSystemHandler:
     ) -> ObjectOutputStream:
         """Open an output stream for sequential writing."""

+class PyQueryBuilder:
+    def __init__(self) -> None: ...
+    def register(self, table_name: str, delta_table: RawDeltaTable) -> None: ...
+    def execute(self, sql: str) -> List[pyarrow.RecordBatch]: ...
+
 class DeltaDataChecker:
     def __init__(self, invariants: List[Tuple[str, str]]) -> None: ...
     def check_batch(self, batch: pyarrow.RecordBatch) -> None: ...
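`PyQueryBuilder` here is the raw binding exported from the Rust extension; the `QueryBuilder` wrapper in `python/deltalake/query.py` below is the intended public entry point. A minimal sketch of what the stub corresponds to when used directly (the table path is hypothetical, and `._table` is the private `RawDeltaTable` handle the wrapper passes through):

from deltalake import DeltaTable
from deltalake._internal import PyQueryBuilder

raw = PyQueryBuilder()
raw.register("tbl", DeltaTable("path/to/table")._table)  # hypothetical path
batches = raw.execute("SELECT * FROM tbl")  # List[pyarrow.RecordBatch]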
25 changes: 25 additions & 0 deletions python/deltalake/query.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import List
+
+import pyarrow
+
+from deltalake._internal import PyQueryBuilder
+from deltalake.table import DeltaTable
+
+
+class QueryBuilder:
+    def __init__(self) -> None:
+        self._query_builder = PyQueryBuilder()
+
+    def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
+        """Add a table to the query builder."""
+        self._query_builder.register(
+            table_name=table_name,
+            delta_table=delta_table._table,
+        )
+        return self
+
+    def execute(self, sql: str) -> List[pyarrow.RecordBatch]:
+        """Execute the query and return a list of record batches."""
+        return self._query_builder.execute(sql)
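Since `register` returns `self`, registrations chain and a single builder can serve several queries. A minimal end-to-end sketch of the wrapper above, assuming a local table path and a `value` column (both illustrative):

import pyarrow as pa

from deltalake import DeltaTable, QueryBuilder

dt = DeltaTable("path/to/table")  # hypothetical table
batches = (
    QueryBuilder()
    .register("tbl", dt)
    .execute("SELECT value FROM tbl ORDER BY value LIMIT 10")
)
table = pa.Table.from_batches(batches)  # assumes at least one batch comes back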
8 changes: 8 additions & 0 deletions python/src/error.rs
@@ -1,4 +1,5 @@
 use arrow_schema::ArrowError;
+use deltalake::datafusion::error::DataFusionError;
 use deltalake::protocol::ProtocolError;
 use deltalake::{errors::DeltaTableError, ObjectStoreError};
 use pyo3::exceptions::{

@@ -79,6 +80,10 @@ fn checkpoint_to_py(err: ProtocolError) -> PyErr {
     }
 }

+fn datafusion_to_py(err: DataFusionError) -> PyErr {
+    DeltaError::new_err(err.to_string())
+}
+
 #[derive(thiserror::Error, Debug)]
 pub enum PythonError {
     #[error("Error in delta table")]

@@ -89,6 +94,8 @@ pub enum PythonError {
     Arrow(#[from] ArrowError),
     #[error("Error in checkpoint")]
     Protocol(#[from] ProtocolError),
+    #[error("Error in data fusion")]
+    DataFusion(#[from] DataFusionError),
 }

 impl From<PythonError> for pyo3::PyErr {

@@ -98,6 +105,7 @@ impl From<PythonError> for pyo3::PyErr {
             PythonError::ObjectStore(err) => object_store_to_py(err),
             PythonError::Arrow(err) => arrow_to_py(err),
             PythonError::Protocol(err) => checkpoint_to_py(err),
+            PythonError::DataFusion(err) => datafusion_to_py(err),
         }
     }
 }
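With the new `DataFusion` variant, planner and execution failures cross the FFI boundary as a `DeltaError` rather than an unhandled Rust error. A sketch of what that looks like from Python, assuming `DeltaError` is the base exception exposed in `deltalake.exceptions` (which is how the other variants surface):

from deltalake import QueryBuilder
from deltalake.exceptions import DeltaError

try:
    QueryBuilder().execute("SELECT * FROM never_registered")
except DeltaError as err:
    print(f"query failed: {err}")  # e.g. a DataFusion "table not found" error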
14 changes: 12 additions & 2 deletions python/src/lib.rs
@@ -2,12 +2,14 @@ mod error;
 mod features;
 mod filesystem;
 mod merge;
+mod query;
 mod schema;
 mod utils;

 use std::collections::{HashMap, HashSet};
 use std::future::IntoFuture;
 use std::str::FromStr;
+use std::sync::Arc;
 use std::time;
 use std::time::{SystemTime, UNIX_EPOCH};

@@ -17,12 +19,18 @@ use delta_kernel::expressions::Scalar;
 use delta_kernel::schema::StructField;
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
+use deltalake::arrow::pyarrow::ToPyArrow;
 use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
 use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
 use deltalake::checkpoints::{cleanup_metadata, create_checkpoint};
+use deltalake::datafusion::datasource::provider_as_source;
+use deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use deltalake::datafusion::physical_plan::ExecutionPlan;
-use deltalake::datafusion::prelude::SessionContext;
-use deltalake::delta_datafusion::DeltaDataChecker;
+use deltalake::datafusion::prelude::{DataFrame, SessionContext};
+use deltalake::delta_datafusion::{
+    DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
+    DeltaTableProvider,
+};
 use deltalake::errors::DeltaTableError;
 use deltalake::kernel::{
     scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, Transaction,
@@ -66,6 +74,7 @@ use crate::error::PythonError;
 use crate::features::TableFeatures;
 use crate::filesystem::FsConfig;
 use crate::merge::PyMergeBuilder;
+use crate::query::PyQueryBuilder;
 use crate::schema::{schema_to_pyobject, Field};
 use crate::utils::rt;

@@ -2069,6 +2078,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
     )?)?;
     m.add_class::<RawDeltaTable>()?;
     m.add_class::<PyMergeBuilder>()?;
+    m.add_class::<PyQueryBuilder>()?;
     m.add_class::<RawDeltaTableMetaData>()?;
     m.add_class::<PyDeltaDataChecker>()?;
     m.add_class::<PyTransaction>()?;
59 changes: 59 additions & 0 deletions python/src/query.rs
@@ -0,0 +1,59 @@
+use std::sync::Arc;
+
+use deltalake::{
+    arrow::pyarrow::ToPyArrow,
+    datafusion::prelude::SessionContext,
+    delta_datafusion::{DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider},
+};
+use pyo3::prelude::*;
+
+use crate::{error::PythonError, utils::rt, RawDeltaTable};
+
+#[pyclass(module = "deltalake._internal")]
+pub(crate) struct PyQueryBuilder {
+    _ctx: SessionContext,
+}
+
+#[pymethods]
+impl PyQueryBuilder {
+    #[new]
+    pub fn new() -> Self {
+        let config = DeltaSessionConfig::default().into();
+        let _ctx = SessionContext::new_with_config(config);
+
+        PyQueryBuilder { _ctx }
+    }
+
+    pub fn register(&self, table_name: &str, delta_table: &RawDeltaTable) -> PyResult<()> {
+        let snapshot = delta_table._table.snapshot().map_err(PythonError::from)?;
+        let log_store = delta_table._table.log_store();
+
+        let scan_config = DeltaScanConfigBuilder::default()
+            .with_parquet_pushdown(false)
+            .build(snapshot)
+            .map_err(PythonError::from)?;
+
+        let provider = Arc::new(
+            DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
+                .map_err(PythonError::from)?,
+        );
+
+        self._ctx
+            .register_table(table_name, provider)
+            .map_err(PythonError::from)?;
+
+        Ok(())
+    }
+
+    pub fn execute(&self, py: Python, sql: &str) -> PyResult<PyObject> {
+        let batches = py.allow_threads(|| {
+            rt().block_on(async {
+                let df = self._ctx.sql(sql).await?;
+                df.collect().await
+            })
+            .map_err(PythonError::from)
+        })?;
+
+        batches.to_pyarrow(py)
+    }
+}
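Two details of `execute` above are worth noting: the SQL runs under `py.allow_threads`, so the GIL is released while the Tokio runtime drives the DataFusion plan, and `df.collect()` materializes the full result set before anything is handed back to Python. A hedged sketch of working within that (the path and column are illustrative):

from deltalake import DeltaTable, QueryBuilder

qb = QueryBuilder().register("tbl", DeltaTable("path/to/large/table"))
# collect() is eager: every batch is in memory before Python sees any of them,
# so bound large scans in the SQL itself.
batches = qb.execute("SELECT id FROM tbl LIMIT 1000")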
54 changes: 54 additions & 0 deletions python/tests/test_table_read.py
@@ -9,6 +9,7 @@

 from deltalake._util import encode_partition_value
 from deltalake.exceptions import DeltaProtocolError
+from deltalake.query import QueryBuilder
 from deltalake.table import ProtocolVersions
 from deltalake.writer import write_deltalake

@@ -946,3 +947,56 @@ def test_is_deltatable_with_storage_opts():
         "DELTA_DYNAMO_TABLE_NAME": "custom_table_name",
     }
     assert DeltaTable.is_deltatable(table_path, storage_options=storage_options)
+
+
+def test_read_query_builder():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["4", "5", "6", "7"],
+        "year": ["2021", "2021", "2021", "2021"],
+        "month": ["4", "12", "12", "12"],
+        "day": ["5", "4", "20", "20"],
+    }
+    actual = pa.Table.from_batches(
+        QueryBuilder()
+        .register("tbl", dt)
+        .execute("SELECT * FROM tbl WHERE year >= 2021 ORDER BY value")
+    ).to_pydict()
+    assert expected == actual
+
+
+def test_read_query_builder_join_multiple_tables(tmp_path):
+    table_path = "../crates/test/tests/data/delta-0.8.0-date"
+    dt1 = DeltaTable(table_path)
+
+    write_deltalake(
+        tmp_path,
+        pa.table(
+            {
+                "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-12-31"],
+                "value": ["a", "b", "c", "d"],
+            }
+        ),
+    )
+    dt2 = DeltaTable(tmp_path)
+
+    expected = {
+        "date": ["2021-01-01", "2021-01-02", "2021-01-03"],
+        "dayOfYear": [1, 2, 3],
+        "value": ["a", "b", "c"],
+    }
+    actual = pa.Table.from_batches(
+        QueryBuilder()
+        .register("tbl1", dt1)
+        .register("tbl2", dt2)
+        .execute(
+            """
+            SELECT tbl2.date, tbl1.dayOfYear, tbl2.value
+            FROM tbl1
+            INNER JOIN tbl2 ON tbl1.date = tbl2.date
+            ORDER BY tbl1.date
+            """
+        )
+    ).to_pydict()
+    assert expected == actual
