From df0462495e7817defaaae94bad712d9406902d19 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Tue, 19 Dec 2023 11:05:18 -0500 Subject: [PATCH 1/5] fix: case sensitivity for z-order (#1982) # Description Enable usage of z-order optimization on columns that have capitalization. # Related Issue(s) - closes #1586 --- .../deltalake-core/src/operations/optimize.rs | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/crates/deltalake-core/src/operations/optimize.rs b/crates/deltalake-core/src/operations/optimize.rs index 7eb046341f..24ecb8e853 100644 --- a/crates/deltalake-core/src/operations/optimize.rs +++ b/crates/deltalake-core/src/operations/optimize.rs @@ -533,6 +533,7 @@ impl MergePlan { context: Arc, ) -> Result>, DeltaTableError> { use datafusion::prelude::{col, ParquetReadOptions}; + use datafusion_common::Column; use datafusion_expr::expr::ScalarUDF; use datafusion_expr::Expr; @@ -549,12 +550,16 @@ impl MergePlan { .schema() .fields() .iter() - .map(|f| col(f.name())) + .map(|f| Expr::Column(Column::from_qualified_name_ignore_case(f.name()))) .collect_vec(); // Add a temporary z-order column we will sort by, and then drop. const ZORDER_KEY_COLUMN: &str = "__zorder_key"; - let cols = context.columns.iter().map(col).collect_vec(); + let cols = context + .columns + .iter() + .map(|col| Expr::Column(Column::from_qualified_name_ignore_case(col))) + .collect_vec(); let expr = Expr::ScalarUDF(ScalarUDF::new( Arc::new(zorder::datafusion::zorder_key_udf()), cols, @@ -1208,6 +1213,7 @@ pub(super) mod zorder { use ::datafusion::assert_batches_eq; use arrow_array::{Int32Array, StringArray}; use arrow_ord::sort::sort_to_indices; + use arrow_schema::Field; use arrow_select::take::take; use rand::Rng; #[test] @@ -1300,6 +1306,42 @@ pub(super) mod zorder { } array } + + #[tokio::test] + async fn test_zorder_mixed_case() { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("moDified", DataType::Utf8, true), + Field::new("ID", DataType::Utf8, true), + Field::new("vaLue", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-01", + "2021-02-01", + "2021-02-02", + "2021-02-02", + ])), + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "D"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 10, 20, 100])), + ], + ) + .unwrap(); + // write some data + let table = crate::DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(crate::protocol::SaveMode::Append) + .await + .unwrap(); + + let res = crate::DeltaOps(table) + .optimize() + .with_type(OptimizeType::ZOrder(vec!["moDified".into()])) + .await; + assert!(res.is_ok()); + } } } From 4ece26d0da666b59df206515d24ab8797329b46f Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Tue, 19 Dec 2023 17:20:26 +0100 Subject: [PATCH 2/5] refactor: trigger metadata retrieval only during `DeltaTable.metadata` (#1979) # Description Triggers metadata retrieval only on metadata call, this is a better approach otherwise we need to add it after each method that does an alteration to the table config. Also now it will only be triggered if the user actually wants to retrieve the metadata. 
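For illustration, a minimal sketch of the resulting behaviour from Python (the table path is hypothetical; `Metadata` is now built on demand instead of being cached in `__init__`):

```python
from deltalake import DeltaTable

dt = DeltaTable("./my_table")  # hypothetical local table

# Metadata is read from the current transaction log state on each call,
# so no manual refresh is needed after operations that alter the config.
print(dt.metadata().configuration)

dt.update_incremental()             # pick up any new table versions
print(dt.metadata().configuration)  # reflects the latest state
```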
Co-authored-by: Robert Pack <42610831+roeap@users.noreply.github.com> --- python/deltalake/table.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index a2d6189fb6..0fe9c25bb7 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -264,7 +264,6 @@ def __init__( without_files=without_files, log_buffer_size=log_buffer_size, ) - self._metadata = Metadata(self._table) @classmethod def from_data_catalog( @@ -499,7 +498,7 @@ def metadata(self) -> Metadata: Returns: the current Metadata registered in the transaction log """ - return self._metadata + return Metadata(self._table) def protocol(self) -> ProtocolVersions: """ From f6d20611c7386a316669b26fdf9207dc17f27e88 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Tue, 19 Dec 2023 14:49:42 -0500 Subject: [PATCH 3/5] fix: implement consistent formatting for constraint expressions (#1985) # Description Implements consistent formatting for constraint expressions so something like `value < 1000` is normalized to `value < 1000` Also includes drive by improvements. 1. Test & Fix that Datafusion expressions can actually be used when adding a constraint 2. Test & Fix that constraints can be added to column with capitalization # Related Issue(s) - closes #1971 --- .../src/delta_datafusion/expr.rs | 15 +- .../src/delta_datafusion/mod.rs | 12 +- .../src/operations/constraints.rs | 172 ++++++++++++++---- 3 files changed, 157 insertions(+), 42 deletions(-) diff --git a/crates/deltalake-core/src/delta_datafusion/expr.rs b/crates/deltalake-core/src/delta_datafusion/expr.rs index f9275832a1..49cdae4387 100644 --- a/crates/deltalake-core/src/delta_datafusion/expr.rs +++ b/crates/deltalake-core/src/delta_datafusion/expr.rs @@ -347,9 +347,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> { mod test { use arrow_schema::DataType as ArrowDataType; use datafusion::prelude::SessionContext; - use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable}; + use crate::delta_datafusion::DeltaSessionContext; use crate::kernel::{DataType, PrimitiveType, StructField, StructType}; use crate::{DeltaOps, DeltaTable}; @@ -388,6 +389,11 @@ mod test { DataType::Primitive(PrimitiveType::Integer), true, ), + StructField::new( + "Value3".to_string(), + DataType::Primitive(PrimitiveType::Integer), + true, + ), StructField::new( "modified".to_string(), DataType::Primitive(PrimitiveType::String), @@ -442,7 +448,10 @@ mod test { }), "arrow_cast(1, 'Int32')".to_string() ), - simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), + simple!( + Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64)), + "Value3 = 3".to_string() + ), simple!(col("active").is_true(), "active IS TRUE".to_string()), simple!(col("active"), "active".to_string()), simple!(col("active").eq(lit(true)), "active = true".to_string()), @@ -536,7 +545,7 @@ mod test { ), ]; - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); for test in tests { let actual = fmt_expr_to_sql(&test.expr).unwrap(); diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 83db86e8e2..9f2818de93 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -1033,7 +1033,7 @@ impl DeltaDataChecker { Self { invariants, 
constraints: vec![], - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } @@ -1042,10 +1042,16 @@ impl DeltaDataChecker { Self { constraints, invariants: vec![], - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } + /// Specify the Datafusion context + pub fn with_session_context(mut self, context: SessionContext) -> Self { + self.ctx = context; + self + } + /// Create a new DeltaDataChecker pub fn new(snapshot: &DeltaTableState) -> Self { let metadata = snapshot.metadata(); @@ -1059,7 +1065,7 @@ impl DeltaDataChecker { Self { invariants, constraints, - ctx: SessionContext::new(), + ctx: DeltaSessionContext::default().into(), } } diff --git a/crates/deltalake-core/src/operations/constraints.rs b/crates/deltalake-core/src/operations/constraints.rs index 889e668b1a..ed5888bd13 100644 --- a/crates/deltalake-core/src/operations/constraints.rs +++ b/crates/deltalake-core/src/operations/constraints.rs @@ -8,11 +8,15 @@ use datafusion::execution::context::SessionState; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_common::ToDFSchema; use futures::future::BoxFuture; use futures::StreamExt; use serde_json::json; -use crate::delta_datafusion::{register_store, DeltaDataChecker, DeltaScanBuilder}; +use crate::delta_datafusion::expr::fmt_expr_to_sql; +use crate::delta_datafusion::{ + register_store, DeltaDataChecker, DeltaScanBuilder, DeltaSessionContext, +}; use crate::kernel::{Action, CommitInfo, IsolationLevel, Metadata, Protocol}; use crate::logstore::LogStoreRef; use crate::operations::datafusion_utils::Expression; @@ -23,6 +27,8 @@ use crate::table::Constraint; use crate::DeltaTable; use crate::{DeltaResult, DeltaTableError}; +use super::datafusion_utils::into_expr; + /// Build a constraint to add to a table pub struct ConstraintBuilder { snapshot: DeltaTableState, @@ -47,10 +53,10 @@ impl ConstraintBuilder { /// Specify the constraint to be added pub fn with_constraint, E: Into>( mut self, - column: S, + name: S, expression: E, ) -> Self { - self.name = Some(column.into()); + self.name = Some(name.into()); self.expr = Some(expression.into()); self } @@ -75,15 +81,10 @@ impl std::future::IntoFuture for ConstraintBuilder { Some(v) => v, None => return Err(DeltaTableError::Generic("No name provided".to_string())), }; - let expr = match this.expr { - Some(Expression::String(s)) => s, - Some(Expression::DataFusion(e)) => e.to_string(), - None => { - return Err(DeltaTableError::Generic( - "No expression provided".to_string(), - )) - } - }; + + let expr = this + .expr + .ok_or_else(|| DeltaTableError::Generic("No Expresion provided".to_string()))?; let mut metadata = this .snapshot @@ -94,23 +95,29 @@ impl std::future::IntoFuture for ConstraintBuilder { if metadata.configuration.contains_key(&configuration_key) { return Err(DeltaTableError::Generic(format!( - "Constraint with name: {} already exists, expr: {}", - name, expr + "Constraint with name: {} already exists", + name ))); } let state = this.state.unwrap_or_else(|| { - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); register_store(this.log_store.clone(), session.runtime_env()); session.state() }); - // Checker built here with the one time constraint to check. 
- let checker = DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr)]); let scan = DeltaScanBuilder::new(&this.snapshot, this.log_store.clone(), &state) .build() .await?; + let schema = scan.schema().to_dfschema()?; + let expr = into_expr(expr, &schema, &state)?; + let expr_str = fmt_expr_to_sql(&expr)?; + + // Checker built here with the one time constraint to check. + let checker = + DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr_str)]); + let plan: Arc = Arc::new(scan); let mut tasks = vec![]; for p in 0..plan.output_partitioning().partition_count() { @@ -140,9 +147,10 @@ impl std::future::IntoFuture for ConstraintBuilder { // We have validated the table passes it's constraints, now to add the constraint to // the table. - metadata - .configuration - .insert(format!("delta.constraints.{}", name), Some(expr.clone())); + metadata.configuration.insert( + format!("delta.constraints.{}", name), + Some(expr_str.clone()), + ); let old_protocol = this.snapshot.protocol(); let protocol = Protocol { @@ -162,12 +170,12 @@ impl std::future::IntoFuture for ConstraintBuilder { let operational_parameters = HashMap::from_iter([ ("name".to_string(), json!(&name)), - ("expr".to_string(), json!(&expr)), + ("expr".to_string(), json!(&expr_str)), ]); let operations = DeltaOperation::AddConstraint { name: name.clone(), - expr: expr.clone(), + expr: expr_str.clone(), }; let commit_info = CommitInfo { @@ -208,11 +216,37 @@ mod tests { use std::sync::Arc; use arrow_array::{Array, Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema}; + use datafusion_expr::{col, lit}; use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch}; - use crate::{DeltaOps, DeltaResult}; + use crate::{DeltaOps, DeltaResult, DeltaTable}; + + fn get_constraint(table: &DeltaTable, name: &str) -> String { + table + .metadata() + .unwrap() + .configuration + .get(name) + .unwrap() + .clone() + .unwrap() + } + + async fn get_constraint_op_params(table: &mut DeltaTable) -> String { + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + last_commit + .operation_parameters + .as_ref() + .unwrap() + .get("expr") + .unwrap() + .as_str() + .unwrap() + .to_owned() + } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_constraint_with_invalid_data() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -225,12 +259,10 @@ mod tests { .add_constraint() .with_constraint("id", "value > 5") .await; - dbg!(&constraint); assert!(constraint.is_err()); Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_valid_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -239,18 +271,89 @@ mod tests { .await?; let table = DeltaOps(write); - let constraint = table + let mut table = table .add_constraint() - .with_constraint("id", "value < 1000") - .await; - dbg!(&constraint); - assert!(constraint.is_ok()); - let version = constraint?.version(); + .with_constraint("id", "value < 1000") + .await?; + let version = table.version(); + assert_eq!(version, 1); + + let expected_expr = "value < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.id"), + expected_expr + ); + Ok(()) + } + + #[tokio::test] + async fn add_constraint_datafusion() -> DeltaResult<()> { + // Add constraint by providing a datafusion expression. 
+ let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + let table = DeltaOps(write); + + let mut table = table + .add_constraint() + .with_constraint("valid_values", col("value").lt(lit(1000))) + .await?; + let version = table.version(); assert_eq!(version, 1); + + let expected_expr = "value < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.valid_values"), + expected_expr + ); + + Ok(()) + } + + #[tokio::test] + async fn test_constraint_case_sensitive() -> DeltaResult<()> { + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("Id", ArrowDataType::Utf8, true), + Field::new("vAlue", ArrowDataType::Int32, true), + Field::new("mOdifieD", ArrowDataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&arrow_schema.clone()), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + + let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap(); + + let mut table = DeltaOps(table) + .add_constraint() + .with_constraint("valid_values", "vAlue < 1000") + .await?; + let version = table.version(); + assert_eq!(version, 1); + + let expected_expr = "vAlue < 1000"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + assert_eq!( + get_constraint(&table, "delta.constraints.valid_values"), + expected_expr + ); + Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn add_conflicting_named_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -269,12 +372,10 @@ mod tests { .add_constraint() .with_constraint("id", "value < 10") .await; - dbg!(&second_constraint); assert!(second_constraint.is_err()); Ok(()) } - #[cfg(feature = "datafusion")] #[tokio::test] async fn write_data_that_violates_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); @@ -294,7 +395,6 @@ mod tests { ]; let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?; let err = table.write(vec![batch]).await; - dbg!(&err); assert!(err.is_err()); Ok(()) } From a5a4e69d2f111923cd154ac4a201a35bf0ac758b Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Tue, 19 Dec 2023 21:04:48 +0100 Subject: [PATCH 4/5] feat(python): combine load_version/load_with_datetime into `load_as_version` (#1968) # Description Combines the two functions into one. 
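A hedged sketch of the combined entry point (the table path is illustrative; the accepted argument types follow the new signature in this diff):

```python
from datetime import datetime, timezone
from deltalake import DeltaTable

dt = DeltaTable("./my_table")  # hypothetical table

dt.load_as_version(1)                            # integer version number
dt.load_as_version("2020-05-25T22:47:31-07:00")  # RFC 3339 timestamp string
dt.load_as_version(datetime(2020, 5, 25, tzinfo=timezone.utc))  # datetime object
```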
# Related Issue(s) - closes https://github.com/delta-io/delta-rs/issues/1910 - closes https://github.com/delta-io/delta-rs/issues/1967 --------- Co-authored-by: Robert Pack <42610831+roeap@users.noreply.github.com> --- crates/deltalake-core/src/table/mod.rs | 5 +- .../deltalake-core/tests/command_restore.rs | 6 ++- python/deltalake/table.py | 54 +++++++++++++++++++ python/tests/test_table_read.py | 25 +++++---- 4 files changed, 77 insertions(+), 13 deletions(-) diff --git a/crates/deltalake-core/src/table/mod.rs b/crates/deltalake-core/src/table/mod.rs index 83374d1657..94fef6ae1b 100644 --- a/crates/deltalake-core/src/table/mod.rs +++ b/crates/deltalake-core/src/table/mod.rs @@ -644,7 +644,7 @@ impl DeltaTable { .object_store() .head(&commit_uri_from_version(version)) .await?; - let ts = meta.last_modified.timestamp(); + let ts = meta.last_modified.timestamp_millis(); // also cache timestamp for version self.version_timestamp.insert(version, ts); @@ -875,14 +875,13 @@ impl DeltaTable { let mut min_version = 0; let mut max_version = self.get_latest_version().await?; let mut version = min_version; - let target_ts = datetime.timestamp(); + let target_ts = datetime.timestamp_millis(); // binary search while min_version <= max_version { let pivot = (max_version + min_version) / 2; version = pivot; let pts = self.get_version_timestamp(pivot).await?; - match pts.cmp(&target_ts) { Ordering::Equal => { break; diff --git a/crates/deltalake-core/tests/command_restore.rs b/crates/deltalake-core/tests/command_restore.rs index 80c2083261..2c1c06cbb6 100644 --- a/crates/deltalake-core/tests/command_restore.rs +++ b/crates/deltalake-core/tests/command_restore.rs @@ -11,6 +11,8 @@ use rand::Rng; use std::error::Error; use std::fs; use std::sync::Arc; +use std::thread; +use std::time::Duration; use tempdir::TempDir; #[derive(Debug)] @@ -42,19 +44,21 @@ async fn setup_test() -> Result> { .await?; let batch = get_record_batch(); - + thread::sleep(Duration::from_secs(1)); let table = DeltaOps(table) .write(vec![batch.clone()]) .with_save_mode(SaveMode::Append) .await .unwrap(); + thread::sleep(Duration::from_secs(1)); let table = DeltaOps(table) .write(vec![batch.clone()]) .with_save_mode(SaveMode::Overwrite) .await .unwrap(); + thread::sleep(Duration::from_secs(1)); let table = DeltaOps(table) .write(vec![batch.clone()]) .with_save_mode(SaveMode::Append) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 0fe9c25bb7..6075f64fd2 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -452,13 +452,59 @@ def file_uris( file_uris.__doc__ = "" + def load_as_version(self, version: Union[int, str, datetime]) -> None: + """ + Load/time travel a DeltaTable to a specified version number, or a timestamp version of the table. If a + string is passed then the argument should be an RFC 3339 and ISO 8601 date and time string format. 
+ + Args: + version: the identifier of the version of the DeltaTable to load + + Example: + **Use a version number** + ``` + dt = DeltaTable("test_table") + dt.load_as_version(1) + ``` + + **Use a datetime object** + ``` + dt.load_as_version(datetime(2023,1,1)) + ``` + + **Use a datetime in string format** + ``` + dt.load_as_version("2018-01-26T18:30:09Z") + dt.load_as_version("2018-12-19T16:39:57-08:00") + dt.load_as_version("2018-01-26T18:30:09.453+00:00") + ``` + """ + if isinstance(version, int): + self._table.load_version(version) + elif isinstance(version, datetime): + self._table.load_with_datetime(version.isoformat()) + elif isinstance(version, str): + self._table.load_with_datetime(version) + else: + raise TypeError( + "Invalid datatype provided for version, only int, str or datetime are accepted." + ) + def load_version(self, version: int) -> None: """ Load a DeltaTable with a specified version. + !!! warning "Deprecated" + Load_version and load_with_datetime have been combined into `DeltaTable.load_as_version`. + Args: version: the identifier of the version of the DeltaTable to load """ + warnings.warn( + "Call to deprecated method DeltaTable.load_version. Use DeltaTable.load_as_version() instead.", + category=DeprecationWarning, + stacklevel=2, + ) self._table.load_version(version) def load_with_datetime(self, datetime_string: str) -> None: @@ -466,6 +512,9 @@ def load_with_datetime(self, datetime_string: str) -> None: Time travel Delta table to the latest version that's created at or before provided `datetime_string` argument. The `datetime_string` argument should be an RFC 3339 and ISO 8601 date and time string. + !!! warning "Deprecated" + Load_version and load_with_datetime have been combined into `DeltaTable.load_as_version`. + Args: datetime_string: the identifier of the datetime point of the DeltaTable to load @@ -476,6 +525,11 @@ def load_with_datetime(self, datetime_string: str) -> None: "2018-01-26T18:30:09.453+00:00" ``` """ + warnings.warn( + "Call to deprecated method DeltaTable.load_with_datetime. 
Use DeltaTable.load_as_version() instead.", + category=DeprecationWarning, + stacklevel=2, + ) self._table.load_with_datetime(datetime_string) @property diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index a49374e710..74c7a1b339 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -63,7 +63,15 @@ def test_read_simple_table_using_options_to_dict(): assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"value": [1, 2, 3]} -def test_load_with_datetime(): +@pytest.mark.parametrize( + ["date_value", "expected_version"], + [ + ("2020-05-01T00:47:31-07:00", 0), + ("2020-05-02T22:47:31-07:00", 1), + ("2020-05-25T22:47:31-07:00", 4), + ], +) +def test_load_as_version_datetime(date_value: str, expected_version): log_dir = "../crates/deltalake-core/tests/data/simple_table/_delta_log" log_mtime_pair = [ ("00000000000000000000.json", 1588398451.0), @@ -78,15 +86,14 @@ def test_load_with_datetime(): table_path = "../crates/deltalake-core/tests/data/simple_table" dt = DeltaTable(table_path) - dt.load_with_datetime("2020-05-01T00:47:31-07:00") - assert dt.version() == 0 - dt.load_with_datetime("2020-05-02T22:47:31-07:00") - assert dt.version() == 1 - dt.load_with_datetime("2020-05-25T22:47:31-07:00") - assert dt.version() == 4 + dt.load_as_version(date_value) + assert dt.version() == expected_version + dt = DeltaTable(table_path) + dt.load_as_version(datetime.fromisoformat(date_value)) + assert dt.version() == expected_version -def test_load_with_datetime_bad_format(): +def test_load_as_version_datetime_bad_format(): table_path = "../crates/deltalake-core/tests/data/simple_table" dt = DeltaTable(table_path) @@ -96,7 +103,7 @@ def test_load_with_datetime_bad_format(): "2020-05-01T00:47:31+08", ]: with pytest.raises(Exception, match="Failed to parse datetime string:"): - dt.load_with_datetime(bad_format) + dt.load_as_version(bad_format) def test_read_simple_table_update_incremental(): From 9eef52712ac0dc15d216fd1829d6d58c3b677563 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Tue, 19 Dec 2023 21:56:06 +0100 Subject: [PATCH 5/5] feat(python): add writer_properties to all operations (#1980) # Description I've changed the API to consolidate that how we use writer properties. You now need to instantiate a WriterProperties class and then pass it to the writer, merge, delete, update, optimize operations. ```python wp = WriterProperties(compression='gzip', compression_level=1) dt.optimize.z_order(['foo'], writer_properties=wp) ``` A potential idea I had is to allow users to set the write properties in the DeltaTable class once, so the properties can be grabbed from the tableclass so you don't have to provide them to each method. 
--------- Co-authored-by: Robert Pack <42610831+roeap@users.noreply.github.com> --- python/deltalake/__init__.py | 1 + python/deltalake/_internal.pyi | 13 ++- python/deltalake/table.py | 108 ++++++++++++++++++++--- python/deltalake/writer.py | 9 +- python/src/lib.rs | 154 +++++++++++++++++++++------------ 5 files changed, 215 insertions(+), 70 deletions(-) diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py index b10a708309..99089ae922 100644 --- a/python/deltalake/__init__.py +++ b/python/deltalake/__init__.py @@ -6,5 +6,6 @@ from .schema import Schema as Schema from .table import DeltaTable as DeltaTable from .table import Metadata as Metadata +from .table import WriterProperties as WriterProperties from .writer import convert_to_deltalake as convert_to_deltalake from .writer import write_deltalake as write_deltalake diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 228488d91a..633bca9737 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -64,6 +64,7 @@ class RawDeltaTable: target_size: Optional[int], max_concurrent_tasks: Optional[int], min_commit_interval: Optional[int], + writer_properties: Optional[Dict[str, Optional[str]]], ) -> str: ... def z_order_optimize( self, @@ -73,6 +74,7 @@ class RawDeltaTable: max_concurrent_tasks: Optional[int], max_spill_size: Optional[int], min_commit_interval: Optional[int], + writer_properties: Optional[Dict[str, Optional[str]]], ) -> str: ... def restore( self, @@ -87,13 +89,17 @@ class RawDeltaTable: ) -> List[Any]: ... def create_checkpoint(self) -> None: ... def get_add_actions(self, flatten: bool) -> pyarrow.RecordBatch: ... - def delete(self, predicate: Optional[str]) -> str: ... + def delete( + self, + predicate: Optional[str], + writer_properties: Optional[Dict[str, Optional[str]]], + ) -> str: ... def repair(self, dry_run: bool) -> str: ... def update( self, updates: Dict[str, str], predicate: Optional[str], - writer_properties: Optional[Dict[str, int]], + writer_properties: Optional[Dict[str, Optional[str]]], safe_cast: bool = False, ) -> str: ... def merge_execute( @@ -102,7 +108,7 @@ class RawDeltaTable: predicate: str, source_alias: Optional[str], target_alias: Optional[str], - writer_properties: Optional[Dict[str, int | None]], + writer_properties: Optional[Dict[str, Optional[str]]], safe_cast: bool, matched_update_updates: Optional[List[Dict[str, str]]], matched_update_predicate: Optional[List[Optional[str]]], @@ -152,6 +158,7 @@ def write_to_deltalake( description: Optional[str], configuration: Optional[Mapping[str, Optional[str]]], storage_options: Optional[Dict[str, str]], + writer_properties: Optional[Dict[str, Optional[str]]], ) -> None: ... 
def convert_to_deltalake( uri: str, diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 6075f64fd2..c862418ee2 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -51,6 +51,65 @@ MAX_SUPPORTED_WRITER_VERSION = 2 +@dataclass(init=True) +class WriterProperties: + """A Writer Properties instance for the Rust parquet writer.""" + + def __init__( + self, + data_page_size_limit: Optional[int] = None, + dictionary_page_size_limit: Optional[int] = None, + data_page_row_count_limit: Optional[int] = None, + write_batch_size: Optional[int] = None, + max_row_group_size: Optional[int] = None, + compression: Optional[str] = None, + compression_level: Optional[int] = None, + ): + """Create a Writer Properties instance for the Rust parquet writer, + see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html: + + Args: + data_page_size_limit: Limit DataPage size to this in bytes. + dictionary_page_size_limit: Limit the size of each DataPage to store dicts to this amount in bytes. + data_page_row_count_limit: Limit the number of rows in each DataPage. + write_batch_size: Splits internally to smaller batch size. + max_row_group_size: Max number of rows in row group. + compression: compression type + compression_level: level of compression, only relevant for subset of compression types + """ + self.data_page_size_limit = data_page_size_limit + self.dictionary_page_size_limit = dictionary_page_size_limit + self.data_page_row_count_limit = data_page_row_count_limit + self.write_batch_size = write_batch_size + self.max_row_group_size = max_row_group_size + + if compression_level is not None and compression is None: + raise ValueError( + """Providing a compression level without the compression type is not possible, + please provide the compression as well.""" + ) + + if compression in ["gzip", "brotli", "zstd"]: + if compression_level is not None: + compression = compression = f"{compression}({compression_level})" + else: + raise ValueError("""Gzip, brotli, ztsd require a compression level""") + self.compression = compression + + def __str__(self) -> str: + return ( + f"WriterProperties(data_page_size_limit: {self.data_page_size_limit}, dictionary_page_size_limit: {self.dictionary_page_size_limit}, " + f"data_page_row_count_limit: {self.data_page_row_count_limit}, write_batch_size: {self.write_batch_size}, " + f"max_row_group_size: {self.max_row_group_size}, compression: {self.compression})" + ) + + def _to_dict(self) -> Dict[str, Optional[str]]: + values = {} + for key, value in self.__dict__.items(): + values[key] = str(value) if isinstance(value, int) else value + return values + + @dataclass(init=False) class Metadata: """Create a Metadata instance.""" @@ -628,7 +687,7 @@ def update( Dict[str, Union[int, float, str, datetime, bool, List[Any]]] ] = None, predicate: Optional[str] = None, - writer_properties: Optional[Dict[str, int]] = None, + writer_properties: Optional[WriterProperties] = None, error_on_type_mismatch: bool = True, ) -> Dict[str, Any]: """`UPDATE` records in the Delta Table that matches an optional predicate. Either updates or new_values needs @@ -638,9 +697,7 @@ def update( updates: a mapping of column name to update SQL expression. new_values: a mapping of column name to python datatype. predicate: a logical expression. 
- writer_properties: Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html, - only the following fields are supported: `data_page_size_limit`, `dictionary_page_size_limit`, - `data_page_row_count_limit`, `write_batch_size`, `max_row_group_size`. + writer_properties: Pass writer properties to the Rust parquet writer. error_on_type_mismatch: specify if update will return error if data types are mismatching :default = True Returns: @@ -719,7 +776,7 @@ def update( metrics = self._table.update( updates, predicate, - writer_properties, + writer_properties._to_dict() if writer_properties else None, safe_cast=not error_on_type_mismatch, ) return json.loads(metrics) @@ -743,6 +800,7 @@ def merge( source_alias: Optional[str] = None, target_alias: Optional[str] = None, error_on_type_mismatch: bool = True, + writer_properties: Optional[WriterProperties] = None, ) -> "TableMerger": """Pass the source data which you want to merge on the target delta table, providing a predicate in SQL query like format. You can also specify on what to do when the underlying data types do not @@ -754,6 +812,7 @@ def merge( source_alias: Alias for the source table target_alias: Alias for the target table error_on_type_mismatch: specify if merge will return error if data types are mismatching :default = True + writer_properties: Pass writer properties to the Rust parquet writer Returns: TableMerger: TableMerger Object @@ -800,6 +859,7 @@ def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: source_alias=source_alias, target_alias=target_alias, safe_cast=not error_on_type_mismatch, + writer_properties=writer_properties, ) def restore( @@ -1017,7 +1077,11 @@ def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: """ return self._table.get_add_actions(flatten) - def delete(self, predicate: Optional[str] = None) -> Dict[str, Any]: + def delete( + self, + predicate: Optional[str] = None, + writer_properties: Optional[WriterProperties] = None, + ) -> Dict[str, Any]: """Delete records from a Delta Table that statisfy a predicate. When a predicate is not provided then all records are deleted from the Delta @@ -1031,7 +1095,9 @@ def delete(self, predicate: Optional[str] = None) -> Dict[str, Any]: Returns: the metrics from delete. """ - metrics = self._table.delete(predicate) + metrics = self._table.delete( + predicate, writer_properties._to_dict() if writer_properties else None + ) return json.loads(metrics) def repair(self, dry_run: bool = False) -> Dict[str, Any]: @@ -1073,6 +1139,7 @@ def __init__( source_alias: Optional[str] = None, target_alias: Optional[str] = None, safe_cast: bool = True, + writer_properties: Optional[WriterProperties] = None, ): self.table = table self.source = source @@ -1080,7 +1147,7 @@ def __init__( self.source_alias = source_alias self.target_alias = target_alias self.safe_cast = safe_cast - self.writer_properties: Optional[Dict[str, Optional[int]]] = None + self.writer_properties = writer_properties self.matched_update_updates: Optional[List[Dict[str, str]]] = None self.matched_update_predicate: Optional[List[Optional[str]]] = None self.matched_delete_predicate: Optional[List[str]] = None @@ -1114,14 +1181,20 @@ def with_writer_properties( Returns: TableMerger: TableMerger Object """ - writer_properties = { + warnings.warn( + "Call to deprecated method TableMerger.with_writer_properties. 
Use DeltaTable.merge(writer_properties=WriterProperties()) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + writer_properties: Dict[str, Any] = { "data_page_size_limit": data_page_size_limit, "dictionary_page_size_limit": dictionary_page_size_limit, "data_page_row_count_limit": data_page_row_count_limit, "write_batch_size": write_batch_size, "max_row_group_size": max_row_group_size, } - self.writer_properties = writer_properties + self.writer_properties = WriterProperties(**writer_properties) return self def when_matched_update( @@ -1518,7 +1591,9 @@ def execute(self) -> Dict[str, Any]: source_alias=self.source_alias, target_alias=self.target_alias, safe_cast=self.safe_cast, - writer_properties=self.writer_properties, + writer_properties=self.writer_properties._to_dict() + if self.writer_properties + else None, matched_update_updates=self.matched_update_updates, matched_update_predicate=self.matched_update_predicate, matched_delete_predicate=self.matched_delete_predicate, @@ -1565,6 +1640,7 @@ def compact( target_size: Optional[int] = None, max_concurrent_tasks: Optional[int] = None, min_commit_interval: Optional[Union[int, timedelta]] = None, + writer_properties: Optional[WriterProperties] = None, ) -> Dict[str, Any]: """ Compacts small files to reduce the total number of files in the table. @@ -1586,6 +1662,7 @@ def compact( min_commit_interval: minimum interval in seconds or as timedeltas before a new commit is created. Interval is useful for long running executions. Set to 0 or timedelta(0), if you want a commit per partition. + writer_properties: Pass writer properties to the Rust parquet writer. Returns: the metrics from optimize @@ -1610,7 +1687,11 @@ def compact( min_commit_interval = int(min_commit_interval.total_seconds()) metrics = self.table._table.compact_optimize( - partition_filters, target_size, max_concurrent_tasks, min_commit_interval + partition_filters, + target_size, + max_concurrent_tasks, + min_commit_interval, + writer_properties._to_dict() if writer_properties else None, ) self.table.update_incremental() return json.loads(metrics) @@ -1623,6 +1704,7 @@ def z_order( max_concurrent_tasks: Optional[int] = None, max_spill_size: int = 20 * 1024 * 1024 * 1024, min_commit_interval: Optional[Union[int, timedelta]] = None, + writer_properties: Optional[WriterProperties] = None, ) -> Dict[str, Any]: """ Reorders the data using a Z-order curve to improve data skipping. @@ -1642,6 +1724,7 @@ def z_order( min_commit_interval: minimum interval in seconds or as timedeltas before a new commit is created. Interval is useful for long running executions. Set to 0 or timedelta(0), if you want a commit per partition. + writer_properties: Pass writer properties to the Rust parquet writer. 
Returns: the metrics from optimize @@ -1672,6 +1755,7 @@ def z_order( max_concurrent_tasks, max_spill_size, min_commit_interval, + writer_properties._to_dict() if writer_properties else None, ) self.table.update_incremental() return json.loads(metrics) diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index bb69fee457..609a6487c6 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -48,7 +48,7 @@ convert_pyarrow_recordbatchreader, convert_pyarrow_table, ) -from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable +from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable, WriterProperties try: import pandas as pd # noqa: F811 @@ -119,7 +119,6 @@ def write_deltalake( schema: Optional[Union[pa.Schema, DeltaSchema]] = ..., partition_by: Optional[Union[List[str], str]] = ..., mode: Literal["error", "append", "overwrite", "ignore"] = ..., - max_rows_per_group: int = ..., name: Optional[str] = ..., description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., @@ -128,6 +127,7 @@ def write_deltalake( predicate: Optional[str] = ..., large_dtypes: bool = ..., engine: Literal["rust"], + writer_properties: WriterProperties = ..., ) -> None: ... @@ -162,6 +162,7 @@ def write_deltalake( predicate: Optional[str] = None, large_dtypes: bool = False, engine: Literal["pyarrow", "rust"] = "pyarrow", + writer_properties: Optional[WriterProperties] = None, ) -> None: """Write to a Delta Lake table @@ -234,6 +235,7 @@ def write_deltalake( large_dtypes: If True, the data schema is kept in large_dtypes, has no effect on pandas dataframe input. engine: writer engine to write the delta table. `Rust` engine is still experimental but you may see up to 4x performance improvements over pyarrow. + writer_properties: Pass writer properties to the Rust parquet writer. 
""" table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) if table is not None: @@ -295,6 +297,9 @@ def write_deltalake( description=description, configuration=configuration, storage_options=storage_options, + writer_properties=writer_properties._to_dict() + if writer_properties + else None, ) if table: table.update_incremental() diff --git a/python/src/lib.rs b/python/src/lib.rs index 645a2f0b72..61f28189c9 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -36,11 +36,13 @@ use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::transaction::commit; use deltalake::operations::update::UpdateBuilder; use deltalake::operations::vacuum::VacuumBuilder; +use deltalake::parquet::basic::Compression; +use deltalake::parquet::errors::ParquetError; use deltalake::parquet::file::properties::WriterProperties; use deltalake::partitions::PartitionFilter; use deltalake::protocol::{ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats}; -use deltalake::DeltaOps; use deltalake::DeltaTableBuilder; +use deltalake::{DeltaOps, DeltaResult}; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyFrozenSet; @@ -270,36 +272,16 @@ impl RawDeltaTable { &mut self, updates: HashMap, predicate: Option, - writer_properties: Option>, + writer_properties: Option>>, safe_cast: bool, ) -> PyResult { let mut cmd = UpdateBuilder::new(self._table.log_store(), self._table.state.clone()) .with_safe_cast(safe_cast); if let Some(writer_props) = writer_properties { - let mut properties = WriterProperties::builder(); - let data_page_size_limit = writer_props.get("data_page_size_limit"); - let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit"); - let data_page_row_count_limit = writer_props.get("data_page_row_count_limit"); - let write_batch_size = writer_props.get("write_batch_size"); - let max_row_group_size = writer_props.get("max_row_group_size"); - - if let Some(data_page_size) = data_page_size_limit { - properties = properties.set_data_page_size_limit(*data_page_size); - } - if let Some(dictionary_page_size) = dictionary_page_size_limit { - properties = properties.set_dictionary_page_size_limit(*dictionary_page_size); - } - if let Some(data_page_row_count) = data_page_row_count_limit { - properties = properties.set_data_page_row_count_limit(*data_page_row_count); - } - if let Some(batch_size) = write_batch_size { - properties = properties.set_write_batch_size(*batch_size); - } - if let Some(row_group_size) = max_row_group_size { - properties = properties.set_max_row_group_size(*row_group_size); - } - cmd = cmd.with_writer_properties(properties.build()); + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); } for (col_name, expression) in updates { @@ -318,13 +300,20 @@ impl RawDeltaTable { } /// Run the optimize command on the Delta Table: merge small files into a large file by bin-packing. 
- #[pyo3(signature = (partition_filters = None, target_size = None, max_concurrent_tasks = None, min_commit_interval = None))] + #[pyo3(signature = ( + partition_filters = None, + target_size = None, + max_concurrent_tasks = None, + min_commit_interval = None, + writer_properties=None, + ))] pub fn compact_optimize( &mut self, partition_filters: Option>, target_size: Option, max_concurrent_tasks: Option, min_commit_interval: Option, + writer_properties: Option>>, ) -> PyResult { let mut cmd = OptimizeBuilder::new(self._table.log_store(), self._table.state.clone()) .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)); @@ -334,6 +323,13 @@ impl RawDeltaTable { if let Some(commit_interval) = min_commit_interval { cmd = cmd.with_min_commit_interval(time::Duration::from_secs(commit_interval)); } + + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + let converted_filters = convert_partition_filters(partition_filters.unwrap_or_default()) .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); @@ -346,7 +342,14 @@ impl RawDeltaTable { } /// Run z-order variation of optimize - #[pyo3(signature = (z_order_columns, partition_filters = None, target_size = None, max_concurrent_tasks = None, max_spill_size = 20 * 1024 * 1024 * 1024, min_commit_interval = None))] + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (z_order_columns, + partition_filters = None, + target_size = None, + max_concurrent_tasks = None, + max_spill_size = 20 * 1024 * 1024 * 1024, + min_commit_interval = None, + writer_properties=None))] pub fn z_order_optimize( &mut self, z_order_columns: Vec, @@ -355,6 +358,7 @@ impl RawDeltaTable { max_concurrent_tasks: Option, max_spill_size: usize, min_commit_interval: Option, + writer_properties: Option>>, ) -> PyResult { let mut cmd = OptimizeBuilder::new(self._table.log_store(), self._table.state.clone()) .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)) @@ -367,6 +371,12 @@ impl RawDeltaTable { cmd = cmd.with_min_commit_interval(time::Duration::from_secs(commit_interval)); } + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + let converted_filters = convert_partition_filters(partition_filters.unwrap_or_default()) .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); @@ -403,7 +413,7 @@ impl RawDeltaTable { source_alias: Option, target_alias: Option, safe_cast: bool, - writer_properties: Option>, + writer_properties: Option>>, matched_update_updates: Option>>, matched_update_predicate: Option>>, matched_delete_predicate: Option>, @@ -439,29 +449,9 @@ impl RawDeltaTable { } if let Some(writer_props) = writer_properties { - let mut properties = WriterProperties::builder(); - let data_page_size_limit = writer_props.get("data_page_size_limit"); - let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit"); - let data_page_row_count_limit = writer_props.get("data_page_row_count_limit"); - let write_batch_size = writer_props.get("write_batch_size"); - let max_row_group_size = writer_props.get("max_row_group_size"); - - if let Some(data_page_size) = data_page_size_limit { - properties = properties.set_data_page_size_limit(*data_page_size); - } - if let Some(dictionary_page_size) = dictionary_page_size_limit { - properties = 
properties.set_dictionary_page_size_limit(*dictionary_page_size); - } - if let Some(data_page_row_count) = data_page_row_count_limit { - properties = properties.set_data_page_row_count_limit(*data_page_row_count); - } - if let Some(batch_size) = write_batch_size { - properties = properties.set_write_batch_size(*batch_size); - } - if let Some(row_group_size) = max_row_group_size { - properties = properties.set_max_row_group_size(*row_group_size); - } - cmd = cmd.with_writer_properties(properties.build()); + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); } if let Some(mu_updates) = matched_update_updates { @@ -846,12 +836,23 @@ impl RawDeltaTable { } /// Run the delete command on the delta table: delete records following a predicate and return the delete metrics. - #[pyo3(signature = (predicate = None))] - pub fn delete(&mut self, predicate: Option) -> PyResult { + #[pyo3(signature = (predicate = None, writer_properties=None))] + pub fn delete( + &mut self, + predicate: Option, + writer_properties: Option>>, + ) -> PyResult { let mut cmd = DeleteBuilder::new(self._table.log_store(), self._table.state.clone()); if let Some(predicate) = predicate { cmd = cmd.with_predicate(predicate); } + + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + let (table, metrics) = rt()? .block_on(cmd.into_future()) .map_err(PythonError::from)?; @@ -874,6 +875,46 @@ impl RawDeltaTable { } } +fn set_writer_properties( + writer_properties: HashMap>, +) -> DeltaResult { + let mut properties = WriterProperties::builder(); + let data_page_size_limit = writer_properties.get("data_page_size_limit"); + let dictionary_page_size_limit = writer_properties.get("dictionary_page_size_limit"); + let data_page_row_count_limit = writer_properties.get("data_page_row_count_limit"); + let write_batch_size = writer_properties.get("write_batch_size"); + let max_row_group_size = writer_properties.get("max_row_group_size"); + let compression = writer_properties.get("compression"); + + if let Some(Some(data_page_size)) = data_page_size_limit { + dbg!(data_page_size.clone()); + properties = properties.set_data_page_size_limit(data_page_size.parse::().unwrap()); + } + if let Some(Some(dictionary_page_size)) = dictionary_page_size_limit { + properties = properties + .set_dictionary_page_size_limit(dictionary_page_size.parse::().unwrap()); + } + if let Some(Some(data_page_row_count)) = data_page_row_count_limit { + properties = + properties.set_data_page_row_count_limit(data_page_row_count.parse::().unwrap()); + } + if let Some(Some(batch_size)) = write_batch_size { + properties = properties.set_write_batch_size(batch_size.parse::().unwrap()); + } + if let Some(Some(row_group_size)) = max_row_group_size { + properties = properties.set_max_row_group_size(row_group_size.parse::().unwrap()); + } + + if let Some(Some(compression)) = compression { + let compress: Compression = compression + .parse() + .map_err(|err: ParquetError| DeltaTableError::Generic(err.to_string()))?; + + properties = properties.set_compression(compress); + } + Ok(properties.build()) +} + fn convert_partition_filters<'a>( partitions_filters: Vec<(&'a str, &'a str, PartitionFilterValue)>, ) -> Result, DeltaTableError> { @@ -1114,6 +1155,7 @@ fn write_to_deltalake( description: Option, configuration: Option>>, storage_options: Option>, + writer_properties: Option>>, ) -> PyResult<()> { let batches = 
data.0.map(|batch| batch.unwrap()).collect::>(); let save_mode = mode.parse().map_err(PythonError::from)?; @@ -1135,6 +1177,12 @@ fn write_to_deltalake( builder = builder.with_partition_columns(partition_columns); } + if let Some(writer_props) = writer_properties { + builder = builder.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } + if let Some(name) = &name { builder = builder.with_table_name(name); };