From f0d615d281364016f2a2253632289d5cf6fef841 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Wed, 28 Feb 2024 16:39:10 +0100 Subject: [PATCH] feat: merge schema support for the write operation and Python (with Rust engine) This replaces the old "overwrite_schema" parameter with a schema_write_mode parameter that basically allows to distinguish between overwrite/merge/none Fixes #1386 --- crates/core/src/operations/cast.rs | 115 +++++-- crates/core/src/operations/delete.rs | 11 +- crates/core/src/operations/merge/mod.rs | 4 +- crates/core/src/operations/optimize.rs | 8 +- crates/core/src/operations/update.rs | 4 +- crates/core/src/operations/write.rs | 426 +++++++++++++++++++++--- crates/core/src/writer/record_batch.rs | 5 +- docs/integrations/delta-lake-pandas.md | 6 +- docs/usage/writing/index.md | 4 +- python/deltalake/_internal.pyi | 7 +- python/deltalake/exceptions.py | 1 + python/deltalake/writer.py | 24 +- python/docs/source/usage.rst | 2 +- python/src/error.rs | 2 + python/src/lib.rs | 10 +- python/tests/test_writer.py | 199 +++++++++-- 16 files changed, 702 insertions(+), 126 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 6e77552286..33155dedd8 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -1,36 +1,109 @@ //! Provide common cast functionality for callers //! -use arrow_array::{Array, ArrayRef, RecordBatch, StructArray}; +use arrow::datatypes::DataType::Dictionary; +use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; - +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, + SchemaRef as ArrowSchemaRef, +}; use std::sync::Arc; use crate::DeltaResult; +pub(crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result { + if let Dictionary(_, value_type) = right.data_type() { + if value_type.equals_datatype(left.data_type()) { + return Ok(left.clone()); + } + } + if let Dictionary(_, value_type) = left.data_type() { + if value_type.equals_datatype(right.data_type()) { + return Ok(right.clone()); + } + } + let mut new_field = left.clone(); + new_field.try_merge(right)?; + Ok(new_field) +} + +pub(crate) fn merge_schema( + left: ArrowSchema, + right: ArrowSchema, +) -> Result { + let mut errors = Vec::with_capacity(left.fields().len()); + let merged_fields: Result, ArrowError> = left + .fields() + .iter() + .map(|field| { + let right_field = right.field_with_name(field.name()); + if let Ok(right_field) = right_field { + let field_or_not = merge_field(field.as_ref(), right_field); + match field_or_not { + Err(e) => { + errors.push(e.to_string()); + Err(e) + } + Ok(f) => Ok(f), + } + } else { + Ok(field.as_ref().clone()) + } + }) + .collect(); + match merged_fields { + Ok(mut fields) => { + for field in right.fields() { + if !left.field_with_name(field.name()).is_ok() { + fields.push(field.as_ref().clone()); + } + } + + Ok(ArrowSchema::new(fields)) + } + Err(e) => { + errors.push(e.to_string()); + Err(ArrowError::SchemaError(errors.join("\n"))) + } + } +} + fn cast_struct( struct_array: &StructArray, fields: &Fields, cast_options: &CastOptions, + add_missing: bool, ) -> Result>, arrow_schema::ArrowError> { fields .iter() .map(|field| { - let col = struct_array.column_by_name(field.name()).unwrap(); - if let (DataType::Struct(_), DataType::Struct(child_fields)) = - (col.data_type(), 
field.data_type()) - { - let child_struct = StructArray::from(col.into_data()); - let s = cast_struct(&child_struct, child_fields, cast_options)?; - Ok(Arc::new(StructArray::new( - child_fields.clone(), - s, - child_struct.nulls().map(ToOwned::to_owned), - )) as ArrayRef) - } else if is_cast_required(col.data_type(), field.data_type()) { - cast_with_options(col, field.data_type(), cast_options) - } else { - Ok(col.clone()) + let col_or_not = struct_array.column_by_name(field.name()); + match col_or_not { + None => match add_missing { + true => Ok(new_null_array(field.data_type(), struct_array.len())), + false => Err(arrow_schema::ArrowError::SchemaError(format!( + "Could not find column {0}", + field.name() + ))), + }, + Some(col) => { + if let (DataType::Struct(_), DataType::Struct(child_fields)) = + (col.data_type(), field.data_type()) + { + let child_struct = StructArray::from(col.into_data()); + let s = + cast_struct(&child_struct, child_fields, cast_options, add_missing)?; + Ok(Arc::new(StructArray::new( + child_fields.clone(), + s, + child_struct.nulls().map(ToOwned::to_owned), + )) as ArrayRef) + } else if is_cast_required(col.data_type(), field.data_type()) { + cast_with_options(col, field.data_type(), cast_options) + } else { + Ok(col.clone()) + } + } } }) .collect::, _>>() @@ -51,6 +124,7 @@ pub fn cast_record_batch( batch: &RecordBatch, target_schema: ArrowSchemaRef, safe: bool, + add_missing: bool, ) -> DeltaResult { let cast_options = CastOptions { safe, @@ -62,8 +136,7 @@ pub fn cast_record_batch( batch.columns().to_owned(), None, ); - - let columns = cast_struct(&s, target_schema.fields(), &cast_options)?; + let columns = cast_struct(&s, target_schema.fields(), &cast_options, add_missing)?; Ok(RecordBatch::try_new(target_schema, columns)?) } @@ -93,7 +166,7 @@ mod tests { )]); let target_schema = Arc::new(Schema::new(fields)) as SchemaRef; - let result = cast_record_batch(&record_batch, target_schema, false); + let result = cast_record_batch(&record_batch, target_schema, false, false); let schema = result.unwrap().schema(); let field = schema.column_with_name("list_column").unwrap().1; diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 2e3e99bde2..1ab55310ea 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -17,6 +17,7 @@ //! .await?; //! ```` +use core::panic; use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -167,9 +168,15 @@ async fn excute_non_empty_expr( None, writer_properties, false, - false, + None, ) - .await?; + .await? 
+ .into_iter() + .map(|a| match a { + Action::Add(a) => a, + _ => panic!("Expected Add action"), + }) + .collect::>(); let read_records = scan.parquet_scan.metrics().and_then(|m| m.output_rows()); let filter_records = filter.metrics().and_then(|m| m.output_rows()); diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index a495afb9e7..a112fa6114 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -1379,13 +1379,13 @@ async fn execute( None, writer_properties, safe_cast, - false, + None, ) .await?; metrics.rewrite_time_ms = Instant::now().duration_since(rewrite_start).as_millis() as u64; - let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add_actions.clone(); metrics.num_target_files_added = actions.len(); let survivors = barrier diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index 990997399e..90334e6de1 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -457,8 +457,12 @@ impl MergePlan { while let Some(maybe_batch) = read_stream.next().await { let mut batch = maybe_batch?; - batch = - super::cast::cast_record_batch(&batch, task_parameters.file_schema.clone(), false)?; + batch = super::cast::cast_record_batch( + &batch, + task_parameters.file_schema.clone(), + false, + false, + )?; partial_metrics.num_batches += 1; writer.write(&batch).await.map_err(DeltaTableError::from)?; } diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index d07f3f9fc0..803b1d0312 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -357,7 +357,7 @@ async fn execute( None, writer_properties, safe_cast, - false, + None, ) .await?; @@ -377,7 +377,7 @@ async fn execute( .duration_since(UNIX_EPOCH) .unwrap() .as_millis() as i64; - let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add_actions.clone(); metrics.num_added_files = actions.len(); metrics.num_removed_files = candidates.candidates.len(); diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 73c1599a7e..2d08082ea9 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -25,12 +25,14 @@ //! 
```` use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use std::vec; use arrow_array::RecordBatch; use arrow_cast::can_cast_types; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_plan::filter::FilterExec; @@ -50,8 +52,9 @@ use crate::delta_datafusion::expr::parse_predicate_expression; use crate::delta_datafusion::DeltaDataChecker; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, Add, PartitionsExt, Remove, StructType}; +use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; +use crate::operations::cast::{cast_record_batch, merge_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -87,6 +90,30 @@ impl From for DeltaTableError { } } +///Specifies how to handle schema drifts +#[derive(PartialEq, Clone, Copy)] +pub enum SchemaMode { + /// Overwrite the schema with the new schema + Overwrite, + /// Append the new schema to the existing schema + Merge, +} + +impl FromStr for SchemaMode { + type Err = DeltaTableError; + + fn from_str(s: &str) -> DeltaResult { + match s.to_ascii_lowercase().as_str() { + "overwrite" => Ok(SchemaMode::Overwrite), + "merge" => Ok(SchemaMode::Merge), + _ => Err(DeltaTableError::Generic(format!( + "Invalid schema write mode provided: {}, only these are supported: ['overwrite', 'merge']", + s + ))), + } + } +} + /// Write data into a DeltaTable pub struct WriteBuilder { /// A snapshot of the to-be-loaded table's state @@ -109,8 +136,8 @@ pub struct WriteBuilder { write_batch_size: Option, /// RecordBatches to be written into the table batches: Option>, - /// whether to overwrite the schema - overwrite_schema: bool, + /// whether to overwrite the schema or to merge it. None means to fail on schmema drift + schema_mode: Option, /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) safe_cast: bool, /// Parquet writer properties @@ -140,7 +167,7 @@ impl WriteBuilder { write_batch_size: None, batches: None, safe_cast: false, - overwrite_schema: false, + schema_mode: None, writer_properties: None, app_metadata: None, name: None, @@ -155,9 +182,9 @@ impl WriteBuilder { self } - /// Add overwrite_schema - pub fn with_overwrite_schema(mut self, overwrite_schema: bool) -> Self { - self.overwrite_schema = overwrite_schema; + /// Add Schema Write Mode + pub fn with_schema_mode(mut self, schema_mode: SchemaMode) -> Self { + self.schema_mode = Some(schema_mode); self } @@ -311,10 +338,9 @@ async fn write_execution_plan_with_predicate( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - overwrite_schema: bool, -) -> DeltaResult> { - // Use input schema to prevent wrapping partitions columns into a dictionary. 
- let schema: ArrowSchemaRef = if overwrite_schema { + schema_mode: Option, +) -> DeltaResult> { + let schema: ArrowSchemaRef = if schema_mode.is_some() { plan.schema() } else { snapshot @@ -352,23 +378,29 @@ async fn write_execution_plan_with_predicate( let mut writer = DeltaWriter::new(object_store.clone(), config); let checker_stream = checker.clone(); let mut stream = inner_plan.execute(i, task_ctx)?; - let handle: tokio::task::JoinHandle>> = + let handle: tokio::task::JoinHandle>> = tokio::task::spawn(async move { while let Some(maybe_batch) = stream.next().await { let batch = maybe_batch?; checker_stream.check_batch(&batch).await?; - let arr = - super::cast::cast_record_batch(&batch, inner_schema.clone(), safe_cast)?; + let arr = super::cast::cast_record_batch( + &batch, + inner_schema.clone(), + safe_cast, + schema_mode == Some(SchemaMode::Merge), + )?; writer.write(&arr).await?; } - writer.close().await + let add_actions = writer.close().await; + match add_actions { + Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::>()), + Err(err) => Err(err), + } }); tasks.push(handle); } - - // Collect add actions to add to commit - Ok(futures::future::join_all(tasks) + let actions = futures::future::join_all(tasks) .await .into_iter() .collect::, _>>() @@ -377,7 +409,9 @@ async fn write_execution_plan_with_predicate( .collect::, _>>()? .concat() .into_iter() - .collect::>()) + .collect::>(); + // Collect add actions to add to commit + Ok(actions) } #[allow(clippy::too_many_arguments)] @@ -391,8 +425,8 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - overwrite_schema: bool, -) -> DeltaResult> { + schema_mode: Option, +) -> DeltaResult> { write_execution_plan_with_predicate( None, snapshot, @@ -404,7 +438,7 @@ pub(crate) async fn write_execution_plan( write_batch_size, writer_properties, safe_cast, - overwrite_schema, + schema_mode, ) .await } @@ -417,7 +451,7 @@ async fn execute_non_empty_expr( expression: &Expr, rewrite: &[Add], writer_properties: Option, -) -> DeltaResult> { +) -> DeltaResult> { // For each identified file perform a parquet scan + filter + limit (1) + count. // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. 
@@ -452,7 +486,7 @@ async fn execute_non_empty_expr( None, writer_properties, false, - false, + None, ) .await?; @@ -488,7 +522,7 @@ async fn prepare_predicate_actions( }; let remove = candidates.candidates; - let mut actions: Vec = add.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add.into_iter().collect(); for action in remove { actions.push(Action::Remove(Remove { @@ -520,6 +554,11 @@ impl std::future::IntoFuture for WriteBuilder { PROTOCOL.check_append_only(snapshot)?; } } + if this.schema_mode == Some(SchemaMode::Overwrite) && this.mode != SaveMode::Overwrite { + return Err(DeltaTableError::Generic( + "Schema overwrite not supported for Append".to_string(), + )); + } // Create table actions to initialize table in case it does not yet exist and should be created let mut actions = this.check_preconditions().await?; @@ -546,8 +585,13 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; - + let mut schema_drift = false; let plan = if let Some(plan) = this.input { + if this.schema_mode == Some(SchemaMode::Merge) { + return Err(DeltaTableError::Generic( + "Schema merge not supported yet for Datafusion".to_string(), + )); + } Ok(plan) } else if let Some(batches) = this.batches { if batches.is_empty() { @@ -555,6 +599,7 @@ impl std::future::IntoFuture for WriteBuilder { } else { let schema = batches[0].schema(); + let mut new_schema = None; if let Some(snapshot) = &this.snapshot { let table_schema = snapshot .physical_arrow_schema(this.log_store.object_store().clone()) @@ -562,23 +607,38 @@ impl std::future::IntoFuture for WriteBuilder { .or_else(|_| snapshot.arrow_schema()) .unwrap_or(schema.clone()); - if !can_cast_batch(schema.fields(), table_schema.fields()) - && !(this.overwrite_schema && matches!(this.mode, SaveMode::Overwrite)) + if let Err(schema_err) = + try_cast_batch(schema.fields(), table_schema.fields()) { - return Err(DeltaTableError::Generic( - "Schema of data does not match table schema".to_string(), - )); - }; + schema_drift = true; + if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() { + new_schema = None // we overwrite anyway, so no need to cast + } else if this.schema_mode == Some(SchemaMode::Merge) { + new_schema = Some(Arc::new(merge_schema( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + )?)); + } else { + return Err(schema_err.into()); + } + } } let data = if !partition_columns.is_empty() { // TODO partitioning should probably happen in its own plan ... let mut partitions: HashMap> = HashMap::new(); for batch in batches { + let real_batch = match new_schema.clone() { + Some(new_schema) => { + cast_record_batch(&batch, new_schema, false, true)? + } + None => batch, + }; + let divided = divide_by_partition_values( - schema.clone(), + new_schema.clone().unwrap_or(schema.clone()), partition_columns.clone(), - &batch, + &real_batch, )?; for part in divided { let key = part.partition_values.hive_partition_path(); @@ -594,17 +654,44 @@ impl std::future::IntoFuture for WriteBuilder { } partitions.into_values().collect::>() } else { - vec![batches] + match new_schema { + Some(ref new_schema) => { + let mut new_batches = vec![]; + for batch in batches { + new_batches.push(cast_record_batch( + &batch, + new_schema.clone(), + false, + true, + )?); + } + vec![new_batches] + } + None => vec![batches], + } }; - Ok(Arc::new(MemoryExec::try_new(&data, schema.clone(), None)?) - as Arc) + Ok(Arc::new(MemoryExec::try_new( + &data, + new_schema.unwrap_or(schema).clone(), + None, + )?) 
as Arc) } } else { Err(WriteError::MissingData) }?; let schema = plan.schema(); - + if this.schema_mode == Some(SchemaMode::Merge) && schema_drift { + if let Some(snapshot) = &this.snapshot { + let schema_struct: StructType = schema.clone().try_into()?; + let schema_action = Action::Metadata(Metadata::try_new( + schema_struct, + partition_columns.clone(), + snapshot.metadata().configuration.clone(), + )?); + actions.push(schema_action); + } + } let state = match this.state { Some(state) => state, None => { @@ -641,10 +728,10 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties.clone(), this.safe_cast, - this.overwrite_schema, + this.schema_mode, ) .await?; - actions.extend(add_actions.into_iter().map(Action::Add)); + actions.extend(add_actions); // Collect remove actions if we are overwriting the table if let Some(snapshot) = &this.snapshot { @@ -729,24 +816,42 @@ impl std::future::IntoFuture for WriteBuilder { } } -fn can_cast_batch(from_fields: &Fields, to_fields: &Fields) -> bool { +fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowError> { if from_fields.len() != to_fields.len() { - return false; + return Err(ArrowError::SchemaError(format!( + "Cannot cast schema, number of fields does not match: {} vs {}", + from_fields.len(), + to_fields.len() + ))); } - from_fields.iter().all(|f| { - if let Some((_, target_field)) = to_fields.find(f.name()) { - if let (DataType::Struct(fields0), DataType::Struct(fields1)) = - (f.data_type(), target_field.data_type()) - { - can_cast_batch(fields0, fields1) + from_fields + .iter() + .map(|f| { + if let Some((_, target_field)) = to_fields.find(f.name()) { + if let (DataType::Struct(fields0), DataType::Struct(fields1)) = + (f.data_type(), target_field.data_type()) + { + try_cast_batch(fields0, fields1) + } else if !can_cast_types(f.data_type(), target_field.data_type()) { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } else { + Ok(()) + } } else { - can_cast_types(f.data_type(), target_field.data_type()) + Err(ArrowError::SchemaError(format!( + "Field {} not found in schema", + f.name() + ))) } - } else { - false - } - }) + }) + .collect::, _>>()?; + Ok(()) } #[cfg(test)] @@ -997,6 +1102,219 @@ mod tests { assert_eq!(table.get_files_count(), 4) } + #[tokio::test] + async fn test_merge_schema() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + 
.unwrap(); + + let mut table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Merge) + .await + .unwrap(); + table.load().await.unwrap(); + assert_eq!(table.version(), 1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let names = fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); + } + + #[tokio::test] + async fn test_merge_schema_with_partitions() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_partition_columns(vec!["id", "value"]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + println!("new_batch: {:?}", new_batch.schema()); + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Merge) + .await + .unwrap(); + + assert_eq!(table.version(), 1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let mut names = fields.iter().map(|f| f.name()).collect::>(); + names.sort(); + assert_eq!(names, vec!["id", "inserted_by", "modified", "value"]); + let part_cols = table.metadata().unwrap().partition_columns.clone(); + assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions + } + + #[tokio::test] + async fn test_overwrite_schema() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + + let table = 
DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Overwrite) + .await; + assert!(table.is_err()); + } + + #[tokio::test] + async fn test_overwrite_check() { + // If you do not pass a schema mode, we want to check the schema + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = + RecordBatch::try_new(Arc::new(new_schema), vec![Arc::new(inserted_by)]).unwrap(); + + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .await; + assert!(table.is_err()); + } + #[tokio::test] async fn test_check_invariants() { let batch = get_record_batch(None, false); diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index 5a4066b5b7..5c8fb57509 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -29,6 +29,7 @@ use super::utils::{ use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; use crate::kernel::{Action, Add, PartitionsExt, Scalar, StructType}; +use crate::operations::cast::merge_schema; use crate::storage::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; use crate::DeltaTable; @@ -305,10 +306,10 @@ impl PartitionWriter { WriteMode::MergeSchema => { debug!("The writer and record batch schemas do not match, merging"); - let merged = ArrowSchema::try_merge(vec![ + let merged = merge_schema( self.arrow_schema.as_ref().clone(), record_batch.schema().as_ref().clone(), - ])?; + )?; self.arrow_schema = Arc::new(merged); let mut cols = vec![]; diff --git a/docs/integrations/delta-lake-pandas.md b/docs/integrations/delta-lake-pandas.md index b14c1bd45b..ca60362838 100644 --- a/docs/integrations/delta-lake-pandas.md +++ b/docs/integrations/delta-lake-pandas.md @@ -250,10 +250,10 @@ Schema enforcement protects your table from getting corrupted by appending data ## Overwriting schema of table -You can overwrite the table contents and schema by setting the `overwrite_schema` option. Here's how to overwrite the table contents: +You can overwrite the table contents and schema by setting the `schema_mode` option. Here's how to overwrite the table contents: ```python -write_deltalake("tmp/some-table", df, mode="overwrite", overwrite_schema=True) +write_deltalake("tmp/some-table", df, mode="overwrite", schema_mode="overwrite") ``` Here are the contents of the table after the values and schema have been overwritten: @@ -267,6 +267,8 @@ Here are the contents of the table after the values and schema have been overwri +-------+----------+ ``` +If you want the schema to be merged instead, specify schema_mode="merge". + ## In-memory vs. 
in-storage data changes It's important to distinguish between data stored in-memory and data stored on disk when understanding the functionality offered by Delta Lake. diff --git a/docs/usage/writing/index.md b/docs/usage/writing/index.md index dc8bb62389..9e9e1bcbec 100644 --- a/docs/usage/writing/index.md +++ b/docs/usage/writing/index.md @@ -23,7 +23,9 @@ of Spark's `pyspark.sql.DataFrameWriter.saveAsTable` DataFrame method. To overwr `write_deltalake` will raise `ValueError` if the schema of the data passed to it differs from the existing table's schema. If you wish to -alter the schema as part of an overwrite pass in `overwrite_schema=True`. +alter the schema as part of an overwrite pass in `schema_mode="overwrite"` or `schema_mode="merge"`. +`schema_mode="overwrite"` will completely overwrite the schema, even if columns are dropped; merge will append the new columns +and fill missing columns with `null`. `schema_mode="merge"` is also supported on append operations. ## Overwriting a partition diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index e8994983f1..695d7e3322 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -174,7 +174,7 @@ def write_to_deltalake( partition_by: Optional[List[str]], mode: str, max_rows_per_group: int, - overwrite_schema: bool, + schema_mode: Optional[str], predicate: Optional[str], name: Optional[str], description: Optional[str], @@ -795,6 +795,11 @@ class DeltaProtocolError(DeltaError): pass +class SchemaMismatchError(DeltaError): + """Raised when a schema mismatch is detected.""" + + pass + FilterLiteralType = Tuple[str, str, Any] FilterConjunctionType = List[FilterLiteralType] FilterDNFType = List[FilterConjunctionType] diff --git a/python/deltalake/exceptions.py b/python/deltalake/exceptions.py index bacd0af9f8..a2e5b1ba1e 100644 --- a/python/deltalake/exceptions.py +++ b/python/deltalake/exceptions.py @@ -1,4 +1,5 @@ from ._internal import CommitFailedError as CommitFailedError from ._internal import DeltaError as DeltaError from ._internal import DeltaProtocolError as DeltaProtocolError +from ._internal import SchemaMismatchError as SchemaMismatchError from ._internal import TableNotFoundError as TableNotFoundError diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 6ebc496436..89a12e2d6e 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -31,6 +31,8 @@ else: from typing_extensions import Literal +import warnings + import pyarrow as pa import pyarrow.dataset as ds import pyarrow.fs as pa_fs @@ -95,6 +97,7 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal["overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., partition_filters: Optional[List[Tuple[str, str, Any]]] = ..., large_dtypes: bool = ..., @@ -122,6 +125,7 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., large_dtypes: bool = ..., engine: Literal["rust"], @@ -149,6 +153,7 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., predicate: 
Optional[str] = ..., large_dtypes: bool = ..., @@ -182,6 +187,7 @@ def write_deltalake( description: Optional[str] = None, configuration: Optional[Mapping[str, Optional[str]]] = None, overwrite_schema: bool = False, + schema_mode: Optional[Literal["merge", "overwrite"]] = None, storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, @@ -235,7 +241,8 @@ def write_deltalake( name: User-provided identifier for this table. description: User-provided description for this table. configuration: A map containing configuration options for the metadata action. - overwrite_schema: If True, allows updating the schema of the table. + overwrite_schema: Deprecated, use schema_mode instead. + schema_mode: If set to "overwrite", allows replacing the schema of the table. Set to "merge" to merge with existing schema. storage_options: options passed to the native delta filesystem. predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. @@ -253,7 +260,14 @@ def write_deltalake( table.update_incremental() __enforce_append_only(table=table, configuration=configuration, mode=mode) + if overwrite_schema: + schema_mode = "overwrite" + warnings.warn( + "overwrite_schema is deprecated, use schema_mode instead. ", + category=DeprecationWarning, + stacklevel=2, + ) if isinstance(partition_by, str): partition_by = [partition_by] @@ -299,7 +313,7 @@ def write_deltalake( partition_by=partition_by, mode=mode, max_rows_per_group=max_rows_per_group, - overwrite_schema=overwrite_schema, + schema_mode=schema_mode, predicate=predicate, name=name, description=description, @@ -314,6 +328,10 @@ def write_deltalake( table.update_incremental() elif engine == "pyarrow": + if schema_mode == "merge": + raise ValueError( + "schema_mode 'merge' is not supported in pyarrow engine. Use engine=rust" + ) # We need to write against the latest table version filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) @@ -324,7 +342,7 @@ def sort_arrow_schema(schema: pa.schema) -> pa.schema: if table: # already exists if sort_arrow_schema(schema) != sort_arrow_schema( table.schema().to_pyarrow(as_large_types=large_dtypes) - ) and not (mode == "overwrite" and overwrite_schema): + ) and not (mode == "overwrite" and schema_mode == "overwrite"): raise ValueError( "Schema of data does not match table schema\n" f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}" diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst index d0349a450c..baa26f275c 100644 --- a/python/docs/source/usage.rst +++ b/python/docs/source/usage.rst @@ -481,7 +481,7 @@ to append pass in ``mode='append'``: :py:meth:`write_deltalake` will raise :py:exc:`ValueError` if the schema of the data passed to it differs from the existing table's schema. If you wish to -alter the schema as part of an overwrite pass in ``overwrite_schema=True``. +alter the schema as part of an overwrite pass in ``schema_mode="overwrite"``. 
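As a hedged illustration for this section (the table path and DataFrame below are made up; the call mirrors the pandas docs updated elsewhere in this patch):

```python
import pandas as pd
from deltalake import write_deltalake

# Any frame whose schema differs from the existing table works here.
df = pd.DataFrame({"x": [1, 2, 3], "label": ["a", "b", "c"]})

# Replace the table contents and its schema in one overwrite.
write_deltalake("tmp/some-table", df, mode="overwrite", schema_mode="overwrite")
```
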
Writing to s3 ~~~~~~~~~~~~~ diff --git a/python/src/error.rs b/python/src/error.rs index a69160e3ec..a54b1e60b4 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -10,6 +10,7 @@ create_exception!(_internal, DeltaError, PyException); create_exception!(_internal, TableNotFoundError, DeltaError); create_exception!(_internal, DeltaProtocolError, DeltaError); create_exception!(_internal, CommitFailedError, DeltaError); +create_exception!(_internal, SchemaMismatchError, DeltaError); fn inner_to_py_err(err: DeltaTableError) -> PyErr { match err { @@ -55,6 +56,7 @@ fn arrow_to_py(err: ArrowError) -> PyErr { ArrowError::DivideByZero => PyValueError::new_err("division by zero"), ArrowError::InvalidArgumentError(msg) => PyValueError::new_err(msg), ArrowError::NotYetImplemented(msg) => PyNotImplementedError::new_err(msg), + ArrowError::SchemaError(msg) => SchemaMismatchError::new_err(msg), other => PyException::new_err(other.to_string()), } } diff --git a/python/src/lib.rs b/python/src/lib.rs index 1992bae642..0800d42927 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1367,7 +1367,7 @@ fn write_to_deltalake( data: PyArrowType, mode: String, max_rows_per_group: i64, - overwrite_schema: bool, + schema_mode: Option, partition_by: Option>, predicate: Option, name: Option, @@ -1390,9 +1390,10 @@ fn write_to_deltalake( let mut builder = table .write(batches) .with_save_mode(save_mode) - .with_overwrite_schema(overwrite_schema) .with_write_batch_size(max_rows_per_group as usize); - + if let Some(schema_mode) = schema_mode { + builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?); + } if let Some(partition_columns) = partition_by { builder = builder.with_partition_columns(partition_columns); } @@ -1623,7 +1624,7 @@ impl PyDeltaDataChecker { #[pymodule] // module name need to match project name fn _internal(py: Python, m: &PyModule) -> PyResult<()> { - use crate::error::{CommitFailedError, DeltaError, TableNotFoundError}; + use crate::error::{CommitFailedError, DeltaError, SchemaMismatchError, TableNotFoundError}; deltalake::aws::register_handlers(None); deltalake::azure::register_handlers(None); @@ -1633,6 +1634,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add("CommitFailedError", py.get_type::())?; m.add("DeltaProtocolError", py.get_type::())?; m.add("TableNotFoundError", py.get_type::())?; + m.add("SchemaMismatchError", py.get_type::())?; env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init(); m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 550fec71ee..0ee751c2d7 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -17,7 +17,12 @@ from pyarrow.lib import RecordBatchReader from deltalake import DeltaTable, Schema, write_deltalake -from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError +from deltalake.exceptions import ( + CommitFailedError, + DeltaError, + DeltaProtocolError, + SchemaMismatchError, +) from deltalake.table import ProtocolVersions from deltalake.writer import try_get_table_and_table_uri @@ -124,11 +129,17 @@ def test_enforce_schema(existing_table: DeltaTable, mode: str): def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str): bad_data = pa.table({"x": pa.array([1, 2, 3])}) - with pytest.raises(DeltaError): + with pytest.raises( + SchemaMismatchError, + match=".*Cannot cast schema, number of fields does not match.*", + ): 
write_deltalake(existing_table, bad_data, mode=mode, engine="rust") table_uri = existing_table._table.table_uri() - with pytest.raises(DeltaError): + with pytest.raises( + SchemaMismatchError, + match=".*Cannot cast schema, number of fields does not match.*", + ): write_deltalake(table_uri, bad_data, mode=mode, engine="rust") @@ -136,47 +147,154 @@ def test_update_schema(existing_table: DeltaTable): new_data = pa.table({"x": pa.array([1, 2, 3])}) with pytest.raises(ValueError): - write_deltalake(existing_table, new_data, mode="append", overwrite_schema=True) + write_deltalake( + existing_table, new_data, mode="append", schema_mode="overwrite" + ) - write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") read_data = existing_table.to_pyarrow_table() assert new_data == read_data assert existing_table.schema().to_pyarrow() == new_data.schema -def test_update_schema_rust_writer(existing_table: DeltaTable): - new_data = pa.table({"x": pa.array([1, 2, 3])}) +def test_merge_schema(existing_table: DeltaTable): + print(existing_table._table.table_uri()) + old_table_data = existing_table.to_pyarrow_table() + new_data = pa.table( + { + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) + + write_deltalake( + existing_table, new_data, mode="append", schema_mode="merge", engine="rust" + ) + # adjust schema of old_table_data and new_data to match each other + + for i in range(old_table_data.num_columns): + col = old_table_data.schema.field(i) + new_data = new_data.add_column(i, col, pa.nulls(new_data.num_rows, col.type)) + + old_table_data = old_table_data.append_column( + pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + old_table_data = old_table_data.append_column( + pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + + # define sort order + read_data = existing_table.to_pyarrow_table().sort_by( + [("utf8", "ascending"), ("new_x", "ascending")] + ) + print(repr(read_data.to_pylist())) + concated = pa.concat_tables([old_table_data, new_data]) + print(repr(concated.to_pylist())) + assert read_data == concated + + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") + + assert existing_table.schema().to_pyarrow() == new_data.schema + + +def test_overwrite_schema(existing_table: DeltaTable): + new_data_invalid = pa.table( + { + "utf8": pa.array([1235, 546, 5645]), + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) with pytest.raises(DeltaError): write_deltalake( existing_table, - new_data, + new_data_invalid, mode="append", - overwrite_schema=True, + schema_mode="overwrite", engine="rust", ) + + new_data = pa.table( + { + "utf8": pa.array(["bla", "bli", "blubb"]), + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) with pytest.raises(DeltaError): write_deltalake( existing_table, new_data, - mode="overwrite", - overwrite_schema=False, + mode="append", + schema_mode="overwrite", + engine="rust", + ) + + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") + + assert existing_table.schema().to_pyarrow() == new_data.schema + + +def test_update_schema_rust_writer_append(existing_table: DeltaTable): + with pytest.raises( + SchemaMismatchError, match="Cannot cast schema, number of fields does not match" + ): + # It's illegal to do schema drift 
without correct schema_mode + write_deltalake( + existing_table, + pa.table({"x4": pa.array([1, 2, 3])}), + mode="append", + schema_mode=None, engine="rust", ) with pytest.raises(DeltaError): + write_deltalake( # schema_mode overwrite is illegal with append + existing_table, + pa.table({"x1": pa.array([1, 2, 3])}), + mode="append", + schema_mode="overwrite", + engine="rust", + ) + with pytest.raises( + SchemaMismatchError, + match="Schema error: Fail to merge schema field 'utf8' because the from data_type = Int64 does not equal Utf8", + ): write_deltalake( existing_table, - new_data, + pa.table({"utf8": pa.array([1, 2, 3])}), mode="append", - overwrite_schema=False, + schema_mode="merge", engine="rust", ) + write_deltalake( + existing_table, + pa.table({"x2": pa.array([1, 2, 3])}), + mode="append", + schema_mode="merge", + engine="rust", + ) + + +def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): + new_data = pa.table({"x5": pa.array([1, 2, 3])}) + with pytest.raises( + SchemaMismatchError, match="Cannot cast schema, number of fields does not match" + ): + write_deltalake( + existing_table, + new_data, + mode="overwrite", + schema_mode=None, + engine="rust", + ) + write_deltalake( existing_table, new_data, mode="overwrite", - overwrite_schema=True, + schema_mode="overwrite", engine="rust", ) @@ -660,35 +778,58 @@ def test_writer_with_options(tmp_path: pathlib.Path): def test_try_get_table_and_table_uri(tmp_path: pathlib.Path): + def _normalize_path(t): # who does not love Windows? ;) + return t[0], t[1].replace("\\", "/") if t[1] else t[1] + data = pa.table({"vals": pa.array(["1", "2", "3"])}) table_or_uri = tmp_path / "delta_table" write_deltalake(table_or_uri, data) delta_table = DeltaTable(table_or_uri) # table_or_uri as DeltaTable - assert try_get_table_and_table_uri(delta_table, None) == ( - delta_table, - str(tmp_path / "delta_table") + "/", + assert _normalize_path( + try_get_table_and_table_uri(delta_table, None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table") + "/", + ) ) # table_or_uri as str - assert try_get_table_and_table_uri(str(tmp_path / "delta_table"), None) == ( - delta_table, - str(tmp_path / "delta_table"), + assert _normalize_path( + try_get_table_and_table_uri(str(tmp_path / "delta_table"), None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table"), + ) ) - assert try_get_table_and_table_uri(str(tmp_path / "str"), None) == ( - None, - str(tmp_path / "str"), + assert _normalize_path( + try_get_table_and_table_uri(str(tmp_path / "str"), None) + ) == _normalize_path( + ( + None, + str(tmp_path / "str"), + ) ) # table_or_uri as Path - assert try_get_table_and_table_uri(tmp_path / "delta_table", None) == ( - delta_table, - str(tmp_path / "delta_table"), + assert _normalize_path( + try_get_table_and_table_uri(tmp_path / "delta_table", None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table"), + ) ) - assert try_get_table_and_table_uri(tmp_path / "Path", None) == ( - None, - str(tmp_path / "Path"), + assert _normalize_path( + try_get_table_and_table_uri(tmp_path / "Path", None) + ) == _normalize_path( + ( + None, + str(tmp_path / "Path"), + ) ) # table_or_uri with invalid parameter type
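
As a closing, hedged sketch of the new merge path end to end from Python (the table path and data are illustrative; `schema_mode="merge"` requires `engine="rust"`, matching the writer changes and tests above):

```python
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

table_uri = "tmp/merge-demo"  # illustrative location

# Create a small table with two columns.
write_deltalake(
    table_uri,
    pa.table({"id": pa.array([1, 2, 3]), "value": pa.array(["a", "b", "c"])}),
)

# Append a batch that drops "value" and introduces "inserted_by";
# schema_mode="merge" adds the new column and fills the missing one with nulls.
new_batch = pa.table({"id": pa.array([4, 5]), "inserted_by": pa.array(["x", "y"])})
write_deltalake(table_uri, new_batch, mode="append", schema_mode="merge", engine="rust")

# The table schema now contains id, value and inserted_by.
print(DeltaTable(table_uri).schema().to_pyarrow())
```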