Commit

merge main

Blajda committed Dec 20, 2023
2 parents 51a4d70 + bc9253c commit 9884480
Showing 23 changed files with 1,348 additions and 251 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -163,7 +163,7 @@ of features outlined in the Delta [protocol][protocol] is also [tracked](#protoc
| Version 2 | Column Invariants | ![done] |
| Version 3 | Enforce `delta.checkpoint.writeStatsAsJson` | [![open]][writer-rs] |
| Version 3 | Enforce `delta.checkpoint.writeStatsAsStruct` | [![open]][writer-rs] |
| Version 3 | CHECK constraints | [![open]][writer-rs] |
| Version 3 | CHECK constraints | [![semi-done]][check-constraints] |
| Version 4 | Change Data Feed | |
| Version 4 | Generated Columns | |
| Version 5 | Column Mapping | |
@@ -185,5 +185,6 @@ of features outlined in the Delta [protocol][protocol] is also [tracked](#protoc
[merge-py]: https://github.com/delta-io/delta-rs/issues/1357
[merge-rs]: https://github.com/delta-io/delta-rs/issues/850
[writer-rs]: https://github.com/delta-io/delta-rs/issues/851
[check-constraints]: https://github.com/delta-io/delta-rs/issues/1881
[onelake-rs]: https://github.com/delta-io/delta-rs/issues/1418
[protocol]: https://github.com/delta-io/delta/blob/master/PROTOCOL.md
crates/deltalake-core/src/delta_datafusion/expr.rs (15 changes: 12 additions & 3 deletions)
@@ -347,9 +347,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> {
mod test {
use arrow_schema::DataType as ArrowDataType;
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_common::{Column, DFSchema, ScalarValue};
use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable};

use crate::delta_datafusion::DeltaSessionContext;
use crate::kernel::{DataType, PrimitiveType, StructField, StructType};
use crate::{DeltaOps, DeltaTable};

@@ -388,6 +389,11 @@ mod test {
DataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"Value3".to_string(),
DataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"modified".to_string(),
DataType::Primitive(PrimitiveType::String),
@@ -442,7 +448,10 @@
}),
"arrow_cast(1, 'Int32')".to_string()
),
simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()),
simple!(
Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64)),
"Value3 = 3".to_string()
),
simple!(col("active").is_true(), "active IS TRUE".to_string()),
simple!(col("active"), "active".to_string()),
simple!(col("active").eq(lit(true)), "active = true".to_string()),
@@ -536,7 +545,7 @@ mod test {
),
];

let session = SessionContext::new();
let session: SessionContext = DeltaSessionContext::default().into();

for test in tests {
let actual = fmt_expr_to_sql(&test.expr).unwrap();
crates/deltalake-core/src/delta_datafusion/mod.rs (38 changes: 35 additions & 3 deletions)
@@ -68,6 +68,7 @@ use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datafusion_proto::logical_plan::LogicalExtensionCodec;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_sql::planner::ParserOptions;
use futures::TryStreamExt;

use itertools::Itertools;
use log::error;
@@ -1034,6 +1035,31 @@ pub(crate) fn logical_expr_to_physical_expr(
create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap()
}

pub(crate) async fn execute_plan_to_batch(
state: &SessionState,
plan: Arc<dyn ExecutionPlan>,
) -> DeltaResult<arrow::record_batch::RecordBatch> {
let data =
futures::future::try_join_all((0..plan.output_partitioning().partition_count()).map(|p| {
let plan_copy = plan.clone();
let task_context = state.task_ctx().clone();
async move {
let batch_stream = plan_copy.execute(p, task_context)?;

let schema = batch_stream.schema();

let batches = batch_stream.try_collect::<Vec<_>>().await?;

DataFusionResult::<_>::Ok(arrow::compute::concat_batches(&schema, batches.iter())?)
}
}))
.await?;

let batch = arrow::compute::concat_batches(&plan.schema(), data.iter())?;

Ok(batch)
}
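
The helper above drives every output partition of a physical plan and concatenates the results into a single `RecordBatch`. The following is a rough sketch of a crate-internal call site; the wrapping function and variable names are illustrative assumptions, not code from this commit.

```rust
// Hypothetical crate-internal usage of execute_plan_to_batch (sketch only).
async fn count_plan_rows(
    state: &SessionState,
    plan: Arc<dyn ExecutionPlan>,
) -> DeltaResult<usize> {
    // Executes all partitions of `plan`, concatenates the per-partition
    // batches, and returns the row count of the combined RecordBatch.
    let batch = execute_plan_to_batch(state, plan).await?;
    Ok(batch.num_rows())
}
```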

/// Responsible for checking batches of data conform to table's invariants.
#[derive(Clone)]
pub struct DeltaDataChecker {
@@ -1048,7 +1074,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints: vec![],
ctx: SessionContext::new(),
ctx: DeltaSessionContext::default().into(),
}
}

@@ -1057,10 +1083,16 @@
Self {
constraints,
invariants: vec![],
ctx: SessionContext::new(),
ctx: DeltaSessionContext::default().into(),
}
}

/// Specify the DataFusion session context
pub fn with_session_context(mut self, context: SessionContext) -> Self {
self.ctx = context;
self
}
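
A short sketch of how the new builder method might be used; the wrapper function is an assumption for illustration, and the context shown is simply the crate's own default converted to a `SessionContext`.

```rust
// Hypothetical usage (sketch): build a checker from a table snapshot, then
// override its session context with one supplied by the caller.
fn checker_with_custom_ctx(snapshot: &DeltaTableState) -> DeltaDataChecker {
    let ctx: SessionContext = DeltaSessionContext::default().into();
    DeltaDataChecker::new(snapshot).with_session_context(ctx)
}
```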

/// Create a new DeltaDataChecker
pub fn new(snapshot: &DeltaTableState) -> Self {
let metadata = snapshot.metadata();
@@ -1074,7 +1106,7 @@
Self {
invariants,
constraints,
ctx: SessionContext::new(),
ctx: DeltaSessionContext::default().into(),
}
}

crates/deltalake-core/src/kernel/actions/mod.rs (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};

pub(crate) mod schemas;
mod serde_path;
pub(crate) mod serde_path;
pub(crate) mod types;

pub use types::*;
crates/deltalake-core/src/kernel/actions/serde_path.rs (2 changes: 1 addition & 1 deletion)
@@ -54,7 +54,7 @@ fn encode_path(path: &str) -> String {
percent_encode(path.as_bytes(), INVALID).to_string()
}

fn decode_path(path: &str) -> Result<String, Utf8Error> {
pub fn decode_path(path: &str) -> Result<String, Utf8Error> {
Ok(percent_decode_str(path).decode_utf8()?.to_string())
}
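
With `decode_path` made public here (and the `serde_path` module exported as `pub(crate)` in the change above), other modules in the crate can percent-decode action paths. Below is a hedged sketch of such a caller; the wrapper function and the module path are assumptions, not part of this commit.

```rust
use std::str::Utf8Error;

// Module path is an assumption based on the file location in this diff.
use crate::kernel::actions::serde_path;

// Hypothetical crate-internal helper: percent-decode an action's `path`
// field before resolving it against the table root, e.g.
// "part-00000-a%20b.snappy.parquet" -> "part-00000-a b.snappy.parquet".
fn decoded_action_path(raw: &str) -> Result<String, Utf8Error> {
    serde_path::decode_path(raw)
}
```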

crates/deltalake-core/src/operations/cast.rs (90 changes: 89 additions & 1 deletion)
@@ -17,6 +17,7 @@ fn cast_record_batch_columns(
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();

if let (DataType::Struct(_), DataType::Struct(child_fields)) =
(col.data_type(), f.data_type())
{
@@ -28,7 +29,7 @@
child_columns.clone(),
None,
)) as ArrayRef)
} else if !col.data_type().equals_datatype(f.data_type()) {
} else if is_cast_required(col.data_type(), f.data_type()) {
cast_with_options(col, f.data_type(), cast_options)
} else {
Ok(col.clone())
@@ -37,6 +38,16 @@
.collect::<Result<Vec<_>, _>>()
}

fn is_cast_required(a: &DataType, b: &DataType) -> bool {
match (a, b) {
(DataType::List(a_item), DataType::List(b_item)) => {
// If the list item name is not the default ('item'), the list must be cast
!a.equals_datatype(b) || a_item.name() != b_item.name()
}
(_, _) => !a.equals_datatype(b),
}
}

/// Cast a RecordBatch to a new target_schema by casting each column array
pub fn cast_record_batch(
batch: &RecordBatch,
@@ -51,3 +62,80 @@ pub fn cast_record_batch(
let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
Ok(RecordBatch::try_new(target_schema, columns)?)
}

#[cfg(test)]
mod tests {
use crate::operations::cast::{cast_record_batch, is_cast_required};
use arrow::array::ArrayData;
use arrow_array::{Array, ArrayRef, ListArray, RecordBatch};
use arrow_buffer::Buffer;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
use std::sync::Arc;

#[test]
fn test_cast_record_batch_with_list_non_default_item() {
let array = Arc::new(make_list_array()) as ArrayRef;
let source_schema = Schema::new(vec![Field::new(
"list_column",
array.data_type().clone(),
false,
)]);
let record_batch = RecordBatch::try_new(Arc::new(source_schema), vec![array]).unwrap();

let fields = Fields::from(vec![Field::new_list(
"list_column",
Field::new("item", DataType::Int8, false),
false,
)]);
let target_schema = Arc::new(Schema::new(fields)) as SchemaRef;

let result = cast_record_batch(&record_batch, target_schema, false);

let schema = result.unwrap().schema();
let field = schema.column_with_name("list_column").unwrap().1;
if let DataType::List(list_item) = field.data_type() {
assert_eq!(list_item.name(), "item");
} else {
panic!("Not a list");
}
}

fn make_list_array() -> ListArray {
let value_data = ArrayData::builder(DataType::Int32)
.len(8)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
.build()
.unwrap();

let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

let list_data_type = DataType::List(Arc::new(Field::new("element", DataType::Int32, true)));
let list_data = ArrayData::builder(list_data_type)
.len(3)
.add_buffer(value_offsets)
.add_child_data(value_data)
.build()
.unwrap();
ListArray::from(list_data)
}

#[test]
fn test_is_cast_required_with_list() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(!is_cast_required(&field1, &field2));
}

#[test]
fn test_is_cast_required_with_list_non_default_item() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));

assert!(is_cast_required(&field1, &field2));
}
}
(Diff truncated: the remaining changed files in this commit are not shown.)
