From b3aa546ca164a6ab948b6cc05c9727f0c557f0c1 Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:43:11 +0000 Subject: [PATCH 1/6] merge using partition filters --- .../src/delta_datafusion/mod.rs | 32 ++ crates/deltalake-core/src/operations/merge.rs | 402 +++++++++++++++++- 2 files changed, 414 insertions(+), 20 deletions(-) diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 973d575904..4e69bb092c 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -65,6 +65,7 @@ use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use futures::TryStreamExt; use log::error; use object_store::ObjectMeta; use serde::{Deserialize, Serialize}; @@ -1013,6 +1014,37 @@ pub(crate) fn logical_expr_to_physical_expr( create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } +pub(crate) async fn execute_plan_to_batch( + state: &SessionState, + plan: Arc<dyn ExecutionPlan>, +) -> DeltaResult<RecordBatch> { + let data = futures::future::try_join_all( + (0..plan.output_partitioning().partition_count()) + .into_iter() + .map(|p| { + let plan_copy = plan.clone(); + let task_context = state.task_ctx().clone(); + async move { + let batch_stream = plan_copy.execute(p, task_context)?; + + let schema = batch_stream.schema(); + + let batches = batch_stream.try_collect::<Vec<_>>().await?; + + DataFusionResult::<_>::Ok(arrow::compute::concat_batches( + &schema, + batches.iter(), + )?) + } + }), + ) + .await?; + + let batch = arrow::compute::concat_batches(&plan.schema(), data.iter())?; + + Ok(batch) +} + /// Responsible for checking batches of data conform to table's invariants.
#[derive(Clone)] pub struct DeltaDataChecker { diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 433e9cda43..5f0ea410e5 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -50,12 +50,16 @@ use datafusion::{ }, prelude::{DataFrame, SessionContext}, }; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; +use datafusion_expr::expr::Placeholder; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, + BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + UserDefinedLogicalNode, UNNAMED_TABLE, }; use futures::future::BoxFuture; +use itertools::Itertools; use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; @@ -65,7 +69,9 @@ use super::transaction::{commit, PROTOCOL}; use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression}; use crate::delta_datafusion::logical::MetricObserver; use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; -use crate::delta_datafusion::{register_store, DeltaScanConfig, DeltaTableProvider}; +use crate::delta_datafusion::{ + execute_plan_to_batch, register_store, DeltaScanConfig, DeltaTableProvider, +}; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; use crate::operations::write::write_execution_plan; @@ -646,6 +652,243 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner { } } +/// Takes the predicate provided and does two things: +/// +/// 1. for any relation between a source column and a target column, if the target column is a +/// partition column, then replace the source side with a placeholder matching the name of the +/// partition column +/// +/// 2. for any other relation with a source column, remove it. +/// +/// For example, for the predicate: +/// +/// `source.date = target.date and source.id = target.id and frob > 42` +/// +/// where `date` is a partition column, this would result in the expr: +/// +/// `$date = target.date and frob > 42` +/// +/// This leaves us with a predicate that we can push into the delta scan after expanding it out to +/// a disjunction over the distinct partitions in the source input.
+/// +/// TODO: A further improvement here might be for non-partition columns to be replaced with min/max +/// checks, so the above example could become: +/// +/// `$date = target.date and target.id between 12345 and 99999 and frob > 42` +fn generalize_filter( + predicate: Expr, + partition_columns: &Vec<String>, + source_name: &TableReference, + target_name: &TableReference, + placeholders: &mut HashMap<String, Expr>, +) -> Option<Expr> { + fn references_table(expr: &Expr, table: &TableReference) -> Option<String> { + match expr { + Expr::Alias(alias) => references_table(&alias.expr, table), + Expr::Column(col) => col.relation.as_ref().and_then(|rel| { + if rel == table { + Some(col.name.to_owned()) + } else { + None + } + }), + Expr::Negative(neg) => references_table(&*neg, table), + Expr::Cast(cast) => references_table(&cast.expr, table), + Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), + Expr::ScalarFunction(func) => { + if func.args.len() == 1 { + references_table(&func.args[0], table) + } else { + None + } + } + Expr::ScalarUDF(udf) => { + if udf.args.len() == 1 { + references_table(&udf.args[0], table) + } else { + None + } + } + _ => None, + } + } + + match predicate { + Expr::BinaryExpr(binary) => { + if let Some(_) = references_table(&binary.right, source_name) { + if let Some(left_target) = references_table(&binary.left, target_name) { + if partition_columns.contains(&left_target) { + let placeholder_name = format!("{left_target}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + let replaced = Expr::BinaryExpr(BinaryExpr { + left: binary.left, + op: binary.op, + right: placeholder.into(), + }); + + placeholders.insert(placeholder_name, *binary.right); + + return Some(replaced); + } + } + return None; + } + if let Some(_) = references_table(&binary.left, source_name) { + if let Some(right_target) = references_table(&binary.right, target_name) { + if partition_columns.contains(&right_target) { + let placeholder_name = format!("{right_target}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + let replaced = Expr::BinaryExpr(BinaryExpr { + right: binary.right, + op: binary.op, + left: placeholder.into(), + }); + + placeholders.insert(placeholder_name, *binary.left); + + return Some(replaced); + } + } + return None; + } + + let left = generalize_filter( + *binary.left, + partition_columns, + source_name, + target_name, + placeholders, + ); + let right = generalize_filter( + *binary.right, + partition_columns, + source_name, + target_name, + placeholders, + ); + + match (left, right) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr { + left: l.into(), + op: binary.op, + right: r.into(), + }) + .into(), + } + } + other => Some(other), + } +} + +fn replace_placeholders(expr: Expr, placeholders: &HashMap<String, ScalarValue>) -> Expr { + expr.transform(&|expr| match expr { + Expr::Placeholder(Placeholder { id, ..
}) => { + let value = placeholders[&id].clone(); + // Replace the placeholder with the value + Ok(Transformed::Yes(Expr::Literal(value))) + } + _ => Ok(Transformed::No(expr)), + }) + .unwrap() +} + +async fn try_construct_early_filter( + join_predicate: Expr, + table_snapshot: &DeltaTableState, + session_state: &SessionState, + source: &LogicalPlan, + source_name: &TableReference<'_>, + target_name: &TableReference<'_>, +) -> Option { + println!("checkit"); + let table_metadata = table_snapshot.metadata()?; + + let partition_columns = &table_metadata.partition_columns; + + if partition_columns.is_empty() { + return None; + } + + let mut placeholders = HashMap::default(); + + println!("generalize {join_predicate}"); + + let filter = generalize_filter( + join_predicate, + partition_columns, + source_name, + target_name, + &mut placeholders, + )?; + + println!("{filter:?}"); + + if placeholders.is_empty() { + // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter + Some(filter) + } else { + // if we have some recognised partitions, then discover the distinct set of partitions in the source data and + // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together) + let distinct_partitions = LogicalPlan::Distinct(Distinct { + input: LogicalPlan::Projection( + Projection::try_new( + placeholders + .into_iter() + .map(|(alias, expr)| expr.alias(alias)) + .collect_vec(), + source.clone().into(), + ) + .unwrap(), + ) + .into(), + }); + + let execution_plan = session_state + .create_physical_plan(&distinct_partitions) + .await + .unwrap(); + + let items = execute_plan_to_batch(session_state, execution_plan) + .await + .unwrap(); + + let placeholder_names = items + .schema() + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect_vec(); + + (0..items.num_rows()) + .into_iter() + .map(|i| { + let replacements = placeholder_names + .iter() + .map(|placeholder| { + let col = items.column_by_name(placeholder).unwrap(); + let value = ScalarValue::try_from_array(col, i).unwrap(); + (placeholder.to_owned(), value) + }) + .collect(); + replace_placeholders(filter.clone(), &replacements) + }) + .reduce(Expr::or) + .unwrap() + .into() + } +} + #[allow(clippy::too_many_arguments)] async fn execute( predicate: Expression, @@ -687,9 +930,12 @@ async fn execute( }; // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work - let source = - LogicalPlanBuilder::scan(source_name, provider_as_source(source.into_view()), None)? - .build()?; + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + )? 
+ .build()?; let source = LogicalPlan::Extension(Extension { node: Arc::new(MetricObserver { @@ -698,17 +944,70 @@ async fn execute( }), }); - let source = DataFrame::new(state.clone(), source); - let source = source.with_column(SOURCE_COLUMN, lit(true))?; - let target_provider = Arc::new(DeltaTableProvider::try_new( snapshot.clone(), log_store.clone(), DeltaScanConfig::default(), )?); + let target_provider = provider_as_source(target_provider); - let target = LogicalPlanBuilder::scan(target_name, target_provider, None)?.build()?; + let target = LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?; + + let source_schema = source.schema(); + let target_schema = target.schema(); + let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; + let predicate = match predicate { + Expression::DataFusion(expr) => expr, + Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, + }; + + let state = state.with_query_planner(Arc::new(MergePlanner {})); + + let (target, files) = { + // Attempt to construct an early filter that we can apply to the Add action list and the delta scan. + // In the case where there are partition columns in the join predicate, we can scan the source table + // to get the distinct list of partitions affected and constrain the search to those. + + if !not_match_source_operations.is_empty() { + // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since + // that implies a full scan + (target, snapshot.files().iter().collect_vec()) + } else { + if let Some(filter) = try_construct_early_filter( + predicate.clone(), + snapshot, + &state, + &source, + &source_name, + &target_name, + ) + .await + { + let file_filter = filter + .clone() + .transform(&|expr| match expr { + Expr::Column(c) => Ok(Transformed::Yes(Expr::Column(Column { + relation: None, // the file filter won't be looking at columns like `target.partition`, it'll just be `partition` + name: c.name, + }))), + expr => Ok(Transformed::No(expr)), + }) + .unwrap(); + let files = snapshot + .files_matching_predicate(&[file_filter])? + .collect_vec(); + + let new_target = LogicalPlan::Filter(Filter::try_new(filter, target.into())?); + (new_target, files) + } else { + (target, snapshot.files().iter().collect_vec()) + } + } + }; + + let source = DataFrame::new(state.clone(), source); + let source = source.with_column(SOURCE_COLUMN, lit(true))?; // TODO: This is here to prevent predicate pushdowns. In the future we can replace this node to allow pushdowns depending on which operations are being used. 
let target = LogicalPlan::Extension(Extension { @@ -720,14 +1019,6 @@ async fn execute( let target = DataFrame::new(state.clone(), target); let target = target.with_column(TARGET_COLUMN, lit(true))?; - let source_schema = source.schema(); - let target_schema = target.schema(); - let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; - let predicate = match predicate { - Expression::DataFusion(expr) => expr, - Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, - }; - let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?; let join_schema_df = join.schema().to_owned(); @@ -990,7 +1281,6 @@ async fn execute( let project = filtered.select(write_projection)?; let optimized = &project.into_optimized_plan()?; - let state = state.with_query_planner(Arc::new(MergePlanner {})); let write = state.create_physical_plan(optimized).await?; let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); @@ -1025,7 +1315,7 @@ async fn execute( let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); metrics.num_target_files_added = actions.len(); - for action in snapshot.files() { + for action in files { metrics.num_target_files_removed += 1; actions.push(Action::Remove(Remove { path: action.path.clone(), @@ -1493,7 +1783,7 @@ mod tests { #[tokio::test] async fn test_merge_partitions() { - /* Validate the join predicate works with partition columns */ + /* Validate the join predicate works with table partitions */ let schema = get_arrow_schema(&None); let table = setup_table(Some(vec!["modified"])).await; @@ -1581,6 +1871,78 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_merge_partitions_skipping() { + /* Validate the join predicate can be used for skipping partitions */ + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + let table = write_data(table, &schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 4); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![999, 999, 999])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge(source, col("target.id").eq(col("source.id"))) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert!(table.get_file_uris().count() >= 3); + assert_eq!(metrics.num_target_files_added, 3); + assert_eq!(metrics.num_target_files_removed, 2); + assert_eq!(metrics.num_target_rows_copied, 0); + assert_eq!(metrics.num_target_rows_updated, 2); + assert_eq!(metrics.num_target_rows_inserted, 1); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_source_rows, 3); + + let expected = vec![ + "+-------+------------+----+", + "| value | modified | id 
|", + "+-------+------------+----+", + "| 1 | 2021-02-01 | A |", + "| 100 | 2021-02-02 | D |", + "| 999 | 2023-07-04 | B |", + "| 999 | 2023-07-04 | C |", + "| 999 | 2023-07-04 | X |", + "+-------+------------+----+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_delete_matched() { // Validate behaviours of match delete From 2e423640f845d70ec3ba6558b2e7a2148e640ce9 Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Mon, 11 Dec 2023 22:26:59 +0000 Subject: [PATCH 2/6] clippy lints --- .../src/delta_datafusion/mod.rs | 36 +++++------ crates/deltalake-core/src/operations/merge.rs | 63 +++++++++---------- 2 files changed, 45 insertions(+), 54 deletions(-) diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 4e69bb092c..979091b824 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -1018,27 +1018,21 @@ pub(crate) async fn execute_plan_to_batch( state: &SessionState, plan: Arc, ) -> DeltaResult { - let data = futures::future::try_join_all( - (0..plan.output_partitioning().partition_count()) - .into_iter() - .map(|p| { - let plan_copy = plan.clone(); - let task_context = state.task_ctx().clone(); - async move { - let batch_stream = plan_copy.execute(p, task_context)?; - - let schema = batch_stream.schema(); - - let batches = batch_stream.try_collect::>().await?; - - DataFusionResult::<_>::Ok(arrow::compute::concat_batches( - &schema, - batches.iter(), - )?) - } - }), - ) - .await?; + let data = + futures::future::try_join_all((0..plan.output_partitioning().partition_count()).map(|p| { + let plan_copy = plan.clone(); + let task_context = state.task_ctx().clone(); + async move { + let batch_stream = plan_copy.execute(p, task_context)?; + + let schema = batch_stream.schema(); + + let batches = batch_stream.try_collect::>().await?; + + DataFusionResult::<_>::Ok(arrow::compute::concat_batches(&schema, batches.iter())?) 
+ } + })) + .await?; let batch = arrow::compute::concat_batches(&plan.schema(), data.iter())?; diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 5f0ea410e5..c6d226bbd6 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -692,7 +692,7 @@ fn generalize_filter( None } }), - Expr::Negative(neg) => references_table(&*neg, table), + Expr::Negative(neg) => references_table(neg, table), Expr::Cast(cast) => references_table(&cast.expr, table), Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), Expr::ScalarFunction(func) => { @@ -715,7 +715,7 @@ fn generalize_filter( match predicate { Expr::BinaryExpr(binary) => { - if let Some(_) = references_table(&binary.right, source_name) { + if references_table(&binary.right, source_name).is_some() { if let Some(left_target) = references_table(&binary.left, target_name) { if partition_columns.contains(&left_target) { let placeholder_name = format!("{left_target}_{}", placeholders.len()); @@ -737,7 +737,7 @@ fn generalize_filter( } return None; } - if let Some(_) = references_table(&binary.left, source_name) { + if references_table(&binary.left, source_name).is_some() { if let Some(right_target) = references_table(&binary.right, target_name) { if partition_columns.contains(&right_target) { let placeholder_name = format!("{right_target}_{}", placeholders.len()); @@ -871,7 +871,6 @@ async fn try_construct_early_filter( .collect_vec(); (0..items.num_rows()) - .into_iter() .map(|i| { let replacements = placeholder_names .iter() @@ -973,36 +972,34 @@ async fn execute( // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since // that implies a full scan (target, snapshot.files().iter().collect_vec()) + } else if let Some(filter) = try_construct_early_filter( + predicate.clone(), + snapshot, + &state, + &source, + &source_name, + &target_name, + ) + .await + { + let file_filter = filter + .clone() + .transform(&|expr| match expr { + Expr::Column(c) => Ok(Transformed::Yes(Expr::Column(Column { + relation: None, // the file filter won't be looking at columns like `target.partition`, it'll just be `partition` + name: c.name, + }))), + expr => Ok(Transformed::No(expr)), + }) + .unwrap(); + let files = snapshot + .files_matching_predicate(&[file_filter])? + .collect_vec(); + + let new_target = LogicalPlan::Filter(Filter::try_new(filter, target.into())?); + (new_target, files) } else { - if let Some(filter) = try_construct_early_filter( - predicate.clone(), - snapshot, - &state, - &source, - &source_name, - &target_name, - ) - .await - { - let file_filter = filter - .clone() - .transform(&|expr| match expr { - Expr::Column(c) => Ok(Transformed::Yes(Expr::Column(Column { - relation: None, // the file filter won't be looking at columns like `target.partition`, it'll just be `partition` - name: c.name, - }))), - expr => Ok(Transformed::No(expr)), - }) - .unwrap(); - let files = snapshot - .files_matching_predicate(&[file_filter])? 
- .collect_vec(); - - let new_target = LogicalPlan::Filter(Filter::try_new(filter, target.into())?); - (new_target, files) - } else { - (target, snapshot.files().iter().collect_vec()) - } + (target, snapshot.files().iter().collect_vec()) } }; From 5e49b0ce0f74c8b30a7b3b1b91da3634597ed94c Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:15:52 +0000 Subject: [PATCH 3/6] add tests for generalize_filter and try_construct_early_filter --- crates/deltalake-core/src/operations/merge.rs | 213 +++++++++++++++++- 1 file changed, 208 insertions(+), 5 deletions(-) diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index c6d226bbd6..168bd31cca 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -811,7 +811,6 @@ async fn try_construct_early_filter( source_name: &TableReference<'_>, target_name: &TableReference<'_>, ) -> Option { - println!("checkit"); let table_metadata = table_snapshot.metadata()?; let partition_columns = &table_metadata.partition_columns; @@ -822,8 +821,6 @@ async fn try_construct_early_filter( let mut placeholders = HashMap::default(); - println!("generalize {join_predicate}"); - let filter = generalize_filter( join_predicate, partition_columns, @@ -832,8 +829,6 @@ async fn try_construct_early_filter( &mut placeholders, )?; - println!("{filter:?}"); - if placeholders.is_empty() { // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter Some(filter) @@ -1436,6 +1431,8 @@ impl std::future::IntoFuture for MergeBuilder { #[cfg(test)] mod tests { + use crate::operations::merge::generalize_filter; + use crate::operations::merge::try_construct_early_filter; use crate::operations::DeltaOps; use crate::protocol::*; use crate::writer::test_utils::datafusion::get_data; @@ -1447,11 +1444,20 @@ mod tests { use arrow::datatypes::Schema as ArrowSchema; use arrow::record_batch::RecordBatch; use datafusion::assert_batches_sorted_eq; + use datafusion::datasource::provider_as_source; use datafusion::prelude::DataFrame; use datafusion::prelude::SessionContext; + use datafusion_common::Column; + use datafusion_common::ScalarValue; + use datafusion_common::TableReference; use datafusion_expr::col; + use datafusion_expr::expr::Placeholder; use datafusion_expr::lit; + use datafusion_expr::Expr; + use datafusion_expr::LogicalPlanBuilder; use serde_json::json; + use std::collections::HashMap; + use std::ops::Neg; use std::sync::Arc; use super::MergeMetrics; @@ -2278,4 +2284,201 @@ mod tests { let actual = get_data(&table).await; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn test_generalize_filter_with_partitions() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_with_partitions_captures_expression() { + // Check that when 
generalizing the filter, the placeholder map captures the expression needed to make the statement the same + // when the distinct values are substituted in + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .neg() + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + + assert_eq!(placeholders.len(), 1); + + let placeholder_expr = &placeholders["id_0"]; + + let expected_placeholder = col(Column::new(source.clone().into(), "id")).neg(); + + assert_eq!(placeholder_expr, &expected_placeholder); + } + + #[tokio::test] + async fn test_generalize_filter_keeps_static_target_references() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_removes_source_references() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(source.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_partitions_expands() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_file_uris().count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), +
None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + &table.state, + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await; + + assert!(pred.is_some()); + + let expected_pred = ["C", "X", "B"] + .into_iter() + .map(|id| { + lit(ScalarValue::Utf8(id.to_owned().into())).eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })) + }) + .reduce(Expr::or) + .unwrap(); + + assert_eq!(pred.unwrap(), expected_pred); + } } From 5e3f8fa11ea90321b5e3505623bcb04881eee804 Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:27:09 +0000 Subject: [PATCH 4/6] post-merge fix --- crates/deltalake-core/src/operations/merge.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 663596895b..3ac79ba509 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -70,7 +70,8 @@ use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression} use crate::delta_datafusion::logical::MetricObserver; use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; use crate::delta_datafusion::{ - register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, DeltaTableProvider, + execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, + DeltaTableProvider, }; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; From 922a4a940edf0f8b3d8b168a875db4ae19015ff4 Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:31:47 +0000 Subject: [PATCH 5/6] better error handling on filter construction --- crates/deltalake-core/src/operations/merge.rs | 120 +++++++++--------- 1 file changed, 63 insertions(+), 57 deletions(-) diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 3ac79ba509..37011c278e 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -815,76 +815,81 @@ async fn try_construct_early_filter( source: &LogicalPlan, source_name: &TableReference<'_>, target_name: &TableReference<'_>, -) -> Option<Expr> { +) -> DeltaResult<Option<Expr>> { - let table_metadata = table_snapshot.metadata()?; + let table_metadata = table_snapshot.metadata(); + + if table_metadata.is_none() { + return Ok(None); + } + + let table_metadata = table_metadata.unwrap(); let partition_columns = &table_metadata.partition_columns; if partition_columns.is_empty() { - return None; + return Ok(None); } let mut placeholders = HashMap::default(); - let filter = generalize_filter( + match generalize_filter( join_predicate, partition_columns, source_name, target_name, &mut placeholders, - )?; - - if placeholders.is_empty() { - // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter - Some(filter) - } else { - // if we have some recognised partitions, then discover the distinct set of partitions in the source data and - // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together) - let distinct_partitions =
LogicalPlan::Distinct(Distinct { - input: LogicalPlan::Projection( - Projection::try_new( - placeholders - .into_iter() - .map(|(alias, expr)| expr.alias(alias)) - .collect_vec(), - source.clone().into(), - ) - .unwrap(), - ) - .into(), - }); - - let execution_plan = session_state - .create_physical_plan(&distinct_partitions) - .await - .unwrap(); - - let items = execute_plan_to_batch(session_state, execution_plan) - .await - .unwrap(); - - let placeholder_names = items - .schema() - .fields() - .iter() - .map(|f| f.name().to_owned()) - .collect_vec(); - - (0..items.num_rows()) - .map(|i| { - let replacements = placeholder_names + ) { + None => Ok(None), + Some(filter) => { + if placeholders.is_empty() { + // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter + Ok(Some(filter)) + } else { + // if we have some recognised partitions, then discover the distinct set of partitions in the source data and + // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together) + let distinct_partitions = LogicalPlan::Distinct(Distinct { + input: LogicalPlan::Projection(Projection::try_new( + placeholders + .into_iter() + .map(|(alias, expr)| expr.alias(alias)) + .collect_vec(), + source.clone().into(), + )?) + .into(), + }); + + let execution_plan = session_state + .create_physical_plan(&distinct_partitions) + .await?; + + let items = execute_plan_to_batch(session_state, execution_plan).await?; + + let placeholder_names = items + .schema() + .fields() .iter() - .map(|placeholder| { - let col = items.column_by_name(placeholder).unwrap(); - let value = ScalarValue::try_from_array(col, i).unwrap(); - (placeholder.to_owned(), value) + .map(|f| f.name().to_owned()) + .collect_vec(); + + let expr = (0..items.num_rows()) + .map(|i| { + let replacements = placeholder_names + .iter() + .map(|placeholder| { + let col = items.column_by_name(placeholder).unwrap(); + let value = ScalarValue::try_from_array(col, i)?; + DeltaResult::Ok((placeholder.to_owned(), value)) + }) + .try_collect()?; + Ok(replace_placeholders(filter.clone(), &replacements)) }) - .collect(); - replace_placeholders(filter.clone(), &replacements) - }) - .reduce(Expr::or) - .unwrap() - .into() + .collect::>>()? + .into_iter() + .reduce(Expr::or); + + Ok(expr) + } + } } } @@ -980,7 +985,7 @@ async fn execute( &source_name, &target_name, ) - .await + .await? 
{ let file_filter = filter .clone() @@ -2559,7 +2564,8 @@ mod tests { &source_name, &target_name, ) - .await; + .await + .unwrap(); assert!(pred.is_some()); From ba10c9a0d0fe8845f7ec2a680830a4f362a86571 Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Tue, 19 Dec 2023 22:54:37 +0000 Subject: [PATCH 6/6] stability when comparing predicate construction --- crates/deltalake-core/src/operations/merge.rs | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 37011c278e..0f0da1c21f 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -1474,6 +1474,7 @@ mod tests { use datafusion_expr::lit; use datafusion_expr::Expr; use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::Operator; use serde_json::json; use std::collections::HashMap; use std::ops::Neg; @@ -2569,17 +2570,43 @@ mod tests { assert!(pred.is_some()); - let expected_pred = ["C", "X", "B"] - .into_iter() - .map(|id| { - lit(ScalarValue::Utf8(id.to_owned().into())).eq(col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - })) - }) - .reduce(Expr::or) - .unwrap(); + let split_pred = { + fn split(expr: Expr, parts: &mut Vec<(String, String)>) { + match expr { + Expr::BinaryExpr(ex) if ex.op == Operator::Or => { + split(*ex.left, parts); + split(*ex.right, parts); + } + Expr::BinaryExpr(ex) if ex.op == Operator::Eq => { + let col = match *ex.right { + Expr::Column(col) => col.name, + ex => panic!("expected column in pred, got {ex}!"), + }; + + let value = match *ex.left { + Expr::Literal(ScalarValue::Utf8(Some(value))) => value, + ex => panic!("expected value in predicate, got {ex}!"), + }; + + parts.push((col, value)) + } + + expr => panic!("expected either = or OR, got {expr}"), + } + } + + let mut parts = vec![]; + split(pred.unwrap(), &mut parts); + parts.sort(); + parts + }; + + let expected_pred_parts = [ + ("id".to_owned(), "B".to_owned()), + ("id".to_owned(), "C".to_owned()), + ("id".to_owned(), "X".to_owned()), + ]; - assert_eq!(pred.unwrap(), expected_pred); + assert_eq!(split_pred, expected_pred_parts); } }
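Usage sketch (illustrative, not part of the patches above): with a target table partitioned by `id`, as exercised by `test_merge_partitions_skipping`, a join predicate that compares the partition column lets `try_construct_early_filter` derive a filter such as `id = 'B' OR id = 'C' OR id = 'X'` from the distinct source values, so partitions not present in the source are neither scanned nor rewritten. The helper name `upsert_by_id` and the `value`/`modified` column names below are assumptions borrowed from the test schema, using the `deltalake_core` crate as built from this repo.

use datafusion::prelude::{col, DataFrame};
use deltalake_core::{operations::DeltaOps, DeltaResult, DeltaTable};

/// Upsert `source` into a Delta table partitioned by `id`. Because the join
/// predicate compares the partition column, the early filter derived from the
/// distinct source `id` values restricts both the Add-action list and the
/// target delta scan to the touched partitions.
async fn upsert_by_id(table: DeltaTable, source: DataFrame) -> DeltaResult<DeltaTable> {
    let (table, _metrics) = DeltaOps(table)
        .merge(source, col("target.id").eq(col("source.id")))
        .with_source_alias("source")
        .with_target_alias("target")
        .when_matched_update(|update| {
            update
                .update("value", col("source.value"))
                .update("modified", col("source.modified"))
        })?
        .when_not_matched_insert(|insert| {
            insert
                .set("id", col("source.id"))
                .set("value", col("source.value"))
                .set("modified", col("source.modified"))
        })?
        .await?;
    Ok(table)
}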