diff --git a/Cargo.toml b/Cargo.toml index 8f007e02f2..9d3679059b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,17 +57,17 @@ lance-test-macros = { version = "=0.10.15", path = "./rust/lance-test-macros" } lance-testing = { version = "=0.10.15", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow -arrow = { version = "50.0.0", optional = false, features = ["prettyprint"] } -arrow-arith = "50.0" -arrow-array = "50.0" -arrow-buffer = "50.0" -arrow-cast = "50.0" -arrow-data = "50.0" -arrow-ipc = { version = "50.0", features = ["zstd"] } -arrow-ord = "50.0" -arrow-row = "50.0" -arrow-schema = "50.0" -arrow-select = "50.0" +arrow = { version = "51.0.0", optional = false, features = ["prettyprint"] } +arrow-arith = "51.0" +arrow-array = "51.0" +arrow-buffer = "51.0" +arrow-cast = "51.0" +arrow-data = "51.0" +arrow-ipc = { version = "51.0", features = ["zstd"] } +arrow-ord = "51.0" +arrow-row = "51.0" +arrow-schema = "51.0" +arrow-select = "51.0" async-recursion = "1.0" async-trait = "0.1" aws-config = "0.56" @@ -85,14 +85,18 @@ chrono = { version = "0.4.25", default-features = false, features = [ "now", ] } criterion = { version = "0.5", features = ["async", "async_tokio"] } -datafusion = { version = "36.0.0", default-features = false, features = [ +datafusion = { version = "37.1", default-features = false, features = [ + "array_expressions", + "regex_expressions", +] } +datafusion-common = "37.1" +datafusion-functions = { version = "37.1", features = ["regex_expressions"] } +datafusion-sql = "37.1" +datafusion-expr = "37.1" +datafusion-execution = "37.1" +datafusion-physical-expr = { version = "37.1", features = [ "regex_expressions", ] } -datafusion-common = "36.0" -datafusion-sql = "36.0" -datafusion-expr = "36.0" -datafusion-execution = "36.0" -datafusion-physical-expr = "36.0" either = "1.0" futures = "0.3" http = "0.2.9" diff --git a/python/Cargo.toml b/python/Cargo.toml index 6bf318b766..85e3691d4f 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -12,10 +12,10 @@ name = "lance" crate-type = ["cdylib"] [dependencies] -arrow = { version = "50.0.0", features = ["pyarrow"] } -arrow-array = "50.0" -arrow-data = "50.0" -arrow-schema = "50.0" +arrow = { version = "51.0.0", features = ["pyarrow"] } +arrow-array = "51.0" +arrow-data = "51.0" +arrow-schema = "51.0" object_store = "0.9.0" async-trait = "0.1" chrono = "0.4.31" diff --git a/rust/lance-datafusion/Cargo.toml b/rust/lance-datafusion/Cargo.toml index 6a8f7bd05b..bb2e0bcf7a 100644 --- a/rust/lance-datafusion/Cargo.toml +++ b/rust/lance-datafusion/Cargo.toml @@ -17,8 +17,9 @@ arrow-ord.workspace = true async-trait.workspace = true datafusion.workspace = true datafusion-common.workspace = true +datafusion-functions.workspace = true datafusion-physical-expr.workspace = true -datafusion-substrait = { version = "36.0", optional = true } +datafusion-substrait = { version = "37.1", optional = true } futures.workspace = true lance-arrow.workspace = true lance-core = { workspace = true, features = ["datafusion"] } diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index dd6dd8c791..dbed261761 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -17,12 +17,12 @@ use datafusion::{ TaskContext, }, physical_plan::{ - streaming::PartitionStream, DisplayAs, DisplayFormatType, ExecutionPlan, + streaming::PartitionStream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }, }; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::Partitioning; +use datafusion_common::{DataFusionError, Statistics}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use lance_arrow::SchemaExt; use lance_core::Result; @@ -32,11 +32,15 @@ use log::{info, warn}; /// /// It can only be used once, and will return the stream. After that the node /// is exhuasted. +/// +/// Note: the stream should be finite, otherwise we will report datafusion properties +/// incorrectly. pub struct OneShotExec { stream: Mutex>, // We save off a copy of the schema to speed up formatting and so ExecutionPlan::schema & display_as // can still function after exhuasted schema: Arc, + properties: PlanProperties, } impl OneShotExec { @@ -45,7 +49,12 @@ impl OneShotExec { let schema = stream.schema().clone(); Self { stream: Mutex::new(Some(stream)), - schema, + schema: schema.clone(), + properties: PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::RoundRobinBatch(1), + datafusion::physical_plan::ExecutionMode::Bounded, + ), } } } @@ -96,14 +105,6 @@ impl ExecutionPlan for OneShotExec { self.schema.clone() } - fn output_partitioning(&self) -> datafusion_physical_expr::Partitioning { - Partitioning::RoundRobinBatch(1) - } - - fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { - None - } - fn children(&self) -> Vec> { vec![] } @@ -135,7 +136,11 @@ impl ExecutionPlan for OneShotExec { } fn statistics(&self) -> datafusion_common::Result { - todo!() + Ok(Statistics::new_unknown(&self.schema)) + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + &self.properties } } @@ -194,7 +199,7 @@ pub fn execute_plan( let session_state = SessionState::new_with_config_rt(session_config, runtime_env); // NOTE: we are only executing the first partition here. Therefore, if // the plan has more than one partition, we will be missing data. - assert_eq!(plan.output_partitioning().partition_count(), 1); + assert_eq!(plan.properties().partitioning.partition_count(), 1); Ok(plan.execute(0, session_state.task_ctx())?) } diff --git a/rust/lance-datafusion/src/expr.rs b/rust/lance-datafusion/src/expr.rs index 4cb059438e..1a07bfb64e 100644 --- a/rust/lance-datafusion/src/expr.rs +++ b/rust/lance-datafusion/src/expr.rs @@ -508,7 +508,7 @@ pub async fn parse_substrait(expr: &[u8], input_schema: Arc) -> Result { if table == "dummy" { - Ok(Transformed::Yes(Expr::Column(Column { + Ok(Transformed::yes(Expr::Column(Column { relation: None, name: column.name, }))) @@ -524,12 +524,12 @@ pub async fn parse_substrait(expr: &[u8], input_schema: Arc) -> Result Err(DataFusionError::Substrait("Unexpected partially or fully qualified table reference encountered when parsing filter".into())) } } else { - Ok(Transformed::No(Expr::Column(column))) + Ok(Transformed::no(Expr::Column(column))) } } - _ => Ok(Transformed::No(node)), + _ => Ok(Transformed::no(node)), })?; - Ok(expr) + Ok(expr.data) } #[cfg(test)] diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index b72e46772c..f825f4487c 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -494,6 +494,8 @@ impl Ord for OrderableScalarValue { (Dictionary(_k1, _v1), Dictionary(_k2, _v2)) => todo!(), (Dictionary(_, v1), Null) => Self(*v1.clone()).cmp(&Self(ScalarValue::Null)), (Dictionary(_, _), _) => panic!("Attempt to compare Dictionary with non-Dictionary"), + // What would a btree of unions even look like? May not be possible. + (Union(_, _, _), _) => todo!("Support for union scalars"), (Null, Null) => Ordering::Equal, (Null, _) => todo!(), } diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index be2577a124..ed13a39c23 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -549,6 +549,18 @@ mod tests { fn options(&self) -> &ConfigOptions { todo!() } + + fn udfs_names(&self) -> Vec { + todo!() + } + + fn udafs_names(&self) -> Vec { + todo!() + } + + fn udwfs_names(&self) -> Vec { + todo!() + } } fn check( diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 63511cd738..a4629de596 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -57,6 +57,7 @@ arrow.workspace = true num_cpus.workspace = true # TODO: use datafusion sub-modules to reduce build size? datafusion.workspace = true +datafusion-functions.workspace = true datafusion-physical-expr.workspace = true lapack = { version = "0.19.0", optional = true } lru_time_cache = "0.11" diff --git a/rust/lance/src/datafusion/logical_expr.rs b/rust/lance/src/datafusion/logical_expr.rs index 74cb79bebc..237f6db540 100644 --- a/rust/lance/src/datafusion/logical_expr.rs +++ b/rust/lance/src/datafusion/logical_expr.rs @@ -6,12 +6,13 @@ use arrow_schema::DataType; use datafusion::logical_expr::ScalarFunctionDefinition; +use datafusion::logical_expr::ScalarUDFImpl; use datafusion::logical_expr::{ - expr::ScalarFunction, BinaryExpr, BuiltinScalarFunction, GetFieldAccess, GetIndexedField, - Operator, + expr::ScalarFunction, BinaryExpr, GetFieldAccess, GetIndexedField, Operator, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; +use datafusion_functions::core::getfield::GetFieldFunc; use lance_arrow::DataTypeExt; use lance_datafusion::expr::safe_coerce_scalar; @@ -34,6 +35,15 @@ fn resolve_value(expr: &Expr, data_type: &DataType) -> Result { } } +/// A simple helper function that interprets an Expr as a string scalar +/// or returns None if it is not. +pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(s))) => Some(s), + _ => None, + } +} + /// Given a Expr::Column or Expr::GetIndexedField, get the data type of referenced /// field in the schema. /// @@ -49,6 +59,15 @@ pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option { field_path.push(c.name.as_str()); break; } + Expr::ScalarFunction(udf) => { + if udf.name() == GetFieldFunc::default().name() { + let name = get_as_string_scalar_opt(&udf.args[1])?; + field_path.push(name); + current_expr = &udf.args[0]; + } else { + return None; + } + } Expr::GetIndexedField(GetIndexedField { expr, field }) => { if let GetFieldAccess::NamedStructField { name: ScalarValue::Utf8(Some(name)), @@ -87,52 +106,41 @@ pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { if matches!(op, Operator::And | Operator::Or) { - return Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Expr::BinaryExpr(BinaryExpr { left: Box::new(resolve_expr(left.as_ref(), schema)?), op: *op, right: Box::new(resolve_expr(right.as_ref(), schema)?), - })); - } - match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::GetIndexedField(_), Expr::Literal(_)) => { - if let Some(resolved_type) = resolve_column_type(left.as_ref(), schema) { - Ok(Expr::BinaryExpr(BinaryExpr { - left: left.clone(), - op: *op, - right: Box::new(resolve_value(right.as_ref(), &resolved_type)?), - })) - } else { - Ok(expr.clone()) - } - } - (Expr::Literal(_), Expr::Column(_) | Expr::GetIndexedField(_)) => { - if let Some(resolved_type) = resolve_column_type(right.as_ref(), schema) { - Ok(Expr::BinaryExpr(BinaryExpr { - left: Box::new(resolve_value(left.as_ref(), &resolved_type)?), - op: *op, - right: right.clone(), - })) - } else { - Ok(expr.clone()) - } + })) + } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) { + match right.as_ref() { + Expr::Literal(_) => Ok(Expr::BinaryExpr(BinaryExpr { + left: left.clone(), + op: *op, + right: Box::new(resolve_value(right.as_ref(), &left_type)?), + })), + // For cases complex expressions (not just literals) on right hand side like x = 1 + 1 + -2*2 + Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr { + left: left.clone(), + op: *op, + right: Box::new(Expr::BinaryExpr(BinaryExpr { + left: coerce_expr(&r.left, &left_type).map(Box::new)?, + op: r.op, + right: coerce_expr(&r.right, &left_type).map(Box::new)?, + })), + })), + _ => Ok(expr.clone()), } - // For cases complex expressions (not just literals) on right hand side like x = 1 + 1 + -2*2 - (Expr::Column(_) | Expr::GetIndexedField(_), Expr::BinaryExpr(r)) => { - if let Some(resolved_type) = resolve_column_type(left.as_ref(), schema) { - Ok(Expr::BinaryExpr(BinaryExpr { - left: left.clone(), - op: *op, - right: Box::new(Expr::BinaryExpr(BinaryExpr { - left: coerce_expr(&r.left, &resolved_type).map(Box::new)?, - op: r.op, - right: coerce_expr(&r.right, &resolved_type).map(Box::new)?, - })), - })) - } else { - Ok(expr.clone()) - } + } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) { + match left.as_ref() { + Expr::Literal(_) => Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(resolve_value(left.as_ref(), &right_type)?), + op: *op, + right: right.clone(), + })), + _ => Ok(expr.clone()), } - _ => Ok(expr.clone()), + } else { + Ok(expr.clone()) } } Expr::InList(in_list) => { @@ -189,25 +197,61 @@ pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result { /// /// - *expr*: a datafusion logical expression pub fn coerce_filter_type_to_boolean(expr: Expr) -> Result { - match expr { + match &expr { // TODO: consider making this dispatch more generic, i.e. fun.output_type -> coerce // instead of hardcoding coerce method for each function Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::RegexpMatch), + func_def: ScalarFunctionDefinition::UDF(udf), .. - }) => Ok(Expr::IsNotNull(Box::new(expr))), - + }) => { + if udf.name() == "regexp_match" { + Ok(Expr::IsNotNull(Box::new(expr))) + } else { + Ok(expr) + } + } _ => Ok(expr), } } #[cfg(test)] -mod tests { +pub mod tests { use std::sync::Arc; use super::*; use arrow_schema::{Field, Schema as ArrowSchema}; + use datafusion::logical_expr::ScalarUDF; + + // As part of the DF 37 release there are now two different ways to + // represent a nested field access in `Expr`. The old way is to use + // `Expr::field` which returns a `GetStructField` and the new way is + // to use `Expr::ScalarFunction` with a `GetFieldFunc` UDF. + // + // Currently, the old path leads to bugs in DF. This is probably a + // bug and will probably be fixed in a future version. In the meantime + // we need to make sure we are always using the new way to avoid this + // bug. This trait adds field_newstyle which lets us easily create + // logical `Expr` that use the new style. + pub trait ExprExt { + // Helper function to replace Expr::field in DF 37 since DF + // confuses itself with the GetStructField returned by Expr::field + fn field_newstyle(&self, name: &str) -> Expr; + } + + impl ExprExt for Expr { + fn field_newstyle(&self, name: &str) -> Expr { + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + GetFieldFunc::default(), + ))), + args: vec![ + self.clone(), + Expr::Literal(ScalarValue::Utf8(Some(name.to_string()))), + ], + }) + } + } #[test] fn test_resolve_large_utf8() { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 9f11090413..37fa6ff7a5 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -702,6 +702,7 @@ impl Scanner { &[], &plan.schema(), "", + false, )?; let plan_schema = plan.schema().clone(); let count_plan = Arc::new(AggregateExec::try_new( diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index db4521af0d..1b95e56a9a 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -431,6 +431,7 @@ impl MergeInsertJob { vec![(Arc::new(target_key), Arc::new(source_key))], None, &JoinType::Full, + None, PartitionMode::CollectLeft, true, ) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 85a8883b8a..2afefdc677 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -8,7 +8,7 @@ use std::task::{Context, Poll}; use arrow_array::cast::AsArray; use arrow_array::{RecordBatch, UInt64Array}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::common::stats::Precision; use datafusion::error::{DataFusionError, Result as DataFusionResult}; @@ -16,6 +16,8 @@ use datafusion::physical_plan::{ stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream as DFRecordBatchStream, SendableRecordBatchStream, Statistics, }; +use datafusion::physical_plan::{ExecutionMode, PlanProperties}; +use datafusion_physical_expr::EquivalenceProperties; use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use lance_core::utils::mask::{RowIdMask, RowIdTreeMap}; use lance_core::{ROW_ID, ROW_ID_FIELD}; @@ -133,6 +135,8 @@ pub struct KNNFlatExec { /// The vector query to execute. pub query: Query, + output_schema: SchemaRef, + properties: PlanProperties, } impl DisplayAs for KNNFlatExec { @@ -177,7 +181,26 @@ impl KNNFlatExec { } } - Ok(Self { input, query }) + let mut fields = schema.fields().to_vec(); + if schema.field_with_name(DIST_COL).is_err() { + fields.push(Arc::new(Field::new(DIST_COL, DataType::Float32, true))); + } + + let output_schema = Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone())); + + // This node has the same partitioning & boundedness as the input node + // but it destroys any ordering. + let properties = input + .properties() + .clone() + .with_eq_properties(EquivalenceProperties::new(output_schema.clone())); + + Ok(Self { + input, + query, + output_schema, + properties, + }) } } @@ -188,24 +211,7 @@ impl ExecutionPlan for KNNFlatExec { /// Flat KNN inherits the schema from input node, and add one distance column. fn schema(&self) -> arrow_schema::SchemaRef { - let input_schema = self.input.schema(); - let mut fields = input_schema.fields().to_vec(); - if input_schema.field_with_name(DIST_COL).is_err() { - fields.push(Arc::new(Field::new(DIST_COL, DataType::Float32, true))); - } - - Arc::new(Schema::new_with_metadata( - fields, - input_schema.metadata().clone(), - )) - } - - fn output_partitioning(&self) -> Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - self.input.output_ordering() + self.output_schema.clone() } fn children(&self) -> Vec> { @@ -236,6 +242,10 @@ impl ExecutionPlan for KNNFlatExec { ..Statistics::new_unknown(self.schema().as_ref()) }) } + + fn properties(&self) -> &PlanProperties { + &self.properties + } } // Utility to convert an input (containing row ids) into a prefilter @@ -334,6 +344,13 @@ pub enum PreFilterSource { None, } +lazy_static::lazy_static! { + pub static ref KNN_INDEX_SCHEMA: SchemaRef = Arc::new(Schema::new(vec![ + Field::new(DIST_COL, DataType::Float32, true), + ROW_ID_FIELD.clone(), + ])); +} + /// [ExecutionPlan] for KNNIndex node. #[derive(Debug)] pub struct KNNIndexExec { @@ -345,6 +362,8 @@ pub struct KNNIndexExec { indices: Vec, /// The vector query to execute. query: Query, + /// The datafusion plan properties + properties: PlanProperties, } impl DisplayAs for KNNIndexExec { @@ -388,11 +407,18 @@ impl KNNIndexExec { }); }; + let properties = PlanProperties::new( + EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + ExecutionMode::Bounded, + ); + Ok(Self { dataset, indices: indices.to_vec(), query: query.clone(), prefilter_source, + properties, }) } } @@ -403,18 +429,7 @@ impl ExecutionPlan for KNNIndexExec { } fn schema(&self) -> arrow_schema::SchemaRef { - Arc::new(Schema::new(vec![ - Field::new(DIST_COL, DataType::Float32, true), - ROW_ID_FIELD.clone(), - ])) - } - - fn output_partitioning(&self) -> Partitioning { - Partitioning::RoundRobinBatch(1) - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - None + KNN_INDEX_SCHEMA.clone() } fn children(&self) -> Vec> { @@ -465,6 +480,10 @@ impl ExecutionPlan for KNNIndexExec { ..Statistics::new_unknown(self.schema().as_ref()) }) } + + fn properties(&self) -> &PlanProperties { + &self.properties + } } #[cfg(test)] diff --git a/rust/lance/src/io/exec/optimizer.rs b/rust/lance/src/io/exec/optimizer.rs index 77fca1888a..fb350b1a99 100644 --- a/rust/lance/src/io/exec/optimizer.rs +++ b/rust/lance/src/io/exec/optimizer.rs @@ -25,16 +25,18 @@ impl PhysicalOptimizerRule for CoalesceTake { plan: Arc, _config: &ConfigOptions, ) -> DFResult> { - plan.transform_down(&|plan| { - if let Some(take) = plan.as_any().downcast_ref::() { - let child = &take.children()[0]; - if let Some(exec_child) = child.as_any().downcast_ref::() { - let upstream_plan = exec_child.children(); - return Ok(Transformed::Yes(plan.with_new_children(upstream_plan)?)); + Ok(plan + .transform_down(&|plan| { + if let Some(take) = plan.as_any().downcast_ref::() { + let child = &take.children()[0]; + if let Some(exec_child) = child.as_any().downcast_ref::() { + let upstream_plan = exec_child.children(); + return Ok(Transformed::yes(plan.with_new_children(upstream_plan)?)); + } } - } - Ok(Transformed::No(plan)) - }) + Ok(Transformed::no(plan)) + })? + .data) } fn name(&self) -> &str { @@ -56,36 +58,38 @@ impl PhysicalOptimizerRule for SimplifyProjection { plan: Arc, _config: &ConfigOptions, ) -> DFResult> { - plan.transform_down(&|plan| { - if let Some(proj) = plan.as_any().downcast_ref::() { - let children = &proj.children(); - if children.len() != 1 { - return Ok(Transformed::No(plan)); - } + Ok(plan + .transform_down(&|plan| { + if let Some(proj) = plan.as_any().downcast_ref::() { + let children = &proj.children(); + if children.len() != 1 { + return Ok(Transformed::no(plan)); + } - let input = &children[0]; + let input = &children[0]; - // TODO: we could try to coalesce consecutive projections, something for later - // For now, we just keep things simple and only remove NoOp projections + // TODO: we could try to coalesce consecutive projections, something for later + // For now, we just keep things simple and only remove NoOp projections - // output has differnet schema, projection needed - if input.schema() != proj.schema() { - return Ok(Transformed::No(plan)); - } + // output has differnet schema, projection needed + if input.schema() != proj.schema() { + return Ok(Transformed::no(plan)); + } - if proj.expr().iter().enumerate().all(|(index, (expr, name))| { - if let Some(expr) = expr.as_any().downcast_ref::() { - // no renaming, no reordering - expr.index() == index && expr.name() == name - } else { - false + if proj.expr().iter().enumerate().all(|(index, (expr, name))| { + if let Some(expr) = expr.as_any().downcast_ref::() { + // no renaming, no reordering + expr.index() == index && expr.name() == name + } else { + false + } + }) { + return Ok(Transformed::yes(input.clone())); } - }) { - return Ok(Transformed::Yes(input.clone())); } - } - Ok(Transformed::No(plan)) - }) + Ok(Transformed::no(plan)) + })? + .data) } fn name(&self) -> &str { diff --git a/rust/lance/src/io/exec/planner.rs b/rust/lance/src/io/exec/planner.rs index a89bb5834c..c4a909b77c 100644 --- a/rust/lance/src/io/exec/planner.rs +++ b/rust/lance/src/io/exec/planner.rs @@ -11,11 +11,17 @@ use arrow_array::ListArray; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit}; use arrow_select::concat::concat; -use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion::common::DFSchema; +use datafusion::config::ConfigOptions; use datafusion::error::Result as DFResult; +use datafusion::execution::config::SessionConfig; +use datafusion::execution::context::SessionState; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ - ColumnarValue, GetFieldAccess, GetIndexedField, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + AggregateUDF, ColumnarValue, GetFieldAccess, GetIndexedField, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, WindowUDF, }; use datafusion::optimizer::simplify_expressions::SimplifyContext; use datafusion::physical_optimizer::optimizer::PhysicalOptimizer; @@ -32,6 +38,7 @@ use datafusion::{ prelude::Expr, scalar::ScalarValue, }; +use datafusion_functions::core::getfield::GetFieldFunc; use lance_arrow::cast::cast_with_options; use lance_datafusion::expr::safe_coerce_scalar; use lance_index::scalar::expression::{ @@ -39,7 +46,7 @@ use lance_index::scalar::expression::{ }; use snafu::{location, Location}; -use crate::datafusion::logical_expr::coerce_filter_type_to_boolean; +use crate::datafusion::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt}; use crate::utils::sql::parse_sql_expr; use crate::{ datafusion::logical_expr::resolve_expr, datatypes::Schema, utils::sql::parse_sql_filter, Error, @@ -167,9 +174,22 @@ impl ScalarUDFImpl for CastListF16Udf { } // Adapter that instructs datafusion how lance expects expressions to be interpreted -#[derive(Default)] struct LanceContextProvider { options: datafusion::config::ConfigOptions, + state: SessionState, +} + +impl Default for LanceContextProvider { + fn default() -> Self { + let config = SessionConfig::new(); + let runtime_config = RuntimeConfig::new(); + let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let state = SessionState::new_with_config_rt(config, runtime); + Self { + options: ConfigOptions::default(), + state, + } + } } impl ContextProvider for LanceContextProvider { @@ -183,25 +203,23 @@ impl ContextProvider for LanceContextProvider { ))) } + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_function_meta(&self, f: &str) -> Option> { match f { // TODO: cast should go thru CAST syntax instead of UDF // Going thru UDF makes it hard for the optimizer to find no-ops "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))), - _ => None, + _ => self.state.scalar_functions().get(f).cloned(), } } - fn get_aggregate_meta(&self, _: &str) -> Option> { - // UDFs not supported yet - None - } - - fn get_window_meta(&self, _: &str) -> Option> { - // UDFs not supported yet - None - } - fn get_variable_type(&self, _: &[String]) -> Option { // Variables (things like @@LANGUAGE) not supported None @@ -210,6 +228,18 @@ impl ContextProvider for LanceContextProvider { fn options(&self) -> &datafusion::config::ConfigOptions { &self.options } + + fn udfs_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udafs_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwfs_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } } pub struct Planner { @@ -224,7 +254,15 @@ impl Planner { fn column(idents: &[Ident]) -> Expr { let mut column = col(&idents[0].value); for ident in &idents[1..] { - column = column.field(&ident.value); + column = Expr::ScalarFunction(ScalarFunction { + args: vec![ + column, + Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))), + ], + func_def: datafusion::logical_expr::ScalarFunctionDefinition::UDF(Arc::new( + ScalarUDF::new_from_impl(GetFieldFunc::default()), + )), + }); } column } @@ -717,9 +755,9 @@ struct ColumnCapturingVisitor { } impl TreeNodeVisitor for ColumnCapturingVisitor { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, node: &Self::N) -> DFResult { + fn f_down(&mut self, node: &Self::Node) -> DFResult { match node { Expr::Column(Column { name, .. }) => { let mut path = name.clone(); @@ -730,6 +768,17 @@ impl TreeNodeVisitor for ColumnCapturingVisitor { self.columns.insert(path); self.current_path.clear(); } + Expr::ScalarFunction(udf) => { + if udf.name() == GetFieldFunc::default().name() { + if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) { + self.current_path.push_front(name.to_string()) + } else { + self.current_path.clear(); + } + } else { + self.current_path.clear(); + } + } Expr::GetIndexedField(GetIndexedField { expr: _, field: GetFieldAccess::NamedStructField { name }, @@ -741,12 +790,15 @@ impl TreeNodeVisitor for ColumnCapturingVisitor { } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } #[cfg(test)] mod tests { + + use crate::datafusion::logical_expr::tests::ExprExt; + use super::*; use arrow_array::{ @@ -755,7 +807,7 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow_schema::{DataType, Fields, Schema}; - use datafusion::logical_expr::{lit, Cast}; + use datafusion::logical_expr::{lit, Cast, ScalarFunctionDefinition}; #[test] fn test_parse_filter_simple() { @@ -776,7 +828,7 @@ mod tests { let expected = col("i") .gt(lit(3_i32)) - .and(col("st").field("x").lt_eq(lit(5.0_f32))) + .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32))) .and( col("s") .eq(lit("str-4")) @@ -875,33 +927,43 @@ mod tests { assert_column_eq(&planner, "s0", &expected); assert_column_eq(&planner, "`s0`", &expected); - let expected = Expr::GetIndexedField(GetIndexedField { - expr: Box::new(Expr::Column(Column { - relation: None, - name: "st".to_string(), - })), - field: GetFieldAccess::NamedStructField { - name: ScalarValue::from("s1"), - }, + let expected = Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + GetFieldFunc::default(), + ))), + args: vec![ + Expr::Column(Column { + relation: None, + name: "st".to_string(), + }), + Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))), + ], }); assert_column_eq(&planner, "st.s1", &expected); assert_column_eq(&planner, "`st`.`s1`", &expected); assert_column_eq(&planner, "st.`s1`", &expected); - let expected = Expr::GetIndexedField(GetIndexedField { - expr: Box::new(Expr::GetIndexedField(GetIndexedField { - expr: Box::new(Expr::Column(Column { - relation: None, - name: "st".to_string(), - })), - field: GetFieldAccess::NamedStructField { - name: ScalarValue::from("st"), - }, - })), - field: GetFieldAccess::NamedStructField { - name: ScalarValue::from("s2"), - }, + let expected = Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + GetFieldFunc::default(), + ))), + args: vec![ + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + GetFieldFunc::default(), + ))), + args: vec![ + Expr::Column(Column { + relation: None, + name: "st".to_string(), + }), + Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))), + ], + }), + Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))), + ], }); + assert_column_eq(&planner, "st.st.s2", &expected); assert_column_eq(&planner, "`st`.`st`.`s2`", &expected); assert_column_eq(&planner, "st.st.`s2`", &expected); diff --git a/rust/lance/src/io/exec/projection.rs b/rust/lance/src/io/exec/projection.rs index 9e8b6e7bcd..d0a8f29bab 100644 --- a/rust/lance/src/io/exec/projection.rs +++ b/rust/lance/src/io/exec/projection.rs @@ -10,10 +10,13 @@ use std::task::{Context, Poll}; use arrow_array::RecordBatch; use arrow_schema::{Schema as ArrowSchema, SchemaRef}; +use datafusion::common::Statistics; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, }; +use datafusion_physical_expr::EquivalenceProperties; use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use tokio::sync::mpsc::Receiver; use tokio::task::JoinHandle; @@ -116,6 +119,7 @@ impl RecordBatchStream for ProjectionStream { pub struct ProjectionExec { input: Arc, project: Arc, + properties: PlanProperties, } impl DisplayAs for ProjectionExec { @@ -137,7 +141,18 @@ impl DisplayAs for ProjectionExec { impl ProjectionExec { pub fn try_new(input: Arc, project: Arc) -> Result { - Ok(Self { input, project }) + let arrow_schema = ArrowSchema::from(project.as_ref()); + // TODO: we reset the EquivalenceProperties here but we could probably just project + // them, that way ordering is maintained (or just use DF project?) + let properties = input + .properties() + .clone() + .with_eq_properties(EquivalenceProperties::new(Arc::new(arrow_schema))); + Ok(Self { + input, + project, + properties, + }) } } @@ -151,14 +166,6 @@ impl ExecutionPlan for ProjectionExec { arrow_schema.into() } - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - self.input.output_ordering() - } - fn children(&self) -> Vec> { vec![self.input.clone()] } @@ -183,6 +190,14 @@ impl ExecutionPlan for ProjectionExec { } fn statistics(&self) -> datafusion::error::Result { - self.input.statistics() + let num_rows = self.input.statistics()?.num_rows; + Ok(Statistics { + num_rows, + ..datafusion::physical_plan::Statistics::new_unknown(self.schema().as_ref()) + }) + } + + fn properties(&self) -> &PlanProperties { + &self.properties } } diff --git a/rust/lance/src/io/exec/pushdown_scan.rs b/rust/lance/src/io/exec/pushdown_scan.rs index 098895b7d4..81089b73c7 100644 --- a/rust/lance/src/io/exec/pushdown_scan.rs +++ b/rust/lance/src/io/exec/pushdown_scan.rs @@ -9,12 +9,13 @@ use arrow_array::types::{Int64Type, UInt64Type}; use arrow_array::{Array, BooleanArray, Int64Array, PrimitiveArray, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef}; use arrow_select::filter::filter_record_batch; +use datafusion::common::Statistics; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::col; use datafusion::logical_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; -use datafusion::physical_plan::ColumnarValue; +use datafusion::physical_plan::{ColumnarValue, ExecutionMode, PlanProperties}; use datafusion::scalar::ScalarValue; use datafusion::{ physical_plan::{ @@ -23,6 +24,7 @@ use datafusion::{ }, prelude::Expr, }; +use datafusion_physical_expr::EquivalenceProperties; use futures::{FutureExt, Stream, StreamExt, TryStreamExt}; use lance_arrow::RecordBatchExt; use lance_core::ROW_ID_FIELD; @@ -82,6 +84,8 @@ pub struct LancePushdownScanExec { predicate_projection: Arc, predicate: Expr, config: ScanConfig, + output_schema: Arc, + properties: PlanProperties, } impl LancePushdownScanExec { @@ -109,6 +113,22 @@ impl LancePushdownScanExec { )); } + let output_schema: ArrowSchema = projection.as_ref().into(); + let output_schema = if config.with_row_id { + let mut fields: Vec> = Vec::with_capacity(output_schema.fields.len() + 1); + fields.push(Arc::new(ROW_ID_FIELD.clone())); + fields.extend(output_schema.fields.iter().cloned()); + Arc::new(ArrowSchema::new(fields)) + } else { + Arc::new(output_schema) + }; + + let properties = PlanProperties::new( + EquivalenceProperties::new(output_schema.clone()), + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Ok(Self { dataset, fragments, @@ -116,6 +136,8 @@ impl LancePushdownScanExec { predicate, predicate_projection, config, + output_schema, + properties, }) } } @@ -126,23 +148,7 @@ impl ExecutionPlan for LancePushdownScanExec { } fn schema(&self) -> SchemaRef { - let schema: ArrowSchema = self.projection.as_ref().into(); - if self.config.with_row_id { - let mut fields: Vec> = Vec::with_capacity(schema.fields.len() + 1); - fields.push(Arc::new(ROW_ID_FIELD.clone())); - fields.extend(schema.fields.iter().cloned()); - Arc::new(ArrowSchema::new(fields)) - } else { - Arc::new(schema) - } - } - - fn output_partitioning(&self) -> Partitioning { - Partitioning::UnknownPartitioning(1) - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - None + self.output_schema.clone() } fn children(&self) -> Vec> { @@ -157,7 +163,7 @@ impl ExecutionPlan for LancePushdownScanExec { } fn statistics(&self) -> datafusion::error::Result { - todo!() + Ok(Statistics::new_unknown(self.output_schema.as_ref())) } fn execute( @@ -203,6 +209,10 @@ impl ExecutionPlan for LancePushdownScanExec { batch_stream, ))) } + + fn properties(&self) -> &PlanProperties { + &self.properties + } } impl DisplayAs for LancePushdownScanExec { @@ -655,7 +665,7 @@ mod test { use lance_arrow::{FixedSizeListArrayExt, SchemaExt}; use tempfile::tempdir; - use crate::dataset::WriteParams; + use crate::{datafusion::logical_expr::tests::ExprExt, dataset::WriteParams}; use super::*; @@ -800,9 +810,9 @@ mod test { let projection = Arc::new(dataset.schema().clone().project_by_ids(&[2, 4])); let predicate = col("x") - .field("a") + .field_newstyle("a") .lt(lit(8)) - .and(col("y").field("b").gt(lit(3))); + .and(col("y").field_newstyle("b").gt(lit(3))); let exec = LancePushdownScanExec::try_new( dataset.clone(), @@ -1038,7 +1048,7 @@ mod test { let dataset = Arc::new(test_dataset().await); let predicate = col("struct") - .field("int") + .field_newstyle("int") .gt(lit(4)) .and(col(Column::from_name("str")).is_not_null()); diff --git a/rust/lance/src/io/exec/scalar_index.rs b/rust/lance/src/io/exec/scalar_index.rs index b1e4f26d1b..f77975667f 100644 --- a/rust/lance/src/io/exec/scalar_index.rs +++ b/rust/lance/src/io/exec/scalar_index.rs @@ -7,11 +7,14 @@ use arrow_array::{RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::{ + common::{stats::Precision, Statistics}, physical_plan::{ - stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionMode, + ExecutionPlan, Partitioning, PlanProperties, }, scalar::ScalarValue, }; +use datafusion_physical_expr::EquivalenceProperties; use futures::{stream::BoxStream, Stream, StreamExt, TryFutureExt, TryStreamExt}; use lance_core::{ utils::{address::RowAddress, mask::RowIdTreeMap}, @@ -63,6 +66,7 @@ impl ScalarIndexLoader for Dataset { pub struct ScalarIndexExec { dataset: Arc, expr: ScalarIndexExpr, + properties: PlanProperties, } impl DisplayAs for ScalarIndexExec { @@ -77,7 +81,16 @@ impl DisplayAs for ScalarIndexExec { impl ScalarIndexExec { pub fn new(dataset: Arc, expr: ScalarIndexExpr) -> Self { - Self { dataset, expr } + let properties = PlanProperties::new( + EquivalenceProperties::new(SCALAR_INDEX_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + ExecutionMode::Bounded, + ); + Self { + dataset, + expr, + properties, + } } async fn do_execute(expr: ScalarIndexExpr, dataset: Arc) -> Result { @@ -99,15 +112,6 @@ impl ExecutionPlan for ScalarIndexExec { SCALAR_INDEX_SCHEMA.clone() } - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - Partitioning::RoundRobinBatch(1) - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - // No guarantee a scalar index scan will return row ids in any meaningful order - None - } - fn children(&self) -> Vec> { vec![] } @@ -136,7 +140,14 @@ impl ExecutionPlan for ScalarIndexExec { } fn statistics(&self) -> datafusion::error::Result { - todo!() + Ok(Statistics { + num_rows: Precision::Exact(2), + ..Statistics::new_unknown(&SCALAR_INDEX_SCHEMA) + }) + } + + fn properties(&self) -> &PlanProperties { + &self.properties } } @@ -152,6 +163,7 @@ pub struct MapIndexExec { dataset: Arc, column_name: String, input: Arc, + properties: PlanProperties, } impl DisplayAs for MapIndexExec { @@ -166,10 +178,16 @@ impl DisplayAs for MapIndexExec { impl MapIndexExec { pub fn new(dataset: Arc, column_name: String, input: Arc) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(INDEX_LOOKUP_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + ExecutionMode::Bounded, + ); Self { dataset, column_name, input, + properties, } } @@ -246,15 +264,6 @@ impl ExecutionPlan for MapIndexExec { INDEX_LOOKUP_SCHEMA.clone() } - fn output_partitioning(&self) -> Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { - // The output does have an implicit ordering but nothing we can express with PhysicalSortExpr - None - } - fn children(&self) -> Vec> { vec![self.input.clone()] } @@ -283,6 +292,10 @@ impl ExecutionPlan for MapIndexExec { stream, ))) } + + fn properties(&self) -> &PlanProperties { + &self.properties + } } lazy_static::lazy_static! { @@ -299,6 +312,7 @@ pub struct MaterializeIndexExec { dataset: Arc, expr: ScalarIndexExpr, fragments: Arc>, + properties: PlanProperties, } impl DisplayAs for MaterializeIndexExec { @@ -356,10 +370,16 @@ impl MaterializeIndexExec { expr: ScalarIndexExpr, fragments: Arc>, ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(MATERIALIZE_INDEX_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + ExecutionMode::Bounded, + ); Self { dataset, expr, fragments, + properties, } } @@ -441,15 +461,6 @@ impl ExecutionPlan for MaterializeIndexExec { MATERIALIZE_INDEX_SCHEMA.clone() } - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - Partitioning::RoundRobinBatch(1) - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - // No guarantee a scalar index scan will return row ids in any meaningful order - None - } - fn children(&self) -> Vec> { vec![] } @@ -482,6 +493,10 @@ impl ExecutionPlan for MaterializeIndexExec { } fn statistics(&self) -> datafusion::error::Result { - todo!() + Ok(Statistics::new_unknown(&MATERIALIZE_INDEX_SCHEMA)) + } + + fn properties(&self) -> &PlanProperties { + &self.properties } } diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs index 2188af9484..c46cb9c1fd 100644 --- a/rust/lance/src/io/exec/scan.rs +++ b/rust/lance/src/io/exec/scan.rs @@ -11,9 +11,10 @@ use arrow_schema::{Field, Schema as ArrowSchema, SchemaRef}; use datafusion::common::stats::Precision; use datafusion::error::{DataFusionError, Result}; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use datafusion_physical_expr::EquivalenceProperties; use futures::stream; use futures::stream::Stream; use futures::{StreamExt, TryStreamExt}; @@ -191,6 +192,8 @@ pub struct LanceScanExec { with_row_id: bool, with_make_deletions_null: bool, ordered_output: bool, + output_schema: Arc, + properties: PlanProperties, } impl DisplayAs for LanceScanExec { @@ -230,6 +233,19 @@ impl LanceScanExec { with_make_deletions_null: bool, ordered_ouput: bool, ) -> Self { + let output_schema: ArrowSchema = projection.as_ref().into(); + let output_schema = if with_row_id { + let mut fields: Vec> = output_schema.fields.to_vec(); + fields.push(Arc::new(ROW_ID_FIELD.clone())); + Arc::new(ArrowSchema::new(fields)) + } else { + Arc::new(output_schema) + }; + let properties = PlanProperties::new( + EquivalenceProperties::new(output_schema.clone()), + Partitioning::RoundRobinBatch(1), + datafusion::physical_plan::ExecutionMode::Bounded, + ); Self { dataset, fragments, @@ -240,6 +256,8 @@ impl LanceScanExec { with_row_id, with_make_deletions_null, ordered_output: ordered_ouput, + output_schema, + properties, } } } @@ -250,22 +268,7 @@ impl ExecutionPlan for LanceScanExec { } fn schema(&self) -> SchemaRef { - let schema: ArrowSchema = self.projection.as_ref().into(); - if self.with_row_id { - let mut fields: Vec> = schema.fields.to_vec(); - fields.push(Arc::new(ROW_ID_FIELD.clone())); - Arc::new(ArrowSchema::new(fields)) - } else { - Arc::new(schema) - } - } - - fn output_partitioning(&self) -> Partitioning { - Partitioning::RoundRobinBatch(1) - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - None + self.output_schema.clone() } /// Scan is the leaf node, so returns an empty vector. @@ -326,4 +329,8 @@ impl ExecutionPlan for LanceScanExec { ..datafusion::physical_plan::Statistics::new_unknown(self.schema().as_ref()) }) } + + fn properties(&self) -> &PlanProperties { + &self.properties + } } diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index 0f3069e326..ea1a0194b1 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -7,10 +7,13 @@ use std::task::{Context, Poll}; use arrow_array::{cast::as_primitive_array, RecordBatch, UInt64Array}; use arrow_schema::{Schema as ArrowSchema, SchemaRef}; +use datafusion::common::Statistics; use datafusion::error::{DataFusionError, Result}; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, }; +use datafusion_physical_expr::EquivalenceProperties; use futures::stream::{self, Stream, StreamExt, TryStreamExt}; use futures::{Future, FutureExt}; use tokio::sync::mpsc::{self, Receiver}; @@ -177,6 +180,8 @@ pub struct TakeExec { output_schema: Schema, batch_readahead: usize, + + properties: PlanProperties, } impl DisplayAs for TakeExec { @@ -219,12 +224,19 @@ impl TakeExec { let remaining_schema = extra_schema.exclude(&input_schema)?; + let output_arrow = Arc::new(ArrowSchema::from(&output_schema)); + let properties = input + .properties() + .clone() + .with_eq_properties(EquivalenceProperties::new(output_arrow)); + Ok(Self { dataset, extra_schema: Arc::new(remaining_schema), input, output_schema, batch_readahead, + properties, }) } } @@ -238,14 +250,6 @@ impl ExecutionPlan for TakeExec { ArrowSchema::from(&self.output_schema).into() } - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - self.input.output_ordering() - } - fn children(&self) -> Vec> { vec![self.input.clone()] } @@ -291,7 +295,14 @@ impl ExecutionPlan for TakeExec { } fn statistics(&self) -> Result { - self.input.statistics() + Ok(Statistics { + num_rows: self.input.statistics()?.num_rows, + ..Statistics::new_unknown(self.schema().as_ref()) + }) + } + + fn properties(&self) -> &PlanProperties { + &self.properties } } diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs index ab10a0a639..9ab2794590 100644 --- a/rust/lance/src/io/exec/testing.rs +++ b/rust/lance/src/io/exec/testing.rs @@ -9,18 +9,32 @@ use std::sync::Arc; use arrow_array::RecordBatch; use datafusion::{ + common::Statistics, execution::context::TaskContext, - physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream}, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties, + SendableRecordBatchStream, + }, }; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; #[derive(Debug)] pub struct TestingExec { pub(crate) batches: Vec, + properties: PlanProperties, } impl TestingExec { pub(crate) fn new(batches: Vec) -> Self { - Self { batches } + let properties = PlanProperties::new( + EquivalenceProperties::new(batches[0].schema().clone()), + Partitioning::RoundRobinBatch(1), + ExecutionMode::Bounded, + ); + Self { + batches, + properties, + } } } @@ -41,14 +55,6 @@ impl ExecutionPlan for TestingExec { self.batches[0].schema() } - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - todo!() - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - todo!() - } - fn children(&self) -> Vec> { vec![] } @@ -69,6 +75,10 @@ impl ExecutionPlan for TestingExec { } fn statistics(&self) -> datafusion::error::Result { - todo!() + Ok(Statistics::new_unknown(self.schema().as_ref())) + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + &self.properties } } diff --git a/rust/lance/src/io/exec/utils.rs b/rust/lance/src/io/exec/utils.rs index 5935f4a5d2..06d2eceb41 100644 --- a/rust/lance/src/io/exec/utils.rs +++ b/rust/lance/src/io/exec/utils.rs @@ -131,14 +131,6 @@ impl ExecutionPlan for ReplayExec { self.input.schema() } - fn output_partitioning(&self) -> datafusion_physical_expr::Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { - self.input.output_ordering() - } - fn children(&self) -> Vec> { vec![self.input.clone()] } @@ -177,6 +169,10 @@ impl ExecutionPlan for ReplayExec { })) } } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + self.input.properties() + } } #[cfg(test)]