From 9730d59786886c28a9d98fb8f240602cd65fcc74 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Tue, 13 Jun 2023 22:43:19 -0400 Subject: [PATCH] feat: implement update operation (#1390) # Description Users can now update data that matches a predicate. This operation should be encouraged over the replace write operation since update determines which values require rewriting based on the supplied predicate. # Related Issue(s) - closes #1126 --------- Co-authored-by: Will Jones Co-authored-by: Robert Pack <42610831+roeap@users.noreply.github.com> --- rust/src/action/mod.rs | 10 +- rust/src/delta_datafusion.rs | 382 ++++++++++++- rust/src/operations/delete.rs | 438 ++------------- rust/src/operations/mod.rs | 11 +- rust/src/operations/update.rs | 989 ++++++++++++++++++++++++++++++++++ 5 files changed, 1429 insertions(+), 401 deletions(-) create mode 100644 rust/src/operations/update.rs diff --git a/rust/src/action/mod.rs b/rust/src/action/mod.rs index d2f0004316..19c9b27a97 100644 --- a/rust/src/action/mod.rs +++ b/rust/src/action/mod.rs @@ -594,6 +594,11 @@ pub enum DeltaOperation { /// The condition the to be deleted data must match predicate: Option, }, + /// Update data matching predicate from delta table + Update { + /// The update predicate + predicate: Option, + }, /// Represents a Delta `StreamingUpdate` operation. #[serde(rename_all = "camelCase")] @@ -632,6 +637,7 @@ impl DeltaOperation { DeltaOperation::Create { .. } => "CREATE TABLE", DeltaOperation::Write { .. } => "WRITE", DeltaOperation::Delete { .. } => "DELETE", + DeltaOperation::Update { .. } => "UPDATE", DeltaOperation::StreamingUpdate { .. } => "STREAMING UPDATE", DeltaOperation::Optimize { .. } => "OPTIMIZE", DeltaOperation::FileSystemCheck { .. } => "FSCK", @@ -675,7 +681,8 @@ impl DeltaOperation { | Self::FileSystemCheck {} | Self::StreamingUpdate { .. } | Self::Write { .. } - | Self::Delete { .. } => true, + | Self::Delete { .. } + | Self::Update { .. } => true, } } @@ -695,6 +702,7 @@ impl DeltaOperation { // TODO add more operations Self::Write { predicate, .. } => predicate.clone(), Self::Delete { predicate, .. } => predicate.clone(), + Self::Update { predicate, .. 
} => predicate.clone(), _ => None, } } diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index b52a330d7d..479269c2cc 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -28,9 +28,12 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType; use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; +use arrow_array::StringArray; +use arrow_schema::Field; use async_trait::async_trait; use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion::datasource::datasource::TableProviderFactory; @@ -43,13 +46,19 @@ use datafusion::optimizer::utils::conjunction; use datafusion::physical_expr::PhysicalSortExpr; use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use datafusion::physical_plan::file_format::FileScanConfig; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::limit::LocalLimitExec; use datafusion::physical_plan::{ ColumnStatistics, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use datafusion_common::scalar::ScalarValue; -use datafusion_common::{Column, DataFusionError, Result as DataFusionResult, ToDFSchema}; +use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion_common::{ + Column, DFSchema, DataFusionError, Result as DataFusionResult, ToDFSchema, +}; +use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; use datafusion_expr::logical_plan::CreateExternalTable; -use datafusion_expr::{Expr, Extension, LogicalPlan, TableProviderFilterPushDown}; +use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility}; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_proto::logical_plan::LogicalExtensionCodec; @@ -64,6 +73,8 @@ use crate::storage::ObjectStoreRef; use crate::table_state::DeltaTableState; use crate::{open_table, open_table_with_storage_options, DeltaTable, Invariant, SchemaDataType}; +const PATH_COLUMN: &str = "__delta_rs_path"; + impl From for DataFusionError { fn from(err: DeltaTableError) -> Self { match err { @@ -994,6 +1005,373 @@ impl TableProviderFactory for DeltaTableFactory { } } +pub(crate) struct FindFilesExprProperties { + pub partition_columns: Vec, + + pub partition_only: bool, + pub result: DeltaResult<()>, +} + +/// Ensure only expressions that make sense are accepted, check for +/// non-deterministic functions, and determine if the expression only contains +/// partition columns +impl TreeNodeVisitor for FindFilesExprProperties { + type N = Expr; + + fn pre_visit(&mut self, expr: &Self::N) -> datafusion_common::Result { + // TODO: We can likely relax the volatility to STABLE. Would require further + // research to confirm the same value is generated during the scan and + // rewrite phases. 
+ + match expr { + Expr::Column(c) => { + if !self.partition_columns.contains(&c.name) { + self.partition_only = false; + } + } + Expr::ScalarVariable(_, _) + | Expr::Literal(_) + | Expr::Alias(_, _) + | Expr::BinaryExpr(_) + | Expr::Like(_) + | Expr::ILike(_) + | Expr::SimilarTo(_) + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Negative(_) + | Expr::InList { .. } + | Expr::GetIndexedField(_) + | Expr::Between(_) + | Expr::Case(_) + | Expr::Cast(_) + | Expr::TryCast(_) => (), + Expr::ScalarFunction(ScalarFunction { fun, .. }) => { + let v = fun.volatility(); + if v > Volatility::Immutable { + self.result = Err(DeltaTableError::Generic(format!( + "Find files predicate contains nondeterministic function {}", + fun + ))); + return Ok(VisitRecursion::Stop); + } + } + Expr::ScalarUDF(ScalarUDF { fun, .. }) => { + let v = fun.signature.volatility; + if v > Volatility::Immutable { + self.result = Err(DeltaTableError::Generic(format!( + "Find files predicate contains nondeterministic function {}", + fun.name + ))); + return Ok(VisitRecursion::Stop); + } + } + _ => { + self.result = Err(DeltaTableError::Generic(format!( + "Find files predicate contains unsupported expression {}", + expr + ))); + return Ok(VisitRecursion::Stop); + } + } + + Ok(VisitRecursion::Continue) + } +} + +pub(crate) struct FindFiles { + pub candidates: Vec, + /// Was a physical read to the datastore required to determine the candidates + pub partition_scan: bool, +} + +fn join_batches_with_add_actions( + batches: Vec, + mut actions: HashMap, +) -> DeltaResult> { + // Given RecordBatches that contains `__delta_rs_path` perform a hash join + // with actions to obtain original add actions + + let mut files = Vec::with_capacity(batches.iter().map(|batch| batch.num_rows()).sum()); + for batch in batches { + let array = batch + .column_by_name(PATH_COLUMN) + .ok_or_else(|| { + DeltaTableError::Generic(format!("Unable to find column {}", PATH_COLUMN)) + })? + .as_any() + .downcast_ref::() + .ok_or(DeltaTableError::Generic(format!( + "Unable to downcast column {}", + PATH_COLUMN + )))?; + for path in array { + let path = path.ok_or(DeltaTableError::Generic(format!( + "{} cannot be null", + PATH_COLUMN + )))?; + + match actions.remove(path) { + Some(action) => files.push(action), + None => { + return Err(DeltaTableError::Generic( + "Unable to map __delta_rs_path to action.".to_owned(), + )) + } + } + } + } + Ok(files) +} + +/// Determine which files contain a record that statisfies the predicate +pub(crate) async fn find_files_scan<'a>( + snapshot: &DeltaTableState, + store: ObjectStoreRef, + schema: Arc, + file_schema: Arc, + candidates: Vec<&'a Add>, + state: &SessionState, + expression: &Expr, +) -> DeltaResult> { + let mut candidate_map: HashMap = HashMap::new(); + + let table_partition_cols = snapshot + .current_metadata() + .ok_or(DeltaTableError::NoMetadata)? 
+ .partition_columns + .clone(); + + let mut file_groups: HashMap, Vec> = HashMap::new(); + for action in candidates { + let mut part = partitioned_file_from_action(action, &schema); + part.partition_values + .push(ScalarValue::Utf8(Some(action.path.clone()))); + + file_groups + .entry(part.partition_values.clone()) + .or_default() + .push(part); + + candidate_map.insert(action.path.to_owned(), action.to_owned()); + } + + let mut table_partition_cols = table_partition_cols + .iter() + .map(|c| Ok((c.to_owned(), schema.field_with_name(c)?.data_type().clone()))) + .collect::, ArrowError>>()?; + // Append a column called __delta_rs_path to track the file path + table_partition_cols.push((PATH_COLUMN.to_owned(), DataType::Utf8)); + + let input_schema = snapshot.input_schema()?; + + let mut fields = Vec::new(); + for field in input_schema.fields.iter() { + fields.push(field.to_owned()); + } + fields.push(Arc::new(Field::new( + PATH_COLUMN, + arrow_schema::DataType::Boolean, + true, + ))); + let input_schema = Arc::new(ArrowSchema::new(fields)); + + // Identify which columns we need to project + let mut used_columns = expression + .to_columns()? + .into_iter() + .map(|column| input_schema.index_of(&column.name)) + .collect::, ArrowError>>() + .unwrap(); + // Add path column + used_columns.push(input_schema.index_of(PATH_COLUMN)?); + + // Project the logical schema so column indicies align between the parquet scan and the expression + let mut fields = vec![]; + for idx in &used_columns { + fields.push(input_schema.field(*idx).to_owned()); + } + let input_schema = Arc::new(ArrowSchema::new(fields)); + let input_dfschema = input_schema.as_ref().clone().try_into()?; + + let parquet_scan = ParquetFormat::new() + .create_physical_plan( + state, + FileScanConfig { + object_store_url: store.object_store_url(), + file_schema, + file_groups: file_groups.into_values().collect(), + statistics: snapshot.datafusion_table_statistics(), + projection: Some(used_columns), + limit: None, + table_partition_cols, + infinite_source: false, + output_ordering: vec![], + }, + None, + ) + .await?; + + let predicate_expr = create_physical_expr( + &Expr::IsTrue(Box::new(expression.clone())), + &input_dfschema, + &input_schema, + state.execution_props(), + )?; + + let filter: Arc = + Arc::new(FilterExec::try_new(predicate_expr, parquet_scan.clone())?); + let limit: Arc = Arc::new(LocalLimitExec::new(filter, 1)); + + let task_ctx = Arc::new(TaskContext::from(state)); + let path_batches = datafusion::physical_plan::collect(limit, task_ctx).await?; + + join_batches_with_add_actions(path_batches, candidate_map) +} + +pub(crate) async fn scan_memory_table( + snapshot: &DeltaTableState, + predicate: &Expr, +) -> DeltaResult> { + let actions = snapshot.files().to_owned(); + + let batch = snapshot.add_actions_table(true)?; + let mut arrays = Vec::new(); + let mut fields = Vec::new(); + + let schema = batch.schema(); + + arrays.push( + batch + .column_by_name("path") + .ok_or(DeltaTableError::Generic( + "Column with name `path` does not exist".to_owned(), + ))? 
+ .to_owned(), + ); + fields.push(Field::new(PATH_COLUMN, DataType::Utf8, false)); + + for field in schema.fields() { + if field.name().starts_with("partition.") { + let name = field.name().strip_prefix("partition.").unwrap(); + + arrays.push(batch.column_by_name(field.name()).unwrap().to_owned()); + fields.push(Field::new( + name, + field.data_type().to_owned(), + field.is_nullable(), + )); + } + } + + let schema = Arc::new(ArrowSchema::new(fields)); + let batch = RecordBatch::try_new(schema, arrays)?; + let mem_table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + + let ctx = SessionContext::new(); + let mut df = ctx.read_table(Arc::new(mem_table))?; + df = df + .filter(predicate.to_owned())? + .select(vec![col(PATH_COLUMN)])?; + let batches = df.collect().await?; + + let map = actions + .into_iter() + .map(|action| (action.path.clone(), action)) + .collect::>(); + + join_batches_with_add_actions(batches, map) +} + +pub(crate) async fn find_files<'a>( + snapshot: &DeltaTableState, + object_store: ObjectStoreRef, + schema: Arc, + state: &SessionState, + predicate: Option, +) -> DeltaResult { + let current_metadata = snapshot + .current_metadata() + .ok_or(DeltaTableError::NoMetadata)?; + let table_partition_cols = current_metadata.partition_columns.clone(); + + match &predicate { + Some(predicate) => { + let file_schema = Arc::new(ArrowSchema::new( + schema + .fields() + .iter() + .filter(|f| !table_partition_cols.contains(f.name())) + .cloned() + .collect::>(), + )); + + let input_schema = snapshot.input_schema()?; + let input_dfschema: DFSchema = input_schema.clone().as_ref().clone().try_into()?; + let expr = create_physical_expr( + predicate, + &input_dfschema, + &input_schema, + state.execution_props(), + )?; + + // Validate the Predicate and determine if it only contains partition columns + let mut expr_properties = FindFilesExprProperties { + partition_only: true, + partition_columns: current_metadata.partition_columns.clone(), + result: Ok(()), + }; + + TreeNode::visit(predicate, &mut expr_properties)?; + expr_properties.result?; + + if expr_properties.partition_only { + let candidates = scan_memory_table(snapshot, predicate).await?; + Ok(FindFiles { + candidates, + partition_scan: true, + }) + } else { + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone())?; + let files_to_prune = pruning_predicate.prune(snapshot)?; + let files: Vec<&Add> = snapshot + .files() + .iter() + .zip(files_to_prune.into_iter()) + .filter_map(|(action, keep)| if keep { Some(action) } else { None }) + .collect(); + + // Create a new delta scan plan with only files that have a record + let candidates = find_files_scan( + snapshot, + object_store.clone(), + schema.clone(), + file_schema.clone(), + files, + state, + predicate, + ) + .await?; + + Ok(FindFiles { + candidates, + partition_scan: false, + }) + } + } + None => Ok(FindFiles { + candidates: snapshot.files().to_owned(), + partition_scan: true, + }), + } +} + #[cfg(test)] mod tests { use arrow::array::StructArray; diff --git a/rust/src/operations/delete.rs b/rust/src/operations/delete.rs index 92814ecac5..80be16e088 100644 --- a/rust/src/operations/delete.rs +++ b/rust/src/operations/delete.rs @@ -17,45 +17,24 @@ //! .await?; //! 
```` -use std::collections::HashMap; -use std::pin::Pin; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; use crate::action::{Action, Add, Remove}; -use arrow::array::StringArray; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use arrow::datatypes::Schema as ArrowSchema; -use arrow::error::ArrowError; -use arrow::record_batch::RecordBatch; -use datafusion::datasource::file_format::{parquet::ParquetFormat, FileFormat}; -use datafusion::datasource::listing::PartitionedFile; -use datafusion::datasource::MemTable; -use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; +use datafusion::execution::context::{SessionContext, SessionState}; use datafusion::physical_expr::create_physical_expr; -use datafusion::physical_optimizer::pruning::PruningPredicate; -use datafusion::physical_plan::file_format::FileScanConfig; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::limit::LocalLimitExec; use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::RecordBatchStream; use datafusion::prelude::Expr; use datafusion_common::scalar::ScalarValue; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; use datafusion_common::DFSchema; -use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; -use datafusion_expr::{col, Volatility}; use futures::future::BoxFuture; -use futures::stream::StreamExt; use parquet::file::properties::WriterProperties; use serde_json::Map; use serde_json::Value; use crate::action::DeltaOperation; -use crate::delta_datafusion::{ - parquet_scan_from_actions, partitioned_file_from_action, register_store, -}; +use crate::delta_datafusion::{find_files, parquet_scan_from_actions, register_store}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::operations::transaction::commit; use crate::operations::write::write_execution_plan; @@ -63,8 +42,6 @@ use crate::storage::{DeltaObjectStore, ObjectStoreRef}; use crate::table_state::DeltaTableState; use crate::DeltaTable; -const PATH_COLUMN: &str = "__delta_rs_path"; - /// Delete Records from the Delta Table. /// See this module's documentaiton for more information pub struct DeleteBuilder { @@ -101,218 +78,6 @@ pub struct DeleteMetrics { pub rewrite_time_ms: u128, } -/// Determine which files contain a record that statisfies the predicate -async fn find_files<'a>( - snapshot: &DeltaTableState, - store: ObjectStoreRef, - schema: Arc, - file_schema: Arc, - candidates: Vec<&'a Add>, - state: &SessionState, - expression: &Expr, -) -> DeltaResult> { - let mut files = Vec::new(); - let mut candidate_map: HashMap = HashMap::new(); - - let table_partition_cols = snapshot - .current_metadata() - .ok_or(DeltaTableError::NoMetadata)? 
- .partition_columns - .clone(); - - let mut file_groups: HashMap, Vec> = HashMap::new(); - for action in candidates { - let mut part = partitioned_file_from_action(action, &schema); - part.partition_values - .push(ScalarValue::Utf8(Some(action.path.clone()))); - - file_groups - .entry(part.partition_values.clone()) - .or_default() - .push(part); - - candidate_map.insert(action.path.to_owned(), action); - } - - let mut table_partition_cols = table_partition_cols - .iter() - .map(|c| Ok((c.to_owned(), schema.field_with_name(c)?.data_type().clone()))) - .collect::, ArrowError>>()?; - // Append a column called __delta_rs_path to track the file path - table_partition_cols.push((PATH_COLUMN.to_owned(), DataType::Utf8)); - - let input_schema = snapshot.input_schema()?; - let input_dfschema: DFSchema = input_schema.clone().as_ref().clone().try_into()?; - - let predicate_expr = create_physical_expr( - &Expr::IsTrue(Box::new(expression.clone())), - &input_dfschema, - &input_schema, - state.execution_props(), - )?; - - let parquet_scan = ParquetFormat::new() - .create_physical_plan( - state, - FileScanConfig { - object_store_url: store.object_store_url(), - file_schema, - file_groups: file_groups.into_values().collect(), - statistics: snapshot.datafusion_table_statistics(), - projection: None, - limit: None, - table_partition_cols, - infinite_source: false, - output_ordering: vec![], - }, - None, - ) - .await?; - - let filter: Arc = - Arc::new(FilterExec::try_new(predicate_expr, parquet_scan.clone())?); - let limit: Arc = Arc::new(LocalLimitExec::new(filter, 1)); - - let task_ctx = Arc::new(TaskContext::from(state)); - let partitions = limit.output_partitioning().partition_count(); - let mut tasks = Vec::with_capacity(partitions); - - for i in 0..partitions { - let stream = limit.execute(i, task_ctx.clone())?; - tasks.push(handle_stream(stream)); - } - - for res in futures::future::join_all(tasks).await.into_iter() { - let path = res?; - if let Some(path) = path { - match candidate_map.remove(&path) { - Some(action) => files.push(action), - None => { - return Err(DeltaTableError::Generic( - "Unable to map __delta_rs_path to action.".to_owned(), - )) - } - } - } - } - - Ok(files) -} - -async fn handle_stream( - mut stream: Pin>, -) -> Result, DeltaTableError> { - if let Some(maybe_batch) = stream.next().await { - let batch: RecordBatch = maybe_batch?; - if batch.num_rows() > 1 { - return Err(DeltaTableError::Generic( - "Find files returned multiple records for batch".to_owned(), - )); - } - let array = batch - .column_by_name(PATH_COLUMN) - .unwrap() - .as_any() - .downcast_ref::() - .ok_or(DeltaTableError::Generic(format!( - "Unable to downcast column {}", - PATH_COLUMN - )))?; - - let path = array - .into_iter() - .next() - .unwrap() - .ok_or(DeltaTableError::Generic(format!( - "{} cannot be null", - PATH_COLUMN - )))?; - return Ok(Some(path.to_string())); - } - - Ok(None) -} - -struct ExprProperties { - partition_columns: Vec, - - partition_only: bool, - result: DeltaResult<()>, -} - -/// Ensure only expressions that make sense are accepted, check for -/// non-deterministic functions, and determine if the expression only contains -/// partition columns -impl TreeNodeVisitor for ExprProperties { - type N = Expr; - - fn pre_visit(&mut self, expr: &Self::N) -> datafusion_common::Result { - // TODO: We can likely relax the volatility to STABLE. Would require further - // research to confirm the same value is generated during the scan and - // rewrite phases. 
- - match expr { - Expr::Column(c) => { - if !self.partition_columns.contains(&c.name) { - self.partition_only = false; - } - } - Expr::ScalarVariable(_, _) - | Expr::Literal(_) - | Expr::Alias(_, _) - | Expr::BinaryExpr(_) - | Expr::Like(_) - | Expr::ILike(_) - | Expr::SimilarTo(_) - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) - | Expr::InList { .. } - | Expr::GetIndexedField(_) - | Expr::Between(_) - | Expr::Case(_) - | Expr::Cast(_) - | Expr::TryCast(_) => (), - Expr::ScalarFunction(ScalarFunction { fun, .. }) => { - let v = fun.volatility(); - if v > Volatility::Immutable { - self.result = Err(DeltaTableError::Generic(format!( - "Delete predicate contains nondeterministic function {}", - fun - ))); - return Ok(VisitRecursion::Stop); - } - } - Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - let v = fun.signature.volatility; - if v > Volatility::Immutable { - self.result = Err(DeltaTableError::Generic(format!( - "Delete predicate contains nondeterministic function {}", - fun.name - ))); - return Ok(VisitRecursion::Stop); - } - } - _ => { - self.result = Err(DeltaTableError::Generic(format!( - "Delete predicate contains unsupported expression {}", - expr - ))); - return Ok(VisitRecursion::Stop); - } - } - - Ok(VisitRecursion::Continue) - } -} - impl DeleteBuilder { /// Create a new [`DeleteBuilder`] pub fn new(object_store: ObjectStoreRef, snapshot: DeltaTableState) -> Self { @@ -371,13 +136,12 @@ async fn excute_non_empty_expr( state: &SessionState, expression: &Expr, metrics: &mut DeleteMetrics, + rewrite: &[Add], writer_properties: Option, -) -> DeltaResult<(Vec, Vec)> { +) -> DeltaResult> { // For each identified file perform a parquet scan + filter + limit (1) + count. // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. - let scan_start = Instant::now(); - let schema = snapshot.arrow_schema()?; let input_schema = snapshot.input_schema()?; let input_dfschema: DFSchema = input_schema.clone().as_ref().clone().try_into()?; @@ -387,50 +151,11 @@ async fn excute_non_empty_expr( .ok_or(DeltaTableError::NoMetadata)? 
.partition_columns .clone(); - let file_schema = Arc::new(ArrowSchema::new( - schema - .fields() - .iter() - .filter(|f| !table_partition_cols.contains(f.name())) - .cloned() - .collect::>(), - )); - let expr = create_physical_expr( - expression, - &input_dfschema, - &input_schema, - state.execution_props(), - )?; - let pruning_predicate = PruningPredicate::try_new(expr, schema.clone())?; - let files_to_prune = pruning_predicate.prune(snapshot)?; - let files: Vec<&Add> = snapshot - .files() - .iter() - .zip(files_to_prune.into_iter()) - .filter_map(|(action, keep)| if keep { Some(action) } else { None }) - .collect(); - - // Create a new delta scan plan with only files that have a record - let rewrite = find_files( - snapshot, - object_store.clone(), - schema.clone(), - file_schema.clone(), - files, - state, - expression, - ) - .await?; - - metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis(); - let write_start = Instant::now(); - - let rewrite: Vec = rewrite.into_iter().map(|s| s.to_owned()).collect(); let parquet_scan = parquet_scan_from_actions( snapshot, object_store.clone(), - &rewrite, + rewrite, &schema, None, state, @@ -463,7 +188,6 @@ async fn excute_non_empty_expr( false, ) .await?; - metrics.rewrite_time_ms = Instant::now().duration_since(write_start).as_millis(); let read_records = parquet_scan.metrics().and_then(|m| m.output_rows()); let filter_records = filter.metrics().and_then(|m| m.output_rows()); @@ -472,7 +196,7 @@ async fn excute_non_empty_expr( .zip(filter_records) .map(|(read, filter)| read - filter); - Ok((add_actions, rewrite)) + Ok(add_actions) } async fn execute( @@ -483,60 +207,53 @@ async fn execute( writer_properties: Option, app_metadata: Option>, ) -> DeltaResult<((Vec, i64), DeleteMetrics)> { - let mut metrics = DeleteMetrics::default(); let exec_start = Instant::now(); + let mut metrics = DeleteMetrics::default(); + let schema = snapshot.arrow_schema()?; - let (add_actions, to_delete) = match &predicate { - Some(expr) => { - let current_metadata = snapshot - .current_metadata() - .ok_or(DeltaTableError::NoMetadata)?; - - let mut expr_properties = ExprProperties { - partition_only: true, - partition_columns: current_metadata.partition_columns.clone(), - result: Ok(()), - }; - - TreeNode::visit(expr, &mut expr_properties)?; - expr_properties.result?; - - if expr_properties.partition_only { - // If the expression only refers to partition columns, we can perform - // the deletion just by removing entire files, so there is no need to - // do an scan. - let scan_start = Instant::now(); - let remove = scan_memory_table(snapshot, expr).await?; - metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_micros(); - (Vec::new(), remove) - } else { - excute_non_empty_expr( - snapshot, - object_store.clone(), - &state, - expr, - &mut metrics, - writer_properties, - ) - .await? 
- } - } - None => (Vec::::new(), snapshot.files().to_owned()), - }; + let scan_start = Instant::now(); + let candidates = find_files( + snapshot, + object_store.clone(), + schema.clone(), + &state, + predicate.clone(), + ) + .await?; + metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_micros(); let predicate = predicate.unwrap_or(Expr::Literal(ScalarValue::Boolean(Some(true)))); + let add = if candidates.partition_scan { + Vec::new() + } else { + let write_start = Instant::now(); + let add = excute_non_empty_expr( + snapshot, + object_store.clone(), + &state, + &predicate, + &mut metrics, + &candidates.candidates, + writer_properties, + ) + .await?; + metrics.rewrite_time_ms = Instant::now().duration_since(write_start).as_millis(); + add + }; + let remove = candidates.candidates; + let deletion_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_millis() as i64; - let mut actions: Vec = add_actions.into_iter().map(Action::add).collect(); + let mut actions: Vec = add.into_iter().map(Action::add).collect(); let mut version = snapshot.version(); - metrics.num_removed_files = to_delete.len(); + metrics.num_removed_files = remove.len(); metrics.num_added_files = actions.len(); - for action in to_delete { + for action in remove { actions.push(Action::remove(Remove { path: action.path, deletion_timestamp: Some(deletion_timestamp), @@ -568,78 +285,6 @@ async fn execute( Ok(((actions, version), metrics)) } -async fn scan_memory_table(snapshot: &DeltaTableState, predicate: &Expr) -> DeltaResult> { - let actions = snapshot.files().to_owned(); - - let batch = snapshot.add_actions_table(true)?; - let mut arrays = Vec::new(); - let mut fields = Vec::new(); - - let schema = batch.schema(); - - arrays.push( - batch - .column_by_name("path") - .ok_or(DeltaTableError::Generic( - "Column with name `path` does not exist".to_owned(), - ))? - .to_owned(), - ); - fields.push(Field::new(PATH_COLUMN, DataType::Utf8, false)); - - for field in schema.fields() { - if field.name().starts_with("partition.") { - let name = field.name().strip_prefix("partition.").unwrap(); - - arrays.push(batch.column_by_name(field.name()).unwrap().to_owned()); - fields.push(Field::new( - name, - field.data_type().to_owned(), - field.is_nullable(), - )); - } - } - - let schema = Arc::new(ArrowSchema::new(fields)); - let batch = RecordBatch::try_new(schema, arrays)?; - let mem_table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - - let ctx = SessionContext::new(); - let mut df = ctx.read_table(Arc::new(mem_table))?; - df = df - .filter(predicate.to_owned())? 
- .select(vec![col(PATH_COLUMN)])?; - let batches = df.collect().await?; - - let mut map = HashMap::new(); - for action in actions { - map.insert(action.path.clone(), action); - } - let mut files = Vec::new(); - - for batch in batches { - let array = batch - .column_by_name(PATH_COLUMN) - .unwrap() - .as_any() - .downcast_ref::() - .ok_or(DeltaTableError::Generic(format!( - "Unable to downcast column {}", - PATH_COLUMN - )))?; - for path in array { - let path = path.ok_or(DeltaTableError::Generic(format!( - "{} cannot be null", - PATH_COLUMN - )))?; - let value = map.remove(path).unwrap(); - files.push(value); - } - } - - Ok(files) -} - impl std::future::IntoFuture for DeleteBuilder { type Output = DeltaResult<(DeltaTable, DeleteMetrics)>; type IntoFuture = BoxFuture<'static, Self::Output>; @@ -742,8 +387,7 @@ mod tests { assert_eq!(metrics.num_deleted_rows, None); assert_eq!(metrics.num_copied_rows, None); - // Scan and rewrite is not required - assert_eq!(metrics.scan_time_ms, 0); + // rewrite is not required assert_eq!(metrics.rewrite_time_ms, 0); // Deletes with no changes to state must not commit diff --git a/rust/src/operations/mod.rs b/rust/src/operations/mod.rs index 786f455344..6bf1584708 100644 --- a/rust/src/operations/mod.rs +++ b/rust/src/operations/mod.rs @@ -22,7 +22,7 @@ pub mod transaction; pub mod vacuum; #[cfg(feature = "datafusion")] -use self::{delete::DeleteBuilder, load::LoadBuilder, write::WriteBuilder}; +use self::{delete::DeleteBuilder, load::LoadBuilder, update::UpdateBuilder, write::WriteBuilder}; #[cfg(feature = "datafusion")] use arrow::record_batch::RecordBatch; #[cfg(feature = "datafusion")] @@ -35,6 +35,8 @@ pub mod delete; #[cfg(feature = "datafusion")] mod load; #[cfg(feature = "datafusion")] +pub mod update; +#[cfg(feature = "datafusion")] pub mod write; #[cfg(all(feature = "arrow", feature = "parquet"))] pub mod writer; @@ -140,6 +142,13 @@ impl DeltaOps { pub fn delete(self) -> DeleteBuilder { DeleteBuilder::new(self.0.object_store(), self.0.state) } + + /// Update data from Delta table + #[cfg(feature = "datafusion")] + #[must_use] + pub fn update(self) -> UpdateBuilder { + UpdateBuilder::new(self.0.object_store(), self.0.state) + } } impl From for DeltaOps { diff --git a/rust/src/operations/update.rs b/rust/src/operations/update.rs new file mode 100644 index 0000000000..34dbf27113 --- /dev/null +++ b/rust/src/operations/update.rs @@ -0,0 +1,989 @@ +//! Update records from a Delta Table for records statisfy a predicate +//! +//! When a predicate is not provided then all records are updated from the Delta +//! Table. Otherwise a scan of the Delta table is performed to mark any files +//! that contain records that satisfy the predicate. Once they are determined +//! then column values are updated with new values provided by the user +//! +//! +//! Predicates MUST be deterministic otherwise undefined behaviour may occur during the +//! scanning and rewriting phase. +//! +//! # Example +//! ```rust ignore +//! let table = open_table("../path/to/table")?; +//! let (table, metrics) = UpdateBuilder::new(table.object_store(), table.state) +//! .with_predicate(col("col1").eq(lit(1))) +//! .with_update("value", col("value") + lit(20)) +//! .await?; +//! 
```` + +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::{Instant, SystemTime, UNIX_EPOCH}, +}; + +use arrow::datatypes::Schema as ArrowSchema; +use arrow_array::RecordBatch; +use arrow_schema::{Field, SchemaRef}; +use datafusion::{ + execution::context::SessionState, + physical_plan::{ + metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + projection::ProjectionExec, + ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, + }, + prelude::SessionContext, +}; +use datafusion_common::Result as DataFusionResult; +use datafusion_common::{Column, DFSchema, ScalarValue}; +use datafusion_expr::{case, col, lit, when, Expr}; +use datafusion_physical_expr::{ + create_physical_expr, + expressions::{self}, + PhysicalExpr, +}; +use futures::{future::BoxFuture, Stream, StreamExt}; +use parquet::file::properties::WriterProperties; +use serde_json::{Map, Value}; + +use crate::{ + action::{Action, DeltaOperation, Remove}, + delta_datafusion::{find_files, parquet_scan_from_actions, register_store}, + storage::{DeltaObjectStore, ObjectStoreRef}, + table_state::DeltaTableState, + DeltaResult, DeltaTable, DeltaTableError, +}; + +use super::{transaction::commit, write::write_execution_plan}; + +/// Used to represent user input of either a Datafusion expression or string expression +pub enum Expression { + /// Datafusion Expression + DataFusion(Expr), + /// String Expression + String(String), +} + +impl From for Expression { + fn from(val: Expr) -> Self { + Expression::DataFusion(val) + } +} + +impl From<&str> for Expression { + fn from(val: &str) -> Self { + Expression::String(val.to_string()) + } +} +impl From for Expression { + fn from(val: String) -> Self { + Expression::String(val) + } +} + +/// Updates records in the Delta Table. +/// See this module's documentation for more information +pub struct UpdateBuilder { + /// Which records to update + predicate: Option, + /// How to update columns in a record that match the predicate + updates: HashMap, + /// A snapshot of the table's state + snapshot: DeltaTableState, + /// Delta object store for handling data files + object_store: Arc, + /// Datafusion session state relevant for executing the input plan + state: Option, + /// Properties passed to underlying parquet writer for when files are rewritten + writer_properties: Option, + /// Additional metadata to be added to commit + app_metadata: Option>, + /// safe_cast determines how data types that do not match the underlying table are handled + /// By default an error is returned + safe_cast: bool, +} + +#[derive(Default)] +/// Metrics collected during the Update operation +pub struct UpdateMetrics { + /// Number of files added. + pub num_added_files: usize, + /// Number of files removed. + pub num_removed_files: usize, + /// Number of rows updated. + pub num_updated_rows: usize, + /// Number of rows just copied over in the process of updating files. + pub num_copied_rows: usize, + /// Time taken to execute the entire operation. + pub execution_time_ms: u64, + /// Time taken to scan the files for matches. 
+ pub scan_time_ms: u64, +} + +impl UpdateBuilder { + /// Create a new ['UpdateBuilder'] + pub fn new(object_store: ObjectStoreRef, snapshot: DeltaTableState) -> Self { + Self { + predicate: None, + updates: HashMap::new(), + snapshot, + object_store, + state: None, + writer_properties: None, + app_metadata: None, + safe_cast: false, + } + } + + /// Which records to update + pub fn with_predicate>(mut self, predicate: E) -> Self { + self.predicate = Some(predicate.into()); + self + } + + /// Perform an additonal update expression during the operaton + pub fn with_update, E: Into>( + mut self, + column: S, + expression: E, + ) -> Self { + self.updates.insert(column.into(), expression.into()); + self + } + + /// The Datafusion session state to use + pub fn with_session_state(mut self, state: SessionState) -> Self { + self.state = Some(state); + self + } + + /// Additional metadata to be added to commit info + pub fn with_metadata( + mut self, + metadata: impl IntoIterator, + ) -> Self { + self.app_metadata = Some(Map::from_iter(metadata)); + self + } + + /// Writer properties passed to parquet writer for when fiiles are rewritten + pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { + self.writer_properties = Some(writer_properties); + self + } + + /// Specify the cast options to use when casting columns that do not match + /// the table's schema. When `cast_options.safe` is set true then any + /// failures to cast a datatype will use null instead of returning an error + /// to the user. + /// + /// Example (column's type is int): + /// Input Output + /// 123 -> 123 + /// Test123 -> null + pub fn with_safe_cast(mut self, safe_cast: bool) -> Self { + self.safe_cast = safe_cast; + self + } +} + +#[allow(clippy::too_many_arguments)] +async fn execute( + predicate: Option, + updates: HashMap, + object_store: ObjectStoreRef, + snapshot: &DeltaTableState, + state: SessionState, + writer_properties: Option, + app_metadata: Option>, + safe_cast: bool, +) -> DeltaResult<((Vec, i64), UpdateMetrics)> { + // Validate the predicate and update expressions. + // + // If the predicate is not set, then all files need to be updated. + // If it only contains partition columns then perform in memory-scan. + // Otherwise, scan files for records that satisfy the predicate. + // + // For files that were identified, scan for records that match the predicate, + // perform update operations, and then commit add and remove actions to + // the log. 
+ + let exec_start = Instant::now(); + let mut metrics = UpdateMetrics::default(); + let mut version = snapshot.version(); + + if updates.is_empty() { + return Ok(((Vec::new(), version), metrics)); + } + + let predicate = match predicate { + Some(predicate) => match predicate { + Expression::DataFusion(expr) => Some(expr), + Expression::String(s) => Some(snapshot.parse_predicate_expression(s)?), + }, + None => None, + }; + + let updates: HashMap = updates + .into_iter() + .map(|(key, expr)| match expr { + Expression::DataFusion(e) => Ok((key, e)), + Expression::String(s) => snapshot.parse_predicate_expression(s).map(|e| (key, e)), + }) + .collect::, _>>()?; + + let current_metadata = snapshot + .current_metadata() + .ok_or(DeltaTableError::NoMetadata)?; + let table_partition_cols = current_metadata.partition_columns.clone(); + let schema = snapshot.arrow_schema()?; + + let scan_start = Instant::now(); + let candidates = find_files( + snapshot, + object_store.clone(), + schema.clone(), + &state, + predicate.clone(), + ) + .await?; + metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis() as u64; + + if candidates.candidates.is_empty() { + return Ok(((Vec::new(), version), metrics)); + } + + let predicate = predicate.unwrap_or(Expr::Literal(ScalarValue::Boolean(Some(true)))); + + let execution_props = state.execution_props(); + // For each rewrite evaluate the predicate and then modify each expression + // to either compute the new value or obtain the old one then write these batches + let parquet_scan = parquet_scan_from_actions( + snapshot, + object_store.clone(), + &candidates.candidates, + &schema, + None, + &state, + None, + None, + ) + .await?; + + // Create a projection for a new column with the predicate evaluated + let input_schema = snapshot.input_schema()?; + + let mut fields = Vec::new(); + for field in input_schema.fields.iter() { + fields.push(field.to_owned()); + } + fields.push(Arc::new(Field::new( + "__delta_rs_update_predicate", + arrow_schema::DataType::Boolean, + true, + ))); + // Recreate the schemas with the new column included + let input_schema = Arc::new(ArrowSchema::new(fields)); + let input_dfschema: DFSchema = input_schema.as_ref().clone().try_into()?; + + let mut expressions: Vec<(Arc, String)> = Vec::new(); + let scan_schema = parquet_scan.schema(); + for (i, field) in scan_schema.fields().into_iter().enumerate() { + expressions.push(( + Arc::new(expressions::Column::new(field.name(), i)), + field.name().to_owned(), + )); + } + + // Take advantage of how null counts are tracked in arrow arrays use the + // null count to track how many records do NOT statisfy the predicate. The + // count is then exposed through the metrics through the `UpdateCountExec` + // execution plan + + let predicate_null = + when(predicate.clone(), lit(true)).otherwise(lit(ScalarValue::Boolean(None)))?; + let predicate_expr = create_physical_expr( + &predicate_null, + &input_dfschema, + &input_schema, + execution_props, + )?; + expressions.push((predicate_expr, "__delta_rs_update_predicate".to_string())); + + let projection_predicate: Arc = + Arc::new(ProjectionExec::try_new(expressions, parquet_scan)?); + + let count_plan = Arc::new(UpdateCountExec::new(projection_predicate.clone())); + + // Perform another projection but instead calculate updated values based on + // the predicate value. 
If the predicate is true then evalute the user + // provided expression otherwise return the original column value + // + // For each update column a new column with a name of __delta_rs_ + `original name` is created + let mut expressions: Vec<(Arc, String)> = Vec::new(); + let scan_schema = count_plan.schema(); + for (i, field) in scan_schema.fields().into_iter().enumerate() { + expressions.push(( + Arc::new(expressions::Column::new(field.name(), i)), + field.name().to_owned(), + )); + } + + // Maintain a map from the original column name to its temporary column index + let mut map = HashMap::::new(); + let mut control_columns = HashSet::::new(); + control_columns.insert("__delta_rs_update_predicate".to_owned()); + + for (column, expr) in updates { + let expr = case(col("__delta_rs_update_predicate")) + .when(lit(true), expr.to_owned()) + .otherwise(col(column.to_owned()))?; + let predicate_expr = + create_physical_expr(&expr, &input_dfschema, &input_schema, execution_props)?; + map.insert(column.name.clone(), expressions.len()); + let c = "__delta_rs_".to_string() + &column.name; + expressions.push((predicate_expr, c.clone())); + control_columns.insert(c); + } + + let projection_update: Arc = + Arc::new(ProjectionExec::try_new(expressions, count_plan.clone())?); + + // Project again to remove __delta_rs columns and rename update columns to their original name + let mut expressions: Vec<(Arc, String)> = Vec::new(); + let scan_schema = projection_update.schema(); + for (i, field) in scan_schema.fields().into_iter().enumerate() { + if !control_columns.contains(field.name()) { + match map.get(field.name()) { + Some(value) => { + expressions.push(( + Arc::new(expressions::Column::new(field.name(), *value)), + field.name().to_owned(), + )); + } + None => { + expressions.push(( + Arc::new(expressions::Column::new(field.name(), i)), + field.name().to_owned(), + )); + } + } + } + } + + let projection: Arc = Arc::new(ProjectionExec::try_new( + expressions, + projection_update.clone(), + )?); + + let add_actions = write_execution_plan( + snapshot, + state.clone(), + projection.clone(), + table_partition_cols.clone(), + object_store.clone(), + Some(snapshot.table_config().target_file_size() as usize), + None, + writer_properties, + safe_cast, + ) + .await?; + + let count_metrics = count_plan.metrics().unwrap(); + + metrics.num_updated_rows = count_metrics + .sum_by_name("num_updated_rows") + .map(|m| m.as_usize()) + .unwrap_or(0); + + metrics.num_copied_rows = count_metrics + .sum_by_name("num_copied_rows") + .map(|m| m.as_usize()) + .unwrap_or(0); + + let deletion_timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64; + let mut actions: Vec = add_actions.into_iter().map(Action::add).collect(); + + metrics.num_added_files = actions.len(); + metrics.num_removed_files = candidates.candidates.len(); + + for action in candidates.candidates { + actions.push(Action::remove(Remove { + path: action.path, + deletion_timestamp: Some(deletion_timestamp), + data_change: true, + extended_file_metadata: Some(true), + partition_values: Some(action.partition_values), + size: Some(action.size), + tags: None, + })) + } + + metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64; + + let operation = DeltaOperation::Update { + predicate: Some(predicate.canonical_name()), + }; + version = commit( + object_store.as_ref(), + &actions, + operation, + snapshot, + app_metadata, + ) + .await?; + + Ok(((actions, version), metrics)) +} + +impl 
std::future::IntoFuture for UpdateBuilder { + type Output = DeltaResult<(DeltaTable, UpdateMetrics)>; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + let mut this = self; + + Box::pin(async move { + let state = this.state.unwrap_or_else(|| { + let session = SessionContext::new(); + + // If a user provides their own their DF state then they must register the store themselves + register_store(this.object_store.clone(), session.runtime_env()); + + session.state() + }); + + let ((actions, version), metrics) = execute( + this.predicate, + this.updates, + this.object_store.clone(), + &this.snapshot, + state, + this.writer_properties, + this.app_metadata, + this.safe_cast, + ) + .await?; + + this.snapshot + .merge(DeltaTableState::from_actions(actions, version)?, true, true); + let table = DeltaTable::new_with_state(this.object_store, this.snapshot); + + Ok((table, metrics)) + }) + } +} + +#[derive(Debug)] +struct UpdateCountExec { + parent: Arc, + metrics: ExecutionPlanMetricsSet, +} + +impl UpdateCountExec { + pub fn new(parent: Arc) -> Self { + UpdateCountExec { + parent, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl ExecutionPlan for UpdateCountExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + self.parent.schema() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + self.parent.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + self.parent.output_ordering() + } + + fn children(&self) -> Vec> { + vec![self.parent.clone()] + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion_common::Result { + let res = self.parent.execute(partition, context)?; + Ok(Box::pin(UpdateCountStream { + schema: self.schema(), + input: res, + metrics: self.metrics.clone(), + })) + } + + fn statistics(&self) -> datafusion_common::Statistics { + self.parent.statistics() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + ExecutionPlan::with_new_children(self.parent.clone(), children) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +struct UpdateCountStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + metrics: ExecutionPlanMetricsSet, +} + +impl Stream for UpdateCountStream { + type Item = DataFusionResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.input.poll_next_unpin(cx).map(|x| match x { + Some(Ok(batch)) => { + let array = batch.column_by_name("__delta_rs_update_predicate").unwrap(); + let copied_rows = array.null_count(); + let num_updated = array.len() - copied_rows; + let c1 = MetricBuilder::new(&self.metrics).global_counter("num_updated_rows"); + c1.add(num_updated); + + let c2 = MetricBuilder::new(&self.metrics).global_counter("num_copied_rows"); + c2.add(copied_rows); + Some(Ok(batch)) + } + other => other, + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.input.size_hint() + } +} + +impl RecordBatchStream for UpdateCountStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + + use crate::operations::DeltaOps; + use crate::writer::test_utils::datafusion::get_data; + use crate::writer::test_utils::{get_arrow_schema, get_delta_schema}; + use crate::DeltaTable; + use crate::{action::*, DeltaResult}; + use 
arrow::datatypes::{Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::Int32Array; + use datafusion::assert_batches_sorted_eq; + use datafusion::from_slice::FromSlice; + use datafusion::prelude::*; + use std::sync::Arc; + + async fn setup_table(partitions: Option>) -> DeltaTable { + let table_schema = get_delta_schema(); + + let table = DeltaOps::new_in_memory() + .create() + .with_columns(table_schema.get_fields().clone()) + .with_partition_columns(partitions.unwrap_or_default()) + .await + .unwrap(); + assert_eq!(table.version(), 0); + table + } + + async fn write_batch(table: DeltaTable, batch: RecordBatch) -> DeltaResult { + DeltaOps(table) + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::Append) + .await + } + + async fn prepare_values_table() -> DeltaTable { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + arrow::datatypes::DataType::Int32, + true, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + None, + Some(2), + None, + Some(4), + ]))], + ) + .unwrap(); + + DeltaOps::new_in_memory().write(vec![batch]).await.unwrap() + } + + #[tokio::test] + async fn test_update_no_predicate() { + let schema = get_arrow_schema(&None); + let table = setup_table(None).await; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from_slice(["A", "B", "A", "A"])), + Arc::new(arrow::array::Int32Array::from_slice([1, 10, 10, 100])), + Arc::new(arrow::array::StringArray::from_slice([ + "2021-02-02", + "2021-02-02", + "2021-02-02", + "2021-02-02", + ])), + ], + ) + .unwrap(); + + let table = write_batch(table, batch).await.unwrap(); + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 1); + + let (table, metrics) = DeltaOps(table) + .update() + .with_update("modified", lit("2023-05-14")) + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert_eq!(table.get_file_uris().count(), 1); + assert_eq!(metrics.num_added_files, 1); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 4); + assert_eq!(metrics.num_copied_rows, 0); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 1 | 2023-05-14 |", + "| A | 10 | 2023-05-14 |", + "| A | 100 | 2023-05-14 |", + "| B | 10 | 2023-05-14 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + + #[tokio::test] + async fn test_update_non_partition() { + let schema = get_arrow_schema(&None); + let table = setup_table(None).await; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from_slice(["A", "B", "A", "A"])), + Arc::new(arrow::array::Int32Array::from_slice([1, 10, 10, 100])), + Arc::new(arrow::array::StringArray::from_slice([ + "2021-02-02", + "2021-02-02", + "2021-02-03", + "2021-02-03", + ])), + ], + ) + .unwrap(); + + // Update a partitioned table where the predicate contains only partition column + // The expectation is that a physical scan of data is not required + + let table = write_batch(table, batch).await.unwrap(); + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 1); + + let (table, metrics) = DeltaOps(table) + .update() + .with_predicate(col("modified").eq(lit("2021-02-03"))) + .with_update("modified", lit("2023-05-14")) + .await + .unwrap(); + + assert_eq!(table.version(), 2); + 
assert_eq!(table.get_file_uris().count(), 1); + assert_eq!(metrics.num_added_files, 1); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 2); + assert_eq!(metrics.num_copied_rows, 2); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 1 | 2021-02-02 |", + "| A | 10 | 2023-05-14 |", + "| A | 100 | 2023-05-14 |", + "| B | 10 | 2021-02-02 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + + #[tokio::test] + async fn test_update_partitions() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from_slice(["A", "B", "A", "A"])), + Arc::new(arrow::array::Int32Array::from_slice([1, 10, 10, 100])), + Arc::new(arrow::array::StringArray::from_slice([ + "2021-02-02", + "2021-02-02", + "2021-02-03", + "2021-02-03", + ])), + ], + ) + .unwrap(); + + let table = write_batch(table, batch.clone()).await.unwrap(); + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 2); + + let (table, metrics) = DeltaOps(table) + .update() + .with_predicate(col("modified").eq(lit("2021-02-03"))) + .with_update("modified", lit("2023-05-14")) + .with_update("id", lit("C")) + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert_eq!(table.get_file_uris().count(), 2); + assert_eq!(metrics.num_added_files, 1); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 2); + assert_eq!(metrics.num_copied_rows, 0); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 1 | 2021-02-02 |", + "| C | 10 | 2023-05-14 |", + "| C | 100 | 2023-05-14 |", + "| B | 10 | 2021-02-02 |", + "+----+-------+------------+", + ]; + + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + + // Update a partitioned table where the predicate contains a partition column and non-partition column + let table = setup_table(Some(vec!["modified"])).await; + let table = write_batch(table, batch).await.unwrap(); + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 2); + + let (table, metrics) = DeltaOps(table) + .update() + .with_predicate( + col("modified") + .eq(lit("2021-02-03")) + .and(col("value").eq(lit(100))), + ) + .with_update("modified", lit("2023-05-14")) + .with_update("id", lit("C")) + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert_eq!(table.get_file_uris().count(), 3); + assert_eq!(metrics.num_added_files, 2); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 1); + assert_eq!(metrics.num_copied_rows, 1); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 1 | 2021-02-02 |", + "| A | 10 | 2021-02-03 |", + "| B | 10 | 2021-02-02 |", + "| C | 100 | 2023-05-14 |", + "+----+-------+------------+", + ]; + + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + + #[tokio::test] + async fn test_update_null() { + let table = prepare_values_table().await; + assert_eq!(table.version(), 0); + assert_eq!(table.get_file_uris().count(), 1); + + let (table, metrics) = DeltaOps(table) + .update() + .with_update("value", col("value") + lit(1)) + .await + 
.unwrap(); + + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 1); + assert_eq!(metrics.num_added_files, 1); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 5); + assert_eq!(metrics.num_copied_rows, 0); + + let expected = [ + "+-------+", + "| value |", + "+-------+", + "| |", + "| |", + "| 1 |", + "| 3 |", + "| 5 |", + "+-------+", + ]; + + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + + // Validate order operators do not include nulls + let table = prepare_values_table().await; + let (table, metrics) = DeltaOps(table) + .update() + .with_predicate(col("value").gt(lit(2)).or(col("value").lt(lit(2)))) + .with_update("value", lit(10)) + .await + .unwrap(); + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 1); + assert_eq!(metrics.num_added_files, 1); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 2); + assert_eq!(metrics.num_copied_rows, 3); + + let expected = [ + "+-------+", + "| value |", + "+-------+", + "| |", + "| |", + "| 2 |", + "| 10 |", + "| 10 |", + "+-------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + + let table = prepare_values_table().await; + let (table, metrics) = DeltaOps(table) + .update() + .with_predicate("value is null") + .with_update("value", "10") + .await + .unwrap(); + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 1); + assert_eq!(metrics.num_added_files, 1); + assert_eq!(metrics.num_removed_files, 1); + assert_eq!(metrics.num_updated_rows, 2); + assert_eq!(metrics.num_copied_rows, 3); + + let expected = [ + "+-------+", + "| value |", + "+-------+", + "| 10 |", + "| 10 |", + "| 0 |", + "| 2 |", + "| 4 |", + "+-------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + + #[tokio::test] + async fn test_no_updates() { + // No Update operations are provided + let table = prepare_values_table().await; + let (table, metrics) = DeltaOps(table).update().await.unwrap(); + + assert_eq!(table.version(), 0); + assert_eq!(metrics.num_added_files, 0); + assert_eq!(metrics.num_removed_files, 0); + assert_eq!(metrics.num_copied_rows, 0); + assert_eq!(metrics.num_removed_files, 0); + assert_eq!(metrics.scan_time_ms, 0); + assert_eq!(metrics.execution_time_ms, 0); + + // The predicate does not match any records + let (table, metrics) = DeltaOps(table) + .update() + .with_predicate(col("value").eq(lit(3))) + .with_update("value", lit(10)) + .await + .unwrap(); + + assert_eq!(table.version(), 0); + assert_eq!(metrics.num_added_files, 0); + assert_eq!(metrics.num_removed_files, 0); + assert_eq!(metrics.num_copied_rows, 0); + assert_eq!(metrics.num_removed_files, 0); + } + + #[tokio::test] + async fn test_expected_failures() { + // The predicate must be deterministic and expression must be valid + + let table = setup_table(None).await; + + let res = DeltaOps(table) + .update() + .with_predicate(col("value").eq(cast( + random() * lit(20.0), + arrow::datatypes::DataType::Int32, + ))) + .with_update("value", col("value") + lit(20)) + .await; + assert!(res.is_err()); + + // Expression result types must match the table's schema + let table = prepare_values_table().await; + let res = DeltaOps(table) + .update() + .with_update("value", lit("a string")) + .await; + assert!(res.is_err()); + } +}
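
A minimal end-to-end sketch of the new operation, assembled from the module docs and tests above. It assumes a downstream crate using `deltalake` with the `datafusion` feature enabled; the table path, column names, and literal values are placeholders, and `open_table` is assumed to be the async helper at the crate root (as imported in `delta_datafusion.rs` above).

```rust
use datafusion::prelude::{col, lit};
use deltalake::operations::DeltaOps;
use deltalake::{open_table, DeltaResult};

async fn example() -> DeltaResult<()> {
    // Placeholder path to an existing Delta table.
    let table = open_table("../path/to/table").await?;

    // Rewrite only the files containing rows that match the predicate;
    // matching rows get the new `modified` value, all other rows in those
    // files are copied through unchanged.
    let (table, metrics) = DeltaOps(table)
        .update()
        .with_predicate(col("modified").eq(lit("2021-02-03")))
        .with_update("modified", lit("2023-05-14"))
        .await?;

    println!(
        "updated {} rows, copied {} rows, rewrote {} files",
        metrics.num_updated_rows, metrics.num_copied_rows, metrics.num_added_files
    );

    // Predicates and update expressions may also be given as strings, which
    // are parsed against the table's schema.
    let (_table, _metrics) = DeltaOps(table)
        .update()
        .with_predicate("value is null")
        .with_update("value", "10")
        .await?;

    Ok(())
}
```

Because update rewrites whole files rather than individual rows, `num_copied_rows` counts the non-matching rows carried over into the rewritten files, while a predicate that references only partition columns skips the data scan entirely (see `find_files` / `scan_memory_table`).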