diff --git a/Cargo.toml b/Cargo.toml index a3d312c857c9..38cdf539360e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -120,7 +120,7 @@ futures = "0.3" half = { version = "2.2.1", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" -itertools = "0.12" +itertools = "0.13" log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9ca04f702241..352f2243aef6 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1208,7 +1208,7 @@ dependencies = [ "half", "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools", "log", "num-traits", "num_cpus", @@ -1359,7 +1359,7 @@ dependencies = [ "datafusion-expr", "hashbrown", "hex", - "itertools 0.12.1", + "itertools", "log", "md-5", "rand", @@ -1414,7 +1414,7 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "itertools 0.12.1", + "itertools", "log", "paste", "rand", @@ -1442,7 +1442,7 @@ dependencies = [ "datafusion-physical-expr", "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools", "log", "paste", "regex-syntax", @@ -1471,7 +1471,7 @@ dependencies = [ "hashbrown", "hex", "indexmap", - "itertools 0.12.1", + "itertools", "log", "paste", "petgraph", @@ -1512,7 +1512,7 @@ dependencies = [ "datafusion-execution", "datafusion-physical-expr", "datafusion-physical-plan", - "itertools 0.12.1", + "itertools", ] [[package]] @@ -1540,7 +1540,7 @@ dependencies = [ "half", "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools", "log", "once_cell", "parking_lot", @@ -2237,15 +2237,6 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -2646,7 +2637,7 @@ dependencies = [ "futures", "humantime", "hyper 1.4.1", - "itertools 0.13.0", + "itertools", "md-5", "parking_lot", "percent-encoding", diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs index 6456bd5c7276..fb2cd4ad06ec 100644 --- a/datafusion/core/src/datasource/physical_plan/file_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -256,7 +256,7 @@ impl FileGroupPartitioner { }, ) .flatten() - .group_by(|(partition_idx, _)| *partition_idx) + .chunk_by(|(partition_idx, _)| *partition_idx) .into_iter() .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) .collect_vec(); diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 6d2fb660f669..401762ad4d36 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -481,16 +481,22 @@ fn type_union_resolution_coercion( } } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// Unlike `coerced_from`, usually the coerced type is for comparison only. 
-/// For example, compare with Dictionary and Dictionary, only value type is what we care about +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a +/// comparison operation +/// +/// Example comparison operations are `lhs = rhs` and `lhs > rhs` +/// +/// Binary comparison kernels require the two arguments to be the (exact) same +/// data type. However, users can write queries where the two arguments are +/// different data types. In such cases, the data types are automatically cast +/// (coerced) to a single data type to pass to the kernels. pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { if lhs_type == rhs_type { // same type => equality is possible return Some(lhs_type.clone()); } binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) + .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| list_coercion(lhs_type, rhs_type)) @@ -501,7 +507,11 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { if lhs_type == rhs_type { // same type => equality is possible @@ -883,7 +893,7 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool /// /// Not all operators support dictionaries, if `preserve_dictionaries` is true /// dictionaries will be preserved if possible -fn dictionary_coercion( +fn dictionary_comparison_coercion( lhs_type: &DataType, rhs_type: &DataType, preserve_dictionaries: bool, @@ -912,26 +922,22 @@ fn dictionary_coercion( /// Coercion rules for string concat. /// This is a union of string coercion rules and specified rules: -/// 1. At lease one side of lhs and rhs should be string type (Utf8 / LargeUtf8) +/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - // If Utf8View is in any side, we coerce to Utf8. - // Ref: https://github.com/apache/datafusion/pull/11796 - (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { - Some(Utf8) + string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + (Utf8View, from_type) | (from_type, Utf8View) => { + string_concat_internal_coercion(from_type, &Utf8View) } - _ => string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { - (Utf8, from_type) | (from_type, Utf8) => { - string_concat_internal_coercion(from_type, &Utf8) - } - (LargeUtf8, from_type) | (from_type, LargeUtf8) => { - string_concat_internal_coercion(from_type, &LargeUtf8) - } - _ => None, - }), - } + (Utf8, from_type) | (from_type, Utf8) => { + string_concat_internal_coercion(from_type, &Utf8) + } + (LargeUtf8, from_type) | (from_type, LargeUtf8) => { + string_concat_internal_coercion(from_type, &LargeUtf8) + } + _ => None, + }) } @@ -942,6 +948,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { } } +/// If `from_type` can be cast to `to_type`, return `to_type`, otherwise +/// return `None`. 
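To make the coercion behavior documented in the hunk above concrete, here is a small usage sketch. It is not part of the diff; the `datafusion_expr::type_coercion::binary` re-export path is assumed, and the expected results follow the rules shown in this file (numeric widening, and the `string_coercion` rule that `LargeUtf8` wins over `Utf8`):

```rust
use arrow::datatypes::DataType;
// Path assumed: `comparison_coercion` is re-exported through datafusion_expr.
use datafusion_expr::type_coercion::binary::comparison_coercion;

fn main() {
    // Mixed numeric widths compare using the wider type.
    assert_eq!(
        comparison_coercion(&DataType::Int32, &DataType::Int64),
        Some(DataType::Int64)
    );
    // If LargeUtf8 is on either side, strings compare as LargeUtf8.
    assert_eq!(
        comparison_coercion(&DataType::Utf8, &DataType::LargeUtf8),
        Some(DataType::LargeUtf8)
    );
}
```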
fn string_concat_internal_coercion( from_type: &DataType, to_type: &DataType, @@ -967,6 +975,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> } // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), + // Utf8 coerces to Utf8 (Utf8, Utf8) => Some(Utf8), _ => None, } @@ -1044,7 +1053,7 @@ pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { string_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type, false)) + .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) .or_else(|| regex_null_coercion(lhs_type, rhs_type)) } @@ -1324,38 +1333,50 @@ mod tests { let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, false), + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Int32) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), Some(Int32) ); // Since we can coerce values of Int16 to Utf8, we can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Utf8)); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Utf8) + ); // Since we can coerce values of Utf8 to Binary, we can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, true), + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), Some(Binary) ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, true), + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Utf8) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), Some(lhs_type.clone()) ); let lhs_type = Utf8; let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, true), + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Utf8) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), Some(rhs_type.clone()) ); } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a96caa03d611..559908bcfdfa 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -60,6 +60,7 @@ pub const UNNAMED_TABLE: &str = "?table?"; /// Builder for logical plans /// +/// # Example: building a simple plan /// ``` /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan}; /// # use datafusion_common::Result; @@ -88,17 +89,27 @@ pub const UNNAMED_TABLE: &str = "?table?"; /// .project(vec![col("last_name")])? 
/// .build()?; /// +/// // Convert from plan back to builder +/// let builder = LogicalPlanBuilder::from(plan); +/// /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] pub struct LogicalPlanBuilder { - plan: LogicalPlan, + plan: Arc<LogicalPlan>, } impl LogicalPlanBuilder { /// Create a builder from an existing plan - pub fn from(plan: LogicalPlan) -> Self { + pub fn new(plan: LogicalPlan) -> Self { + Self { + plan: Arc::new(plan), + } + } + + /// Create a builder from an existing plan wrapped in an `Arc` + pub fn new_from_arc(plan: Arc<LogicalPlan>) -> Self { Self { plan } } @@ -116,7 +127,7 @@ impl LogicalPlanBuilder { /// /// `produce_one_row` set to true means this empty node needs to produce a placeholder row. pub fn empty(produce_one_row: bool) -> Self { - Self::from(LogicalPlan::EmptyRelation(EmptyRelation { + Self::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, schema: DFSchemaRef::new(DFSchema::empty()), })) @@ -125,7 +136,7 @@ impl LogicalPlanBuilder { /// Convert a regular plan into a recursive query. /// `is_distinct` indicates whether the recursive term should be de-duplicated (`UNION`) after each iteration or not (`UNION ALL`). pub fn to_recursive_query( - &self, + self, name: String, recursive_term: LogicalPlan, is_distinct: bool, @@ -150,7 +161,7 @@ impl LogicalPlanBuilder { coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, - static_term: Arc::new(self.plan.clone()), + static_term: self.plan, recursive_term: Arc::new(coerced_recursive_term), is_distinct, }))) @@ -228,7 +239,7 @@ impl LogicalPlanBuilder { .collect::>(); let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); - Ok(Self::from(LogicalPlan::Values(Values { schema, values }))) + Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) } /// Convert a table provider into a builder with a TableScan @@ -279,7 +290,7 @@ impl LogicalPlanBuilder { options: HashMap<String, String>, partition_by: Vec<String>, ) -> Result<Self> { - Ok(Self::from(LogicalPlan::Copy(CopyTo { + Ok(Self::new(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url, partition_by, @@ -303,7 +314,7 @@ impl LogicalPlanBuilder { WriteOp::InsertInto }; - Ok(Self::from(LogicalPlan::Dml(DmlStatement::new( + Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( table_name.into(), table_schema, op, @@ -320,7 +331,7 @@ impl LogicalPlanBuilder { ) -> Result<Self> { TableScan::try_new(table_name, table_source, projection, filters, None) .map(LogicalPlan::TableScan) - .map(Self::from) + .map(Self::new) } /// Wrap a plan in a window @@ -365,7 +376,7 @@ impl LogicalPlanBuilder { self, expr: impl IntoIterator<Item = impl Into<Expr>>, ) -> Result<Self> { - project(self.plan, expr).map(Self::from) + project(unwrap_arc(self.plan), expr).map(Self::new) } /// Select the given column indices @@ -380,17 +391,25 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(self, expr: impl Into<Expr>) -> Result<Self> { let expr = normalize_col(expr.into(), &self.plan)?; - Filter::try_new(expr, Arc::new(self.plan)) + Filter::try_new(expr, self.plan) .map(LogicalPlan::Filter) .map(Self::new) + } + + /// Apply a filter which is used for a having clause + pub fn having(self, expr: impl Into<Expr>) -> Result<Self> { + let expr = normalize_col(expr.into(), &self.plan)?; + Filter::try_new_with_having(expr, self.plan) .map(LogicalPlan::Filter) .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan pub fn prepare(self, name: String, data_types: Vec<DataType>) -> Result<Self> { - 
Ok(Self::from(LogicalPlan::Prepare(Prepare { + Ok(Self::new(LogicalPlan::Prepare(Prepare { name, data_types, - input: Arc::new(self.plan), + input: self.plan, }))) } @@ -401,16 +420,16 @@ impl LogicalPlanBuilder { /// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows, /// if specified. pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<Self> { - Ok(Self::from(LogicalPlan::Limit(Limit { + Ok(Self::new(LogicalPlan::Limit(Limit { skip, fetch, - input: Arc::new(self.plan), + input: self.plan, }))) } /// Apply an alias pub fn alias(self, alias: impl Into<TableReference>) -> Result<Self> { - subquery_alias(self.plan, alias).map(Self::from) + subquery_alias(unwrap_arc(self.plan), alias).map(Self::new) } /// Add missing sort columns to all downstream projection @@ -465,7 +484,7 @@ impl LogicalPlanBuilder { Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?; } expr.extend(missing_exprs); - project((*input).clone(), expr) + project(unwrap_arc(input), expr) } _ => { let is_distinct = @@ -550,9 +569,9 @@ impl LogicalPlanBuilder { })?; if missing_cols.is_empty() { - return Ok(Self::from(LogicalPlan::Sort(Sort { + return Ok(Self::new(LogicalPlan::Sort(Sort { expr: normalize_cols(exprs, &self.plan)?, - input: Arc::new(self.plan), + input: self.plan, fetch: None, }))); } @@ -561,7 +580,8 @@ let new_expr = schema.columns().into_iter().map(Expr::Column).collect(); let is_distinct = false; - let plan = Self::add_missing_columns(self.plan, &missing_cols, is_distinct)?; + let plan = + Self::add_missing_columns(unwrap_arc(self.plan), &missing_cols, is_distinct)?; let sort_plan = LogicalPlan::Sort(Sort { expr: normalize_cols(exprs, &plan)?, input: Arc::new(plan), @@ -570,29 +590,27 @@ Projection::try_new(new_expr, Arc::new(sort_plan)) .map(LogicalPlan::Projection) - .map(Self::from) + .map(Self::new) } /// Apply a union, preserving duplicate rows pub fn union(self, plan: LogicalPlan) -> Result<Self> { - union(self.plan, plan).map(Self::from) + union(unwrap_arc(self.plan), plan).map(Self::new) } /// Apply a union, removing duplicate rows pub fn union_distinct(self, plan: LogicalPlan) -> Result<Self> { - let left_plan: LogicalPlan = self.plan; + let left_plan: LogicalPlan = unwrap_arc(self.plan); let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( union(left_plan, right_plan)?, ))))) } /// Apply deduplication: Only distinct (different) values are returned pub fn distinct(self) -> Result<Self> { - Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( - self.plan, - ))))) + Ok(Self::new(LogicalPlan::Distinct(Distinct::All(self.plan)))) } /// Project first values of the specified expression list according to the provided @@ -603,8 +621,8 @@ select_expr: Vec<Expr>, sort_expr: Option<Vec<Expr>>, ) -> Result<Self> { - Ok(Self::from(LogicalPlan::Distinct(Distinct::On( - DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + Ok(Self::new(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, )))) } @@ -819,8 +837,8 @@ let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), on, filter, @@ -883,8 +901,8 @@ DataFusionError::Internal("filters should not be 
None here".to_string()) })?) } else { - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), on: join_on, filter: filters, @@ -900,8 +918,8 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::from(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin { + left: self.plan, right: Arc::new(right), schema: DFSchemaRef::new(join_schema), }))) @@ -909,8 +927,8 @@ impl LogicalPlanBuilder { /// Repartition pub fn repartition(self, partitioning_scheme: Partitioning) -> Result { - Ok(Self::from(LogicalPlan::Repartition(Repartition { - input: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Repartition(Repartition { + input: self.plan, partitioning_scheme, }))) } @@ -922,9 +940,9 @@ impl LogicalPlanBuilder { ) -> Result { let window_expr = normalize_cols(window_expr, &self.plan)?; validate_unique_names("Windows", &window_expr)?; - Ok(Self::from(LogicalPlan::Window(Window::try_new( + Ok(Self::new(LogicalPlan::Window(Window::try_new( window_expr, - Arc::new(self.plan), + self.plan, )?))) } @@ -941,9 +959,9 @@ impl LogicalPlanBuilder { let group_expr = add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; - Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) + Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate) - .map(Self::from) + .map(Self::new) } /// Create an expression to represent the explanation of the plan @@ -957,18 +975,18 @@ impl LogicalPlanBuilder { let schema = schema.to_dfschema_ref()?; if analyze { - Ok(Self::from(LogicalPlan::Analyze(Analyze { + Ok(Self::new(LogicalPlan::Analyze(Analyze { verbose, - input: Arc::new(self.plan), + input: self.plan, schema, }))) } else { let stringified_plans = vec![self.plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(Self::from(LogicalPlan::Explain(Explain { + Ok(Self::new(LogicalPlan::Explain(Explain { verbose, - plan: Arc::new(self.plan), + plan: self.plan, stringified_plans, schema, logical_optimization_succeeded: false, @@ -1046,7 +1064,7 @@ impl LogicalPlanBuilder { /// Build the plan pub fn build(self) -> Result { - Ok(self.plan) + Ok(unwrap_arc(self.plan)) } /// Apply a join with the expression on constraint. @@ -1106,8 +1124,8 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), on: join_key_pairs, filter, @@ -1120,7 +1138,7 @@ impl LogicalPlanBuilder { /// Unnest the given column. 
pub fn unnest_column(self, column: impl Into) -> Result { - Ok(Self::from(unnest(self.plan, vec![column.into()])?)) + unnest(unwrap_arc(self.plan), vec![column.into()]).map(Self::new) } /// Unnest the given column given [`UnnestOptions`] @@ -1129,11 +1147,8 @@ impl LogicalPlanBuilder { column: impl Into, options: UnnestOptions, ) -> Result { - Ok(Self::from(unnest_with_options( - self.plan, - vec![column.into()], - options, - )?)) + unnest_with_options(unwrap_arc(self.plan), vec![column.into()], options) + .map(Self::new) } /// Unnest the given columns with the given [`UnnestOptions`] @@ -1142,45 +1157,19 @@ impl LogicalPlanBuilder { columns: Vec, options: UnnestOptions, ) -> Result { - Ok(Self::from(unnest_with_options( - self.plan, columns, options, - )?)) + unnest_with_options(unwrap_arc(self.plan), columns, options).map(Self::new) } } -/// Converts a `Arc` into `LogicalPlanBuilder` -/// ``` -/// # use datafusion_expr::{Expr, expr, col, LogicalPlanBuilder, logical_plan::table_scan}; -/// # use datafusion_common::Result; -/// # use arrow::datatypes::{Schema, DataType, Field}; -/// # fn main() -> Result<()> { -/// # -/// # fn employee_schema() -> Schema { -/// # Schema::new(vec![ -/// # Field::new("id", DataType::Int32, false), -/// # Field::new("first_name", DataType::Utf8, false), -/// # Field::new("last_name", DataType::Utf8, false), -/// # Field::new("state", DataType::Utf8, false), -/// # Field::new("salary", DataType::Int32, false), -/// # ]) -/// # } -/// # -/// // Create the plan -/// let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? -/// .sort(vec![ -/// Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), -/// Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), -/// ])? -/// .build()?; -/// // Convert LogicalPlan into LogicalPlanBuilder -/// let plan_builder: LogicalPlanBuilder = std::sync::Arc::new(plan).into(); -/// # Ok(()) -/// # } -/// ``` +impl From for LogicalPlanBuilder { + fn from(plan: LogicalPlan) -> Self { + LogicalPlanBuilder::new(plan) + } +} impl From> for LogicalPlanBuilder { fn from(plan: Arc) -> Self { - LogicalPlanBuilder::from(unwrap_arc(plan)) + LogicalPlanBuilder::new_from_arc(plan) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index f93b7c0fedd0..ca7d04b9b03e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -643,9 +643,12 @@ impl LogicalPlan { // todo it isn't clear why the schema is not recomputed here Ok(LogicalPlan::Values(Values { schema, values })) } - LogicalPlan::Filter(Filter { predicate, input }) => { - Filter::try_new(predicate, input).map(LogicalPlan::Filter) - } + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => Filter::try_new_internal(predicate, input, having) + .map(LogicalPlan::Filter), LogicalPlan::Repartition(_) => Ok(self), LogicalPlan::Window(Window { input, @@ -2080,6 +2083,8 @@ pub struct Filter { pub predicate: Expr, /// The incoming logical plan pub input: Arc, + /// The flag to indicate if the filter is a having clause + pub having: bool, } impl Filter { @@ -2088,6 +2093,20 @@ impl Filter { /// Notes: as Aliases have no effect on the output of a filter operator, /// they are removed from the predicate expression. pub fn try_new(predicate: Expr, input: Arc) -> Result { + Self::try_new_internal(predicate, input, false) + } + + /// Create a new filter operator for a having clause. 
+ /// This is similar to a filter, but its having flag is set to true. + pub fn try_new_with_having(predicate: Expr, input: Arc) -> Result { + Self::try_new_internal(predicate, input, true) + } + + fn try_new_internal( + predicate: Expr, + input: Arc, + having: bool, + ) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and @@ -2104,6 +2123,7 @@ impl Filter { Ok(Self { predicate: predicate.unalias_nested().data, input, + having, }) } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index dbe43128fd38..539cb1cf5fb2 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -87,8 +87,17 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Filter(Filter { predicate, input }) => rewrite_arc(input, f)? - .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -561,10 +570,17 @@ impl LogicalPlan { value.into_iter().map_until_stop_and_collect(&mut f) })? .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? - .update_data(|predicate| { - LogicalPlan::Filter(Filter { predicate, input }) - }), + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => f(predicate)?.update_data(|predicate| { + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index cb278c767974..7b4b3bb95c46 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -337,6 +337,9 @@ where /// let expr = geometric_mean.call(vec![col("a")]); /// ``` pub trait AggregateUDFImpl: Debug + Send + Sync { + // Note: When adding any methods (with default implementations), remember to add them also + // into the AliasedAggregateUDFImpl below! 
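The note above guards against a subtle failure mode. A standalone sketch (plain Rust, not DataFusion code) of the hazard: a renaming wrapper that does not forward a defaulted trait method silently falls back to the default instead of the wrapped implementation:

```rust
trait Udf {
    fn name(&self) -> &str;
    // Defaulted method: wrappers must forward this explicitly.
    fn short_circuits(&self) -> bool {
        false
    }
}

struct Inner;
impl Udf for Inner {
    fn name(&self) -> &str { "inner" }
    fn short_circuits(&self) -> bool { true }
}

/// Renaming wrapper, analogous to the `Aliased*UDFImpl` types in this diff.
struct Aliased { inner: Inner }
impl Udf for Aliased {
    fn name(&self) -> &str { "alias" }
    // Without this forwarding impl, `Aliased` would report `false`.
    fn short_circuits(&self) -> bool { self.inner.short_circuits() }
}

fn main() {
    assert!(Aliased { inner: Inner }.short_circuits());
}
```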
+ /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -635,6 +638,60 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { &self.aliases } + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.inner.state_fields(args) + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner.groups_accumulator_supported(args) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.create_groups_accumulator(args) + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.accumulator(args) + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Arc::clone(&self.inner) + .with_beneficial_ordering(beneficial_ordering) + .map(|udf| { + udf.map(|udf| { + Arc::new(AliasedAggregateUDFImpl { + inner: udf, + aliases: self.aliases.clone(), + }) as Arc + }) + }) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + self.inner.order_sensitivity() + } + + fn simplify(&self) -> Option { + self.inner.simplify() + } + + fn reverse_expr(&self) -> ReversedUDAF { + self.inner.reverse_expr() + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { if let Some(other) = other.as_any().downcast_ref::() { self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases @@ -649,6 +706,10 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.aliases.hash(hasher); hasher.finish() } + + fn is_descending(&self) -> Option { + self.inner.is_descending() + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index a4584038e48b..be3f811dbe51 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -346,6 +346,9 @@ where /// let expr = add_one.call(vec![col("a")]); /// ``` pub trait ScalarUDFImpl: Debug + Send + Sync { + // Note: When adding any methods (with default implementations), remember to add them also + // into the AliasedScalarUDFImpl below! 
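For context on when this aliasing wrapper is created, a hedged sketch: the usual entry point is `ScalarUDF::with_aliases` (assumed available on `ScalarUDF`, as is the `datafusion_functions::math::random` helper), which wraps the inner implementation in `AliasedScalarUDFImpl`. That is why every defaulted method must be forwarded, as the note above says:

```rust
use datafusion_expr::ScalarUDF;

// Registering a second name for an existing UDF goes through the aliasing
// wrapper; look-ups by "rand" then resolve to the same implementation.
fn aliased_random() -> ScalarUDF {
    let udf: ScalarUDF = datafusion_functions::math::random().as_ref().clone();
    udf.with_aliases(["rand"])
}
```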
+ /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -632,6 +635,14 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.name() } + fn display_name(&self, args: &[Expr]) -> Result { + self.inner.display_name(args) + } + + fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + fn signature(&self) -> &Signature { self.inner.signature() } @@ -640,12 +651,57 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + schema: &dyn ExprSchema, + arg_types: &[DataType], + ) -> Result { + self.inner.return_type_from_exprs(args, schema, arg_types) + } + fn invoke(&self, args: &[ColumnarValue]) -> Result { self.inner.invoke(args) } - fn aliases(&self) -> &[String] { - &self.aliases + fn invoke_no_args(&self, number_rows: usize) -> Result { + self.inner.invoke_no_args(number_rows) + } + + fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + self.inner.simplify(args, info) + } + + fn short_circuits(&self) -> bool { + self.inner.short_circuits() + } + + fn evaluate_bounds(&self, input: &[&Interval]) -> Result { + self.inner.evaluate_bounds(input) + } + + fn propagate_constraints( + &self, + interval: &Interval, + inputs: &[&Interval], + ) -> Result>> { + self.inner.propagate_constraints(interval, inputs) + } + + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + self.inner.output_ordering(inputs) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) } fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 88b3d613cb43..e5fdaaceb439 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -266,6 +266,9 @@ where /// .unwrap(); /// ``` pub trait WindowUDFImpl: Debug + Send + Sync { + // Note: When adding any methods (with default implementations), remember to add them also + // into the AliasedWindowUDFImpl below! + /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -428,6 +431,10 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { &self.aliases } + fn simplify(&self) -> Option { + self.inner.simplify() + } + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { if let Some(other) = other.as_any().downcast_ref::() { self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases @@ -442,6 +449,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.aliases.hash(hasher); hasher.finish() } + + fn nullable(&self) -> bool { + self.inner.nullable() + } + + fn sort_options(&self) -> Option { + self.inner.sort_options() + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 5f5c468fa2f5..11a244a944f8 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -804,6 +804,15 @@ pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { match input { LogicalPlan::Window(window) => find_base_plan(&window.input), LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), + LogicalPlan::Filter(filter) => { + if filter.having { + // If a filter is used for a having clause, its input plan is an aggregation. 
+ // We should expand the wildcard expression based on the aggregation's input plan. + find_base_plan(&filter.input) + } else { + input + } + } _ => input, } } diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 6a1973ecfed1..5e1a15233cb5 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -50,7 +50,7 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } -itertools = { version = "0.12", features = ["use_std"] } +itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "1.0.14" rand = "0.8.5" diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index aea3d4a59e02..43d2796ad7dc 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -52,13 +52,13 @@ pub fn row_number_udwf() -> std::sync::Arc { /// row_number expression #[derive(Debug)] -struct RowNumber { +pub struct RowNumber { signature: Signature, } impl RowNumber { /// Create a new `row_number` function - fn new() -> Self { + pub fn new() -> Self { Self { signature: Signature::any(0, Volatility::Immutable), } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 9ef020b772f0..337379a74670 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -161,3 +161,8 @@ required-features = ["string_expressions"] harness = false name = "random" required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "substr" +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs new file mode 100644 index 000000000000..14a3389da380 --- /dev/null +++ b/datafusion/functions/benches/substr.rs @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode; +use std::sync::Arc; + +fn create_args_without_count( + size: usize, + str_len: usize, + start_half_way: bool, + use_string_view: bool, +) -> Vec { + let start_array = Arc::new(Int64Array::from( + (0..size) + .map(|_| { + if start_half_way { + (str_len / 2) as i64 + } else { + 1i64 + } + }) + .collect::>(), + )); + + if use_string_view { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ] + } +} + +fn create_args_with_count( + size: usize, + str_len: usize, + count_max: usize, + use_string_view: bool, +) -> Vec { + let start_array = + Arc::new(Int64Array::from((0..size).map(|_| 1).collect::>())); + let count = count_max.min(str_len) as i64; + let count_array = Arc::new(Int64Array::from( + (0..size).map(|_| count).collect::>(), + )); + + if use_string_view { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ColumnarValue::Array(count_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let substr = unicode::substr(); + for size in [1024, 4096] { + // string_len = 12, substring_len=6 (see `create_args_without_count`) + let len = 12; + let mut group = c.benchmark_group("SHORTER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true); + group.bench_function( + &format!("substr_string_view [size={}, strlen={}]", size, len), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + let args = create_args_without_count::(size, len, false, false); + group.bench_function( + &format!("substr_string [size={}, strlen={}]", size, len), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + let args = create_args_without_count::(size, len, true, false); + group.bench_function( + &format!("substr_large_string [size={}, strlen={}]", size, len), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + group.finish(); + + // string_len = 128, start=1, count=64, substring_len=64 + let len = 128; + let count = 64; + let mut group = c.benchmark_group("LONGER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + &format!( + "substr_string_view [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + &format!( + "substr_string [size={}, count={}, strlen={}]", + size, 
count, len, + ), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + &format!( + "substr_large_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + group.finish(); + + // string_len = 128, start=1, count=6, substring_len=6 + let len = 128; + let count = 6; + let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + &format!( + "substr_string_view [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + &format!( + "substr_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + &format!( + "substr_large_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| b.iter(|| black_box(substr.invoke(&args))), + ); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 54aebb039046..6f23a5ddd236 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! Common utilities for implementing string functions + use std::fmt::{Display, Formatter}; use std::sync::Arc; @@ -252,7 +254,69 @@ impl<'a> ColumnarValueRef<'a> { } } +/// Abstracts iteration over different types of string arrays. +/// +/// The [`StringArrayType`] trait helps write generic code for string functions that can work with +/// different types of string arrays. +/// +/// Currently three types are supported: +/// - [`StringArray`] +/// - [`LargeStringArray`] +/// - [`StringViewArray`] +/// +/// It is inspired / copied from [arrow-rs]. +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/bf0ea9129e617e4a3cf915a900b747cc5485315f/arrow-string/src/like.rs#L151-L157 +/// +/// # Examples +/// Generic function that works for [`StringArray`], [`LargeStringArray`] +/// and [`StringViewArray`]: +/// ``` +/// # use arrow::array::{StringArray, LargeStringArray, StringViewArray}; +/// # use datafusion_functions::string::common::StringArrayType; +/// +/// /// Combines string values for any StringArrayType type. 
It can be invoked on +/// /// any combination of `StringArray`, `LargeStringArray` or `StringViewArray` +/// fn combine_values<'a, S1, S2>(array1: S1, array2: S2) -> Vec<String> +/// where S1: StringArrayType<'a>, S2: StringArrayType<'a> +/// { +/// // iterate over the elements of the 2 arrays in parallel +/// array1 +/// .iter() +/// .zip(array2.iter()) +/// .map(|(s1, s2)| { +/// // if both values are non null, combine them +/// if let (Some(s1), Some(s2)) = (s1, s2) { +/// format!("{s1}{s2}") +/// } else { +/// "None".to_string() +/// } +/// }) +/// .collect() +/// } +/// +/// let string_array = StringArray::from(vec!["foo", "bar"]); +/// let large_string_array = LargeStringArray::from(vec!["foo2", "bar2"]); +/// let string_view_array = StringViewArray::from(vec!["foo3", "bar3"]); +/// +/// // can invoke this function with a string array and large string array +/// assert_eq!( +/// combine_values(&string_array, &large_string_array), +/// vec![String::from("foofoo2"), String::from("barbar2")] +/// ); +/// +/// // Can call the same function with string array and string view array +/// assert_eq!( +/// combine_values(&string_array, &string_view_array), +/// vec![String::from("foofoo3"), String::from("barbar3")] +/// ); +/// ``` +/// +/// [`LargeStringArray`]: arrow::array::LargeStringArray pub trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized { + /// Return an [`ArrayIter`] over the values of the array. + /// + /// This iterator returns `Option<&str>` for each item in the array. fn iter(&self) -> ArrayIter<Self>; } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 19721f0fad28..8d292315a35a 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. 
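The rewritten `split_part` below normalizes its arguments before dispatching on their types: if no argument is an array, it computes on length-1 arrays and folds the result back into a scalar at the end. A condensed sketch of that normalization pattern (the helper name here is illustrative, not one of the function's actual helpers):

```rust
use std::sync::Arc;
use arrow::array::ArrayRef;
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;

fn normalize_args(args: &[ColumnarValue]) -> Result<(Vec<ArrayRef>, bool)> {
    // Length of the first array argument, if any.
    let len = args.iter().find_map(|arg| match arg {
        ColumnarValue::Array(a) => Some(a.len()),
        _ => None,
    });
    let inferred_length = len.unwrap_or(1);
    let arrays = args
        .iter()
        .map(|arg| match arg {
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
            ColumnarValue::Array(array) => Ok(Arc::clone(array)),
        })
        .collect::<Result<Vec<_>>>()?;
    // `true` means all inputs were scalars, so the output should be a scalar.
    Ok((arrays, len.is_none()))
}
```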
-use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; -use arrow::datatypes::DataType; - -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, +use arrow::array::{ + ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringViewArray, }; +use arrow::array::{AsArray, GenericStringBuilder}; +use arrow::datatypes::DataType; +use datafusion_common::cast::as_int64_array; +use datafusion_common::ScalarValue; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use std::any::Any; +use std::sync::Arc; + +use crate::utils::utf8_to_str_type; -use datafusion_common::{exec_err, Result}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use super::common::StringArrayType; #[derive(Debug)] pub struct SplitPartFunc { @@ -82,36 +84,121 @@ impl ScalarUDFImpl for SplitPartFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { - match (args[0].data_type(), args[1].data_type()) { - ( - DataType::Utf8 | DataType::Utf8View, - DataType::Utf8 | DataType::Utf8View, - ) => make_scalar_function(split_part::<i32>, vec![])(args), + // First, determine if any of the arguments is an Array + let len = args.iter().find_map(|arg| match arg { + ColumnarValue::Array(a) => Some(a.len()), + _ => None, + }); + + let inferred_length = len.unwrap_or(1); + let is_scalar = len.is_none(); + + // Convert all ColumnarValues to ArrayRefs + let args = args + .iter() + .map(|arg| match arg { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length), + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + }) + .collect::<Result<Vec<_>>>()?; + + // Unpack the ArrayRefs from the arguments + let n_array = as_int64_array(&args[2])?; + let result = match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8View, DataType::Utf8View) => { + split_part_impl::<&StringViewArray, &StringViewArray, i32>( + args[0].as_string_view(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::Utf8View, DataType::Utf8) => { + split_part_impl::<&StringViewArray, &GenericStringArray<i32>, i32>( + args[0].as_string_view(), + args[1].as_string::<i32>(), + n_array, + ) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + split_part_impl::<&StringViewArray, &GenericStringArray<i64>, i32>( + args[0].as_string_view(), + args[1].as_string::<i64>(), + n_array, + ) + } + (DataType::Utf8, DataType::Utf8View) => { + split_part_impl::<&GenericStringArray<i32>, &StringViewArray, i32>( + args[0].as_string::<i32>(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::LargeUtf8, DataType::Utf8View) => { + split_part_impl::<&GenericStringArray<i64>, &StringViewArray, i64>( + args[0].as_string::<i64>(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::Utf8, DataType::Utf8) => { + split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i32>, i32>( + args[0].as_string::<i32>(), + args[1].as_string::<i32>(), + n_array, + ) + } (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(split_part::<i64>, vec![])(args) + split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i64>, i64>( + args[0].as_string::<i64>(), + args[1].as_string::<i64>(), + n_array, + ) } - (_, DataType::LargeUtf8) => { - make_scalar_function(split_part::<i64>, vec![])(args) + (DataType::Utf8, DataType::LargeUtf8) => { + split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i64>, i32>( + args[0].as_string::<i32>(), + args[1].as_string::<i64>(), + n_array, + ) } - 
(DataType::LargeUtf8, _) => { - make_scalar_function(split_part::<i64>, vec![])(args) + (DataType::LargeUtf8, DataType::Utf8) => { + split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i32>, i64>( + args[0].as_string::<i64>(), + args[1].as_string::<i32>(), + n_array, + ) } - (first_type, second_type) => exec_err!( - "unsupported first type {} and second type {} for split_part function", - first_type, - second_type - ), + _ => exec_err!("Unsupported combination of argument types for split_part"), + }; + if is_scalar { + // If all inputs are scalar, keep the output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } } } -macro_rules! process_split_part { - ($string_array: expr, $delimiter_array: expr, $n_array: expr) => {{ - let result = $string_array - .iter() - .zip($delimiter_array.iter()) - .zip($n_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { +/// Generic implementation of `split_part`, covering any combination of string array types +pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>( + string_array: StringArrType, + delimiter_array: DelimiterArrType, + n_array: &Int64Array, +) -> Result<ArrayRef> +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + StringArrayLen: OffsetSizeTrait, +{ + let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new(); + + string_array + .iter() + .zip(delimiter_array.iter()) + .zip(n_array.iter()) + .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> { + match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { let split_string: Vec<&str> = string.split(delimiter).collect(); let len = split_string.len(); @@ -125,58 +212,17 @@ } as usize; if index < len { - Ok(Some(split_string[index])) + builder.append_value(split_string[index]); } else { - Ok(Some("")) + builder.append_value(""); } } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - }}; -} - -/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). 
-/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' -fn split_part( - args: &[ArrayRef], -) -> Result { - let n_array = as_int64_array(&args[2])?; - match (args[0].data_type(), args[1].data_type()) { - (DataType::Utf8View, _) => { - let string_array = as_string_view_array(&args[0])?; - match args[1].data_type() { - DataType::Utf8View => { - let delimiter_array = as_string_view_array(&args[1])?; - process_split_part!(string_array, delimiter_array, n_array) - } - _ => { - let delimiter_array = - as_generic_string_array::(&args[1])?; - process_split_part!(string_array, delimiter_array, n_array) - } - } - } - (_, DataType::Utf8View) => { - let delimiter_array = as_string_view_array(&args[1])?; - match args[0].data_type() { - DataType::Utf8View => { - let string_array = as_string_view_array(&args[0])?; - process_split_part!(string_array, delimiter_array, n_array) - } - _ => { - let string_array = as_generic_string_array::(&args[0])?; - process_split_part!(string_array, delimiter_array, n_array) - } + _ => builder.append_null(), } - } - (_, _) => { - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - process_split_part!(string_array, delimiter_array, n_array) - } - } + Ok(()) + })?; + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 702baf6e8fa7..cf10b18ae338 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -19,11 +19,10 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -52,6 +51,9 @@ impl StrposFunc { Exact(vec![Utf8, LargeUtf8]), Exact(vec![LargeUtf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8]), + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8View, Utf8]), + Exact(vec![Utf8View, LargeUtf8]), ], Volatility::Immutable, ), @@ -78,21 +80,7 @@ impl ScalarUDFImpl for StrposFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match (args[0].data_type(), args[1].data_type()) { - (DataType::Utf8, DataType::Utf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::Utf8, DataType::LargeUtf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::LargeUtf8, DataType::Utf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(strpos::, vec![])(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - } + make_scalar_function(strpos, vec![])(args) } fn aliases(&self) -> &[String] { @@ -100,30 +88,71 @@ impl ScalarUDFImpl for StrposFunc { } } +fn strpos(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = 
args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8View) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string_view(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + + other => { + exec_err!("Unsupported data type combination {other:?} for function strpos") + } + } +} + /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters -fn strpos( - args: &[ArrayRef], +fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>( + string_array: V1, + substring_array: V2, ) -> Result where - T0::Native: OffsetSizeTrait, - T1::Native: OffsetSizeTrait, + V1: ArrayAccessor, + V2: ArrayAccessor, { - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; + let string_iter = ArrayIter::new(string_array); + let substring_iter = ArrayIter::new(substring_array); - let result = string_array - .iter() - .zip(substring_array.iter()) + let result = string_iter + .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T0::Native::from_usize( + // The `find` method returns the byte index of the substring. + // We count the number of chars up to that byte index. + T::Native::from_usize( string .find(substring) .map(|x| string[..x].chars().count() + 1) @@ -132,20 +161,21 @@ where } _ => None, }) - .collect::>(); + .collect::>(); Ok(Arc::new(result) as ArrayRef) } #[cfg(test)] -mod test { - use super::*; +mod tests { + use arrow::array::{Array, Int32Array, Int64Array}; + use arrow::datatypes::DataType::{Int32, Int64}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::strpos::StrposFunc; use crate::utils::test::test_function; - use arrow::{ - array::{Array as _, Int32Array, Int64Array}, - datatypes::DataType::{Int32, Int64}, - }; - use datafusion_common::ScalarValue; macro_rules! 
test_strpos { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { @@ -164,21 +194,54 @@ mod test { } #[test] - fn strpos() { - test_strpos!("foo", "bar" -> 0; Utf8 Utf8 i32 Int32 Int32Array); - test_strpos!("foobar", "foo" -> 1; Utf8 Utf8 i32 Int32 Int32Array); - test_strpos!("foobar", "bar" -> 4; Utf8 Utf8 i32 Int32 Int32Array); - - test_strpos!("foo", "bar" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("foobar", "foo" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("foobar", "bar" -> 4; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - - test_strpos!("foo", "bar" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("foobar", "foo" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("foobar", "bar" -> 4; Utf8 LargeUtf8 i32 Int32 Int32Array); - - test_strpos!("foo", "bar" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("foobar", "foo" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("foobar", "bar" -> 4; LargeUtf8 Utf8 i64 Int64 Int64Array); + fn test_strpos_functions() { + // Utf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + + // LargeUtf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + + // Utf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + + // LargeUtf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + + // Utf8View and Utf8View combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + + // Utf8View and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + + // 
Utf8View and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); } } diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index 53ba3042f522..dd422f7aab95 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -160,14 +160,13 @@ fn replace_columns( mod tests { use arrow::datatypes::{DataType, Field, Schema}; + use crate::test::{assert_analyzed_plan_eq_display_indent, test_table_scan}; + use crate::Analyzer; use datafusion_common::{JoinType, TableReference}; use datafusion_expr::{ col, in_subquery, qualified_wildcard, table_scan, wildcard, LogicalPlanBuilder, }; - use crate::test::{assert_analyzed_plan_eq_display_indent, test_table_scan}; - use crate::Analyzer; - use super::*; fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 26885ae1350c..b663d8614275 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -41,6 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; +use crate::expressions::binary::kernels::concat_elements_utf8view; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -131,34 +132,6 @@ impl std::fmt::Display for BinaryExpr { } } -/// Invoke a compute kernel on a pair of binary data arrays -macro_rules! compute_utf8_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?)) - }}; -} - -macro_rules! binary_string_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - match $LEFT.data_type() { - DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), - DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray), - other => internal_err!( - "Data type {:?} not supported for binary operation '{}' on string arrays", - other, stringify!($OP) - ), - } - }}; -} - /// Invoke a boolean kernel on a pair of arrays macro_rules! 
boolean_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
@@ -662,7 +635,7 @@ impl BinaryExpr {
             BitwiseXor => bitwise_xor_dyn(left, right),
             BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
             BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
-            StringConcat => binary_string_array_op!(left, right, concat_elements),
+            StringConcat => concat_elements(left, right),
             AtArrow | ArrowAt => {
                 unreachable!("ArrowAt and AtArrow should be rewritten to function")
             }
@@ -670,6 +643,28 @@ impl BinaryExpr {
     }
 }
 
+fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
+    Ok(match left.data_type() {
+        DataType::Utf8 => Arc::new(concat_elements_utf8(
+            left.as_string::<i32>(),
+            right.as_string::<i32>(),
+        )?),
+        DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
+            left.as_string::<i64>(),
+            right.as_string::<i64>(),
+        )?),
+        DataType::Utf8View => Arc::new(concat_elements_utf8view(
+            left.as_string_view(),
+            right.as_string_view(),
+        )?),
+        other => {
+            return internal_err!(
+                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
+            );
+        }
+    })
+}
+
 /// Create a binary expression whose arguments are correctly coerced.
 /// This function errors if it is not possible to coerce the arguments
 /// to computational types supported by the operator.
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs
index b0736e140fec..1f9cfed1a44f 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::DataType;
 use datafusion_common::internal_err;
 use datafusion_common::{Result, ScalarValue};
 
+use arrow_schema::ArrowError;
 use std::sync::Arc;
 
 /// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT)
@@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar);
 create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar);
 create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, bitwise_shift_right_scalar);
 create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, bitwise_shift_left_scalar);
+
+pub fn concat_elements_utf8view(
+    left: &StringViewArray,
+    right: &StringViewArray,
+) -> std::result::Result<StringViewArray, ArrowError> {
+    let capacity = left
+        .data_buffers()
+        .iter()
+        .zip(right.data_buffers().iter())
+        .map(|(b1, b2)| b1.len() + b2.len())
+        .sum();
+    let mut result = StringViewBuilder::with_capacity(capacity);
+
+    // Avoid reallocations by writing to a reused buffer (note we
+    // could be even more efficient by creating the view directly
+    // here and avoiding the buffer, but that would be more complex)
+    let mut buffer = String::new();
+
+    for (left, right) in left.iter().zip(right.iter()) {
+        if let (Some(left), Some(right)) = (left, right) {
+            use std::fmt::Write;
+            buffer.clear();
+            write!(&mut buffer, "{left}{right}")
+                .expect("writing into string buffer failed");
+            result.append_value(&buffer);
+        } else {
+            // at least one of the values is null, so the output is also null
+            result.append_null()
+        }
+    }
+    Ok(result.finish())
+}
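The new Utf8View kernel mirrors the existing `concat_elements_utf8` contract, including null propagation: a row is null whenever either input is null. A minimal usage sketch of that behavior (not part of this diff; the input values are illustrative, and it assumes arrow's `FromIterator<Option<&str>>` impl for `StringViewArray`):

use arrow::array::{Array, StringViewArray};

// Hypothetical inputs, for illustration only.
let left = StringViewArray::from_iter(vec![Some("foo"), None, Some("a")]);
let right = StringViewArray::from_iter(vec![Some("bar"), Some("x"), None]);
let out = concat_elements_utf8view(&left, &right).unwrap();
assert_eq!(out.value(0), "foobar"); // both sides non-null: concatenated
assert!(out.is_null(1)); // null on the left propagates
assert!(out.is_null(2)); // null on the right propagates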
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index e7e1c5481f80..a81b09948cca 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -499,6 +499,12 @@ impl ExternalSorter {
         metrics: BaselineMetrics,
     ) -> Result<SendableRecordBatchStream> {
         assert_ne!(self.in_mem_batches.len(), 0);
+
+        // The elapsed compute timer is updated when the value is dropped.
+        // There is no need for an explicit call to drop.
+        let elapsed_compute = metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
         if self.in_mem_batches.len() == 1 {
             let batch = self.in_mem_batches.remove(0);
             let reservation = self.reservation.take();
@@ -552,7 +558,9 @@ impl ExternalSorter {
         let fetch = self.fetch;
         let expressions = Arc::clone(&self.expr);
         let stream = futures::stream::once(futures::future::lazy(move |_| {
+            let timer = metrics.elapsed_compute().timer();
             let sorted = sort_batch(&batch, &expressions, fetch)?;
+            timer.done();
             metrics.record_output(sorted.num_rows());
             drop(batch);
             drop(reservation);
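Both hunks rely on the RAII guard returned by `Time::timer()`: the guard starts timing when created and folds the elapsed time into the metric either via an explicit `done()` or implicitly on drop, which is why the first hunk never drops `_timer` by hand. A condensed sketch of the pattern (the workload function is hypothetical):

use datafusion::physical_plan::metrics::BaselineMetrics;

fn do_expensive_work() { /* hypothetical workload */ }

fn timed_section(metrics: &BaselineMetrics) {
    // Timing starts when the guard is created...
    let timer = metrics.elapsed_compute().timer();
    do_expensive_work();
    // ...and is recorded here; letting the guard fall out of scope
    // would record it just the same.
    timer.done();
}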
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index acf540d44465..826992e132ba 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1140,6 +1140,7 @@ message NestedLoopJoinExecNode {
 message CoalesceBatchesExecNode {
   PhysicalPlanNode input = 1;
   uint32 target_batch_size = 2;
+  optional uint32 fetch = 3;
 }
 
 message CoalescePartitionsExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index e59f1c121cd8..e78ffe1004a9 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -2002,6 +2002,9 @@ impl serde::Serialize for CoalesceBatchesExecNode {
         if self.target_batch_size != 0 {
             len += 1;
         }
+        if self.fetch.is_some() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.CoalesceBatchesExecNode", len)?;
         if let Some(v) = self.input.as_ref() {
             struct_ser.serialize_field("input", v)?;
@@ -2009,6 +2012,9 @@ impl serde::Serialize for CoalesceBatchesExecNode {
         if self.target_batch_size != 0 {
             struct_ser.serialize_field("targetBatchSize", &self.target_batch_size)?;
         }
+        if let Some(v) = self.fetch.as_ref() {
+            struct_ser.serialize_field("fetch", v)?;
+        }
         struct_ser.end()
     }
 }
@@ -2022,12 +2028,14 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
             "input",
             "target_batch_size",
             "targetBatchSize",
+            "fetch",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Input,
             TargetBatchSize,
+            Fetch,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -2051,6 +2059,7 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
                     match value {
                         "input" => Ok(GeneratedField::Input),
                         "targetBatchSize" | "target_batch_size" => Ok(GeneratedField::TargetBatchSize),
+                        "fetch" => Ok(GeneratedField::Fetch),
                         _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                     }
                 }
@@ -2072,6 +2081,7 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
             {
                 let mut input__ = None;
                 let mut target_batch_size__ = None;
+                let mut fetch__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Input => {
@@ -2088,11 +2098,20 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode {
                                 Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
                             ;
                         }
+                        GeneratedField::Fetch => {
+                            if fetch__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("fetch"));
+                            }
+                            fetch__ = 
+                                map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
+                            ;
+                        }
                     }
                 }
                 Ok(CoalesceBatchesExecNode {
                     input: input__,
                     target_batch_size: target_batch_size__.unwrap_or_default(),
+                    fetch: fetch__,
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 41d469fe5ee4..7b1f5d7bc38e 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1813,6 +1813,8 @@ pub struct CoalesceBatchesExecNode {
     pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(uint32, tag = "2")]
     pub target_batch_size: u32,
+    #[prost(uint32, optional, tag = "3")]
+    pub fetch: ::core::option::Option<u32>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 0f6722dd375b..96fb45eafe62 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -259,10 +259,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     runtime,
                     extension_codec,
                 )?;
-                Ok(Arc::new(CoalesceBatchesExec::new(
-                    input,
-                    coalesce_batches.target_batch_size as usize,
-                )))
+                Ok(Arc::new(
+                    CoalesceBatchesExec::new(
+                        input,
+                        coalesce_batches.target_batch_size as usize,
+                    )
+                    .with_fetch(coalesce_batches.fetch.map(|f| f as usize)),
+                ))
             }
             PhysicalPlanType::Merge(merge) => {
                 let input: Arc<dyn ExecutionPlan> =
@@ -1536,6 +1539,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     protobuf::CoalesceBatchesExecNode {
                         input: Some(Box::new(input)),
                         target_batch_size: coalesce_batches.target_batch_size() as u32,
+                        fetch: coalesce_batches.fetch().map(|n| n as u32),
                     },
                 ))),
             });
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 6766468ef443..0ffc494321fb 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -25,6 +25,7 @@ use std::vec;
 use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
 use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
+use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
 use datafusion_functions_aggregate::min_max::max_udaf;
@@ -629,6 +630,23 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> {
     ))
 }
 
+#[test]
+fn roundtrip_coalesce_with_fetch() -> Result<()> {
+    let field_a = Field::new("a", DataType::Boolean, false);
+    let field_b = Field::new("b", DataType::Int64, false);
+    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+    roundtrip_test(Arc::new(CoalesceBatchesExec::new(
+        Arc::new(EmptyExec::new(schema.clone())),
+        8096,
+    )))?;
+
+    roundtrip_test(Arc::new(
+        CoalesceBatchesExec::new(Arc::new(EmptyExec::new(schema.clone())), 8096)
+            .with_fetch(Some(10)),
+    ))
+}
+
 #[test]
 fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> {
     let scan_config = FileScanConfig {
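Together these changes let a `CoalesceBatchesExec` that absorbed a `LIMIT` survive a serialization round trip. A hedged sketch of the builder-style API the new test exercises; the input plan and the numbers are illustrative:

use std::sync::Arc;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::ExecutionPlan;

fn coalesce_with_limit(input: Arc<dyn ExecutionPlan>) -> CoalesceBatchesExec {
    // Coalesce into ~8192-row batches, but stop emitting after 10 rows total;
    // the fetch now round-trips through the protobuf representation above.
    CoalesceBatchesExec::new(input, 8192).with_fetch(Some(10))
}

diff --git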
a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs
index 4e0ce33f1334..45fda094557b 100644
--- a/datafusion/sql/src/select.rs
+++ b/datafusion/sql/src/select.rs
@@ -215,7 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr {
             LogicalPlanBuilder::from(plan)
-                .filter(having_expr_post_aggr)?
+                .having(having_expr_post_aggr)?
                 .build()?
         } else {
             plan
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index af161bba45c1..c32acecaae5f 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -17,8 +17,6 @@
 //! SQL Utility Functions
 
-use std::collections::HashMap;
-
 use arrow_schema::{
     DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
 };
@@ -33,6 +31,7 @@ use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction};
 use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
 use datafusion_expr::{expr_vec_fmt, Expr, ExprSchemable, LogicalPlan};
 use sqlparser::ast::{Ident, Value};
+use std::collections::HashMap;
 
 /// Make a best-effort attempt at resolving all columns in the expression tree
 pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index d381b95310a8..46b59c84f171 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -5655,6 +5655,97 @@ select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(N
 ----
 0 NULL NULL NULL NULL NULL NULL NULL
 
+statement ok
+create table having_test(v1 int, v2 int)
+
+statement ok
+create table join_table(v1 int, v2 int)
+
+statement ok
+insert into having_test values (1, 2), (2, 3), (3, 4)
+
+statement ok
+insert into join_table values (1, 2), (2, 3), (3, 4)
+
+
+query II
+select * from having_test group by v1, v2 having max(v1) = 3
+----
+3 4
+
+query TT
+EXPLAIN select * from having_test group by v1, v2 having max(v1) = 3
+----
+logical_plan
+01)Projection: having_test.v1, having_test.v2
+02)--Filter: max(having_test.v1) = Int32(3)
+03)----Aggregate: groupBy=[[having_test.v1, having_test.v2]], aggr=[[max(having_test.v1)]]
+04)------TableScan: having_test projection=[v1, v2]
+physical_plan
+01)ProjectionExec: expr=[v1@0 as v1, v2@1 as v2]
+02)--CoalesceBatchesExec: target_batch_size=8192
+03)----FilterExec: max(having_test.v1)@2 = 3
+04)------AggregateExec: mode=FinalPartitioned, gby=[v1@0 as v1, v2@1 as v2], aggr=[max(having_test.v1)]
+05)--------CoalesceBatchesExec: target_batch_size=8192
+06)----------RepartitionExec: partitioning=Hash([v1@0, v2@1], 4), input_partitions=4
+07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+08)--------------AggregateExec: mode=Partial, gby=[v1@0 as v1, v2@1 as v2], aggr=[max(having_test.v1)]
+09)----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+
+query error
+select * from having_test having max(v1) = 3
+
+query I
+select max(v1) from having_test having max(v1) = 3
+----
+3
+
+query I
+select max(v1), * exclude (v1, v2) from having_test having max(v1) = 3
+----
+3
+
+# because v1 and v2 are in the group by clause, expanding the wildcard is valid
+query III
+select max(v1), * replace ('v1' as v3) from having_test group by v1, v2 having max(v1) = 3
+----
+3 3 4
+
+query III
+select max(v1), t.* from having_test t group by v1, v2 having max(v1) = 3
+----
+3 3 4
+
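Routing the post-aggregation predicate through `having` instead of `filter` keeps the HAVING semantics visible to later analysis, such as the wildcard expansion exercised by the tests here. A hedged sketch of the builder call; the input plan and predicate are hypothetical stand-ins for what the SQL planner produces:

use datafusion_common::Result;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};

// `predicate` would be e.g. `max(v1) = 3`, built during SQL planning.
fn apply_having(plan: LogicalPlan, predicate: Expr) -> Result<LogicalPlan> {
    LogicalPlanBuilder::from(plan)
        .having(predicate)? // like `filter`, but marked as a HAVING predicate
        .build()
}

+# j.* should also be included in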
the group-by clause +query error +select max(t.v1), j.* from having_test t join join_table j on t.v1 = j.v1 group by t.v1, t.v2 having max(t.v1) = 3 + +query III +select max(t.v1), j.* from having_test t join join_table j on t.v1 = j.v1 group by j.v1, j.v2 having max(t.v1) = 3 +---- +3 3 4 + +# If the select items only contain scalar expressions, the having clause is valid. +query P +select now() from having_test having max(v1) = 4 +---- + +# If the select items only contain scalar expressions, the having clause is valid. +query I +select 0 from having_test having max(v1) = 4 +---- + +# v2 should also be included in group-by clause +query error +select * from having_test group by v1 having max(v1) = 3 + +statement ok +drop table having_test + +statement ok +drop table join_table + # test min/max Float16 without group expression query RRTT WITH data AS ( diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 0b441bcbeb8f..3b3d7b88a4a1 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -1066,9 +1066,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: strpos(__common_expr_1, Utf8("f")) AS c, strpos(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: strpos(test.column1_utf8view, Utf8("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for SUBSTR ## TODO file ticket @@ -1145,6 +1144,63 @@ FROM test; 0 NULL +# || mixed types +# expect all results to be the same for each row as they all have the same values +query TTTTTTTT +SELECT + column1_utf8view || column2_utf8view, + column1_utf8 || column2_utf8view, + column1_large_utf8 || column2_utf8view, + column1_dict || column2_utf8view, + -- reverse argument order + column2_utf8view || column1_utf8view, + column2_utf8view || column1_utf8, + column2_utf8view || column1_large_utf8, + column2_utf8view || column1_dict +FROM test; +---- +AndrewX AndrewX AndrewX AndrewX XAndrew XAndrew XAndrew XAndrew +XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng +RaphaelR RaphaelR RaphaelR RaphaelR RRaphael RRaphael RRaphael RRaphael +NULL NULL NULL NULL NULL NULL NULL NULL + +# || constants +# expect all results to be the same for each row as they all have the same values +query TTTTTTTT +SELECT + column1_utf8view || 'foo', + column1_utf8 || 'foo', + column1_large_utf8 || 'foo', + column1_dict || 'foo', + -- reverse argument order + 'foo' || column1_utf8view, + 'foo' || column1_utf8, + 'foo' || column1_large_utf8, + 'foo' || column1_dict +FROM test; +---- +Andrewfoo Andrewfoo Andrewfoo Andrewfoo fooAndrew fooAndrew fooAndrew fooAndrew +Xiangpengfoo Xiangpengfoo Xiangpengfoo Xiangpengfoo fooXiangpeng fooXiangpeng fooXiangpeng fooXiangpeng +Raphaelfoo Raphaelfoo Raphaelfoo Raphaelfoo fooRaphael fooRaphael fooRaphael fooRaphael +NULL NULL NULL NULL NULL NULL NULL NULL + +# || same type (column1 has null, so also tests NULL || NULL) +# expect all results to be the same for each row as they all have the same values +query TTT +SELECT + column1_utf8view || column1_utf8view, + column1_utf8 || column1_utf8, + column1_large_utf8 
|| column1_large_utf8 + -- Dictionary/Dictionary coercion doesn't work + -- https://github.com/apache/datafusion/issues/12101 + --column1_dict || column1_dict +FROM test; +---- +AndrewAndrew AndrewAndrew AndrewAndrew +XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng +RaphaelRaphael RaphaelRaphael RaphaelRaphael +NULL NULL NULL + statement ok drop table test; @@ -1168,18 +1224,25 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt; statement ok drop table dates; +### Tests for `||` with Utf8View specifically + statement ok create table temp as values ('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')), ('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool', 'Utf8View')); +query TTT +select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from temp; +---- +Utf8 Utf8View Utf8View +Utf8 Utf8View Utf8View + query T select column2||' is fast' from temp; ---- rust is fast datafusion is fast - query T select column2 || ' is ' || column3 from temp; ---- @@ -1190,15 +1253,15 @@ query TT explain select column2 || 'is' || column3 from temp; ---- logical_plan -01)Projection: CAST(temp.column2 AS Utf8) || Utf8("is") || CAST(temp.column3 AS Utf8) +01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2 || Utf8("is") || temp.column3 02)--TableScan: temp projection=[column2, column3] - +# should not cast the column2 to utf8 query TT explain select column2||' is fast' from temp; ---- logical_plan -01)Projection: CAST(temp.column2 AS Utf8) || Utf8(" is fast") +01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8(" is fast") 02)--TableScan: temp projection=[column2] @@ -1212,7 +1275,7 @@ query TT explain select column2||column3 from temp; ---- logical_plan -01)Projection: CAST(temp.column2 AS Utf8) || CAST(temp.column3 AS Utf8) +01)Projection: temp.column2 || temp.column3 02)--TableScan: temp projection=[column2, column3] query T diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index f618f844fea7..ff02ef8c7ef6 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -40,7 +40,7 @@ itertools = { workspace = true } object_store = { workspace = true } pbjson-types = "0.7" prost = "0.13" -substrait = { version = "0.37", features = ["serde"] } +substrait = { version = "0.41", features = ["serde"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f2756bb06d1e..b1b510f1792d 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -42,14 +42,14 @@ use crate::variation_const::{ DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; #[allow(deprecated)] use crate::variation_const::{ INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{new_empty_array, AsArray}; use 
datafusion::common::scalar::ScalarStructBuilder;
@@ -69,6 +69,7 @@ use datafusion::{
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 use substrait::proto::exchange_rel::ExchangeKind;
+use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
 use substrait::proto::expression::literal::user_defined::Val;
 use substrait::proto::expression::literal::{
     IntervalDayToSecond, IntervalYearToMonth, UserDefined,
 };
@@ -95,6 +96,13 @@ use substrait::proto::{
 };
 use substrait::proto::{FunctionArgument, SortField};
 
+// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
+// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone
+// results in correct points on the timeline, and we pick UTC as a reasonable default.
+// However, DF uses the timezone also for some arithmetic and display purposes (see e.g.
+// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749).
+const DEFAULT_TIMEZONE: &str = "UTC";
+
 pub fn name_to_op(name: &str) -> Option<Operator> {
     match name {
         "equal" => Some(Operator::Eq),
@@ -877,8 +885,8 @@ fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
             join_rel::JoinType::Left => Ok(JoinType::Left),
             join_rel::JoinType::Right => Ok(JoinType::Right),
             join_rel::JoinType::Outer => Ok(JoinType::Full),
-            join_rel::JoinType::Anti => Ok(JoinType::LeftAnti),
-            join_rel::JoinType::Semi => Ok(JoinType::LeftSemi),
+            join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti),
+            join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi),
             _ => plan_err!("unsupported join type {substrait_join_type:?}"),
         }
     } else {
@@ -1369,23 +1377,51 @@ fn from_substrait_type(
         },
         r#type::Kind::Fp32(_) => Ok(DataType::Float32),
         r#type::Kind::Fp64(_) => Ok(DataType::Float64),
-        r#type::Kind::Timestamp(ts) => match ts.type_variation_reference {
-            TIMESTAMP_SECOND_TYPE_VARIATION_REF => {
-                Ok(DataType::Timestamp(TimeUnit::Second, None))
-            }
-            TIMESTAMP_MILLI_TYPE_VARIATION_REF => {
-                Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
-            }
-            TIMESTAMP_MICRO_TYPE_VARIATION_REF => {
-                Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
-            }
-            TIMESTAMP_NANO_TYPE_VARIATION_REF => {
-                Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
+        r#type::Kind::Timestamp(ts) => {
+            // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead
+            #[allow(deprecated)]
+            match ts.type_variation_reference {
+                TIMESTAMP_SECOND_TYPE_VARIATION_REF => {
+                    Ok(DataType::Timestamp(TimeUnit::Second, None))
+                }
+                TIMESTAMP_MILLI_TYPE_VARIATION_REF => {
+                    Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
+                }
+                TIMESTAMP_MICRO_TYPE_VARIATION_REF => {
+                    Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
+                }
+                TIMESTAMP_NANO_TYPE_VARIATION_REF => {
+                    Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
+                }
+                v => not_impl_err!(
+                    "Unsupported Substrait type variation {v} of type {s_kind:?}"
+                ),
             }
-            v => not_impl_err!(
-                "Unsupported Substrait type variation {v} of type {s_kind:?}"
-            ),
-        },
+        }
+        r#type::Kind::PrecisionTimestamp(pts) => {
+            let unit = match pts.precision {
+                0 => Ok(TimeUnit::Second),
+                3 => Ok(TimeUnit::Millisecond),
+                6 => Ok(TimeUnit::Microsecond),
+                9 => Ok(TimeUnit::Nanosecond),
+                p => not_impl_err!(
+                    "Unsupported Substrait precision {p} for PrecisionTimestamp"
+                ),
+            }?;
+            Ok(DataType::Timestamp(unit, None))
+        }
+        r#type::Kind::PrecisionTimestampTz(pts) => {
+            let unit = match pts.precision {
+                0 => Ok(TimeUnit::Second),
+                3 => Ok(TimeUnit::Millisecond),
+ 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestampTz" + ), + }?; + Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) + } r#type::Kind::Date(date) => match date.type_variation_reference { DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), @@ -1465,22 +1501,10 @@ fn from_substrait_type( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, - r#type::Kind::IntervalYear(i) => match i.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::IntervalDay(i) => match i.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, + r#type::Kind::IntervalYear(_) => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), r#type::Kind::UserDefined(u) => { if let Some(name) = extensions.types.get(&u.type_reference) { match name.as_ref() { @@ -1676,21 +1700,59 @@ fn from_substrait_literal( }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), - Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - ScalarValue::TimestampSecond(Some(*t), None) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - ScalarValue::TimestampMillisecond(Some(*t), None) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - ScalarValue::TimestampMicrosecond(Some(*t), None) + Some(LiteralType::Timestamp(t)) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + ScalarValue::TimestampSecond(Some(*t), None) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + ScalarValue::TimestampMillisecond(Some(*t), None) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + ScalarValue::TimestampMicrosecond(Some(*t), None) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + ScalarValue::TimestampNanosecond(Some(*t), None) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - ScalarValue::TimestampNanosecond(Some(*t), None) + } + Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond(Some(pt.value), None), + 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), + 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), + 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); } - others => { - return substrait_err!("Unknown type variation reference {others}"); + }, + Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 3 => ScalarValue::TimestampMillisecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 6 => ScalarValue::TimestampMicrosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 9 => ScalarValue::TimestampNanosecond( + Some(pt.value), + 
Some(DEFAULT_TIMEZONE.into()), + ), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), @@ -1881,10 +1943,24 @@ fn from_substrait_literal( Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, seconds, - microseconds, + subseconds, + precision_mode, })) => { - // DF only supports millisecond precision, so we lose the micros here - ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000)) + // DF only supports millisecond precision, so for any more granular type we lose precision + let milliseconds = match precision_mode { + Some(PrecisionMode::Microseconds(ms)) => ms / 1000, + Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, + Some(PrecisionMode::Precision(3)) => *subseconds as i32, + Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, + Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode" + ) + } + }; + + ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) } Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { ScalarValue::new_interval_ym(*years, *months) @@ -2026,21 +2102,55 @@ fn from_substrait_null( }, r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), - r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampSecond(None, None)) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampMillisecond(None, None)) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampMicrosecond(None, None)) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampNanosecond(None, None)) + r#type::Kind::Timestamp(ts) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampSecond(None, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampMillisecond(None, None)) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampMicrosecond(None, None)) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampNanosecond(None, None)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ), } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" + } + r#type::Kind::PrecisionTimestamp(pts) => match pts.precision { + 0 => Ok(ScalarValue::TimestampSecond(None, None)), + 3 => Ok(ScalarValue::TimestampMillisecond(None, None)), + 6 => Ok(ScalarValue::TimestampMicrosecond(None, None)), + 9 => Ok(ScalarValue::TimestampNanosecond(None, None)), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ), + }, + r#type::Kind::PrecisionTimestampTz(pts) => match pts.precision { + 0 => Ok(ScalarValue::TimestampSecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + 3 => Ok(ScalarValue::TimestampMillisecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + 6 => Ok(ScalarValue::TimestampMicrosecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + 9 => Ok(ScalarValue::TimestampNanosecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + p => not_impl_err!( + "Unsupported Substrait precision {p} 
for PrecisionTimestamp" ), }, r#type::Kind::Date(date) => match date.type_variation_reference { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index ee04749f5e6b..72b6760be29c 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -38,8 +38,6 @@ use crate::variation_const::{ DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; @@ -55,10 +53,11 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, Struct, - UserDefined, + user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, + PrecisionTimestamp, Struct, UserDefined, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; @@ -658,8 +657,8 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::Left => join_rel::JoinType::Left, JoinType::Right => join_rel::JoinType::Right, JoinType::Full => join_rel::JoinType::Outer, - JoinType::LeftAnti => join_rel::JoinType::Anti, - JoinType::LeftSemi => join_rel::JoinType::Semi, + JoinType::LeftAnti => join_rel::JoinType::LeftAnti, + JoinType::LeftSemi => join_rel::JoinType::LeftSemi, JoinType::RightAnti | JoinType::RightSemi => unimplemented!(), } } @@ -1376,20 +1375,31 @@ fn to_substrait_type( nullability, })), }), - // Timezone is ignored. - DataType::Timestamp(unit, _) => { - let type_variation_reference = match unit { - TimeUnit::Second => TIMESTAMP_SECOND_TYPE_VARIATION_REF, - TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_VARIATION_REF, - TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_VARIATION_REF, + DataType::Timestamp(unit, tz) => { + let precision = match unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, }; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference, + let kind = match tz { + None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, - })), - }) + precision, + }), + Some(_) => { + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. 
+ r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }) + } + }; + Ok(substrait::proto::Type { kind: Some(kind) }) } DataType::Date32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { @@ -1415,6 +1425,7 @@ fn to_substrait_type( kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, + precision: Some(3), // DayTime precision is always milliseconds })), }), IntervalUnit::MonthDayNano => { @@ -1798,21 +1809,64 @@ fn to_substrait_literal( ScalarValue::Float64(Some(f)) => { (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampSecond(Some(t), _) => ( - LiteralType::Timestamp(*t), - TIMESTAMP_SECOND_TYPE_VARIATION_REF, + ScalarValue::TimestampSecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, ), - ScalarValue::TimestampMillisecond(Some(t), _) => ( - LiteralType::Timestamp(*t), - TIMESTAMP_MILLI_TYPE_VARIATION_REF, + ScalarValue::TimestampMillisecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, ), - ScalarValue::TimestampMicrosecond(Some(t), _) => ( - LiteralType::Timestamp(*t), - TIMESTAMP_MICRO_TYPE_VARIATION_REF, + ScalarValue::TimestampMicrosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, ), - ScalarValue::TimestampNanosecond(Some(t), _) => ( - LiteralType::Timestamp(*t), - TIMESTAMP_NANO_TYPE_VARIATION_REF, + ScalarValue::TimestampNanosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. 
+ ScalarValue::TimestampSecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::Date32(Some(d)) => { (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) @@ -1847,7 +1901,8 @@ fn to_substrait_literal( LiteralType::IntervalDayToSecond(IntervalDayToSecond { days: i.days, seconds: i.milliseconds / 1000, - microseconds: (i.milliseconds % 1000) * 1000, + subseconds: (i.milliseconds % 1000) as i64, + precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds }), DEFAULT_TYPE_VARIATION_REF, ), @@ -2142,6 +2197,18 @@ mod test { round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + for (ts, tz) in [ + (Some(12345), None), + (None, None), + (Some(12345), Some("UTC".into())), + (None, Some("UTC".into())), + ] { + round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; + } + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( &[ScalarValue::Float32(Some(1.0))], &DataType::Float32, @@ -2271,10 +2338,14 @@ mod test { round_trip_type(DataType::UInt64)?; round_trip_type(DataType::Float32)?; round_trip_type(DataType::Float64)?; - round_trip_type(DataType::Timestamp(TimeUnit::Second, None))?; - round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, None))?; - round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, None))?; - round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, None))?; + + for tz in [None, Some("UTC".into())] { + round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; + } + round_trip_type(DataType::Date32)?; round_trip_type(DataType::Date64)?; round_trip_type(DataType::Binary)?; diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index c94ad2d669fd..1525da764509 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -38,10 +38,16 @@ /// The "system-preferred" variation (i.e., no variation). 
pub const DEFAULT_TYPE_VARIATION_REF: u32 = 0; pub const UNSIGNED_INTEGER_TYPE_VARIATION_REF: u32 = 1; + +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] pub const TIMESTAMP_SECOND_TYPE_VARIATION_REF: u32 = 0; +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] pub const TIMESTAMP_MILLI_TYPE_VARIATION_REF: u32 = 1; +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] pub const TIMESTAMP_MICRO_TYPE_VARIATION_REF: u32 = 2; +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] pub const TIMESTAMP_NANO_TYPE_VARIATION_REF: u32 = 3; + pub const DATE_32_TYPE_VARIATION_REF: u32 = 0; pub const DATE_64_TYPE_VARIATION_REF: u32 = 1; pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0;
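For reference, the deprecated variation constants above correspond to PrecisionTimestamp's fractional-second digit counts rather than to type variation references. A small sketch of the mapping the producer now applies inline; the function name is illustrative:

use datafusion::arrow::datatypes::TimeUnit;

// Substrait PrecisionTimestamp(Tz) encodes precision as fractional-second digits.
fn to_substrait_precision(unit: &TimeUnit) -> i32 {
    match unit {
        TimeUnit::Second => 0,      // formerly TIMESTAMP_SECOND_TYPE_VARIATION_REF
        TimeUnit::Millisecond => 3, // formerly TIMESTAMP_MILLI_TYPE_VARIATION_REF
        TimeUnit::Microsecond => 6, // formerly TIMESTAMP_MICRO_TYPE_VARIATION_REF
        TimeUnit::Nanosecond => 9,  // formerly TIMESTAMP_NANO_TYPE_VARIATION_REF
    }
}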