diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 155d1c740b1a..ee58206a776e 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -947,6 +947,8 @@ macro_rules! equal_rows_elem { } /// Left and right row have equal values +/// If more data types are supported here, please also add them to the `can_hash` function +/// so that the logical planner generates hash joins for them. fn equal_rows( left: usize, right: usize, diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index df853b7c9f04..6b3b8c33964d 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1204,3 +1204,141 @@ async fn join_partitioned() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn join_with_hash_unsupported_data_type() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::Date32, true), + ]); + let data = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from_slice(&[1, 2, 3])), + Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])), + Arc::new(Int64Array::from_slice(&[100, 200, 300])), + Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])), + ], + )?; + let table = MemTable::try_new(data.schema(), vec![vec![data]])?; + ctx.register_table("foo", Arc::new(table))?; + + // join on a hash-unsupported data type (Date32): planned as a filtered cross join instead of a hash join + let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " 
Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+----+-----+-----+------------+----+-----+-----+------------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+----+-----+-----+------------+----+-----+-----+------------+", + "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", + "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", + "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", + "+----+-----+-----+------------+----+-----+-----+------------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + // join on a hash-supported data type (Int32): the equality becomes an inner join key + let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1"; + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); /* NOTE(review): `msg` was formatted with the first query's SQL, not this one */ + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, 
c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+----+-----+-----+------------+----+-----+-----+------------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+----+-----+-----+------------+----+-----+-----+------------+", + "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", + "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", + "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", + "+----+-----+-----+------------+----+-----+-----+------------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + // join on two columns, one hash-supported (Int64) and one hash-unsupported (Date32): + // the Int64 equality becomes the join key and the Date32 equality stays a filter. 
+ let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4"; + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); /* NOTE(review): `msg` still names the first query's SQL */ + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+----+-----+-----+------------+----+-----+-----+------------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+----+-----+-----+------------+----+-----+-----+------------+", + "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", + "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", + "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", + "+----+-----+-----+------------+----+-----+-----+------------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 
9176061442e6..5e7fa031cb00 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -19,6 +19,7 @@ use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs}; use crate::utils::{columnize_expr, exprlist_to_fields, from_plan}; +use crate::{and, binary_expr, Operator}; use crate::{ logical_plan::{ Aggregate, Analyze, CrossJoin, EmptyRelation, Explain, Filter, Join, @@ -27,7 +28,7 @@ use crate::{ Union, Values, Window, }, utils::{ - expand_qualified_wildcard, expand_wildcard, expr_to_columns, + can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns, group_window_expr_by_sort_keys, }, Expr, ExprSchemable, TableSource, @@ -605,17 +606,46 @@ impl LogicalPlanBuilder { let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan.clone()), - right: Arc::new(right.clone()), - on, - filter: None, - join_type, - join_constraint: JoinConstraint::Using, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + let mut join_on: Vec<(Column, Column)> = vec![]; /* key pairs whose type passes can_hash */ + let mut filters: Option = None; /* equality predicates for the remaining pairs, AND-ed together */ + for (l, r) in &on { + if self.plan.schema().field_from_column(l).is_ok() + && right.schema().field_from_column(r).is_ok() + && can_hash(self.plan.schema().field_from_column(l).unwrap().data_type()) /* only the type of the column resolved in self.plan is checked */ + { + join_on.push((l.clone(), r.clone())); + } else if self.plan.schema().field_from_column(r).is_ok() + && right.schema().field_from_column(l).is_ok() + && can_hash(self.plan.schema().field_from_column(r).unwrap().data_type()) + { + join_on.push((r.clone(), l.clone())); /* pair was given in swapped order: store it as (self.plan column, right column) */ + } else { /* not usable as a hash key: keep the pair as an `l = r` filter expression */ + let expr = binary_expr( + Expr::Column(l.clone()), + Operator::Eq, + Expr::Column(r.clone()), + ); + match filters { + None => filters = Some(expr), + Some(filter_expr) => filters = Some(and(expr, filter_expr)), + } + } + } + if 
join_on.is_empty() { /* no hashable key pair: fall back to a filtered cross join */ + let join = Self::from(self.plan.clone()).cross_join(&right.clone())?; + join.filter(filters.unwrap()) /* NOTE(review): unwrap panics if `on` was empty (filters is still None), and join_type is not applied on this path -- confirm this is intended for non-inner joins */ + } else { + Ok(Self::from(LogicalPlan::Join(Join { + left: Arc::new(self.plan.clone()), + right: Arc::new(right.clone()), + on: join_on, + filter: filters, + join_type, + join_constraint: JoinConstraint::Using, + schema: DFSchemaRef::new(join_schema), + null_equals_null: false, + }))) + } } /// Apply a cross join diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3986eb3e64e3..8e0cdbc00af0 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -25,6 +25,7 @@ use crate::logical_plan::{ Values, Window, }; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -643,6 +644,35 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { } } +/// Can this data type be used as a key in hash join equality conditions? +/// The data types here mirror those handled by `equal_rows`; if more data types are supported +/// in `equal_rows` (hash join), add them here so the planner generates a hash join for them. 
+pub fn can_hash(data_type: &DataType) -> bool { + match data_type { + DataType::Null => true, + DataType::Boolean => true, + DataType::Int8 => true, + DataType::Int16 => true, + DataType::Int32 => true, + DataType::Int64 => true, + DataType::UInt8 => true, + DataType::UInt16 => true, + DataType::UInt32 => true, + DataType::UInt64 => true, + DataType::Float32 => true, + DataType::Float64 => true, + DataType::Timestamp(time_unit, None) => match time_unit { /* only timestamps without a timezone match; every TimeUnit is accepted */ + TimeUnit::Second => true, + TimeUnit::Millisecond => true, + TimeUnit::Microsecond => true, + TimeUnit::Nanosecond => true, + }, + DataType::Utf8 => true, + DataType::LargeUtf8 => true, + _ => false, /* everything else (e.g. Date32, timestamps with a timezone) is kept out of hash join keys */ + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 1e5daa472bab..92462221c7a6 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -29,8 +29,9 @@ use datafusion_expr::logical_plan::{ ToStringifiedPlan, }; use datafusion_expr::utils::{ - expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, - find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION, + can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, + expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs, + COUNT_STAR_EXPANSION, }; use datafusion_expr::{ and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF, @@ -600,7 +601,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut filter = vec![]; // extract join keys - extract_join_keys(expr, &mut keys, &mut filter); + extract_join_keys( + expr, + &mut keys, + &mut filter, + left.schema(), + right.schema(), + ); let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); @@ -819,10 +826,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { for (l, r) in &possible_join_keys { if left_schema.field_from_column(l).is_ok() && right_schema.field_from_column(r).is_ok() + && can_hash( + 
left_schema + .field_from_column(l) + .unwrap() + .data_type(), /* only the left-side column's type is checked */ + ) { join_keys.push((l.clone(), r.clone())); } else if left_schema.field_from_column(r).is_ok() && right_schema.field_from_column(l).is_ok() + && can_hash( + left_schema + .field_from_column(r) + .unwrap() + .data_type(), + ) { join_keys.push((r.clone(), l.clone())); } @@ -2516,12 +2535,26 @@ fn extract_join_keys( expr: Expr, accum: &mut Vec<(Column, Column)>, accum_filter: &mut Vec, + left_schema: &Arc, + right_schema: &Arc, ) { match &expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { - accum.push((l.clone(), r.clone())); + if left_schema.field_from_column(l).is_ok() + && right_schema.field_from_column(r).is_ok() + && can_hash(left_schema.field_from_column(l).unwrap().data_type()) + { + accum.push((l.clone(), r.clone())); + } else if left_schema.field_from_column(r).is_ok() + && right_schema.field_from_column(l).is_ok() + && can_hash(left_schema.field_from_column(r).unwrap().data_type()) + { + accum.push((r.clone(), l.clone())); /* equality was written right-side first: store the pair as (left, right) */ + } else { + accum_filter.push(expr); /* pair is unhashable or did not resolve against the two schemas: it stays a filter predicate */ + } } _other => { accum_filter.push(expr); @@ -2529,8 +2562,20 @@ fn extract_join_keys( }, Operator::And => { if let Expr::BinaryExpr { left, op: _, right } = expr { - extract_join_keys(*left, accum, accum_filter); - extract_join_keys(*right, accum, accum_filter); + extract_join_keys( + *left, + accum, + accum_filter, + left_schema, + right_schema, + ); + extract_join_keys( + *right, + accum, + accum_filter, + left_schema, + right_schema, + ); } } _other => {