From c97e55063cd0ca3ccd37e7c5a3873749531d61c0 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 19 Jun 2021 23:09:42 -0700 Subject: [PATCH 1/5] fix join column handling logic for `On` and `Using` constraints --- ballista/rust/core/proto/ballista.proto | 12 +- .../core/src/serde/logical_plan/from_proto.rs | 43 ++-- .../core/src/serde/logical_plan/to_proto.rs | 15 +- ballista/rust/core/src/serde/mod.rs | 46 +++- .../src/serde/physical_plan/from_proto.rs | 26 ++- .../rust/core/src/serde/physical_plan/mod.rs | 4 +- .../core/src/serde/physical_plan/to_proto.rs | 14 +- benchmarks/queries/q7.sql | 2 +- datafusion/src/execution/context.rs | 90 ++++++++ datafusion/src/execution/dataframe_impl.rs | 23 +- datafusion/src/logical_plan/builder.rs | 87 ++++---- datafusion/src/logical_plan/dfschema.rs | 174 ++++++++++------ datafusion/src/logical_plan/expr.rs | 8 +- datafusion/src/logical_plan/plan.rs | 15 +- datafusion/src/optimizer/filter_push_down.rs | 70 ++++--- datafusion/src/physical_plan/hash_join.rs | 197 +++++++++++++----- datafusion/src/physical_plan/hash_utils.rs | 74 ++----- datafusion/src/physical_plan/planner.rs | 18 +- datafusion/src/sql/planner.rs | 22 +- datafusion/src/test/mod.rs | 11 +- 20 files changed, 613 insertions(+), 338 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 365d8e9fd9a4..e42ddb364d70 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -378,12 +378,18 @@ enum JoinType { ANTI = 5; } +enum JoinConstraint { + ON = 0; + USING = 1; +} + message JoinNode { LogicalPlanNode left = 1; LogicalPlanNode right = 2; JoinType join_type = 3; - repeated Column left_join_column = 4; - repeated Column right_join_column = 5; + JoinConstraint join_constraint = 4; + repeated Column left_join_column = 5; + repeated Column right_join_column = 6; } message LimitNode { @@ -570,7 +576,7 @@ message HashJoinExecNode { PhysicalPlanNode right = 2; repeated 
JoinOn on = 3; JoinType join_type = 4; - + JoinConstraint join_constraint = 5; } message PhysicalColumn { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index a1136cf4a7d6..51683f3d437d 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -26,8 +26,8 @@ use datafusion::logical_plan::window_frames::{ }; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, - sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, + sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType, + LogicalPlan, LogicalPlanBuilder, Operator, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; @@ -257,23 +257,34 @@ impl TryInto for &protobuf::LogicalPlanNode { join.join_type )) })?; - let join_type = match join_type { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Semi => JoinType::Semi, - protobuf::JoinType::Anti => JoinType::Anti, - }; - LogicalPlanBuilder::from(convert_box_required!(join.left)?) - .join( + let join_constraint = protobuf::JoinConstraint::from_i32( + join.join_constraint, + ) + .ok_or_else(|| { + proto_error(format!( + "Received a JoinNode message with unknown JoinConstraint {}", + join.join_constraint + )) + })?; + + let builder = + LogicalPlanBuilder::from(&convert_box_required!(join.left)?); + + let builder = match join_constraint.into() { + JoinConstraint::On => builder.join( &convert_box_required!(join.right)?, - join_type, + join_type.into(), left_keys, right_keys, - )? 
- .build() - .map_err(|e| e.into()) + )?, + JoinConstraint::Using => builder.join_using( + &convert_box_required!(join.right)?, + join_type.into(), + left_keys, + )?, + }; + + builder.build().map_err(|e| e.into()) } } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4049622b83dc..07d7a59c114c 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn use datafusion::datasource::CsvFile; use datafusion::logical_plan::{ window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, - Column, Expr, JoinType, LogicalPlan, + Column, Expr, JoinConstraint, JoinType, LogicalPlan, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::functions::BuiltinScalarFunction; @@ -804,26 +804,23 @@ impl TryInto for &LogicalPlan { right, on, join_type, + join_constraint, .. 
} => { let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?; let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?; - let join_type = match join_type { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::Semi => protobuf::JoinType::Semi, - JoinType::Anti => protobuf::JoinType::Anti, - }; let (left_join_column, right_join_column) = on.iter().map(|(l, r)| (l.into(), r.into())).unzip(); + let join_type: protobuf::JoinType = join_type.to_owned().into(); + let join_constraint: protobuf::JoinConstraint = + join_constraint.to_owned().into(); Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { left: Some(Box::new(left)), right: Some(Box::new(right)), join_type: join_type.into(), + join_constraint: join_constraint.into(), left_join_column, right_join_column, }, diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index af83660baab5..1df0675ecae5 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,7 +20,7 @@ use std::{convert::TryInto, io::Cursor}; -use datafusion::logical_plan::Operator; +use datafusion::logical_plan::{JoinConstraint, JoinType, Operator}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; @@ -291,3 +291,47 @@ impl Into for protobuf::PrimitiveScalarT } } } + +impl From for JoinType { + fn from(t: protobuf::JoinType) -> Self { + match t { + protobuf::JoinType::Inner => JoinType::Inner, + protobuf::JoinType::Left => JoinType::Left, + protobuf::JoinType::Right => JoinType::Right, + protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Semi => JoinType::Semi, + protobuf::JoinType::Anti => JoinType::Anti, + } + } +} + +impl From for protobuf::JoinType { + fn from(t: 
JoinType) -> Self { + match t { + JoinType::Inner => protobuf::JoinType::Inner, + JoinType::Left => protobuf::JoinType::Left, + JoinType::Right => protobuf::JoinType::Right, + JoinType::Full => protobuf::JoinType::Full, + JoinType::Semi => protobuf::JoinType::Semi, + JoinType::Anti => protobuf::JoinType::Anti, + } + } +} + +impl From for JoinConstraint { + fn from(t: protobuf::JoinConstraint) -> Self { + match t { + protobuf::JoinConstraint::On => JoinConstraint::On, + protobuf::JoinConstraint::Using => JoinConstraint::Using, + } + } +} + +impl From for protobuf::JoinConstraint { + fn from(t: JoinConstraint) -> Self { + match t { + JoinConstraint::On => protobuf::JoinConstraint::On, + JoinConstraint::Using => protobuf::JoinConstraint::Using, + } + } +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 4b87be4105be..e7ce75d7215c 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -35,7 +35,9 @@ use datafusion::catalog::catalog::{ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; -use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr}; +use datafusion::logical_plan::{ + window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, +}; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::PartitionMode; @@ -57,7 +59,6 @@ use datafusion::physical_plan::{ filter::FilterExec, functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, hash_join::HashJoinExec, - hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, parquet::ParquetExec, projection::ProjectionExec, @@ -348,19 +349,22 @@ impl TryInto> for &protobuf::PhysicalPlanNode { hashjoin.join_type )) })?; - let 
join_type = match join_type { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Semi => JoinType::Semi, - protobuf::JoinType::Anti => JoinType::Anti, - }; + + let join_constraint = + protobuf::JoinConstraint::from_i32(hashjoin.join_constraint) + .ok_or_else(|| { + proto_error(format!( + "Received a HashJoinNode message with unknown JoinConstraint {}", + hashjoin.join_constraint, + )) + })?; + Ok(Arc::new(HashJoinExec::try_new( left, right, on, - &join_type, + &join_type.into(), + join_constraint.into(), PartitionMode::CollectLeft, )?)) } diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index c0fe81f0ffb9..b2ddd48397c2 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -27,7 +27,7 @@ mod roundtrip_tests { compute::kernels::sort::SortOptions, datatypes::{DataType, Field, Schema}, }, - logical_plan::Operator, + logical_plan::{JoinConstraint, JoinType, Operator}, physical_plan::{ empty::EmptyExec, expressions::{binary, col, lit, InListExpr, NotExpr}, @@ -35,7 +35,6 @@ mod roundtrip_tests { filter::FilterExec, hash_aggregate::{AggregateMode, HashAggregateExec}, hash_join::{HashJoinExec, PartitionMode}, - hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, @@ -93,6 +92,7 @@ mod roundtrip_tests { Arc::new(EmptyExec::new(false, Arc::new(schema_right))), on, &JoinType::Inner, + JoinConstraint::On, PartitionMode::CollectLeft, )?)) } diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index cf5401b65019..a4b823bc0717 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ 
b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -26,6 +26,7 @@ use std::{ sync::Arc, }; +use datafusion::logical_plan::{JoinConstraint, JoinType}; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::csv::CsvExec; use datafusion::physical_plan::expressions::{ @@ -35,7 +36,6 @@ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::AggregateMode; use datafusion::physical_plan::hash_join::HashJoinExec; -use datafusion::physical_plan::hash_utils::JoinType; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::parquet::ParquetExec; use datafusion::physical_plan::projection::ProjectionExec; @@ -135,14 +135,9 @@ impl TryInto for Arc { }), }) .collect(); - let join_type = match exec.join_type() { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::Semi => protobuf::JoinType::Semi, - JoinType::Anti => protobuf::JoinType::Anti, - }; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let join_constraint: protobuf::JoinConstraint = exec.join_constraint().into(); + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { @@ -150,6 +145,7 @@ impl TryInto for Arc { right: Some(Box::new(right)), on, join_type: join_type.into(), + join_constraint: join_constraint.into(), }, ))), }) diff --git a/benchmarks/queries/q7.sql b/benchmarks/queries/q7.sql index d53877c8dde6..512e5be55a2d 100644 --- a/benchmarks/queries/q7.sql +++ b/benchmarks/queries/q7.sql @@ -36,4 +36,4 @@ group by order by supp_nation, cust_nation, - l_year; \ No newline at end of file + l_year; diff --git a/datafusion/src/execution/context.rs 
b/datafusion/src/execution/context.rs index 8ce408de86a5..33a17cfab0ff 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1259,6 +1259,96 @@ mod tests { Ok(()) } + #[tokio::test] + async fn left_join_using() -> Result<()> { + let results = execute( + "SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + + #[tokio::test] + async fn left_join_using_join_key_projection() -> Result<()> { + let results = execute( + "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+----+", + "| c1 | c2 | c2 |", + "+----+----+----+", + "| 0 | 1 | 1 |", + "| 0 | 2 | 2 |", + "| 0 | 3 | 3 |", + "| 0 | 4 | 4 |", + "| 0 | 5 | 5 |", + "| 0 | 6 | 6 |", + "| 0 | 7 | 7 |", + "| 0 | 8 | 8 |", + "| 0 | 9 | 9 |", + "| 0 | 10 | 10 |", + "+----+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + + #[tokio::test] + async fn left_join() -> Result<()> { + let results = execute( + "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+----+", + "| c1 | c2 | c2 |", + "+----+----+----+", + "| 0 | 1 | 1 |", + "| 0 | 2 | 2 |", + "| 0 | 3 | 3 |", + "| 0 | 4 | 4 |", + "| 0 | 5 | 5 |", + "| 0 | 6 | 6 |", + "| 0 | 7 | 7 |", + "| 0 | 8 | 8 |", + "| 0 | 9 | 9 |", + "| 0 | 10 | 10 |", + "+----+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + #[tokio::test] async fn window() -> Result<()> { let results = 
execute( diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 7cf779740c47..4edd01c2c0a9 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -264,7 +264,7 @@ mod tests { #[tokio::test] async fn join() -> Result<()> { let left = test_table()?.select_columns(&["c1", "c2"])?; - let right = test_table()?.select_columns(&["c1", "c3"])?; + let right = test_table_with_name("c2")?.select_columns(&["c1", "c3"])?; let left_rows = left.collect().await?; let right_rows = right.collect().await?; let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; @@ -315,7 +315,7 @@ mod tests { #[test] fn registry() -> Result<()> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; + register_aggregate_csv(&mut ctx, "aggregate_test_100")?; // declare the udf let my_fn: ScalarFunctionImplementation = @@ -366,21 +366,28 @@ mod tests { /// Create a logical plan from a SQL query fn create_plan(sql: &str) -> Result { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; + register_aggregate_csv(&mut ctx, "aggregate_test_100")?; ctx.create_logical_plan(sql) } - fn test_table() -> Result> { + fn test_table_with_name(name: &str) -> Result> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; - ctx.table("aggregate_test_100") + register_aggregate_csv(&mut ctx, name)?; + ctx.table(name) + } + + fn test_table() -> Result> { + test_table_with_name("aggregate_test_100") } - fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { + fn register_aggregate_csv( + ctx: &mut ExecutionContext, + table_name: &str, + ) -> Result<()> { let schema = test::aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); ctx.register_csv( - "aggregate_test_100", + table_name, &format!("{}/csv/aggregate_test_100.csv", testdata), CsvReadOptions::new().schema(schema.as_ref()), )?; diff --git 
a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 17fe6636439c..0bab3f7fe790 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -441,60 +441,53 @@ pub fn build_join_schema( join_constraint: &JoinConstraint, ) -> Result { let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { - let duplicate_keys = match join_constraint { - JoinConstraint::On => on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.clone()) - .collect::>(), - // using join requires unique join columns in the output schema, so we mark all - // right join keys as duplicate - JoinConstraint::Using => { - on.iter().map(|on| on.1.clone()).collect::>() + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + match join_constraint { + JoinConstraint::On => { + let right_fields = right.fields().iter(); + let left_fields = left.fields().iter(); + // left then right + left_fields.chain(right_fields).cloned().collect() } - }; - - let left_fields = left.fields().iter(); - - // remove right-side join keys if they have the same names as the left-side - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(&f.qualified_column())); + JoinConstraint::Using => { + // using join requires unique join column in the output schema, so we mark all + // right join keys as duplicate + let duplicate_join_names = + on.iter().map(|on| &on.1.name).collect::>(); + + let right_fields = right + .fields() + .iter() + .filter(|f| !duplicate_join_names.contains(f.name())) + .cloned(); + + let left_fields = left.fields().iter().map(|f| { + for key in on.iter() { + // update qualifiers for shared fields + if duplicate_join_names.contains(f.name()) { + let mut hs = HashSet::new(); + if let Some(q) = &key.0.relation { + hs.insert(q.to_string()); + } + if let Some(q) = &key.1.relation { + hs.insert(q.to_string()); + } + return f.clone().set_shared_qualifiers(hs); 
+ } + } + + f.clone() + }); - // left then right - left_fields.chain(right_fields).cloned().collect() + // left then right + left_fields.chain(right_fields).collect() + } + } } JoinType::Semi | JoinType::Anti => { // Only use the left side for the schema left.fields().clone() } - JoinType::Right => { - let duplicate_keys = match join_constraint { - JoinConstraint::On => on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.clone()) - .collect::>(), - // using join requires unique join columns in the output schema, so we mark all - // left join keys as duplicate - JoinConstraint::Using => { - on.iter().map(|on| on.0.clone()).collect::>() - } - }; - - // remove left-side join keys if they have the same names as the right-side - let left_fields = left - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(&f.qualified_column())); - - let right_fields = right.fields().iter(); - - // left then right - left_fields.chain(right_fields).cloned().collect() - } }; DFSchema::new(fields) diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index b46e067a268b..75b9b5e308e0 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -48,6 +48,7 @@ impl DFSchema { pub fn new(fields: Vec) -> Result { let mut qualified_names = HashSet::new(); let mut unqualified_names = HashSet::new(); + for field in &fields { if let Some(qualifier) = field.qualifier() { if !qualified_names.insert((qualifier, field.name())) { @@ -56,6 +57,15 @@ impl DFSchema { field.qualified_name() ))); } + } else if let Some(shared_qualifiers) = field.shared_qualifiers() { + for qualifier in shared_qualifiers { + if !qualified_names.insert((qualifier, field.name())) { + return Err(DataFusionError::Plan(format!( + "Schema contains duplicate qualified field name '{}'", + field.qualified_name() + ))); + } + } } else if !unqualified_names.insert(field.name()) { return Err(DataFusionError::Plan(format!( "Schema contains 
duplicate unqualified field name '{}'", @@ -94,10 +104,7 @@ impl DFSchema { schema .fields() .iter() - .map(|f| DFField { - field: f.clone(), - qualifier: Some(qualifier.to_owned()), - }) + .map(|f| DFField::from_qualified(qualifier, f.clone())) .collect(), ) } @@ -149,35 +156,74 @@ impl DFSchema { ))) } - /// Find the index of the column with the given qualifer and name - pub fn index_of_column(&self, col: &Column) -> Result { - for i in 0..self.fields.len() { - let field = &self.fields[i]; - if field.qualifier() == col.relation.as_ref() && field.name() == &col.name { - return Ok(i); - } + fn index_of_column_by_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result { + let matches: Vec = self + .fields + .iter() + .enumerate() + .filter(|(_, field)| match (qualifier, &field.qualifier) { + // field to lookup is qualified. + // current field is qualified and not shared between relations, compare both + // qualifier and name. + (Some(q), Some(field_q)) => q == field_q && field.name() == name, + // field to lookup is qualified. + // current field is either unqualified or qualified and shared between relations. + (Some(q), None) => { + if let Some(shared_q) = field.shared_qualifiers() { + // current field is a shared qualified field, check for all shared + // relation names. + shared_q.contains(q) && field.name() == name + } else { + // current field is unqualified + false + } + } + // field to lookup is unqualified, no need to compare qualifier + _ => field.name() == name, + }) + .map(|(idx, _)| idx) + .collect(); + + match matches.len() { + 0 => Err(DataFusionError::Plan(format!( + "No field named '{}.{}'. 
Valid fields are {}.", + qualifier.unwrap_or(""), + name, + self.get_field_names() + ))), + 1 => Ok(matches[0]), + _ => Err(DataFusionError::Internal(format!( + "Ambiguous reference to qualified field named '{}.{}'", + qualifier.unwrap_or(""), + name + ))), } - Err(DataFusionError::Plan(format!( - "No field matches column '{}'", - col, - ))) + } + + /// Find the index of the column with the given qualifier and name + pub fn index_of_column(&self, col: &Column) -> Result { + self.index_of_column_by_name(col.relation.as_deref(), &col.name) } /// Find the field with the given name pub fn field_with_name( &self, - relation_name: Option<&str>, + qualifier: Option<&str>, name: &str, - ) -> Result { - if let Some(relation_name) = relation_name { - self.field_with_qualified_name(relation_name, name) + ) -> Result<&DFField> { + if let Some(qualifier) = qualifier { + self.field_with_qualified_name(qualifier, name) } else { self.field_with_unqualified_name(name) } } /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result { + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { let matches: Vec<&DFField> = self .fields .iter() @@ -189,7 +235,7 @@ impl DFSchema { name, self.get_field_names() ))), - 1 => Ok(matches[0].to_owned()), + 1 => Ok(matches[0]), _ => Err(DataFusionError::Plan(format!( "Ambiguous reference to field named '{}'", name @@ -200,33 +246,15 @@ impl DFSchema { /// Find the field with the given qualified name pub fn field_with_qualified_name( &self, - relation_name: &str, + qualifier: &str, name: &str, - ) -> Result { - let matches: Vec<&DFField> = self - .fields - .iter() - .filter(|field| { - field.qualifier == Some(relation_name.to_string()) && field.name() == name - }) - .collect(); - match matches.len() { - 0 => Err(DataFusionError::Plan(format!( - "No field named '{}.{}'. 
Valid fields are {}.", - relation_name, - name, - self.get_field_names() - ))), - 1 => Ok(matches[0].to_owned()), - _ => Err(DataFusionError::Internal(format!( - "Ambiguous reference to qualified field named '{}.{}'", - relation_name, name - ))), - } + ) -> Result<&DFField> { + let idx = self.index_of_column_by_name(Some(qualifier), name)?; + Ok(self.field(idx)) } /// Find the field with the given qualified column - pub fn field_from_qualified_column(&self, column: &Column) -> Result { + pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { match &column.relation { Some(r) => self.field_with_qualified_name(r, &column.name), None => self.field_with_unqualified_name(&column.name), @@ -247,31 +275,20 @@ impl DFSchema { fields: self .fields .into_iter() - .map(|f| { - if f.qualifier().is_some() { - DFField::new( - None, - f.name(), - f.data_type().to_owned(), - f.is_nullable(), - ) - } else { - f - } - }) + .map(|f| f.strip_qualifier()) .collect(), } } /// Replace all field qualifier with new value in schema - pub fn replace_qualifier(self, qualifer: &str) -> Self { + pub fn replace_qualifier(self, qualifier: &str) -> Self { DFSchema { fields: self .fields .into_iter() .map(|f| { DFField::new( - Some(qualifer), + Some(qualifier), f.name(), f.data_type().to_owned(), f.is_nullable(), @@ -328,10 +345,7 @@ impl TryFrom for DFSchema { schema .fields() .iter() - .map(|f| DFField { - field: f.clone(), - qualifier: None, - }) + .map(|f| DFField::from(f.clone())) .collect(), ) } @@ -403,6 +417,9 @@ impl Display for DFSchema { pub struct DFField { /// Optional qualifier (usually a table or relation name) qualifier: Option, + /// Optional set of qualifiers that all share this same field. This is used for `JOIN USING` + /// clause where the join keys are combined into a shared column. 
+ shared_qualifiers: Option>, /// Arrow field definition field: Field, } @@ -417,6 +434,7 @@ impl DFField { ) -> Self { DFField { qualifier: qualifier.map(|s| s.to_owned()), + shared_qualifiers: None, field: Field::new(name, data_type, nullable), } } @@ -425,6 +443,7 @@ impl DFField { pub fn from(field: Field) -> Self { Self { qualifier: None, + shared_qualifiers: None, field, } } @@ -433,6 +452,7 @@ impl DFField { pub fn from_qualified(qualifier: &str, field: Field) -> Self { Self { qualifier: Some(qualifier.to_owned()), + shared_qualifiers: None, field, } } @@ -454,8 +474,8 @@ impl DFField { /// Returns a string to the `DFField`'s qualified name pub fn qualified_name(&self) -> String { - if let Some(relation_name) = &self.qualifier { - format!("{}.{}", relation_name, self.field.name()) + if let Some(qualifier) = &self.qualifier { + format!("{}.{}", qualifier, self.field.name()) } else { self.field.name().to_owned() } @@ -469,15 +489,41 @@ impl DFField { } } + /// Builds an unqualified column based on self + pub fn unqualified_column(&self) -> Column { + Column { + relation: None, + name: self.field.name().to_string(), + } + } + /// Get the optional qualifier pub fn qualifier(&self) -> Option<&String> { self.qualifier.as_ref() } + /// Get the optional set of shared qualifiers + pub fn shared_qualifiers(&self) -> Option<&HashSet> { + self.shared_qualifiers.as_ref() + } + /// Get the arrow field pub fn field(&self) -> &Field { &self.field } + + /// Return field with qualifier stripped + pub fn strip_qualifier(mut self) -> Self { + self.qualifier = None; + self + } + + /// Return field with shared qualifiers set and qualifier stripped + pub fn set_shared_qualifiers(mut self, shared_qualifiers: HashSet) -> Self { + self.qualifier = None; + self.shared_qualifiers = Some(shared_qualifiers); + self + } } #[cfg(test)] diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index d20b1f698238..a27140f65384 100644 --- a/datafusion/src/logical_plan/expr.rs 
+++ b/datafusion/src/logical_plan/expr.rs @@ -316,9 +316,7 @@ impl Expr { pub fn get_type(&self, schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.get_type(schema), - Expr::Column(c) => { - Ok(schema.field_from_qualified_column(c)?.data_type().clone()) - } + Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), @@ -390,9 +388,7 @@ impl Expr { pub fn nullable(&self, input_schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.nullable(input_schema), - Expr::Column(c) => { - Ok(input_schema.field_from_qualified_column(c)?.is_nullable()) - } + Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), Expr::Literal(value) => Ok(value.is_null()), Expr::ScalarVariable(_) => Ok(true), Expr::Case { diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 99f0fa14a2d9..fa40904bd4e9 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -709,10 +709,21 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Join { on: ref keys, .. } => { + LogicalPlan::Join { + on: ref keys, + join_constraint, + .. + } => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); - write!(f, "Join: {}", join_expr.join(", ")) + match join_constraint { + JoinConstraint::On => { + write!(f, "Join: {}", join_expr.join(", ")) + } + JoinConstraint::Using => { + write!(f, "Join: Using {}", join_expr.join(", ")) + } + } } LogicalPlan::CrossJoin { .. 
} => { write!(f, "CrossJoin:") diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 7b1ff326c3c6..bdb7177cca66 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -96,12 +96,21 @@ fn get_join_predicates<'a>( let left_columns = &left .fields() .iter() - .map(|f| f.qualified_column()) + .map(|f| { + std::iter::once(f.qualified_column()) + // we need to push down filter using unqualified column as well + .chain(std::iter::once(f.unqualified_column())) + }) + .flatten() .collect::>(); let right_columns = &right .fields() .iter() - .map(|f| f.qualified_column()) + .map(|f| { + std::iter::once(f.qualified_column()) + .chain(std::iter::once(f.unqualified_column())) + }) + .flatten() .collect::>(); let filters = state @@ -882,16 +891,16 @@ mod tests { #[test] fn filter_join_on_common_independent() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()).build()?; - let right = LogicalPlanBuilder::from(table_scan) + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a")])? .build()?; let plan = LogicalPlanBuilder::from(left) - .join( + .join_using( &right, JoinType::Inner, vec![Column::from_name("a".to_string())], - vec![Column::from_name("a".to_string())], )? .filter(col("a").lt_eq(lit(1i64)))? 
.build()?; @@ -900,21 +909,21 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #test.a LtEq Int64(1)\ - \n Join: #test.a = #test.a\ + Filter: #a LtEq Int64(1)\ + \n Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ - \n Projection: #test.a\ - \n TableScan: test projection=None" + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" ); // filter sent to side before the join let expected = "\ - Join: #test.a = #test.a\ - \n Filter: #test.a LtEq Int64(1)\ + Join: Using #test.a = #test2.a\ + \n Filter: #a LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #test.a\ - \n Filter: #test.a LtEq Int64(1)\ - \n TableScan: test projection=None"; + \n Projection: #test2.a\ + \n Filter: #a LtEq Int64(1)\ + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } @@ -923,10 +932,11 @@ mod tests { #[test] fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()) + let left = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("c")])? .build()?; - let right = LogicalPlanBuilder::from(table_scan) + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a"), col("b")])? 
.build()?; let plan = LogicalPlanBuilder::from(left) @@ -944,12 +954,12 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #test.c LtEq #test.b\ - \n Join: #test.a = #test.a\ + Filter: #test.c LtEq #test2.b\ + \n Join: #test.a = #test2.a\ \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.b\ - \n TableScan: test projection=None" + \n Projection: #test2.a, #test2.b\ + \n TableScan: test2 projection=None" ); // expected is equal: no push-down @@ -962,12 +972,14 @@ mod tests { #[test] fn filter_join_on_one_side() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()) + let left = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("b")])? .build()?; - let right = LogicalPlanBuilder::from(table_scan) + let table_scan_right = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(table_scan_right) .project(vec![col("a"), col("c")])? .build()?; + let plan = LogicalPlanBuilder::from(left) .join( &right, @@ -983,20 +995,20 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.b LtEq Int64(1)\ - \n Join: #test.a = #test.a\ + \n Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.c\ - \n TableScan: test projection=None" + \n Projection: #test2.a, #test2.c\ + \n TableScan: test2 projection=None" ); let expected = "\ - Join: #test.a = #test.a\ + Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n Filter: #test.b LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.c\ - \n TableScan: test projection=None"; + \n Projection: #test2.a, #test2.c\ + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index ad356079387a..ef330c579c11 100644 --- 
a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -54,10 +54,11 @@ use arrow::array::{ use super::expressions::Column; use super::{ - hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType}, + hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, merge::MergeExec, }; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::{JoinConstraint, JoinType}; use super::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -94,6 +95,8 @@ pub struct HashJoinExec { on: Vec<(Column, Column)>, /// How the join is performed join_type: JoinType, + /// Join constraint + join_constraint: JoinConstraint, /// The schema once the join is applied schema: SchemaRef, /// Build-side @@ -130,6 +133,7 @@ impl HashJoinExec { right: Arc, on: JoinOn, join_type: &JoinType, + join_constraint: JoinConstraint, partition_mode: PartitionMode, ) -> Result { let left_schema = left.schema(); @@ -141,6 +145,7 @@ impl HashJoinExec { &right_schema, &on, join_type, + join_constraint, )); let random_state = RandomState::with_seeds(0, 0, 0, 0); @@ -150,6 +155,7 @@ impl HashJoinExec { right, on, join_type: *join_type, + join_constraint, schema, build_side: Arc::new(Mutex::new(None)), random_state, @@ -177,6 +183,11 @@ impl HashJoinExec { &self.join_type } + /// Join constraint + pub fn join_constraint(&self) -> JoinConstraint { + self.join_constraint + } + /// Calculates column indices and left/right placement on input / output schemas and jointype fn column_indices_from_schema(&self) -> ArrowResult> { let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { @@ -234,6 +245,7 @@ impl ExecutionPlan for HashJoinExec { children[1].clone(), self.on.clone(), &self.join_type, + self.join_constraint, self.mode, )?)), _ => Err(DataFusionError::Internal( @@ -1313,8 +1325,16 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, + join_constraint: JoinConstraint, ) -> Result { - 
HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft) + HashJoinExec::try_new( + left, + right, + on, + join_type, + join_constraint, + PartitionMode::CollectLeft, + ) } async fn join_collect( @@ -1322,8 +1342,9 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, + join_constraint: JoinConstraint, ) -> Result<(Vec, Vec)> { - let join = join(left, right, on, join_type)?; + let join = join(left, right, on, join_type, join_constraint)?; let columns = columns(&join.schema()); let stream = join.execute(0).await?; @@ -1337,6 +1358,7 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, + join_constraint: JoinConstraint, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1361,6 +1383,7 @@ mod tests { )?), on, join_type, + join_constraint, PartitionMode::Partitioned, )?; @@ -1399,9 +1422,57 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = - join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) - .await?; + let (columns, batches) = join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + JoinConstraint::On, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_using() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", 
&right.schema())?, + )]; + + let (columns, batches) = join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + JoinConstraint::Using, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); @@ -1441,6 +1512,7 @@ mod tests { right.clone(), on.clone(), &JoinType::Inner, + JoinConstraint::Using, ) .await?; @@ -1477,7 +1549,8 @@ mod tests { Column::new_with_schema("b2", &right.schema())?, )]; - let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, JoinConstraint::On).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1519,20 +1592,21 @@ mod tests { ), ]; - let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, JoinConstraint::On).await?; - assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+", - "| a1 | b2 | c1 | c2 |", - "+----+----+----+----+", - "| 1 | 1 | 7 | 70 |", - "| 2 | 2 | 8 | 80 |", - "| 2 | 2 | 9 | 80 |", - "+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1542,7 +1616,7 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] - async fn join_inner_one_two_parts_left() -> Result<()> { + async fn join_inner_using_one_two_parts_left() -> Result<()> { let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1571,7 +1645,9 @@ mod tests { ), ]; - let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + 
let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, JoinConstraint::Using) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); @@ -1594,7 +1670,7 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] - async fn join_inner_one_two_parts_right() -> Result<()> { + async fn join_inner_using_one_two_parts_right() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1618,7 +1694,7 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let join = join(left, right, on, &JoinType::Inner)?; + let join = join(left, right, on, &JoinType::Inner, JoinConstraint::Using)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); @@ -1684,7 +1760,7 @@ mod tests { Column::new_with_schema("b1", &right.schema()).unwrap(), )]; - let join = join(left, right, on, &JoinType::Left).unwrap(); + let join = join(left, right, on, &JoinType::Left, JoinConstraint::Using).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); @@ -1725,7 +1801,7 @@ mod tests { Column::new_with_schema("b2", &right.schema()).unwrap(), )]; - let join = join(left, right, on, &JoinType::Full).unwrap(); + let join = join(left, right, on, &JoinType::Full, JoinConstraint::On).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1764,7 +1840,7 @@ mod tests { )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); - let join = join(left, right, on, &JoinType::Left).unwrap(); + let join = join(left, right, on, &JoinType::Left, JoinConstraint::Using).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); @@ -1799,7 +1875,7 @@ mod tests { )]; let schema = right.schema(); let right = 
Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); - let join = join(left, right, on, &JoinType::Full).unwrap(); + let join = join(left, right, on, &JoinType::Full, JoinConstraint::On).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1837,19 +1913,24 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = - join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) - .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + let (columns, batches) = join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Left, + JoinConstraint::On, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1878,6 +1959,7 @@ mod tests { right.clone(), on.clone(), &JoinType::Left, + JoinConstraint::Using, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); @@ -1913,7 +1995,7 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let join = join(left, right, on, &JoinType::Semi)?; + let join = join(left, right, on, &JoinType::Semi, JoinConstraint::On)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -1952,7 +2034,7 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let join = join(left, right, on, &JoinType::Anti)?; + let join = join(left, right, on, &JoinType::Anti, JoinConstraint::On)?; let 
columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -1989,18 +2071,19 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; + let (columns, batches) = + join_collect(left, right, on, &JoinType::Right, JoinConstraint::On).await?; - assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+", - "| | | 30 | 6 | 90 |", - "| 1 | 7 | 10 | 4 | 70 |", - "| 2 | 8 | 20 | 5 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2025,18 +2108,24 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right).await?; + let (columns, batches) = partitioned_join_collect( + left, + right, + on, + &JoinType::Right, + JoinConstraint::Using, + ) + .await?; - assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); let expected = vec![ "+----+----+----+----+----+", - "| a1 | c1 | a2 | b1 | c2 |", + "| a1 | b1 | c1 | a2 | c2 |", "+----+----+----+----+----+", - "| | | 30 | 6 | 90 |", - "| 1 | 7 | 10 | 4 | 70 |", - "| 2 | 8 | 20 | 5 | 80 |", + "| | 6 | | 30 | 90 |", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", "+----+----+----+----+----+", ]; @@ -2062,7 +2151,7 @@ mod tests { Column::new_with_schema("b2", &right.schema()).unwrap(), )]; - let join = join(left, right, on, &JoinType::Full)?; + let join = join(left, right, on, &JoinType::Full, 
JoinConstraint::On)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 0cf0b9212cd2..e4fd17be91a0 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -21,25 +21,9 @@ use crate::error::{DataFusionError, Result}; use arrow::datatypes::{Field, Schema}; use std::collections::HashSet; +use crate::logical_plan::{JoinConstraint, JoinType}; use crate::physical_plan::expressions::Column; -/// All valid types of joins. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum JoinType { - /// Inner Join - Inner, - /// Left Join - Left, - /// Right Join - Right, - /// Full Join - Full, - /// Semi Join - Semi, - /// Anti Join - Anti, -} - /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. @@ -109,43 +93,29 @@ pub fn build_join_schema( right: &Schema, on: JoinOnRef, join_type: &JoinType, + join_constraint: JoinConstraint, ) -> Schema { let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { - // remove right-side join keys if they have the same names as the left-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l.name() == r.name()) - .map(|on| on.1.name()) - .collect::>(); - - let left_fields = left.fields().iter(); - - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinType::Right => { - // remove left-side join keys if they have the same names as the right-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l.name() == r.name()) - .map(|on| on.1.name()) - .collect::>(); - - let left_fields = left - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - - let 
right_fields = right.fields().iter(); - - // left then right - left_fields.chain(right_fields).cloned().collect() + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + match join_constraint { + JoinConstraint::On => { + let left_fields = left.fields().iter(); + let right_fields = right.fields().iter(); + // left then right + left_fields.chain(right_fields).cloned().collect() + } + JoinConstraint::Using => { + // using join requires unique join columns in the output schema, so we mark all + // right join keys as duplicate + let duplicate_keys = + &on.iter().map(|on| on.1.name()).collect::>(); + let right_fields = right + .fields() + .iter() + .filter(|f| !duplicate_keys.contains(f.name().as_str())); + left.fields().iter().chain(right_fields).cloned().collect() + } + } } JoinType::Semi | JoinType::Anti => left.fields().clone(), }; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d59004243533..1153771f5f49 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -42,7 +42,6 @@ use crate::physical_plan::{hash_utils, Partitioning}; use crate::physical_plan::{ AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner, WindowExpr, }; -use crate::prelude::JoinType; use crate::scalar::ScalarValue; use crate::sql::utils::generate_sort_key; use crate::variable::VarType; @@ -564,20 +563,13 @@ impl DefaultPhysicalPlanner { right, on: keys, join_type, + join_constraint, .. 
} => { let left_df_schema = left.schema(); let physical_left = self.create_initial_plan(left, ctx_state)?; let right_df_schema = right.schema(); let physical_right = self.create_initial_plan(right, ctx_state)?; - let physical_join_type = match join_type { - JoinType::Inner => hash_utils::JoinType::Inner, - JoinType::Left => hash_utils::JoinType::Left, - JoinType::Right => hash_utils::JoinType::Right, - JoinType::Full => hash_utils::JoinType::Full, - JoinType::Semi => hash_utils::JoinType::Semi, - JoinType::Anti => hash_utils::JoinType::Anti, - }; let join_on = keys .iter() .map(|(l, r)| { @@ -611,7 +603,8 @@ impl DefaultPhysicalPlanner { Partitioning::Hash(right_expr, ctx_state.config.concurrency), )?), join_on, - &physical_join_type, + join_type, + *join_constraint, PartitionMode::Partitioned, )?)) } else { @@ -619,7 +612,8 @@ impl DefaultPhysicalPlanner { physical_left, physical_right, join_on, - &physical_join_type, + join_type, + *join_constraint, PartitionMode::CollectLeft, )?)) } @@ -1395,7 +1389,7 @@ mod tests { let expected_error: &str = "Error during planning: \ Extension planner for NoOp created an ExecutionPlan with mismatched schema. 
\ LogicalPlan schema: DFSchema { fields: [\ - DFField { qualifier: None, field: Field { \ + DFField { qualifier: None, shared_qualifiers: None, field: Field { \ name: \"a\", \ data_type: Int32, \ nullable: false, \ diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index b86dc0f48c14..c855413dd13e 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -477,12 +477,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let right_schema = right.schema(); let mut join_keys = vec![]; for (l, r) in &possible_join_keys { - if left_schema.field_from_qualified_column(l).is_ok() - && right_schema.field_from_qualified_column(r).is_ok() + if left_schema.field_from_column(l).is_ok() + && right_schema.field_from_column(r).is_ok() { join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_qualified_column(r).is_ok() - && right_schema.field_from_qualified_column(l).is_ok() + } else if left_schema.field_from_column(r).is_ok() + && right_schema.field_from_column(l).is_ok() { join_keys.push((r.clone(), l.clone())); } @@ -818,10 +818,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .try_for_each(|col| match col { Expr::Column(col) => { match &col.relation { - Some(r) => schema.field_with_qualified_name(r, &col.name), - None => schema.field_with_unqualified_name(&col.name), + Some(r) => { + Ok(schema.field_with_qualified_name(r, &col.name)?.to_owned()) + } + None => { + Ok(schema.field_with_unqualified_name(&col.name)?.to_owned()) + } } - .map_err(|_| { + .map_err(|_: DataFusionError| { DataFusionError::Plan(format!( "Invalid identifier '{}' for schema {}", col, @@ -2720,8 +2724,8 @@ mod tests { FROM person \ JOIN person as person2 \ USING (id)"; - let expected = "Projection: #person.first_name, #person.id\ - \n Join: #person.id = #person2.id\ + let expected = "Projection: #person.first_name, #id\ + \n Join: Using #person.id = #person2.id\ \n TableScan: person projection=None\ \n TableScan: person2 
projection=None"; quick_test(sql, expected); diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 7ca7cc12d9ef..c06feffd9f99 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -110,14 +110,19 @@ pub fn aggr_test_schema() -> SchemaRef { ])) } -/// some tests share a common table -pub fn test_table_scan() -> Result { +/// some tests share a common table with different names +pub fn test_table_scan_with_name(name: &str) -> Result { let schema = Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::UInt32, false), Field::new("c", DataType::UInt32, false), ]); - LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() + LogicalPlanBuilder::scan_empty(Some(name), &schema, None)?.build() +} + +/// some tests share a common table +pub fn test_table_scan() -> Result { + test_table_scan_with_name("test") } pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { From 702a5a0b70d009b3f41af6099478e2f94ccb5ab7 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 26 Jun 2021 10:48:12 -0700 Subject: [PATCH 2/5] handling join column expansion during USING JOIN planning get rid of shared field and move column expansion logic into plan builder and optimizer. 
--- ballista/rust/core/proto/ballista.proto | 1 - .../core/src/serde/logical_plan/from_proto.rs | 4 +- .../src/serde/physical_plan/from_proto.rs | 10 - .../rust/core/src/serde/physical_plan/mod.rs | 3 +- .../core/src/serde/physical_plan/to_proto.rs | 2 - datafusion/src/logical_plan/builder.rs | 91 ++---- datafusion/src/logical_plan/dfschema.rs | 55 +--- datafusion/src/logical_plan/expr.rs | 76 +++-- datafusion/src/logical_plan/mod.rs | 8 +- datafusion/src/logical_plan/plan.rs | 39 +++ datafusion/src/optimizer/filter_push_down.rs | 141 ++++++++- .../src/optimizer/projection_push_down.rs | 88 +++++- datafusion/src/optimizer/utils.rs | 9 +- datafusion/src/physical_plan/hash_join.rs | 269 ++++++------------ datafusion/src/physical_plan/hash_utils.rs | 33 +-- datafusion/src/physical_plan/planner.rs | 5 +- datafusion/src/sql/planner.rs | 68 ++--- 17 files changed, 481 insertions(+), 421 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index e42ddb364d70..fba760a0373e 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -576,7 +576,6 @@ message HashJoinExecNode { PhysicalPlanNode right = 2; repeated JoinOn on = 3; JoinType join_type = 4; - JoinConstraint join_constraint = 5; } message PhysicalColumn { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 51683f3d437d..cad054392308 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -267,9 +267,7 @@ impl TryInto for &protobuf::LogicalPlanNode { )) })?; - let builder = - LogicalPlanBuilder::from(&convert_box_required!(join.left)?); - + let builder = LogicalPlanBuilder::from(convert_box_required!(join.left)?); let builder = match join_constraint.into() { JoinConstraint::On => builder.join( &convert_box_required!(join.right)?, diff --git 
a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index e7ce75d7215c..493143ddff40 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -350,21 +350,11 @@ impl TryInto> for &protobuf::PhysicalPlanNode { )) })?; - let join_constraint = - protobuf::JoinConstraint::from_i32(hashjoin.join_constraint) - .ok_or_else(|| { - proto_error(format!( - "Received a HashJoinNode message with unknown JoinConstraint {}", - hashjoin.join_constraint, - )) - })?; - Ok(Arc::new(HashJoinExec::try_new( left, right, on, &join_type.into(), - join_constraint.into(), PartitionMode::CollectLeft, )?)) } diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index b2ddd48397c2..e2281d4e580c 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -27,7 +27,7 @@ mod roundtrip_tests { compute::kernels::sort::SortOptions, datatypes::{DataType, Field, Schema}, }, - logical_plan::{JoinConstraint, JoinType, Operator}, + logical_plan::{JoinType, Operator}, physical_plan::{ empty::EmptyExec, expressions::{binary, col, lit, InListExpr, NotExpr}, @@ -92,7 +92,6 @@ mod roundtrip_tests { Arc::new(EmptyExec::new(false, Arc::new(schema_right))), on, &JoinType::Inner, - JoinConstraint::On, PartitionMode::CollectLeft, )?)) } diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index a4b823bc0717..1bda2d6861d3 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -136,7 +136,6 @@ impl TryInto for Arc { }) .collect(); let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let join_constraint: protobuf::JoinConstraint = exec.join_constraint().into(); 
Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( @@ -145,7 +144,6 @@ impl TryInto for Arc { right: Some(Box::new(right)), on, join_type: join_type.into(), - join_constraint: join_constraint.into(), }, ))), }) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 0bab3f7fe790..64db75930fbc 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -40,7 +40,6 @@ use crate::logical_plan::{ columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema, DFSchemaRef, Partitioning, }; -use std::collections::HashSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -213,7 +212,6 @@ impl LogicalPlanBuilder { /// * An invalid expression is used (e.g. a `sort` expression) pub fn project(&self, expr: impl IntoIterator) -> Result { let input_schema = self.plan.schema(); - let all_schemas = self.plan.all_schemas(); let mut projected_expr = vec![]; for e in expr { match e { @@ -223,10 +221,8 @@ impl LogicalPlanBuilder { .push(Expr::Column(input_schema.field(i).qualified_column())) }); } - _ => projected_expr.push(columnize_expr( - normalize_col(e, &all_schemas)?, - input_schema, - )), + _ => projected_expr + .push(columnize_expr(normalize_col(e, &self.plan)?, input_schema)), } } @@ -243,7 +239,7 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(&self, expr: Expr) -> Result { - let expr = normalize_col(expr, &self.plan.all_schemas())?; + let expr = normalize_col(expr, &self.plan)?; Ok(Self::from(LogicalPlan::Filter { predicate: expr, input: Arc::new(self.plan.clone()), @@ -260,9 +256,8 @@ impl LogicalPlanBuilder { /// Apply a sort pub fn sort(&self, exprs: impl IntoIterator) -> Result { - let schemas = self.plan.all_schemas(); Ok(Self::from(LogicalPlan::Sort { - expr: normalize_cols(exprs, &schemas)?, + expr: normalize_cols(exprs, &self.plan)?, input: Arc::new(self.plan.clone()), })) } @@ 
-288,20 +283,15 @@ impl LogicalPlanBuilder { let left_keys: Vec = left_keys .into_iter() - .map(|c| c.into().normalize(&self.plan.all_schemas())) + .map(|c| c.into().normalize(&self.plan)) .collect::>()?; let right_keys: Vec = right_keys .into_iter() - .map(|c| c.into().normalize(&right.all_schemas())) + .map(|c| c.into().normalize(right)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - let join_schema = build_join_schema( - self.plan.schema(), - right.schema(), - &on, - &join_type, - &JoinConstraint::On, - )?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join { left: Arc::new(self.plan.clone()), @@ -323,21 +313,16 @@ impl LogicalPlanBuilder { let left_keys: Vec = using_keys .clone() .into_iter() - .map(|c| c.into().normalize(&self.plan.all_schemas())) + .map(|c| c.into().normalize(&self.plan)) .collect::>()?; let right_keys: Vec = using_keys .into_iter() - .map(|c| c.into().normalize(&right.all_schemas())) + .map(|c| c.into().normalize(right)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - let join_schema = build_join_schema( - self.plan.schema(), - right.schema(), - &on, - &join_type, - &JoinConstraint::Using, - )?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join { left: Arc::new(self.plan.clone()), @@ -390,9 +375,8 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator, aggr_expr: impl IntoIterator, ) -> Result { - let schemas = self.plan.all_schemas(); - let group_expr = normalize_cols(group_expr, &schemas)?; - let aggr_expr = normalize_cols(aggr_expr, &schemas)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; + let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; let all_expr = group_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone(), 
self.plan.schema())?; @@ -436,53 +420,14 @@ impl LogicalPlanBuilder { pub fn build_join_schema( left: &DFSchema, right: &DFSchema, - on: &[(Column, Column)], join_type: &JoinType, - join_constraint: &JoinConstraint, ) -> Result { let fields: Vec = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - match join_constraint { - JoinConstraint::On => { - let right_fields = right.fields().iter(); - let left_fields = left.fields().iter(); - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinConstraint::Using => { - // using join requires unique join column in the output schema, so we mark all - // right join keys as duplicate - let duplicate_join_names = - on.iter().map(|on| &on.1.name).collect::>(); - - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_join_names.contains(f.name())) - .cloned(); - - let left_fields = left.fields().iter().map(|f| { - for key in on.iter() { - // update qualifiers for shared fields - if duplicate_join_names.contains(f.name()) { - let mut hs = HashSet::new(); - if let Some(q) = &key.0.relation { - hs.insert(q.to_string()); - } - if let Some(q) = &key.1.relation { - hs.insert(q.to_string()); - } - return f.clone().set_shared_qualifiers(hs); - } - } - - f.clone() - }); - - // left then right - left_fields.chain(right_fields).collect() - } - } + let right_fields = right.fields().iter(); + let left_fields = left.fields().iter(); + // left then right + left_fields.chain(right_fields).cloned().collect() } JoinType::Semi | JoinType::Anti => { // Only use the left side for the schema diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 75b9b5e308e0..b4bde87f3471 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -57,15 +57,6 @@ impl DFSchema { field.qualified_name() ))); } - } else if let Some(shared_qualifiers) = field.shared_qualifiers() { - for qualifier in 
shared_qualifiers { - if !qualified_names.insert((qualifier, field.name())) { - return Err(DataFusionError::Plan(format!( - "Schema contains duplicate qualified field name '{}'", - field.qualified_name() - ))); - } - } } else if !unqualified_names.insert(field.name()) { return Err(DataFusionError::Plan(format!( "Schema contains duplicate unqualified field name '{}'", @@ -170,18 +161,8 @@ impl DFSchema { // current field is qualified and not shared between relations, compare both // qualifer and name. (Some(q), Some(field_q)) => q == field_q && field.name() == name, - // field to lookup is qualified. - // current field is either unqualified or qualified and shared between relations. - (Some(q), None) => { - if let Some(shared_q) = field.shared_qualifiers() { - // current field is a shared qualified field, check for all shared - // relation names. - shared_q.contains(q) && field.name() == name - } else { - // current field is unqualifiied - false - } - } + // field to lookup is qualified but current field is unqualified. + (Some(_), None) => false, // field to lookup is unqualified, no need to compare qualifier _ => field.name() == name, }) @@ -222,13 +203,17 @@ impl DFSchema { } } - /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { - let matches: Vec<&DFField> = self - .fields + /// Find all fields match the given name + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { + self.fields .iter() .filter(|field| field.name() == name) - .collect(); + .collect() + } + + /// Find the field with the given name + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { + let matches = self.fields_with_unqualified_name(name); match matches.len() { 0 => Err(DataFusionError::Plan(format!( "No field with unqualified name '{}'. 
Valid fields are {}.", @@ -417,9 +402,6 @@ impl Display for DFSchema { pub struct DFField { /// Optional qualifier (usually a table or relation name) qualifier: Option, - /// Optional set of qualifiers that all share this same field. This is used for `JOIN USING` - /// clause where the join keys are combined into a shared column. - shared_qualifiers: Option>, /// Arrow field definition field: Field, } @@ -434,7 +416,6 @@ impl DFField { ) -> Self { DFField { qualifier: qualifier.map(|s| s.to_owned()), - shared_qualifiers: None, field: Field::new(name, data_type, nullable), } } @@ -443,7 +424,6 @@ impl DFField { pub fn from(field: Field) -> Self { Self { qualifier: None, - shared_qualifiers: None, field, } } @@ -452,7 +432,6 @@ impl DFField { pub fn from_qualified(qualifier: &str, field: Field) -> Self { Self { qualifier: Some(qualifier.to_owned()), - shared_qualifiers: None, field, } } @@ -502,11 +481,6 @@ impl DFField { self.qualifier.as_ref() } - /// Get the optional qualifier - pub fn shared_qualifiers(&self) -> Option<&HashSet> { - self.shared_qualifiers.as_ref() - } - /// Get the arrow field pub fn field(&self) -> &Field { &self.field @@ -517,13 +491,6 @@ impl DFField { self.qualifier = None; self } - - /// Return field with shared qualifiers set and qualifier stripped - pub fn set_shared_qualifiers(mut self, shared_qualifiers: HashSet) -> Self { - self.qualifier = None; - self.shared_qualifiers = Some(shared_qualifiers); - self - } } #[cfg(test)] diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index a27140f65384..c63b97fb91cc 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,7 +20,7 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::{window_frames, DFField, DFSchema, DFSchemaRef}; +use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan}; use crate::physical_plan::{ aggregates, 
expressions::binary_operator_data_type, functions, udf::ScalarUDF, window_functions, @@ -29,7 +29,7 @@ use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::sync::Arc; @@ -84,14 +84,33 @@ impl Column { } /// Normalize Column with qualifier based on provided dataframe schemas. - pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result { + pub fn normalize(self, plan: &LogicalPlan) -> Result { if self.relation.is_some() { return Ok(self); } - for schema in schemas { - if let Ok(field) = schema.field_with_unqualified_name(&self.name) { - return Ok(field.qualified_column()); + let schemas = plan.all_schemas(); + let using_columns = plan.using_columns()?; + + for schema in &schemas { + let fields = schema.fields_with_unqualified_name(&self.name); + match fields.len() { + 0 => continue, + 1 => { + return Ok(fields[0].qualified_column()); + } + _ => { + for using_col in &using_columns { + let all_matched = fields + .iter() + .all(|f| using_col.contains(&f.qualified_column())); + // All matched fields belong to the same using column set, use the first + // qualifer + if all_matched { + return Ok(fields[0].qualified_column()); + } + } + } } } @@ -1109,35 +1128,56 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { } } -/// Recursively normalize all Column expressions in a given expression tree -pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result { - struct ColumnNormalizer<'a, 'b> { - schemas: &'a [&'b DFSchemaRef], +/// Recursively replace all Column expressions in a given expression tree with Column expressions +/// provided by the hash map argument. 
+pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { + struct ColumnReplacer<'a> { + replace_map: &'a HashMap<&'a Column, &'a Column>, + } + + impl<'a> ExprRewriter for ColumnReplacer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = &expr { + match self.replace_map.get(c) { + Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), + None => Ok(expr), + } + } else { + Ok(expr) + } + } + } + + e.rewrite(&mut ColumnReplacer { replace_map }) +} + +/// Recursively normalize all Column expressions in a given expression tree by adding qualifiers +/// wherever applicable +pub fn normalize_col(e: Expr, plan: &LogicalPlan) -> Result { + struct ColumnNormalizer<'a> { + plan: &'a LogicalPlan, } - impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> { + impl<'a> ExprRewriter for ColumnNormalizer<'a> { fn mutate(&mut self, expr: Expr) -> Result { if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize(self.schemas)?)) + Ok(Expr::Column(c.normalize(self.plan)?)) } else { Ok(expr) } } } - e.rewrite(&mut ColumnNormalizer { schemas }) + e.rewrite(&mut ColumnNormalizer { plan }) } /// Recursively normalize all Column expressions in a list of expression trees #[inline] pub fn normalize_cols( exprs: impl IntoIterator, - schemas: &[&DFSchemaRef], + plan: &LogicalPlan, ) -> Result> { - exprs - .into_iter() - .map(|e| normalize_col(e, schemas)) - .collect() + exprs.into_iter().map(|e| normalize_col(e, plan)).collect() } /// Create an expression to represent the min() aggregate function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 69d03d22bb21..86a2f567d7de 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -41,10 +41,10 @@ pub use expr::{ cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, normalize_col, 
normalize_cols, now, octet_length, or, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, - sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, - to_hex, translate, trim, trunc, upper, when, Column, Expr, ExprRewriter, - ExpressionVisitor, Literal, Recursion, + regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, + sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, + substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr, + ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index fa40904bd4e9..b954b6a97950 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -21,9 +21,11 @@ use super::display::{GraphvizVisitor, IndentVisitor}; use super::expr::{Column, Expr}; use super::extension::UserDefinedLogicalNode; use crate::datasource::TableProvider; +use crate::error::DataFusionError; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::collections::HashSet; use std::{ fmt::{self, Display}, sync::Arc, @@ -354,6 +356,43 @@ impl LogicalPlan { | LogicalPlan::CreateExternalTable { .. } => vec![], } } + + /// returns all `Using` join columns in a logical plan + pub fn using_columns(&self) -> Result>, DataFusionError> { + struct UsingJoinColumnVisitor { + using_columns: Vec>, + } + + impl PlanVisitor for UsingJoinColumnVisitor { + type Error = DataFusionError; + + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + if let LogicalPlan::Join { + join_constraint: JoinConstraint::Using, + on, + .. 
+ } = plan + { + self.using_columns.push( + on.iter() + .map(|entry| { + std::iter::once(entry.0.clone()) + .chain(std::iter::once(entry.1.clone())) + }) + .flatten() + .collect::>(), + ); + } + Ok(true) + } + } + + let mut visitor = UsingJoinColumnVisitor { + using_columns: vec![], + }; + self.accept(&mut visitor)?; + Ok(visitor.using_columns) + } } /// Logical partitioning schemes supported by the repartition operator. diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index bdb7177cca66..2960cf3816cf 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -16,7 +16,7 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; -use crate::logical_plan::{and, Column, LogicalPlan}; +use crate::logical_plan::{and, replace_col, Column, LogicalPlan}; use crate::logical_plan::{DFSchema, Expr}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -345,8 +345,91 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { .collect::>(); issue_filters(state, used_columns, plan) } - LogicalPlan::Join { left, right, .. } - | LogicalPlan::CrossJoin { left, right, .. } => { + LogicalPlan::CrossJoin { left, right, .. 
} => { + let (pushable_to_left, pushable_to_right, keep) = + get_join_predicates(&state, left.schema(), right.schema()); + + let mut left_state = state.clone(); + left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); + let left = optimize(left, left_state)?; + + let mut right_state = state.clone(); + right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); + let right = optimize(right, right_state)?; + + // create a new Join with the new `left` and `right` + let expr = plan.expressions(); + let plan = utils::from_plan(plan, &expr, &[left, right])?; + + if keep.0.is_empty() { + Ok(plan) + } else { + // wrap the join on the filter whose predicates must be kept + let plan = add_filter(plan, &keep.0); + state.filters = remove_filters(&state.filters, &keep.1); + + Ok(plan) + } + } + LogicalPlan::Join { + left, right, on, .. + } => { + // duplicate filters for joined columns so filters can be pushed down to both sides. + // Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. 
+ let join_side_filters = state + .filters + .iter() + .filter_map(|(predicate, columns)| { + let mut join_cols_to_replace = HashMap::new(); + for col in columns.iter() { + for (l, r) in on { + if col == l { + join_cols_to_replace.insert(col, r); + break; + } else if col == r { + join_cols_to_replace.insert(col, l); + break; + } + } + } + + if join_cols_to_replace.is_empty() { + return None; + } + + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + let join_side_columns = columns + .clone() + .into_iter() + // replace keys in join_cols_to_replace with values in resulting column + // set + .filter(|c| !join_cols_to_replace.contains_key(c)) + .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) + .collect(); + + Some(Ok((join_side_predicate, join_side_columns))) + }) + .collect::>>()?; + + state.filters.extend(join_side_filters); + let (pushable_to_left, pushable_to_right, keep) = get_join_predicates(&state, left.schema(), right.schema()); @@ -887,9 +970,51 @@ mod tests { Ok(()) } - /// post-join predicates on a column common to both sides is pushed to both sides + /// post-on-join predicates on a column common to both sides is pushed to both sides + #[test] + fn filter_on_join_on_common_independent() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + )? + .filter(col("a").lt_eq(lit(1i64)))? 
+ .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test.a LtEq Int64(1)\ + \n Join: #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter sent to side before the join + let expected = "\ + Join: #test.a = #test2.a\ + \n Filter: #test.a LtEq Int64(1)\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n Filter: #test2.a LtEq Int64(1)\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// post-using-join predicates on a column common to both sides is pushed to both sides #[test] - fn filter_join_on_common_independent() -> Result<()> { + fn filter_using_join_on_common_independent() -> Result<()> { let table_scan = test_table_scan()?; let left = LogicalPlanBuilder::from(table_scan).build()?; let right_table_scan = test_table_scan_with_name("test2")?; @@ -909,7 +1034,7 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #a LtEq Int64(1)\ + Filter: #test.a LtEq Int64(1)\ \n Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -919,10 +1044,10 @@ mod tests { // filter sent to side before the join let expected = "\ Join: Using #test.a = #test2.a\ - \n Filter: #a LtEq Int64(1)\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ - \n Filter: #a LtEq Int64(1)\ + \n Filter: #test2.a LtEq Int64(1)\ \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 4bf2b6e797f8..3abcf326c2ca 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -216,9 +216,7 @@ fn optimize_plan( let schema = build_join_schema( optimized_left.schema(), optimized_right.schema(), 
- on, join_type, - join_constraint, )?; Ok(LogicalPlan::Join { @@ -499,7 +497,7 @@ mod tests { } #[test] - fn join_schema_trim() -> Result<()> { + fn join_schema_trim_full_join_column_projection() -> Result<()> { let table_scan = test_table_scan()?; let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); @@ -511,7 +509,7 @@ mod tests { .project(vec![col("a"), col("b"), col("c1")])? .build()?; - // make sure projections are pushed down to table scan + // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b, #test2.c1\ \n Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ @@ -521,7 +519,48 @@ mod tests { let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); - // make sure schema for join node doesn't include c1 column + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "c1", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + + #[test] + fn join_schema_trim_partial_join_column_projection() -> Result<()> { + // test join column push down without explicit column projections + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])? + // projecting joined column `a` should push the right side column `c1` projection as + // well into test2 table even though `c1` is not referenced in projection. + .project(vec![col("a"), col("b")])? 
+ .build()?; + + // make sure projections are pushed down to both table scans + let expected = "Projection: #test.a, #test.b\ + \n Join: #test.a = #test2.c1\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; assert_eq!( **optimized_join.schema(), @@ -535,6 +574,45 @@ mod tests { Ok(()) } + #[test] + fn join_schema_trim_using_join() -> Result<()> { + // shared join colums from using join should be pushed to both sides + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join_using(&table2_scan, JoinType::Left, vec!["a"])? + .project(vec![col("a"), col("b")])? 
+ .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: #test.a, #test.b\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "a", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + #[test] fn cast() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 394308f5af80..7c5d1b51d478 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -215,13 +215,8 @@ pub fn from_plan( on, .. 
} => { - let schema = build_join_schema( - inputs[0].schema(), - inputs[1].schema(), - on, - join_type, - join_constraint, - )?; + let schema = + build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; Ok(LogicalPlan::Join { left: Arc::new(inputs[0].clone()), right: Arc::new(inputs[1].clone()), diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index ef330c579c11..a8072ca40515 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -58,7 +58,7 @@ use super::{ merge::MergeExec, }; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::{JoinConstraint, JoinType}; +use crate::logical_plan::JoinType; use super::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -95,8 +95,6 @@ pub struct HashJoinExec { on: Vec<(Column, Column)>, /// How the join is performed join_type: JoinType, - /// Join constraint - join_constraint: JoinConstraint, /// The schema once the join is applied schema: SchemaRef, /// Build-side @@ -133,20 +131,13 @@ impl HashJoinExec { right: Arc, on: JoinOn, join_type: &JoinType, - join_constraint: JoinConstraint, partition_mode: PartitionMode, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &on)?; - let schema = Arc::new(build_join_schema( - &left_schema, - &right_schema, - &on, - join_type, - join_constraint, - )); + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type)); let random_state = RandomState::with_seeds(0, 0, 0, 0); @@ -155,7 +146,6 @@ impl HashJoinExec { right, on, join_type: *join_type, - join_constraint, schema, build_side: Arc::new(Mutex::new(None)), random_state, @@ -183,11 +173,6 @@ impl HashJoinExec { &self.join_type } - /// Join constraint - pub fn join_constraint(&self) -> JoinConstraint { - self.join_constraint - } - /// Calculates column indices and left/right placement on input 
/ output schemas and jointype fn column_indices_from_schema(&self) -> ArrowResult> { let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { @@ -245,7 +230,6 @@ impl ExecutionPlan for HashJoinExec { children[1].clone(), self.on.clone(), &self.join_type, - self.join_constraint, self.mode, )?)), _ => Err(DataFusionError::Internal( @@ -1325,16 +1309,8 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - join_constraint: JoinConstraint, ) -> Result { - HashJoinExec::try_new( - left, - right, - on, - join_type, - join_constraint, - PartitionMode::CollectLeft, - ) + HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft) } async fn join_collect( @@ -1342,9 +1318,8 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - join_constraint: JoinConstraint, ) -> Result<(Vec, Vec)> { - let join = join(left, right, on, join_type, join_constraint)?; + let join = join(left, right, on, join_type)?; let columns = columns(&join.schema()); let stream = join.execute(0).await?; @@ -1358,7 +1333,6 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - join_constraint: JoinConstraint, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1383,7 +1357,6 @@ mod tests { )?), on, join_type, - join_constraint, PartitionMode::Partitioned, )?; @@ -1422,14 +1395,9 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = join_collect( - left.clone(), - right.clone(), - on.clone(), - &JoinType::Inner, - JoinConstraint::On, - ) - .await?; + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1447,49 +1415,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn join_inner_using() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 5]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", 
&vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, - )]; - - let (columns, batches) = join_collect( - left.clone(), - right.clone(), - on.clone(), - &JoinType::Inner, - JoinConstraint::Using, - ) - .await?; - - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); - - let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 5 | 9 | 20 | 80 |", - "+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - - Ok(()) - } - #[tokio::test] async fn partitioned_join_inner_one() -> Result<()> { let left = build_table( @@ -1512,20 +1437,19 @@ mod tests { right.clone(), on.clone(), &JoinType::Inner, - JoinConstraint::Using, ) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 5 | 9 | 20 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1549,8 +1473,7 @@ mod tests { Column::new_with_schema("b2", &right.schema())?, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, JoinConstraint::On).await?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1592,8 +1515,7 @@ mod tests { ), ]; - let (columns, batches) 
= - join_collect(left, right, on, &JoinType::Inner, JoinConstraint::On).await?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1616,7 +1538,7 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] - async fn join_inner_using_one_two_parts_left() -> Result<()> { + async fn join_inner_one_two_parts_left() -> Result<()> { let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1645,22 +1567,20 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, JoinConstraint::Using) - .await?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+", - "| a1 | b2 | c1 | c2 |", - "+----+----+----+----+", - "| 1 | 1 | 7 | 70 |", - "| 2 | 2 | 8 | 80 |", - "| 2 | 2 | 9 | 80 |", - "+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1670,7 +1590,7 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] - async fn join_inner_using_one_two_parts_right() -> Result<()> { + async fn join_inner_one_two_parts_right() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1694,10 +1614,10 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let join = join(left, right, on, &JoinType::Inner, JoinConstraint::Using)?; + let join = join(left, right, on, &JoinType::Inner)?; let columns = 
columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part let stream = join.execute(0).await?; @@ -1705,11 +1625,11 @@ mod tests { assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1718,12 +1638,12 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 2 | 5 | 8 | 30 | 90 |", - "| 3 | 5 | 9 | 30 | 90 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 5 | 90 |", + "| 3 | 5 | 9 | 30 | 5 | 90 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1760,24 +1680,24 @@ mod tests { Column::new_with_schema("b1", &right.schema()).unwrap(), )]; - let join = join(left, right, on, &JoinType::Left, JoinConstraint::Using).unwrap(); + let join = join(left, right, on, &JoinType::Left).unwrap(); let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let stream = join.execute(0).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 2 | 5 | 8 | 20 | 80 |", - 
"| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1801,7 +1721,7 @@ mod tests { Column::new_with_schema("b2", &right.schema()).unwrap(), )]; - let join = join(left, right, on, &JoinType::Full, JoinConstraint::On).unwrap(); + let join = join(left, right, on, &JoinType::Full).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1840,22 +1760,22 @@ mod tests { )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); - let join = join(left, right, on, &JoinType::Left, JoinConstraint::Using).unwrap(); + let join = join(left, right, on, &JoinType::Left).unwrap(); let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let stream = join.execute(0).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | | |", - "| 2 | 5 | 8 | | |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | 4 | |", + "| 2 | 5 | 8 | | 5 | |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1875,7 +1795,7 @@ mod tests { )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); - let join = join(left, right, on, 
&JoinType::Full, JoinConstraint::On).unwrap(); + let join = join(left, right, on, &JoinType::Full).unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1913,14 +1833,9 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = join_collect( - left.clone(), - right.clone(), - on.clone(), - &JoinType::Left, - JoinConstraint::On, - ) - .await?; + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ @@ -1959,19 +1874,18 @@ mod tests { right.clone(), on.clone(), &JoinType::Left, - JoinConstraint::Using, ) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1995,7 +1909,7 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let join = join(left, right, on, &JoinType::Semi, JoinConstraint::On)?; + let join = join(left, right, on, &JoinType::Semi)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2034,7 +1948,7 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let join = join(left, right, on, &JoinType::Anti, JoinConstraint::On)?; + let join = join(left, right, on, &JoinType::Anti)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ 
-2071,8 +1985,7 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, JoinConstraint::On).await?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2108,25 +2021,19 @@ mod tests { Column::new_with_schema("b1", &right.schema())?, )]; - let (columns, batches) = partitioned_join_collect( - left, - right, - on, - &JoinType::Right, - JoinConstraint::Using, - ) - .await?; + let (columns, batches) = + partitioned_join_collect(left, right, on, &JoinType::Right).await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| | 6 | | 30 | 90 |", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2151,7 +2058,7 @@ mod tests { Column::new_with_schema("b2", &right.schema()).unwrap(), )]; - let join = join(left, right, on, &JoinType::Full, JoinConstraint::On)?; + let join = join(left, right, on, &JoinType::Full)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index e4fd17be91a0..9243affe9cfc 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -21,7 +21,7 @@ use crate::error::{DataFusionError, Result}; use arrow::datatypes::{Field, Schema}; use 
std::collections::HashSet; -use crate::logical_plan::{JoinConstraint, JoinType}; +use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; /// The on clause of the join, as vector of (left, right) columns. @@ -88,34 +88,13 @@ fn check_join_set_is_valid( /// Creates a schema for a join operation. /// The fields from the left side are first -pub fn build_join_schema( - left: &Schema, - right: &Schema, - on: JoinOnRef, - join_type: &JoinType, - join_constraint: JoinConstraint, -) -> Schema { +pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema { let fields: Vec = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - match join_constraint { - JoinConstraint::On => { - let left_fields = left.fields().iter(); - let right_fields = right.fields().iter(); - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinConstraint::Using => { - // using join requires unique join columns in the output schema, so we mark all - // right join keys as duplicate - let duplicate_keys = - &on.iter().map(|on| on.1.name()).collect::>(); - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - left.fields().iter().chain(right_fields).cloned().collect() - } - } + let left_fields = left.fields().iter(); + let right_fields = right.fields().iter(); + // left then right + left_fields.chain(right_fields).cloned().collect() } JoinType::Semi | JoinType::Anti => left.fields().clone(), }; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 1153771f5f49..34b20b64b1b1 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -563,7 +563,6 @@ impl DefaultPhysicalPlanner { right, on: keys, join_type, - join_constraint, .. 
} => { let left_df_schema = left.schema(); @@ -604,7 +603,6 @@ impl DefaultPhysicalPlanner { )?), join_on, join_type, - *join_constraint, PartitionMode::Partitioned, )?)) } else { @@ -613,7 +611,6 @@ impl DefaultPhysicalPlanner { physical_right, join_on, join_type, - *join_constraint, PartitionMode::CollectLeft, )?)) } @@ -1389,7 +1386,7 @@ mod tests { let expected_error: &str = "Error during planning: \ Extension planner for NoOp created an ExecutionPlan with mismatched schema. \ LogicalPlan schema: DFSchema { fields: [\ - DFField { qualifier: None, shared_qualifiers: None, field: Field { \ + DFField { qualifier: None, field: Field { \ name: \"a\", \ data_type: Int32, \ nullable: false, \ diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index c855413dd13e..0174c0c25232 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -27,8 +27,8 @@ use crate::datasource::TableProvider; use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ - and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, PlanType, StringifiedPlan, ToDFSchema, + and, col, lit, normalize_col, union_with_alias, Column, DFSchema, Expr, LogicalPlan, + LogicalPlanBuilder, Operator, PlanType, StringifiedPlan, ToDFSchema, }; use crate::prelude::JoinType; use crate::scalar::ScalarValue; @@ -560,7 +560,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // SELECT c1 AS m FROM t HAVING c1 > 10; // SELECT c1, MAX(c2) AS m FROM t GROUP BY c1 HAVING MAX(c2) > 10; // - resolve_aliases_to_exprs(&having_expr, &alias_map) + let having_expr = resolve_aliases_to_exprs(&having_expr, &alias_map)?; + normalize_col(having_expr, &projected_plan) }) .transpose()?; @@ -584,6 +585,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let group_by_expr = resolve_positions_to_exprs(&group_by_expr, &select_exprs) .unwrap_or(group_by_expr); + let 
group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( plan.schema(), &[group_by_expr.clone()], @@ -662,13 +664,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result> { let input_schema = plan.schema(); - Ok(projection + projection .iter() .map(|expr| self.sql_select_to_rex(expr, input_schema)) .collect::>>()? .iter() .flat_map(|expr| expand_wildcard(expr, input_schema)) - .collect::>()) + .map(|expr| normalize_col(expr, plan)) + .collect::>>() } /// Wrap a plan in a projection @@ -816,24 +819,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { find_column_exprs(exprs) .iter() .try_for_each(|col| match col { - Expr::Column(col) => { - match &col.relation { - Some(r) => { - Ok(schema.field_with_qualified_name(r, &col.name)?.to_owned()) - } - None => { - Ok(schema.field_with_unqualified_name(&col.name)?.to_owned()) + Expr::Column(col) => match &col.relation { + Some(r) => { + schema.field_with_qualified_name(r, &col.name)?; + Ok(()) + } + None => { + if !schema.fields_with_unqualified_name(&col.name).is_empty() { + Ok(()) + } else { + Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'", + &col.name + ))) } } - .map_err(|_: DataFusionError| { - DataFusionError::Plan(format!( - "Invalid identifier '{}' for schema {}", - col, - schema.to_string() - )) - })?; - Ok(()) } + .map_err(|_: DataFusionError| { + DataFusionError::Plan(format!( + "Invalid identifier '{}' for schema {}", + col, + schema.to_string() + )) + }), _ => Err(DataFusionError::Internal("Not a column".to_string())), }) } @@ -911,11 +919,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { - Ok(Expr::Column( - schema - .field_with_unqualified_name(&id.value)? 
- .qualified_column(), - )) + Ok(col(&id.value)) } } @@ -1654,7 +1658,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -1712,7 +1716,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -1722,7 +1726,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'x'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#x' for schema "), )); } @@ -2193,7 +2197,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -2283,7 +2287,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Column #doesnotexist not found in provided schemas"), )); } @@ -2293,7 +2297,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -2724,7 +2728,7 @@ mod tests 
{ FROM person \ JOIN person as person2 \ USING (id)"; - let expected = "Projection: #person.first_name, #id\ + let expected = "Projection: #person.first_name, #person.id\ \n Join: Using #person.id = #person2.id\ \n TableScan: person projection=None\ \n TableScan: person2 projection=None"; From ccbb844503dd5bd61a41428ba4502c374ea3db09 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 3 Jul 2021 22:03:46 -0700 Subject: [PATCH 3/5] add more comments & fix clippy --- datafusion/src/logical_plan/expr.rs | 17 +++++++++++++++-- datafusion/src/physical_plan/planner.rs | 4 +--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index d16f7346271e..9454d7593c3f 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -105,12 +105,25 @@ impl Column { return Ok(fields[0].qualified_column()); } _ => { + // More than one field in this schema has its name set to self.name. + // + // This should only happen when a JOIN query with USING constraint references + // join columns using unqualified column name. For example: + // + // ```sql + // SELECT id FROM t1 JOIN t2 USING(id) + // ``` + // + // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. + // We will use the relation from the first matched field to normalize self. + + // Compare matched fields with one USING JOIN clause at a time for using_col in &using_columns { let all_matched = fields .iter() .all(|f| using_col.contains(&f.qualified_column())); - // All matched fields belong to the same using column set, use the first - // qualifer + // All matched fields belong to the same using column set, in other words + // the same join clause. We simply pick the qualifier from the first match.
if all_matched { return Ok(fields[0].qualified_column()); } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 5d56ba7a3f3d..12f563618d85 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -39,9 +39,7 @@ use crate::physical_plan::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{hash_utils, Partitioning}; -use crate::physical_plan::{ - AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner, WindowExpr, -}; +use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use crate::scalar::ScalarValue; use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys}; use crate::variable::VarType; From 091bb75fcc0c280923a7794bfe9dd55c4629eeea Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 3 Jul 2021 22:14:02 -0700 Subject: [PATCH 4/5] add more comment --- datafusion/src/sql/planner.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index c96eab596012..b633e6e8ca22 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -919,6 +919,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { + // create a column expression based on raw user input; this column will be + // normalized with a qualifier later by the SQL planner.
Ok(col(&id.value)) } } From e97b86a8bc410983a73b5802ae44eb7a55faecd3 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 3 Jul 2021 23:41:29 -0700 Subject: [PATCH 5/5] reduce duplicate code in join predicate pushdown --- datafusion/src/optimizer/filter_push_down.rs | 83 ++++++++------------ 1 file changed, 34 insertions(+), 49 deletions(-) diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 9541057e6ef9..76d8c05bed4c 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -241,6 +241,38 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { } } +fn optimize_join( + mut state: State, + plan: &LogicalPlan, + left: &LogicalPlan, + right: &LogicalPlan, +) -> Result { + let (pushable_to_left, pushable_to_right, keep) = + get_join_predicates(&state, left.schema(), right.schema()); + + let mut left_state = state.clone(); + left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); + let left = optimize(left, left_state)?; + + let mut right_state = state.clone(); + right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); + let right = optimize(right, right_state)?; + + // create a new Join with the new `left` and `right` + let expr = plan.expressions(); + let plan = utils::from_plan(plan, &expr, &[left, right])?; + + if keep.0.is_empty() { + Ok(plan) + } else { + // wrap the join on the filter whose predicates must be kept + let plan = add_filter(plan, &keep.0); + state.filters = remove_filters(&state.filters, &keep.1); + + Ok(plan) + } +} + fn optimize(plan: &LogicalPlan, mut state: State) -> Result { match plan { LogicalPlan::Explain { .. } => { @@ -346,30 +378,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { issue_filters(state, used_columns, plan) } LogicalPlan::CrossJoin { left, right, .. 
} => { - let (pushable_to_left, pushable_to_right, keep) = - get_join_predicates(&state, left.schema(), right.schema()); - - let mut left_state = state.clone(); - left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); - let left = optimize(left, left_state)?; - - let mut right_state = state.clone(); - right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); - let right = optimize(right, right_state)?; - - // create a new Join with the new `left` and `right` - let expr = plan.expressions(); - let plan = utils::from_plan(plan, &expr, &[left, right])?; - - if keep.0.is_empty() { - Ok(plan) - } else { - // wrap the join on the filter whose predicates must be kept - let plan = add_filter(plan, &keep.0); - state.filters = remove_filters(&state.filters, &keep.1); - - Ok(plan) - } + optimize_join(state, plan, left, right) } LogicalPlan::Join { left, right, on, .. @@ -427,33 +436,9 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { Some(Ok((join_side_predicate, join_side_columns))) }) .collect::>>()?; - state.filters.extend(join_side_filters); - let (pushable_to_left, pushable_to_right, keep) = - get_join_predicates(&state, left.schema(), right.schema()); - - let mut left_state = state.clone(); - left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); - let left = optimize(left, left_state)?; - - let mut right_state = state.clone(); - right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); - let right = optimize(right, right_state)?; - - // create a new Join with the new `left` and `right` - let expr = plan.expressions(); - let plan = utils::from_plan(plan, &expr, &[left, right])?; - - if keep.0.is_empty() { - Ok(plan) - } else { - // wrap the join on the filter whose predicates must be kept - let plan = add_filter(plan, &keep.0); - state.filters = remove_filters(&state.filters, &keep.1); - - Ok(plan) - } + optimize_join(state, plan, left, right) } LogicalPlan::TableScan { source,