diff --git a/README.md b/README.md index 6121110fb0..1913b6305e 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ of features outlined in the Delta [protocol][protocol] is also [tracked](#protoc | Version 2 | Column Invariants | ![done] | | Version 3 | Enforce `delta.checkpoint.writeStatsAsJson` | [![open]][writer-rs] | | Version 3 | Enforce `delta.checkpoint.writeStatsAsStruct` | [![open]][writer-rs] | -| Version 3 | CHECK constraints | [![open]][writer-rs] | +| Version 3 | CHECK constraints | [![semi-done]][check-constraints] | | Version 4 | Change Data Feed | | | Version 4 | Generated Columns | | | Version 5 | Column Mapping | | @@ -185,5 +185,6 @@ of features outlined in the Delta [protocol][protocol] is also [tracked](#protoc [merge-py]: https://github.com/delta-io/delta-rs/issues/1357 [merge-rs]: https://github.com/delta-io/delta-rs/issues/850 [writer-rs]: https://github.com/delta-io/delta-rs/issues/851 +[check-constraints]: https://github.com/delta-io/delta-rs/issues/1881 [onelake-rs]: https://github.com/delta-io/delta-rs/issues/1418 [protocol]: https://github.com/delta-io/delta/blob/master/PROTOCOL.md diff --git a/crates/deltalake-core/src/delta_datafusion/expr.rs b/crates/deltalake-core/src/delta_datafusion/expr.rs index f9275832a1..49cdae4387 100644 --- a/crates/deltalake-core/src/delta_datafusion/expr.rs +++ b/crates/deltalake-core/src/delta_datafusion/expr.rs @@ -347,9 +347,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> { mod test { use arrow_schema::DataType as ArrowDataType; use datafusion::prelude::SessionContext; - use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable}; + use crate::delta_datafusion::DeltaSessionContext; use crate::kernel::{DataType, PrimitiveType, StructField, StructType}; use crate::{DeltaOps, DeltaTable}; @@ -388,6 +389,11 @@ mod test { DataType::Primitive(PrimitiveType::Integer), true, ), + StructField::new( + "Value3".to_string(), + DataType::Primitive(PrimitiveType::Integer), + true, + ), StructField::new( "modified".to_string(), DataType::Primitive(PrimitiveType::String), @@ -442,7 +448,10 @@ mod test { }), "arrow_cast(1, 'Int32')".to_string() ), - simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), + simple!( + Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64)), + "Value3 = 3".to_string() + ), simple!(col("active").is_true(), "active IS TRUE".to_string()), simple!(col("active"), "active".to_string()), simple!(col("active").eq(lit(true)), "active = true".to_string()), @@ -536,7 +545,7 @@ mod test { ), ]; - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); for test in tests { let actual = fmt_expr_to_sql(&test.expr).unwrap(); diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 41a38b1d0e..4314945680 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -68,6 +68,7 @@ use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_sql::planner::ParserOptions; +use futures::TryStreamExt; use itertools::Itertools; use log::error; @@ -1034,6 +1035,31 @@ pub(crate) fn logical_expr_to_physical_expr( create_physical_expr(expr, &df_schema, schema, 
&execution_props).unwrap() } +pub(crate) async fn execute_plan_to_batch( + state: &SessionState, + plan: Arc, +) -> DeltaResult { + let data = + futures::future::try_join_all((0..plan.output_partitioning().partition_count()).map(|p| { + let plan_copy = plan.clone(); + let task_context = state.task_ctx().clone(); + async move { + let batch_stream = plan_copy.execute(p, task_context)?; + + let schema = batch_stream.schema(); + + let batches = batch_stream.try_collect::>().await?; + + DataFusionResult::<_>::Ok(arrow::compute::concat_batches(&schema, batches.iter())?) + } + })) + .await?; + + let batch = arrow::compute::concat_batches(&plan.schema(), data.iter())?; + + Ok(batch) +} + /// Responsible for checking batches of data conform to table's invariants. #[derive(Clone)] pub struct DeltaDataChecker { @@ -1048,7 +1074,7 @@ impl DeltaDataChecker { Self { invariants, constraints: vec![], - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } @@ -1057,10 +1083,16 @@ impl DeltaDataChecker { Self { constraints, invariants: vec![], - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } + /// Specify the Datafusion context + pub fn with_session_context(mut self, context: SessionContext) -> Self { + self.ctx = context; + self + } + /// Create a new DeltaDataChecker pub fn new(snapshot: &DeltaTableState) -> Self { let metadata = snapshot.metadata(); @@ -1074,7 +1106,7 @@ impl DeltaDataChecker { Self { invariants, constraints, - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } diff --git a/crates/deltalake-core/src/kernel/actions/mod.rs b/crates/deltalake-core/src/kernel/actions/mod.rs index 637d520c41..97d32943a8 100644 --- a/crates/deltalake-core/src/kernel/actions/mod.rs +++ b/crates/deltalake-core/src/kernel/actions/mod.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; pub(crate) mod schemas; -mod serde_path; +pub(crate) mod serde_path; pub(crate) mod types; pub use types::*; diff --git a/crates/deltalake-core/src/kernel/actions/serde_path.rs b/crates/deltalake-core/src/kernel/actions/serde_path.rs index 9868523e81..ae647fa54c 100644 --- a/crates/deltalake-core/src/kernel/actions/serde_path.rs +++ b/crates/deltalake-core/src/kernel/actions/serde_path.rs @@ -54,7 +54,7 @@ fn encode_path(path: &str) -> String { percent_encode(path.as_bytes(), INVALID).to_string() } -fn decode_path(path: &str) -> Result { +pub fn decode_path(path: &str) -> Result { Ok(percent_decode_str(path).decode_utf8()?.to_string()) } diff --git a/crates/deltalake-core/src/operations/cast.rs b/crates/deltalake-core/src/operations/cast.rs index d6f712ec70..e697c06d54 100644 --- a/crates/deltalake-core/src/operations/cast.rs +++ b/crates/deltalake-core/src/operations/cast.rs @@ -17,6 +17,7 @@ fn cast_record_batch_columns( .iter() .map(|f| { let col = batch.column_by_name(f.name()).unwrap(); + if let (DataType::Struct(_), DataType::Struct(child_fields)) = (col.data_type(), f.data_type()) { @@ -28,7 +29,7 @@ fn cast_record_batch_columns( child_columns.clone(), None, )) as ArrayRef) - } else if !col.data_type().equals_datatype(f.data_type()) { + } else if is_cast_required(col.data_type(), f.data_type()) { cast_with_options(col, f.data_type(), cast_options) } else { Ok(col.clone()) @@ -37,6 +38,16 @@ fn cast_record_batch_columns( .collect::, _>>() } +fn is_cast_required(a: &DataType, b: &DataType) -> bool { + match (a, b) { + (DataType::List(a_item), DataType::List(b_item)) => { + // If list item name is not the 
default('item') the list must be casted + !a.equals_datatype(b) || a_item.name() != b_item.name() + } + (_, _) => !a.equals_datatype(b), + } +} + /// Cast recordbatch to a new target_schema, by casting each column array pub fn cast_record_batch( batch: &RecordBatch, @@ -51,3 +62,80 @@ pub fn cast_record_batch( let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?; Ok(RecordBatch::try_new(target_schema, columns)?) } + +#[cfg(test)] +mod tests { + use crate::operations::cast::{cast_record_batch, is_cast_required}; + use arrow::array::ArrayData; + use arrow_array::{Array, ArrayRef, ListArray, RecordBatch}; + use arrow_buffer::Buffer; + use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; + use std::sync::Arc; + + #[test] + fn test_cast_record_batch_with_list_non_default_item() { + let array = Arc::new(make_list_array()) as ArrayRef; + let source_schema = Schema::new(vec![Field::new( + "list_column", + array.data_type().clone(), + false, + )]); + let record_batch = RecordBatch::try_new(Arc::new(source_schema), vec![array]).unwrap(); + + let fields = Fields::from(vec![Field::new_list( + "list_column", + Field::new("item", DataType::Int8, false), + false, + )]); + let target_schema = Arc::new(Schema::new(fields)) as SchemaRef; + + let result = cast_record_batch(&record_batch, target_schema, false); + + let schema = result.unwrap().schema(); + let field = schema.column_with_name("list_column").unwrap().1; + if let DataType::List(list_item) = field.data_type() { + assert_eq!(list_item.name(), "item"); + } else { + panic!("Not a list"); + } + } + + fn make_list_array() -> ListArray { + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + + let list_data_type = DataType::List(Arc::new(Field::new("element", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + ListArray::from(list_data) + } + + #[test] + fn test_is_cast_required_with_list() { + let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false))); + let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false))); + + assert!(!is_cast_required(&field1, &field2)); + } + + #[test] + fn test_is_cast_required_with_list_non_default_item() { + let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false))); + let field2 = DataType::List(FieldRef::from(Field::new( + "element", + DataType::Int32, + false, + ))); + + assert!(is_cast_required(&field1, &field2)); + } +} diff --git a/crates/deltalake-core/src/operations/constraints.rs b/crates/deltalake-core/src/operations/constraints.rs index 889e668b1a..ed5888bd13 100644 --- a/crates/deltalake-core/src/operations/constraints.rs +++ b/crates/deltalake-core/src/operations/constraints.rs @@ -8,11 +8,15 @@ use datafusion::execution::context::SessionState; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_common::ToDFSchema; use futures::future::BoxFuture; use futures::StreamExt; use serde_json::json; -use crate::delta_datafusion::{register_store, DeltaDataChecker, DeltaScanBuilder}; +use crate::delta_datafusion::expr::fmt_expr_to_sql; +use crate::delta_datafusion::{ + register_store, 
DeltaDataChecker, DeltaScanBuilder, DeltaSessionContext, +}; use crate::kernel::{Action, CommitInfo, IsolationLevel, Metadata, Protocol}; use crate::logstore::LogStoreRef; use crate::operations::datafusion_utils::Expression; @@ -23,6 +27,8 @@ use crate::table::Constraint; use crate::DeltaTable; use crate::{DeltaResult, DeltaTableError}; +use super::datafusion_utils::into_expr; + /// Build a constraint to add to a table pub struct ConstraintBuilder { snapshot: DeltaTableState, @@ -47,10 +53,10 @@ impl ConstraintBuilder { /// Specify the constraint to be added pub fn with_constraint, E: Into>( mut self, - column: S, + name: S, expression: E, ) -> Self { - self.name = Some(column.into()); + self.name = Some(name.into()); self.expr = Some(expression.into()); self } @@ -75,15 +81,10 @@ impl std::future::IntoFuture for ConstraintBuilder { Some(v) => v, None => return Err(DeltaTableError::Generic("No name provided".to_string())), }; - let expr = match this.expr { - Some(Expression::String(s)) => s, - Some(Expression::DataFusion(e)) => e.to_string(), - None => { - return Err(DeltaTableError::Generic( - "No expression provided".to_string(), - )) - } - }; + + let expr = this + .expr + .ok_or_else(|| DeltaTableError::Generic("No Expresion provided".to_string()))?; let mut metadata = this .snapshot @@ -94,23 +95,29 @@ impl std::future::IntoFuture for ConstraintBuilder { if metadata.configuration.contains_key(&configuration_key) { return Err(DeltaTableError::Generic(format!( - "Constraint with name: {} already exists, expr: {}", - name, expr + "Constraint with name: {} already exists", + name ))); } let state = this.state.unwrap_or_else(|| { - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); register_store(this.log_store.clone(), session.runtime_env()); session.state() }); - // Checker built here with the one time constraint to check. - let checker = DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr)]); let scan = DeltaScanBuilder::new(&this.snapshot, this.log_store.clone(), &state) .build() .await?; + let schema = scan.schema().to_dfschema()?; + let expr = into_expr(expr, &schema, &state)?; + let expr_str = fmt_expr_to_sql(&expr)?; + + // Checker built here with the one time constraint to check. + let checker = + DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr_str)]); + let plan: Arc = Arc::new(scan); let mut tasks = vec![]; for p in 0..plan.output_partitioning().partition_count() { @@ -140,9 +147,10 @@ impl std::future::IntoFuture for ConstraintBuilder { // We have validated the table passes it's constraints, now to add the constraint to // the table. 
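For context, a minimal usage sketch of the builder this block commits for: `with_constraint` accepts either a SQL string or a DataFusion `Expr`, and both are normalized through `fmt_expr_to_sql` into the same SQL text stored under `delta.constraints.<name>` (this mirrors the tests further down; the `deltalake_core` crate path is an assumption).

```rust
// Minimal sketch (crate path assumed); both forms end up stored as the
// canonical SQL string produced by fmt_expr_to_sql.
use datafusion_expr::{col, lit};
use deltalake_core::{DeltaOps, DeltaResult, DeltaTable};

async fn add_example_constraints(table: DeltaTable) -> DeltaResult<DeltaTable> {
    // SQL string form: parsed against the table's scan schema, then normalized.
    let table = DeltaOps(table)
        .add_constraint()
        .with_constraint("value_upper_bound", "value < 1000")
        .await?;

    // DataFusion expression form: formatted to the same SQL text on commit.
    DeltaOps(table)
        .add_constraint()
        .with_constraint("value_positive", col("value").gt(lit(0)))
        .await
}
```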
- metadata - .configuration - .insert(format!("delta.constraints.{}", name), Some(expr.clone())); + metadata.configuration.insert( + format!("delta.constraints.{}", name), + Some(expr_str.clone()), + ); let old_protocol = this.snapshot.protocol(); let protocol = Protocol { @@ -162,12 +170,12 @@ impl std::future::IntoFuture for ConstraintBuilder { let operational_parameters = HashMap::from_iter([ ("name".to_string(), json!(&name)), - ("expr".to_string(), json!(&expr)), + ("expr".to_string(), json!(&expr_str)), ]); let operations = DeltaOperation::AddConstraint { name: name.clone(), - expr: expr.clone(), + expr: expr_str.clone(), }; let commit_info = CommitInfo { @@ -208,11 +216,37 @@ mod tests { use std::sync::Arc; use arrow_array::{Array, Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema}; + use datafusion_expr::{col, lit}; use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch}; - use crate::{DeltaOps, DeltaResult}; + use crate::{DeltaOps, DeltaResult, DeltaTable}; + + fn get_constraint(table: &DeltaTable, name: &str) -> String { + table + .metadata() + .unwrap() + .configuration + .get(name) + .unwrap() + .clone() + .unwrap() + } + + async fn get_constraint_op_params(table: &mut DeltaTable) -> String { + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + last_commit + .operation_parameters + .as_ref() + .unwrap() + .get("expr") + .unwrap() + .as_str() + .unwrap() + .to_owned() + } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_constraint_with_invalid_data() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -225,12 +259,10 @@ mod tests { .add_constraint() .with_constraint("id", "value > 5") .await; - dbg!(&constraint); assert!(constraint.is_err()); Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_valid_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -239,18 +271,89 @@ mod tests { .await?; let table = DeltaOps(write); - let constraint = table + let mut table = table .add_constraint() - .with_constraint("id", "value < 1000") - .await; - dbg!(&constraint); - assert!(constraint.is_ok()); - let version = constraint?.version(); + .with_constraint("id", "value < 1000") + .await?; + let version = table.version(); + assert_eq!(version, 1); + + let expected_expr = "value < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.id"), + expected_expr + ); + Ok(()) + } + + #[tokio::test] + async fn add_constraint_datafusion() -> DeltaResult<()> { + // Add constraint by providing a datafusion expression. 
+ let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + let table = DeltaOps(write); + + let mut table = table + .add_constraint() + .with_constraint("valid_values", col("value").lt(lit(1000))) + .await?; + let version = table.version(); assert_eq!(version, 1); + + let expected_expr = "value < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.valid_values"), + expected_expr + ); + + Ok(()) + } + + #[tokio::test] + async fn test_constraint_case_sensitive() -> DeltaResult<()> { + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("Id", ArrowDataType::Utf8, true), + Field::new("vAlue", ArrowDataType::Int32, true), + Field::new("mOdifieD", ArrowDataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&arrow_schema.clone()), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + + let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap(); + + let mut table = DeltaOps(table) + .add_constraint() + .with_constraint("valid_values", "vAlue < 1000") + .await?; + let version = table.version(); + assert_eq!(version, 1); + + let expected_expr = "vAlue < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.valid_values"), + expected_expr + ); + Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_conflicting_named_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -269,12 +372,10 @@ mod tests { .add_constraint() .with_constraint("id", "value < 10") .await; - dbg!(&second_constraint); assert!(second_constraint.is_err()); Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn write_data_that_violates_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -294,7 +395,6 @@ mod tests { ]; let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?; let err = table.write(vec![batch]).await; - dbg!(&err); assert!(err.is_err()); Ok(()) } diff --git a/crates/deltalake-core/src/operations/merge/mod.rs b/crates/deltalake-core/src/operations/merge/mod.rs index 08af4d65ef..7cb752dc21 100644 --- a/crates/deltalake-core/src/operations/merge/mod.rs +++ b/crates/deltalake-core/src/operations/merge/mod.rs @@ -51,12 +51,16 @@ use datafusion::{ }, prelude::{DataFrame, SessionContext}, }; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; +use datafusion_expr::expr::Placeholder; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, + BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + UserDefinedLogicalNode, UNNAMED_TABLE, }; use futures::future::BoxFuture; +use itertools::Itertools; use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; @@ -69,7 +73,8 @@ use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression} use crate::delta_datafusion::logical::MetricObserver; use 
crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; use crate::delta_datafusion::{ - register_store, DeltaColumn, DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider, + execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfigBuilder, DeltaSessionConfig, + DeltaTableProvider, }; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; @@ -668,6 +673,242 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner { } } +/// Takes the predicate provided and does two things: +/// +/// 1. for any relations between a source column and a target column, if the target column is a +/// partition column, then replace source with a placeholder matching the name of the partition +/// columns +/// +/// 2. for any other relation with a source column, remove them. +/// +/// For example, for the predicate: +/// +/// `source.date = target.date and source.id = target.id and frob > 42` +/// +/// where `date` is a partition column, would result in the expr: +/// +/// `$date = target.date and frob > 42` +/// +/// This leaves us with a predicate that we can push into delta scan after expanding it out to +/// a conjunction between the disinct partitions in the source input. +/// +/// TODO: A futher improvement here might be for non-partition columns to be replaced with min/max +/// checks, so the above example could become: +/// +/// `$date = target.date and target.id between 12345 and 99999 and frob > 42` +fn generalize_filter( + predicate: Expr, + partition_columns: &Vec, + source_name: &TableReference, + target_name: &TableReference, + placeholders: &mut HashMap, +) -> Option { + fn references_table(expr: &Expr, table: &TableReference) -> Option { + match expr { + Expr::Alias(alias) => references_table(&alias.expr, table), + Expr::Column(col) => col.relation.as_ref().and_then(|rel| { + if rel == table { + Some(col.name.to_owned()) + } else { + None + } + }), + Expr::Negative(neg) => references_table(neg, table), + Expr::Cast(cast) => references_table(&cast.expr, table), + Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), + Expr::ScalarFunction(func) => { + if func.args.len() == 1 { + references_table(&func.args[0], table) + } else { + None + } + } + Expr::ScalarUDF(udf) => { + if udf.args.len() == 1 { + references_table(&udf.args[0], table) + } else { + None + } + } + _ => None, + } + } + + match predicate { + Expr::BinaryExpr(binary) => { + if references_table(&binary.right, source_name).is_some() { + if let Some(left_target) = references_table(&binary.left, target_name) { + if partition_columns.contains(&left_target) { + let placeholder_name = format!("{left_target}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + let replaced = Expr::BinaryExpr(BinaryExpr { + left: binary.left, + op: binary.op, + right: placeholder.into(), + }); + + placeholders.insert(placeholder_name, *binary.right); + + return Some(replaced); + } + } + return None; + } + if references_table(&binary.left, source_name).is_some() { + if let Some(right_target) = references_table(&binary.right, target_name) { + if partition_columns.contains(&right_target) { + let placeholder_name = format!("{right_target}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + let replaced = Expr::BinaryExpr(BinaryExpr { + right: binary.right, + op: binary.op, + left: 
placeholder.into(), + }); + + placeholders.insert(placeholder_name, *binary.left); + + return Some(replaced); + } + } + return None; + } + + let left = generalize_filter( + *binary.left, + partition_columns, + source_name, + target_name, + placeholders, + ); + let right = generalize_filter( + *binary.right, + partition_columns, + source_name, + target_name, + placeholders, + ); + + match (left, right) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr { + left: l.into(), + op: binary.op, + right: r.into(), + }) + .into(), + } + } + other => Some(other), + } +} + +fn replace_placeholders(expr: Expr, placeholders: &HashMap) -> Expr { + expr.transform(&|expr| match expr { + Expr::Placeholder(Placeholder { id, .. }) => { + let value = placeholders[&id].clone(); + // Replace the placeholder with the value + Ok(Transformed::Yes(Expr::Literal(value))) + } + _ => Ok(Transformed::No(expr)), + }) + .unwrap() +} + +async fn try_construct_early_filter( + join_predicate: Expr, + table_snapshot: &DeltaTableState, + session_state: &SessionState, + source: &LogicalPlan, + source_name: &TableReference<'_>, + target_name: &TableReference<'_>, +) -> DeltaResult> { + let table_metadata = table_snapshot.metadata(); + + if table_metadata.is_none() { + return Ok(None); + } + + let table_metadata = table_metadata.unwrap(); + + let partition_columns = &table_metadata.partition_columns; + + if partition_columns.is_empty() { + return Ok(None); + } + + let mut placeholders = HashMap::default(); + + match generalize_filter( + join_predicate, + partition_columns, + source_name, + target_name, + &mut placeholders, + ) { + None => Ok(None), + Some(filter) => { + if placeholders.is_empty() { + // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter + Ok(Some(filter)) + } else { + // if we have some recognised partitions, then discover the distinct set of partitions in the source data and + // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together) + let distinct_partitions = LogicalPlan::Distinct(Distinct { + input: LogicalPlan::Projection(Projection::try_new( + placeholders + .into_iter() + .map(|(alias, expr)| expr.alias(alias)) + .collect_vec(), + source.clone().into(), + )?) + .into(), + }); + + let execution_plan = session_state + .create_physical_plan(&distinct_partitions) + .await?; + + let items = execute_plan_to_batch(session_state, execution_plan).await?; + + let placeholder_names = items + .schema() + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect_vec(); + + let expr = (0..items.num_rows()) + .map(|i| { + let replacements = placeholder_names + .iter() + .map(|placeholder| { + let col = items.column_by_name(placeholder).unwrap(); + let value = ScalarValue::try_from_array(col, i)?; + DeltaResult::Ok((placeholder.to_owned(), value)) + }) + .try_collect()?; + Ok(replace_placeholders(filter.clone(), &replacements)) + }) + .collect::>>()? + .into_iter() + .reduce(Expr::or); + + Ok(expr) + } + } + } +} + #[allow(clippy::too_many_arguments)] async fn execute( predicate: Expression, @@ -709,9 +950,12 @@ async fn execute( }; // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work - let source = - LogicalPlanBuilder::scan(source_name, provider_as_source(source.into_view()), None)? 
- .build()?; + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + )? + .build()?; let source = LogicalPlan::Extension(Extension { node: Arc::new(MetricObserver { @@ -721,9 +965,6 @@ async fn execute( }), }); - let source = DataFrame::new(state.clone(), source); - let source = source.with_column(SOURCE_COLUMN, lit(true))?; - let scan_config = DeltaScanConfigBuilder::default() .with_file_column(true) .build(snapshot)?; @@ -735,9 +976,48 @@ async fn execute( log_store.clone(), scan_config, )?); + let target_provider = provider_as_source(target_provider); - let target = LogicalPlanBuilder::scan(target_name, target_provider, None)?.build()?; + let target = LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?; + + let source_schema = source.schema(); + let target_schema = target.schema(); + let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; + let predicate = match predicate { + Expression::DataFusion(expr) => expr, + Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, + }; + + let state = state.with_query_planner(Arc::new(MergePlanner {})); + + let target = { + // Attempt to construct an early filter that we can apply to the Add action list and the delta scan. + // In the case where there are partition columns in the join predicate, we can scan the source table + // to get the distinct list of partitions affected and constrain the search to those. + + if !not_match_source_operations.is_empty() { + // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since + // that implies a full scan + target + } else if let Some(filter) = try_construct_early_filter( + predicate.clone(), + snapshot, + &state, + &source, + &source_name, + &target_name, + ) + .await? + { + LogicalPlan::Filter(Filter::try_new(filter, target.into())?) 
+ } else { + target + } + }; + + let source = DataFrame::new(state.clone(), source); + let source = source.with_column(SOURCE_COLUMN, lit(true))?; // Not match operations imply a full scan of the target table is required let enable_pushdown = @@ -752,14 +1032,6 @@ async fn execute( let target = DataFrame::new(state.clone(), target); let target = target.with_column(TARGET_COLUMN, lit(true))?; - let source_schema = source.schema(); - let target_schema = target.schema(); - let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; - let predicate = match predicate { - Expression::DataFusion(expr) => expr, - Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, - }; - let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?; let join_schema_df = join.schema().to_owned(); @@ -1037,7 +1309,6 @@ async fn execute( let project = filtered.select(write_projection)?; let merge_final = &project.into_unoptimized_plan(); - let state = state.with_query_planner(Arc::new(MergePlanner {})); let write = state.create_physical_plan(merge_final).await?; let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); @@ -1212,6 +1483,8 @@ mod tests { use crate::kernel::DataType; use crate::kernel::PrimitiveType; use crate::kernel::StructField; + use crate::operations::merge::generalize_filter; + use crate::operations::merge::try_construct_early_filter; use crate::operations::DeltaOps; use crate::protocol::*; use crate::writer::test_utils::datafusion::get_data; @@ -1225,11 +1498,21 @@ mod tests { use arrow_schema::DataType as ArrowDataType; use arrow_schema::Field; use datafusion::assert_batches_sorted_eq; + use datafusion::datasource::provider_as_source; use datafusion::prelude::DataFrame; use datafusion::prelude::SessionContext; + use datafusion_common::Column; + use datafusion_common::ScalarValue; + use datafusion_common::TableReference; use datafusion_expr::col; + use datafusion_expr::expr::Placeholder; use datafusion_expr::lit; + use datafusion_expr::Expr; + use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::Operator; use serde_json::json; + use std::collections::HashMap; + use std::ops::Neg; use std::sync::Arc; use super::MergeMetrics; @@ -1560,7 +1843,7 @@ mod tests { #[tokio::test] async fn test_merge_partitions() { - /* Validate the join predicate works with partition columns */ + /* Validate the join predicate works with table partitions */ let schema = get_arrow_schema(&None); let table = setup_table(Some(vec!["modified"])).await; @@ -1648,6 +1931,78 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_merge_partitions_skipping() { + /* Validate the join predicate can be used for skipping partitions */ + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + let table = write_data(table, &schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 4); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![999, 999, 999])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge(source, col("target.id").eq(col("source.id"))) + 
.with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert!(table.get_file_uris().count() >= 3); + assert_eq!(metrics.num_target_files_added, 3); + assert_eq!(metrics.num_target_files_removed, 2); + assert_eq!(metrics.num_target_rows_copied, 0); + assert_eq!(metrics.num_target_rows_updated, 2); + assert_eq!(metrics.num_target_rows_inserted, 1); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_source_rows, 3); + + let expected = vec![ + "+-------+------------+----+", + "| value | modified | id |", + "+-------+------------+----+", + "| 1 | 2021-02-01 | A |", + "| 100 | 2021-02-02 | D |", + "| 999 | 2023-07-04 | B |", + "| 999 | 2023-07-04 | C |", + "| 999 | 2023-07-04 | X |", + "+-------+------------+----+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_delete_matched() { // Validate behaviours of match delete @@ -2067,4 +2422,228 @@ mod tests { let actual = get_data(&table).await; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn test_generalize_filter_with_partitions() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_with_partitions_captures_expression() { + // Check that when generalizing the filter, the placeholder map captures the expression needed to make the statement the same + // when the distinct values are substitiuted in + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .neg() + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + + assert_eq!(placeholders.len(), 1); + + let placeholder_expr = &placeholders["id_0"]; + + let expected_placeholder = col(Column::new(source.clone().into(), "id")).neg(); + + assert_eq!(placeholder_expr, &expected_placeholder); + } + + #[tokio::test] + async fn test_generalize_filter_keeps_static_target_references() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = 
col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_removes_source_references() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(source.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_partitions_expands() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_file_uris().count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + &table.state, + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let split_pred = { + fn split(expr: Expr, parts: &mut Vec<(String, String)>) { + match expr { + Expr::BinaryExpr(ex) if ex.op == Operator::Or => { + split(*ex.left, parts); + split(*ex.right, parts); + } + Expr::BinaryExpr(ex) if ex.op == Operator::Eq => { + let col = match *ex.right { + Expr::Column(col) => col.name, + ex => panic!("expected column in pred, got {ex}!"), + }; + + let value = match *ex.left { + Expr::Literal(ScalarValue::Utf8(Some(value))) => value, + ex => panic!("expected value in predicate, got {ex}!"), + }; + + parts.push((col, value)) + } + + expr => panic!("expected either = or OR, got {expr}"), + } + } + + let mut parts = vec![]; + split(pred.unwrap(), &mut parts); + parts.sort(); + parts + }; + + let expected_pred_parts = [ + ("id".to_owned(), "B".to_owned()), + 
("id".to_owned(), "C".to_owned()), + ("id".to_owned(), "X".to_owned()), + ]; + + assert_eq!(split_pred, expected_pred_parts); + } } diff --git a/crates/deltalake-core/src/operations/optimize.rs b/crates/deltalake-core/src/operations/optimize.rs index 7eb046341f..24ecb8e853 100644 --- a/crates/deltalake-core/src/operations/optimize.rs +++ b/crates/deltalake-core/src/operations/optimize.rs @@ -533,6 +533,7 @@ impl MergePlan { context: Arc, ) -> Result>, DeltaTableError> { use datafusion::prelude::{col, ParquetReadOptions}; + use datafusion_common::Column; use datafusion_expr::expr::ScalarUDF; use datafusion_expr::Expr; @@ -549,12 +550,16 @@ impl MergePlan { .schema() .fields() .iter() - .map(|f| col(f.name())) + .map(|f| Expr::Column(Column::from_qualified_name_ignore_case(f.name()))) .collect_vec(); // Add a temporary z-order column we will sort by, and then drop. const ZORDER_KEY_COLUMN: &str = "__zorder_key"; - let cols = context.columns.iter().map(col).collect_vec(); + let cols = context + .columns + .iter() + .map(|col| Expr::Column(Column::from_qualified_name_ignore_case(col))) + .collect_vec(); let expr = Expr::ScalarUDF(ScalarUDF::new( Arc::new(zorder::datafusion::zorder_key_udf()), cols, @@ -1208,6 +1213,7 @@ pub(super) mod zorder { use ::datafusion::assert_batches_eq; use arrow_array::{Int32Array, StringArray}; use arrow_ord::sort::sort_to_indices; + use arrow_schema::Field; use arrow_select::take::take; use rand::Rng; #[test] @@ -1300,6 +1306,42 @@ pub(super) mod zorder { } array } + + #[tokio::test] + async fn test_zorder_mixed_case() { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("moDified", DataType::Utf8, true), + Field::new("ID", DataType::Utf8, true), + Field::new("vaLue", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-01", + "2021-02-01", + "2021-02-02", + "2021-02-02", + ])), + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "D"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 10, 20, 100])), + ], + ) + .unwrap(); + // write some data + let table = crate::DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(crate::protocol::SaveMode::Append) + .await + .unwrap(); + + let res = crate::DeltaOps(table) + .optimize() + .with_type(OptimizeType::ZOrder(vec!["moDified".into()])) + .await; + assert!(res.is_ok()); + } } } diff --git a/crates/deltalake-core/src/protocol/checkpoints.rs b/crates/deltalake-core/src/protocol/checkpoints.rs index 55b36a64e1..af695b662e 100644 --- a/crates/deltalake-core/src/protocol/checkpoints.rs +++ b/crates/deltalake-core/src/protocol/checkpoints.rs @@ -70,6 +70,13 @@ impl From for ProtocolError { } } +use core::str::Utf8Error; +impl From for ProtocolError { + fn from(value: Utf8Error) -> Self { + Self::Generic(value.to_string()) + } +} + /// The record batch size for checkpoint parquet file pub const CHECKPOINT_RECORD_BATCH_SIZE: usize = 5000; diff --git a/crates/deltalake-core/src/protocol/parquet_read/mod.rs b/crates/deltalake-core/src/protocol/parquet_read/mod.rs index 21ad2bdff8..a546e4b0b0 100644 --- a/crates/deltalake-core/src/protocol/parquet_read/mod.rs +++ b/crates/deltalake-core/src/protocol/parquet_read/mod.rs @@ -6,6 +6,7 @@ use num_traits::cast::ToPrimitive; use parquet::record::{Field, ListAccessor, MapAccessor, RowAccessor}; use serde_json::json; +use crate::kernel::serde_path::decode_path; use crate::kernel::{ Action, Add, AddCDCFile, DeletionVectorDescriptor, Metadata, 
Protocol, Remove, StorageType, Txn, }; @@ -119,10 +120,13 @@ impl Add { for (i, (name, _)) in record.get_column_iter().enumerate() { match name.as_str() { "path" => { - re.path = record - .get_string(i) - .map_err(|_| gen_action_type_error("add", "path", "string"))? - .clone(); + re.path = decode_path( + record + .get_string(i) + .map_err(|_| gen_action_type_error("add", "path", "string"))? + .clone() + .as_str(), + )?; } "size" => { re.size = record @@ -515,10 +519,13 @@ impl Remove { for (i, (name, _)) in record.get_column_iter().enumerate() { match name.as_str() { "path" => { - re.path = record - .get_string(i) - .map_err(|_| gen_action_type_error("remove", "path", "string"))? - .clone(); + re.path = decode_path( + record + .get_string(i) + .map_err(|_| gen_action_type_error("remove", "path", "string"))? + .clone() + .as_str(), + )?; } "dataChange" => { re.data_change = record diff --git a/crates/deltalake-core/src/protocol/serde_path.rs b/crates/deltalake-core/src/protocol/serde_path.rs deleted file mode 100644 index 9868523e81..0000000000 --- a/crates/deltalake-core/src/protocol/serde_path.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::str::Utf8Error; - -use percent_encoding::{percent_decode_str, percent_encode, AsciiSet, CONTROLS}; -use serde::{self, Deserialize, Deserializer, Serialize, Serializer}; - -pub fn deserialize<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - decode_path(&s).map_err(serde::de::Error::custom) -} - -pub fn serialize(value: &str, serializer: S) -> Result -where - S: Serializer, -{ - let encoded = encode_path(value); - String::serialize(&encoded, serializer) -} - -pub const _DELIMITER: &str = "/"; -/// The path delimiter as a single byte -pub const _DELIMITER_BYTE: u8 = _DELIMITER.as_bytes()[0]; - -/// Characters we want to encode. 
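For reference, a hedged round-trip sketch of the percent-encoding behavior that the checkpoint path handling above now relies on, using the `percent-encoding` crate directly. The small `DEMO_SET` below is an illustrative stand-in for the full `INVALID` set defined in this module.

```rust
// Illustrative round trip only; the real code uses the larger INVALID
// AsciiSet rather than this reduced demo set.
use percent_encoding::{percent_decode_str, utf8_percent_encode, AsciiSet, CONTROLS};

// Controls plus a few of the same reserved characters as INVALID.
const DEMO_SET: &AsciiSet = &CONTROLS.add(b'%').add(b'[').add(b']').add(b'#');

fn main() {
    let raw = "string=$%25[1]#/part-00000-abc.c000.snappy.parquet";
    let encoded = utf8_percent_encode(raw, DEMO_SET).to_string();
    assert_ne!(encoded, raw); // '%', '[', ']' and '#' were percent-encoded
    let decoded = percent_decode_str(&encoded)
        .decode_utf8()
        .unwrap()
        .to_string();
    assert_eq!(decoded, raw);
}
```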
-const INVALID: &AsciiSet = &CONTROLS - // The delimiter we are reserving for internal hierarchy - // .add(DELIMITER_BYTE) - // Characters AWS recommends avoiding for object keys - // https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingMetadata.html - .add(b'\\') - .add(b'{') - .add(b'^') - .add(b'}') - .add(b'%') - .add(b'`') - .add(b']') - .add(b'"') - .add(b'>') - .add(b'[') - // .add(b'~') - .add(b'<') - .add(b'#') - .add(b'|') - // Characters Google Cloud Storage recommends avoiding for object names - // https://cloud.google.com/storage/docs/naming-objects - .add(b'\r') - .add(b'\n') - .add(b'*') - .add(b'?'); - -fn encode_path(path: &str) -> String { - percent_encode(path.as_bytes(), INVALID).to_string() -} - -fn decode_path(path: &str) -> Result { - Ok(percent_decode_str(path).decode_utf8()?.to_string()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_encode_path() { - let cases = [ - ( - "string=$%25&%2F()%3D%5E%22%5B%5D%23%2A%3F.%3A/part-00023-4b06bc90-0678-4a63-94a2-f09af1adb945.c000.snappy.parquet", - "string=$%2525&%252F()%253D%255E%2522%255B%255D%2523%252A%253F.%253A/part-00023-4b06bc90-0678-4a63-94a2-f09af1adb945.c000.snappy.parquet", - ), - ( - "string=$%25&%2F()%3D%5E%22<>~%5B%5D%7B}`%23|%2A%3F%2F%5Cr%5Cn.%3A/part-00023-e0a68495-8098-40a6-be5f-b502b111b789.c000.snappy.parquet", - "string=$%2525&%252F()%253D%255E%2522%3C%3E~%255B%255D%257B%7D%60%2523%7C%252A%253F%252F%255Cr%255Cn.%253A/part-00023-e0a68495-8098-40a6-be5f-b502b111b789.c000.snappy.parquet" - ), - ( - "string=$%25&%2F()%3D%5E%22<>~%5B%5D%7B}`%23|%2A%3F%2F%5Cr%5Cn.%3A_-/part-00023-346b6795-dafa-4948-bda5-ecdf4baa4445.c000.snappy.parquet", - "string=$%2525&%252F()%253D%255E%2522%3C%3E~%255B%255D%257B%7D%60%2523%7C%252A%253F%252F%255Cr%255Cn.%253A_-/part-00023-346b6795-dafa-4948-bda5-ecdf4baa4445.c000.snappy.parquet" - ) - ]; - - for (raw, expected) in cases { - let encoded = encode_path(raw); - assert_eq!(encoded, expected); - let decoded = decode_path(expected).unwrap(); - assert_eq!(decoded, raw); - } - } -} diff --git a/crates/deltalake-core/src/table/mod.rs b/crates/deltalake-core/src/table/mod.rs index 83374d1657..94fef6ae1b 100644 --- a/crates/deltalake-core/src/table/mod.rs +++ b/crates/deltalake-core/src/table/mod.rs @@ -644,7 +644,7 @@ impl DeltaTable { .object_store() .head(&commit_uri_from_version(version)) .await?; - let ts = meta.last_modified.timestamp(); + let ts = meta.last_modified.timestamp_millis(); // also cache timestamp for version self.version_timestamp.insert(version, ts); @@ -875,14 +875,13 @@ impl DeltaTable { let mut min_version = 0; let mut max_version = self.get_latest_version().await?; let mut version = min_version; - let target_ts = datetime.timestamp(); + let target_ts = datetime.timestamp_millis(); // binary search while min_version <= max_version { let pivot = (max_version + min_version) / 2; version = pivot; let pts = self.get_version_timestamp(pivot).await?; - match pts.cmp(&target_ts) { Ordering::Equal => { break; diff --git a/crates/deltalake-core/src/test_utils.rs b/crates/deltalake-core/src/test_utils.rs index c594a80a63..1d68f43420 100644 --- a/crates/deltalake-core/src/test_utils.rs +++ b/crates/deltalake-core/src/test_utils.rs @@ -67,7 +67,7 @@ impl IntegrationContext { if let StorageIntegration::Google = integration { gs_cli::prepare_env(); let base_url = std::env::var("GOOGLE_BASE_URL")?; - let token = json!({"gcs_base_url": base_url, "disable_oauth": true, "client_email": "", "private_key": ""}); + let token = json!({"gcs_base_url": base_url, 
"disable_oauth": true, "client_email": "", "private_key": "", "private_key_id": ""}); let account_path = tmp_dir.path().join("gcs.json"); std::fs::write(&account_path, serde_json::to_vec(&token)?)?; set_env_if_not_set( diff --git a/crates/deltalake-core/tests/command_restore.rs b/crates/deltalake-core/tests/command_restore.rs index 80c2083261..2c1c06cbb6 100644 --- a/crates/deltalake-core/tests/command_restore.rs +++ b/crates/deltalake-core/tests/command_restore.rs @@ -11,6 +11,8 @@ use rand::Rng; use std::error::Error; use std::fs; use std::sync::Arc; +use std::thread; +use std::time::Duration; use tempdir::TempDir; #[derive(Debug)] @@ -42,19 +44,21 @@ async fn setup_test() -> Result> { .await?; let batch = get_record_batch(); - + thread::sleep(Duration::from_secs(1)); let table = DeltaOps(table) .write(vec![batch.clone()]) .with_save_mode(SaveMode::Append) .await .unwrap(); + thread::sleep(Duration::from_secs(1)); let table = DeltaOps(table) .write(vec![batch.clone()]) .with_save_mode(SaveMode::Overwrite) .await .unwrap(); + thread::sleep(Duration::from_secs(1)); let table = DeltaOps(table) .write(vec![batch.clone()]) .with_save_mode(SaveMode::Append) diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py index b10a708309..99089ae922 100644 --- a/python/deltalake/__init__.py +++ b/python/deltalake/__init__.py @@ -6,5 +6,6 @@ from .schema import Schema as Schema from .table import DeltaTable as DeltaTable from .table import Metadata as Metadata +from .table import WriterProperties as WriterProperties from .writer import convert_to_deltalake as convert_to_deltalake from .writer import write_deltalake as write_deltalake diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 228488d91a..b4d0ca8c3d 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -64,6 +64,7 @@ class RawDeltaTable: target_size: Optional[int], max_concurrent_tasks: Optional[int], min_commit_interval: Optional[int], + writer_properties: Optional[Dict[str, Optional[str]]], ) -> str: ... def z_order_optimize( self, @@ -73,7 +74,12 @@ class RawDeltaTable: max_concurrent_tasks: Optional[int], max_spill_size: Optional[int], min_commit_interval: Optional[int], + writer_properties: Optional[Dict[str, Optional[str]]], ) -> str: ... + def add_constraints( + self, + constraints: Dict[str, str], + ) -> None: ... def restore( self, target: Optional[Any], @@ -87,13 +93,17 @@ class RawDeltaTable: ) -> List[Any]: ... def create_checkpoint(self) -> None: ... def get_add_actions(self, flatten: bool) -> pyarrow.RecordBatch: ... - def delete(self, predicate: Optional[str]) -> str: ... + def delete( + self, + predicate: Optional[str], + writer_properties: Optional[Dict[str, Optional[str]]], + ) -> str: ... def repair(self, dry_run: bool) -> str: ... def update( self, updates: Dict[str, str], predicate: Optional[str], - writer_properties: Optional[Dict[str, int]], + writer_properties: Optional[Dict[str, Optional[str]]], safe_cast: bool = False, ) -> str: ... 
def merge_execute( @@ -102,7 +112,7 @@ class RawDeltaTable: predicate: str, source_alias: Optional[str], target_alias: Optional[str], - writer_properties: Optional[Dict[str, int | None]], + writer_properties: Optional[Dict[str, Optional[str]]], safe_cast: bool, matched_update_updates: Optional[List[Dict[str, str]]], matched_update_predicate: Optional[List[Optional[str]]], @@ -152,6 +162,7 @@ def write_to_deltalake( description: Optional[str], configuration: Optional[Mapping[str, Optional[str]]], storage_options: Optional[Dict[str, str]], + writer_properties: Optional[Dict[str, Optional[str]]], ) -> None: ... def convert_to_deltalake( uri: str, diff --git a/python/deltalake/table.py b/python/deltalake/table.py index a2d6189fb6..5adeaaa9dc 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -51,6 +51,65 @@ MAX_SUPPORTED_WRITER_VERSION = 2 +@dataclass(init=True) +class WriterProperties: + """A Writer Properties instance for the Rust parquet writer.""" + + def __init__( + self, + data_page_size_limit: Optional[int] = None, + dictionary_page_size_limit: Optional[int] = None, + data_page_row_count_limit: Optional[int] = None, + write_batch_size: Optional[int] = None, + max_row_group_size: Optional[int] = None, + compression: Optional[str] = None, + compression_level: Optional[int] = None, + ): + """Create a Writer Properties instance for the Rust parquet writer, + see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html: + + Args: + data_page_size_limit: Limit DataPage size to this in bytes. + dictionary_page_size_limit: Limit the size of each DataPage to store dicts to this amount in bytes. + data_page_row_count_limit: Limit the number of rows in each DataPage. + write_batch_size: Splits internally to smaller batch size. + max_row_group_size: Max number of rows in row group. 
+ compression: compression type + compression_level: level of compression, only relevant for subset of compression types + """ + self.data_page_size_limit = data_page_size_limit + self.dictionary_page_size_limit = dictionary_page_size_limit + self.data_page_row_count_limit = data_page_row_count_limit + self.write_batch_size = write_batch_size + self.max_row_group_size = max_row_group_size + + if compression_level is not None and compression is None: + raise ValueError( + """Providing a compression level without the compression type is not possible, + please provide the compression as well.""" + ) + + if compression in ["gzip", "brotli", "zstd"]: + if compression_level is not None: + compression = compression = f"{compression}({compression_level})" + else: + raise ValueError("""Gzip, brotli, ztsd require a compression level""") + self.compression = compression + + def __str__(self) -> str: + return ( + f"WriterProperties(data_page_size_limit: {self.data_page_size_limit}, dictionary_page_size_limit: {self.dictionary_page_size_limit}, " + f"data_page_row_count_limit: {self.data_page_row_count_limit}, write_batch_size: {self.write_batch_size}, " + f"max_row_group_size: {self.max_row_group_size}, compression: {self.compression})" + ) + + def _to_dict(self) -> Dict[str, Optional[str]]: + values = {} + for key, value in self.__dict__.items(): + values[key] = str(value) if isinstance(value, int) else value + return values + + @dataclass(init=False) class Metadata: """Create a Metadata instance.""" @@ -264,7 +323,6 @@ def __init__( without_files=without_files, log_buffer_size=log_buffer_size, ) - self._metadata = Metadata(self._table) @classmethod def from_data_catalog( @@ -453,13 +511,59 @@ def file_uris( file_uris.__doc__ = "" + def load_as_version(self, version: Union[int, str, datetime]) -> None: + """ + Load/time travel a DeltaTable to a specified version number, or a timestamp version of the table. If a + string is passed then the argument should be an RFC 3339 and ISO 8601 date and time string format. + + Args: + version: the identifier of the version of the DeltaTable to load + + Example: + **Use a version number** + ``` + dt = DeltaTable("test_table") + dt.load_as_version(1) + ``` + + **Use a datetime object** + ``` + dt.load_as_version(datetime(2023,1,1)) + ``` + + **Use a datetime in string format** + ``` + dt.load_as_version("2018-01-26T18:30:09Z") + dt.load_as_version("2018-12-19T16:39:57-08:00") + dt.load_as_version("2018-01-26T18:30:09.453+00:00") + ``` + """ + if isinstance(version, int): + self._table.load_version(version) + elif isinstance(version, datetime): + self._table.load_with_datetime(version.isoformat()) + elif isinstance(version, str): + self._table.load_with_datetime(version) + else: + raise TypeError( + "Invalid datatype provided for version, only int, str or datetime are accepted." + ) + def load_version(self, version: int) -> None: """ Load a DeltaTable with a specified version. + !!! warning "Deprecated" + Load_version and load_with_datetime have been combined into `DeltaTable.load_as_version`. + Args: version: the identifier of the version of the DeltaTable to load """ + warnings.warn( + "Call to deprecated method DeltaTable.load_version. 
Use DeltaTable.load_as_version() instead.", + category=DeprecationWarning, + stacklevel=2, + ) self._table.load_version(version) def load_with_datetime(self, datetime_string: str) -> None: @@ -467,6 +571,9 @@ def load_with_datetime(self, datetime_string: str) -> None: Time travel Delta table to the latest version that's created at or before provided `datetime_string` argument. The `datetime_string` argument should be an RFC 3339 and ISO 8601 date and time string. + !!! warning "Deprecated" + Load_version and load_with_datetime have been combined into `DeltaTable.load_as_version`. + Args: datetime_string: the identifier of the datetime point of the DeltaTable to load @@ -477,6 +584,11 @@ def load_with_datetime(self, datetime_string: str) -> None: "2018-01-26T18:30:09.453+00:00" ``` """ + warnings.warn( + "Call to deprecated method DeltaTable.load_with_datetime. Use DeltaTable.load_as_version() instead.", + category=DeprecationWarning, + stacklevel=2, + ) self._table.load_with_datetime(datetime_string) @property @@ -499,7 +611,7 @@ def metadata(self) -> Metadata: Returns: the current Metadata registered in the transaction log """ - return self._metadata + return Metadata(self._table) def protocol(self) -> ProtocolVersions: """ @@ -575,7 +687,7 @@ def update( Dict[str, Union[int, float, str, datetime, bool, List[Any]]] ] = None, predicate: Optional[str] = None, - writer_properties: Optional[Dict[str, int]] = None, + writer_properties: Optional[WriterProperties] = None, error_on_type_mismatch: bool = True, ) -> Dict[str, Any]: """`UPDATE` records in the Delta Table that matches an optional predicate. Either updates or new_values needs @@ -585,9 +697,7 @@ def update( updates: a mapping of column name to update SQL expression. new_values: a mapping of column name to python datatype. predicate: a logical expression. - writer_properties: Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html, - only the following fields are supported: `data_page_size_limit`, `dictionary_page_size_limit`, - `data_page_row_count_limit`, `write_batch_size`, `max_row_group_size`. + writer_properties: Pass writer properties to the Rust parquet writer. error_on_type_mismatch: specify if update will return error if data types are mismatching :default = True Returns: @@ -666,7 +776,7 @@ def update( metrics = self._table.update( updates, predicate, - writer_properties, + writer_properties._to_dict() if writer_properties else None, safe_cast=not error_on_type_mismatch, ) return json.loads(metrics) @@ -677,6 +787,13 @@ def optimize( ) -> "TableOptimizer": return TableOptimizer(self) + @property + def alter( + self, + ) -> "TableAlterer": + """Namespace for all table alter related methods""" + return TableAlterer(self) + def merge( self, source: Union[ @@ -690,6 +807,7 @@ def merge( source_alias: Optional[str] = None, target_alias: Optional[str] = None, error_on_type_mismatch: bool = True, + writer_properties: Optional[WriterProperties] = None, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a predicate in SQL query like format. 
You can also specify on what to do when the underlying data types do not @@ -701,6 +819,7 @@ def merge( source_alias: Alias for the source table target_alias: Alias for the target table error_on_type_mismatch: specify if merge will return error if data types are mismatching :default = True + writer_properties: Pass writer properties to the Rust parquet writer Returns: TableMerger: TableMerger Object @@ -747,6 +866,7 @@ def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: source_alias=source_alias, target_alias=target_alias, safe_cast=not error_on_type_mismatch, + writer_properties=writer_properties, ) def restore( @@ -964,7 +1084,11 @@ def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: """ return self._table.get_add_actions(flatten) - def delete(self, predicate: Optional[str] = None) -> Dict[str, Any]: + def delete( + self, + predicate: Optional[str] = None, + writer_properties: Optional[WriterProperties] = None, + ) -> Dict[str, Any]: """Delete records from a Delta Table that statisfy a predicate. When a predicate is not provided then all records are deleted from the Delta @@ -978,7 +1102,9 @@ def delete(self, predicate: Optional[str] = None) -> Dict[str, Any]: Returns: the metrics from delete. """ - metrics = self._table.delete(predicate) + metrics = self._table.delete( + predicate, writer_properties._to_dict() if writer_properties else None + ) return json.loads(metrics) def repair(self, dry_run: bool = False) -> Dict[str, Any]: @@ -1020,6 +1146,7 @@ def __init__( source_alias: Optional[str] = None, target_alias: Optional[str] = None, safe_cast: bool = True, + writer_properties: Optional[WriterProperties] = None, ): self.table = table self.source = source @@ -1027,7 +1154,7 @@ def __init__( self.source_alias = source_alias self.target_alias = target_alias self.safe_cast = safe_cast - self.writer_properties: Optional[Dict[str, Optional[int]]] = None + self.writer_properties = writer_properties self.matched_update_updates: Optional[List[Dict[str, str]]] = None self.matched_update_predicate: Optional[List[Optional[str]]] = None self.matched_delete_predicate: Optional[List[str]] = None @@ -1061,14 +1188,20 @@ def with_writer_properties( Returns: TableMerger: TableMerger Object """ - writer_properties = { + warnings.warn( + "Call to deprecated method TableMerger.with_writer_properties. 
Use DeltaTable.merge(writer_properties=WriterProperties()) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + writer_properties: Dict[str, Any] = { "data_page_size_limit": data_page_size_limit, "dictionary_page_size_limit": dictionary_page_size_limit, "data_page_row_count_limit": data_page_row_count_limit, "write_batch_size": write_batch_size, "max_row_group_size": max_row_group_size, } - self.writer_properties = writer_properties + self.writer_properties = WriterProperties(**writer_properties) return self def when_matched_update( @@ -1465,7 +1598,9 @@ def execute(self) -> Dict[str, Any]: source_alias=self.source_alias, target_alias=self.target_alias, safe_cast=self.safe_cast, - writer_properties=self.writer_properties, + writer_properties=self.writer_properties._to_dict() + if self.writer_properties + else None, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, matched_delete_predicate=self.matched_delete_predicate, @@ -1481,6 +1616,43 @@ def execute(self) -> Dict[str, Any]: return json.loads(metrics) +class TableAlterer: + """API for various table alteration commands.""" + + def __init__(self, table: DeltaTable) -> None: + self.table = table + + def add_constraint(self, constraints: Dict[str, str]) -> None: + """ + Add constraints to the table. Limited to `single constraint` at once. + + Args: + constraints: mapping of constraint name to SQL-expression to evaluate on write + + Example: + ```python + from deltalake import DeltaTable + dt = DeltaTable("test_table_constraints") + dt.alter.add_constraint({ + "value_gt_5": "value > 5", + }) + ``` + + **Check configuration** + ``` + dt.metadata().configuration + {'delta.constraints.value_gt_5': 'value > 5'} + ``` + """ + if len(constraints.keys()) > 1: + raise ValueError( + """add_constraints is limited to a single constraint addition at once for now. + Please execute add_constraints multiple times with each time a different constraint.""" + ) + + self.table._table.add_constraints(constraints) + + class TableOptimizer: """API for various table optimization commands.""" @@ -1512,6 +1684,7 @@ def compact( target_size: Optional[int] = None, max_concurrent_tasks: Optional[int] = None, min_commit_interval: Optional[Union[int, timedelta]] = None, + writer_properties: Optional[WriterProperties] = None, ) -> Dict[str, Any]: """ Compacts small files to reduce the total number of files in the table. @@ -1533,6 +1706,7 @@ def compact( min_commit_interval: minimum interval in seconds or as timedeltas before a new commit is created. Interval is useful for long running executions. Set to 0 or timedelta(0), if you want a commit per partition. + writer_properties: Pass writer properties to the Rust parquet writer. 
Returns: the metrics from optimize @@ -1557,7 +1731,11 @@ def compact( min_commit_interval = int(min_commit_interval.total_seconds()) metrics = self.table._table.compact_optimize( - partition_filters, target_size, max_concurrent_tasks, min_commit_interval + partition_filters, + target_size, + max_concurrent_tasks, + min_commit_interval, + writer_properties._to_dict() if writer_properties else None, ) self.table.update_incremental() return json.loads(metrics) @@ -1570,6 +1748,7 @@ def z_order( max_concurrent_tasks: Optional[int] = None, max_spill_size: int = 20 * 1024 * 1024 * 1024, min_commit_interval: Optional[Union[int, timedelta]] = None, + writer_properties: Optional[WriterProperties] = None, ) -> Dict[str, Any]: """ Reorders the data using a Z-order curve to improve data skipping. @@ -1589,6 +1768,7 @@ def z_order( min_commit_interval: minimum interval in seconds or as timedeltas before a new commit is created. Interval is useful for long running executions. Set to 0 or timedelta(0), if you want a commit per partition. + writer_properties: Pass writer properties to the Rust parquet writer. Returns: the metrics from optimize @@ -1619,6 +1799,7 @@ def z_order( max_concurrent_tasks, max_spill_size, min_commit_interval, + writer_properties._to_dict() if writer_properties else None, ) self.table.update_incremental() return json.loads(metrics) diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index bb69fee457..609a6487c6 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -48,7 +48,7 @@ convert_pyarrow_recordbatchreader, convert_pyarrow_table, ) -from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable +from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable, WriterProperties try: import pandas as pd # noqa: F811 @@ -119,7 +119,6 @@ def write_deltalake( schema: Optional[Union[pa.Schema, DeltaSchema]] = ..., partition_by: Optional[Union[List[str], str]] = ..., mode: Literal["error", "append", "overwrite", "ignore"] = ..., - max_rows_per_group: int = ..., name: Optional[str] = ..., description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., @@ -128,6 +127,7 @@ def write_deltalake( predicate: Optional[str] = ..., large_dtypes: bool = ..., engine: Literal["rust"], + writer_properties: WriterProperties = ..., ) -> None: ... @@ -162,6 +162,7 @@ def write_deltalake( predicate: Optional[str] = None, large_dtypes: bool = False, engine: Literal["pyarrow", "rust"] = "pyarrow", + writer_properties: Optional[WriterProperties] = None, ) -> None: """Write to a Delta Lake table @@ -234,6 +235,7 @@ def write_deltalake( large_dtypes: If True, the data schema is kept in large_dtypes, has no effect on pandas dataframe input. engine: writer engine to write the delta table. `Rust` engine is still experimental but you may see up to 4x performance improvements over pyarrow. + writer_properties: Pass writer properties to the Rust parquet writer. 
""" table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) if table is not None: @@ -295,6 +297,9 @@ def write_deltalake( description=description, configuration=configuration, storage_options=storage_options, + writer_properties=writer_properties._to_dict() + if writer_properties + else None, ) if table: table.update_incremental() diff --git a/python/src/error.rs b/python/src/error.rs index f72c6361d2..a69160e3ec 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -20,7 +20,7 @@ fn inner_to_py_err(err: DeltaTableError) -> PyErr { DeltaTableError::InvalidJsonLog { .. } => DeltaProtocolError::new_err(err.to_string()), DeltaTableError::InvalidStatsJson { .. } => DeltaProtocolError::new_err(err.to_string()), DeltaTableError::InvalidData { violations } => { - DeltaProtocolError::new_err(format!("Inaviant violations: {:?}", violations)) + DeltaProtocolError::new_err(format!("Invariant violations: {:?}", violations)) } // commit errors diff --git a/python/src/lib.rs b/python/src/lib.rs index 645a2f0b72..55a7442281 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -27,6 +27,7 @@ use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; use deltalake::kernel::{Action, Add, Invariant, Remove, StructType}; +use deltalake::operations::constraints::ConstraintBuilder; use deltalake::operations::convert_to_delta::{ConvertToDeltaBuilder, PartitionStrategy}; use deltalake::operations::delete::DeleteBuilder; use deltalake::operations::filesystem_check::FileSystemCheckBuilder; @@ -36,11 +37,13 @@ use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::transaction::commit; use deltalake::operations::update::UpdateBuilder; use deltalake::operations::vacuum::VacuumBuilder; +use deltalake::parquet::basic::Compression; +use deltalake::parquet::errors::ParquetError; use deltalake::parquet::file::properties::WriterProperties; use deltalake::partitions::PartitionFilter; use deltalake::protocol::{ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats}; -use deltalake::DeltaOps; use deltalake::DeltaTableBuilder; +use deltalake::{DeltaOps, DeltaResult}; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyFrozenSet; @@ -270,36 +273,16 @@ impl RawDeltaTable { &mut self, updates: HashMap, predicate: Option, - writer_properties: Option>, + writer_properties: Option>>, safe_cast: bool, ) -> PyResult { let mut cmd = UpdateBuilder::new(self._table.log_store(), self._table.state.clone()) .with_safe_cast(safe_cast); if let Some(writer_props) = writer_properties { - let mut properties = WriterProperties::builder(); - let data_page_size_limit = writer_props.get("data_page_size_limit"); - let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit"); - let data_page_row_count_limit = writer_props.get("data_page_row_count_limit"); - let write_batch_size = writer_props.get("write_batch_size"); - let max_row_group_size = writer_props.get("max_row_group_size"); - - if let Some(data_page_size) = data_page_size_limit { - properties = properties.set_data_page_size_limit(*data_page_size); - } - if let Some(dictionary_page_size) = dictionary_page_size_limit { - properties = properties.set_dictionary_page_size_limit(*dictionary_page_size); - } - if let Some(data_page_row_count) = data_page_row_count_limit { - properties = properties.set_data_page_row_count_limit(*data_page_row_count); - } - if let Some(batch_size) = 
write_batch_size { - properties = properties.set_write_batch_size(*batch_size); - } - if let Some(row_group_size) = max_row_group_size { - properties = properties.set_max_row_group_size(*row_group_size); - } - cmd = cmd.with_writer_properties(properties.build()); + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); } for (col_name, expression) in updates { @@ -318,13 +301,20 @@ impl RawDeltaTable { } /// Run the optimize command on the Delta Table: merge small files into a large file by bin-packing. - #[pyo3(signature = (partition_filters = None, target_size = None, max_concurrent_tasks = None, min_commit_interval = None))] + #[pyo3(signature = ( + partition_filters = None, + target_size = None, + max_concurrent_tasks = None, + min_commit_interval = None, + writer_properties=None, + ))] pub fn compact_optimize( &mut self, partition_filters: Option>, target_size: Option, max_concurrent_tasks: Option, min_commit_interval: Option, + writer_properties: Option>>, ) -> PyResult { let mut cmd = OptimizeBuilder::new(self._table.log_store(), self._table.state.clone()) .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)); @@ -334,6 +324,13 @@ impl RawDeltaTable { if let Some(commit_interval) = min_commit_interval { cmd = cmd.with_min_commit_interval(time::Duration::from_secs(commit_interval)); } + + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + let converted_filters = convert_partition_filters(partition_filters.unwrap_or_default()) .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); @@ -346,7 +343,14 @@ impl RawDeltaTable { } /// Run z-order variation of optimize - #[pyo3(signature = (z_order_columns, partition_filters = None, target_size = None, max_concurrent_tasks = None, max_spill_size = 20 * 1024 * 1024 * 1024, min_commit_interval = None))] + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (z_order_columns, + partition_filters = None, + target_size = None, + max_concurrent_tasks = None, + max_spill_size = 20 * 1024 * 1024 * 1024, + min_commit_interval = None, + writer_properties=None))] pub fn z_order_optimize( &mut self, z_order_columns: Vec, @@ -355,6 +359,7 @@ impl RawDeltaTable { max_concurrent_tasks: Option, max_spill_size: usize, min_commit_interval: Option, + writer_properties: Option>>, ) -> PyResult { let mut cmd = OptimizeBuilder::new(self._table.log_store(), self._table.state.clone()) .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)) @@ -367,6 +372,12 @@ impl RawDeltaTable { cmd = cmd.with_min_commit_interval(time::Duration::from_secs(commit_interval)); } + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + let converted_filters = convert_partition_filters(partition_filters.unwrap_or_default()) .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); @@ -378,6 +389,22 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } + #[pyo3(signature = (constraints))] + pub fn add_constraints(&mut self, constraints: HashMap) -> PyResult<()> { + let mut cmd = + ConstraintBuilder::new(self._table.log_store(), self._table.get_state().clone()); + + for (col_name, expression) in constraints { + cmd = cmd.with_constraint(col_name.clone(), expression.clone()); + } + + let table = rt()? 
+ .block_on(cmd.into_future()) + .map_err(PythonError::from)?; + self._table.state = table.state; + Ok(()) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (source, predicate, @@ -403,7 +430,7 @@ impl RawDeltaTable { source_alias: Option, target_alias: Option, safe_cast: bool, - writer_properties: Option>, + writer_properties: Option>>, matched_update_updates: Option>>, matched_update_predicate: Option>>, matched_delete_predicate: Option>, @@ -439,29 +466,9 @@ impl RawDeltaTable { } if let Some(writer_props) = writer_properties { - let mut properties = WriterProperties::builder(); - let data_page_size_limit = writer_props.get("data_page_size_limit"); - let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit"); - let data_page_row_count_limit = writer_props.get("data_page_row_count_limit"); - let write_batch_size = writer_props.get("write_batch_size"); - let max_row_group_size = writer_props.get("max_row_group_size"); - - if let Some(data_page_size) = data_page_size_limit { - properties = properties.set_data_page_size_limit(*data_page_size); - } - if let Some(dictionary_page_size) = dictionary_page_size_limit { - properties = properties.set_dictionary_page_size_limit(*dictionary_page_size); - } - if let Some(data_page_row_count) = data_page_row_count_limit { - properties = properties.set_data_page_row_count_limit(*data_page_row_count); - } - if let Some(batch_size) = write_batch_size { - properties = properties.set_write_batch_size(*batch_size); - } - if let Some(row_group_size) = max_row_group_size { - properties = properties.set_max_row_group_size(*row_group_size); - } - cmd = cmd.with_writer_properties(properties.build()); + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); } if let Some(mu_updates) = matched_update_updates { @@ -846,12 +853,23 @@ impl RawDeltaTable { } /// Run the delete command on the delta table: delete records following a predicate and return the delete metrics. - #[pyo3(signature = (predicate = None))] - pub fn delete(&mut self, predicate: Option) -> PyResult { + #[pyo3(signature = (predicate = None, writer_properties=None))] + pub fn delete( + &mut self, + predicate: Option, + writer_properties: Option>>, + ) -> PyResult { let mut cmd = DeleteBuilder::new(self._table.log_store(), self._table.state.clone()); if let Some(predicate) = predicate { cmd = cmd.with_predicate(predicate); } + + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + let (table, metrics) = rt()? 
.block_on(cmd.into_future()) .map_err(PythonError::from)?; @@ -874,6 +892,46 @@ impl RawDeltaTable { } } +fn set_writer_properties( + writer_properties: HashMap>, +) -> DeltaResult { + let mut properties = WriterProperties::builder(); + let data_page_size_limit = writer_properties.get("data_page_size_limit"); + let dictionary_page_size_limit = writer_properties.get("dictionary_page_size_limit"); + let data_page_row_count_limit = writer_properties.get("data_page_row_count_limit"); + let write_batch_size = writer_properties.get("write_batch_size"); + let max_row_group_size = writer_properties.get("max_row_group_size"); + let compression = writer_properties.get("compression"); + + if let Some(Some(data_page_size)) = data_page_size_limit { + dbg!(data_page_size.clone()); + properties = properties.set_data_page_size_limit(data_page_size.parse::().unwrap()); + } + if let Some(Some(dictionary_page_size)) = dictionary_page_size_limit { + properties = properties + .set_dictionary_page_size_limit(dictionary_page_size.parse::().unwrap()); + } + if let Some(Some(data_page_row_count)) = data_page_row_count_limit { + properties = + properties.set_data_page_row_count_limit(data_page_row_count.parse::().unwrap()); + } + if let Some(Some(batch_size)) = write_batch_size { + properties = properties.set_write_batch_size(batch_size.parse::().unwrap()); + } + if let Some(Some(row_group_size)) = max_row_group_size { + properties = properties.set_max_row_group_size(row_group_size.parse::().unwrap()); + } + + if let Some(Some(compression)) = compression { + let compress: Compression = compression + .parse() + .map_err(|err: ParquetError| DeltaTableError::Generic(err.to_string()))?; + + properties = properties.set_compression(compress); + } + Ok(properties.build()) +} + fn convert_partition_filters<'a>( partitions_filters: Vec<(&'a str, &'a str, PartitionFilterValue)>, ) -> Result, DeltaTableError> { @@ -1114,6 +1172,7 @@ fn write_to_deltalake( description: Option, configuration: Option>>, storage_options: Option>, + writer_properties: Option>>, ) -> PyResult<()> { let batches = data.0.map(|batch| batch.unwrap()).collect::>(); let save_mode = mode.parse().map_err(PythonError::from)?; @@ -1135,6 +1194,12 @@ fn write_to_deltalake( builder = builder.with_partition_columns(partition_columns); } + if let Some(writer_props) = writer_properties { + builder = builder.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + if let Some(name) = &name { builder = builder.with_table_name(name); }; diff --git a/python/tests/test_alter.py b/python/tests/test_alter.py new file mode 100644 index 0000000000..edc6d3eda1 --- /dev/null +++ b/python/tests/test_alter.py @@ -0,0 +1,48 @@ +import pathlib + +import pyarrow as pa +import pytest + +from deltalake import DeltaTable, write_deltalake +from deltalake.exceptions import DeltaError, DeltaProtocolError + + +def test_add_constraint(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, sample_table) + + dt = DeltaTable(tmp_path) + + dt.alter.add_constraint({"check_price": "price >= 0"}) + + last_action = dt.history(1)[0] + assert last_action["operation"] == "ADD CONSTRAINT" + assert dt.version() == 1 + assert dt.metadata().configuration == { + "delta.constraints.check_price": "price >= 0" + } + + with pytest.raises(DeltaError): + # Invalid constraint + dt.alter.add_constraint({"check_price": "price < 0"}) + + with pytest.raises(DeltaProtocolError): + data = pa.table( + { + "id": pa.array(["1"]), + "price": 
pa.array([-1], pa.int64()), + "sold": pa.array(list(range(1)), pa.int32()), + "deleted": pa.array([False] * 1), + } + ) + write_deltalake(tmp_path, data, engine="rust", mode="append") + + +def test_add_multiple_constraints(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, sample_table) + + dt = DeltaTable(tmp_path) + + with pytest.raises(ValueError): + dt.alter.add_constraint( + {"check_price": "price >= 0", "check_price2": "price >= 0"} + ) diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index a49374e710..74c7a1b339 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -63,7 +63,15 @@ def test_read_simple_table_using_options_to_dict(): assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"value": [1, 2, 3]} -def test_load_with_datetime(): +@pytest.mark.parametrize( + ["date_value", "expected_version"], + [ + ("2020-05-01T00:47:31-07:00", 0), + ("2020-05-02T22:47:31-07:00", 1), + ("2020-05-25T22:47:31-07:00", 4), + ], +) +def test_load_as_version_datetime(date_value: str, expected_version): log_dir = "../crates/deltalake-core/tests/data/simple_table/_delta_log" log_mtime_pair = [ ("00000000000000000000.json", 1588398451.0), @@ -78,15 +86,14 @@ def test_load_with_datetime(): table_path = "../crates/deltalake-core/tests/data/simple_table" dt = DeltaTable(table_path) - dt.load_with_datetime("2020-05-01T00:47:31-07:00") - assert dt.version() == 0 - dt.load_with_datetime("2020-05-02T22:47:31-07:00") - assert dt.version() == 1 - dt.load_with_datetime("2020-05-25T22:47:31-07:00") - assert dt.version() == 4 + dt.load_as_version(date_value) + assert dt.version() == expected_version + dt = DeltaTable(table_path) + dt.load_as_version(datetime.fromisoformat(date_value)) + assert dt.version() == expected_version -def test_load_with_datetime_bad_format(): +def test_load_as_version_datetime_bad_format(): table_path = "../crates/deltalake-core/tests/data/simple_table" dt = DeltaTable(table_path) @@ -96,7 +103,7 @@ def test_load_with_datetime_bad_format(): "2020-05-01T00:47:31+08", ]: with pytest.raises(Exception, match="Failed to parse datetime string:"): - dt.load_with_datetime(bad_format) + dt.load_as_version(bad_format) def test_read_simple_table_update_incremental():
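
Not part of the patch above — a minimal end-to-end sketch of how the Python surface introduced in this change (the `WriterProperties` class, `DeltaTable.load_as_version`, `DeltaTable.alter.add_constraint`, and the `writer_properties` arguments on `delete`/`optimize`/`write_deltalake`) fits together. The table path and the sample data are illustrative, not taken from the diff.

```python
import pyarrow as pa

from deltalake import DeltaTable, write_deltalake
from deltalake.table import WriterProperties

# Illustrative data and path, not from the patch.
data = pa.table({"id": pa.array(["1", "2"]), "price": pa.array([10, 20], pa.int64())})

# The compression level is folded into the compression string (e.g. "zstd(3)").
props = WriterProperties(compression="zstd", compression_level=3)

# The Rust engine accepts writer properties on write.
write_deltalake("tmp/constrained_table", data, engine="rust", writer_properties=props)

dt = DeltaTable("tmp/constrained_table")

# Only one constraint can be added per call; writes violating it later fail
# with DeltaProtocolError, as exercised in test_alter.py above.
dt.alter.add_constraint({"price_gt_0": "price > 0"})

# delete() and both optimize variants now also accept WriterProperties.
dt.delete(predicate="price > 15", writer_properties=props)
dt.optimize.compact(writer_properties=props)

# load_as_version() replaces load_version()/load_with_datetime();
# an int, a datetime, or an RFC 3339 string is accepted.
dt.load_as_version(0)
```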
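For reference, the serialized form that `WriterProperties._to_dict()` hands to the Rust `set_writer_properties` helper looks roughly like the sketch below: integers are stringified, unset fields stay `None`, and the level is already baked into the compression string. The concrete values are illustrative.

```python
from deltalake.table import WriterProperties

props = WriterProperties(max_row_group_size=128_000, compression="zstd", compression_level=3)

print(props._to_dict())
# e.g. {'data_page_size_limit': None, 'dictionary_page_size_limit': None,
#       'data_page_row_count_limit': None, 'write_batch_size': None,
#       'max_row_group_size': '128000', 'compression': 'zstd(3)'}
```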