Skip to content

Commit

Permalink
fix join column handling logic for On and Using constraints (#605)
Browse files Browse the repository at this point in the history
* fix join column handling logic for `On` and `Using` constraints

* handle join column expansion during USING JOIN planning

get rid of shared field and move column expansion logic into plan
builder and optimizer.

* add more comments & fix clippy

* add more comments

* reduce duplicate code in join predicate pushdown
  • Loading branch information
QP Hou authored Jul 7, 2021
1 parent 3664766 commit 18c581c
Show file tree
Hide file tree
Showing 23 changed files with 836 additions and 500 deletions.
10 changes: 8 additions & 2 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -378,12 +378,18 @@ enum JoinType {
ANTI = 5;
}

// Records which kind of join constraint a JoinNode was built from, so the
// reader can pick the matching plan-builder call when decoding.
enum JoinConstraint {
// Join expressed with an ON constraint.
ON = 0;
// Join expressed with a USING constraint.
USING = 1;
}

message JoinNode {
LogicalPlanNode left = 1;
LogicalPlanNode right = 2;
JoinType join_type = 3;
repeated Column left_join_column = 4;
repeated Column right_join_column = 5;
JoinConstraint join_constraint = 4;
repeated Column left_join_column = 5;
repeated Column right_join_column = 6;
}

message LimitNode {
Expand Down
41 changes: 25 additions & 16 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use datafusion::logical_plan::window_frames::{
};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan,
LogicalPlanBuilder, Operator,
sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType,
LogicalPlan, LogicalPlanBuilder, Operator,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
Expand Down Expand Up @@ -257,23 +257,32 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
join.join_type
))
})?;
let join_type = match join_type {
protobuf::JoinType::Inner => JoinType::Inner,
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
protobuf::JoinType::Semi => JoinType::Semi,
protobuf::JoinType::Anti => JoinType::Anti,
};
LogicalPlanBuilder::from(convert_box_required!(join.left)?)
.join(
let join_constraint = protobuf::JoinConstraint::from_i32(
join.join_constraint,
)
.ok_or_else(|| {
proto_error(format!(
"Received a JoinNode message with unknown JoinConstraint {}",
join.join_constraint
))
})?;

let builder = LogicalPlanBuilder::from(convert_box_required!(join.left)?);
let builder = match join_constraint.into() {
JoinConstraint::On => builder.join(
&convert_box_required!(join.right)?,
join_type,
join_type.into(),
left_keys,
right_keys,
)?
.build()
.map_err(|e| e.into())
)?,
JoinConstraint::Using => builder.join_using(
&convert_box_required!(join.right)?,
join_type.into(),
left_keys,
)?,
};

builder.build().map_err(|e| e.into())
}
}
}
Expand Down
15 changes: 6 additions & 9 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
Column, Expr, JoinType, LogicalPlan,
Column, Expr, JoinConstraint, JoinType, LogicalPlan,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::functions::BuiltinScalarFunction;
Expand Down Expand Up @@ -804,26 +804,23 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
right,
on,
join_type,
join_constraint,
..
} => {
let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?;
let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?;
let join_type = match join_type {
JoinType::Inner => protobuf::JoinType::Inner,
JoinType::Left => protobuf::JoinType::Left,
JoinType::Right => protobuf::JoinType::Right,
JoinType::Full => protobuf::JoinType::Full,
JoinType::Semi => protobuf::JoinType::Semi,
JoinType::Anti => protobuf::JoinType::Anti,
};
let (left_join_column, right_join_column) =
on.iter().map(|(l, r)| (l.into(), r.into())).unzip();
let join_type: protobuf::JoinType = join_type.to_owned().into();
let join_constraint: protobuf::JoinConstraint =
join_constraint.to_owned().into();
Ok(protobuf::LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::Join(Box::new(
protobuf::JoinNode {
left: Some(Box::new(left)),
right: Some(Box::new(right)),
join_type: join_type.into(),
join_constraint: join_constraint.into(),
left_join_column,
right_join_column,
},
Expand Down
46 changes: 45 additions & 1 deletion ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

use std::{convert::TryInto, io::Cursor};

use datafusion::logical_plan::Operator;
use datafusion::logical_plan::{JoinConstraint, JoinType, Operator};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;

Expand Down Expand Up @@ -291,3 +291,47 @@ impl Into<datafusion::arrow::datatypes::DataType> for protobuf::PrimitiveScalarT
}
}
}

/// Converts the protobuf wire representation of a join type into the
/// logical-plan [`JoinType`].
///
/// Every wire variant maps to the variant of the same name.
impl From<protobuf::JoinType> for JoinType {
    fn from(proto: protobuf::JoinType) -> Self {
        match proto {
            protobuf::JoinType::Inner => Self::Inner,
            protobuf::JoinType::Left => Self::Left,
            protobuf::JoinType::Right => Self::Right,
            protobuf::JoinType::Full => Self::Full,
            protobuf::JoinType::Semi => Self::Semi,
            protobuf::JoinType::Anti => Self::Anti,
        }
    }
}

/// Converts a logical-plan [`JoinType`] into its protobuf wire
/// representation.
///
/// Every variant maps to the wire variant of the same name.
impl From<JoinType> for protobuf::JoinType {
    fn from(join_type: JoinType) -> Self {
        match join_type {
            JoinType::Inner => Self::Inner,
            JoinType::Left => Self::Left,
            JoinType::Right => Self::Right,
            JoinType::Full => Self::Full,
            JoinType::Semi => Self::Semi,
            JoinType::Anti => Self::Anti,
        }
    }
}

/// Converts the protobuf wire representation of a join constraint into the
/// logical-plan [`JoinConstraint`].
///
/// `On` and `Using` each map to the variant of the same name.
impl From<protobuf::JoinConstraint> for JoinConstraint {
    fn from(proto: protobuf::JoinConstraint) -> Self {
        match proto {
            protobuf::JoinConstraint::On => Self::On,
            protobuf::JoinConstraint::Using => Self::Using,
        }
    }
}

/// Converts a logical-plan [`JoinConstraint`] into its protobuf wire
/// representation.
///
/// `On` and `Using` each map to the wire variant of the same name.
impl From<JoinConstraint> for protobuf::JoinConstraint {
    fn from(constraint: JoinConstraint) -> Self {
        match constraint {
            JoinConstraint::On => Self::On,
            JoinConstraint::Using => Self::Using,
        }
    }
}
16 changes: 5 additions & 11 deletions ballista/rust/core/src/serde/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ use datafusion::catalog::catalog::{
use datafusion::execution::context::{
ExecutionConfig, ExecutionContextState, ExecutionProps,
};
use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr};
use datafusion::logical_plan::{
window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType,
};
use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
Expand All @@ -57,7 +59,6 @@ use datafusion::physical_plan::{
filter::FilterExec,
functions::{self, BuiltinScalarFunction, ScalarFunctionExpr},
hash_join::HashJoinExec,
hash_utils::JoinType,
limit::{GlobalLimitExec, LocalLimitExec},
parquet::ParquetExec,
projection::ProjectionExec,
Expand Down Expand Up @@ -348,14 +349,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
hashjoin.join_type
))
})?;
let join_type = match join_type {
protobuf::JoinType::Inner => JoinType::Inner,
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
protobuf::JoinType::Semi => JoinType::Semi,
protobuf::JoinType::Anti => JoinType::Anti,
};

let partition_mode =
protobuf::PartitionMode::from_i32(hashjoin.partition_mode)
.ok_or_else(|| {
Expand All @@ -372,7 +366,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
left,
right,
on,
&join_type,
&join_type.into(),
partition_mode,
)?))
}
Expand Down
3 changes: 1 addition & 2 deletions ballista/rust/core/src/serde/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ mod roundtrip_tests {
compute::kernels::sort::SortOptions,
datatypes::{DataType, Field, Schema},
},
logical_plan::Operator,
logical_plan::{JoinType, Operator},
physical_plan::{
empty::EmptyExec,
expressions::{binary, col, lit, InListExpr, NotExpr},
expressions::{Avg, Column, PhysicalSortExpr},
filter::FilterExec,
hash_aggregate::{AggregateMode, HashAggregateExec},
hash_join::{HashJoinExec, PartitionMode},
hash_utils::JoinType,
limit::{GlobalLimitExec, LocalLimitExec},
sort::SortExec,
AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning,
Expand Down
13 changes: 4 additions & 9 deletions ballista/rust/core/src/serde/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::{
sync::Arc,
};

use datafusion::logical_plan::JoinType;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::csv::CsvExec;
use datafusion::physical_plan::expressions::{
Expand All @@ -35,7 +36,6 @@ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::hash_aggregate::AggregateMode;
use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::hash_utils::JoinType;
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::parquet::ParquetExec;
use datafusion::physical_plan::projection::ProjectionExec;
Expand Down Expand Up @@ -135,18 +135,13 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
}),
})
.collect();
let join_type = match exec.join_type() {
JoinType::Inner => protobuf::JoinType::Inner,
JoinType::Left => protobuf::JoinType::Left,
JoinType::Right => protobuf::JoinType::Right,
JoinType::Full => protobuf::JoinType::Full,
JoinType::Semi => protobuf::JoinType::Semi,
JoinType::Anti => protobuf::JoinType::Anti,
};
let join_type: protobuf::JoinType = exec.join_type().to_owned().into();

let partition_mode = match exec.partition_mode() {
PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft,
PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned,
};

Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new(
protobuf::HashJoinExecNode {
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/queries/q7.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ group by
order by
supp_nation,
cust_nation,
l_year;
l_year;
90 changes: 90 additions & 0 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,96 @@ mod tests {
Ok(())
}

/// Self-joins `test` with `USING (c2)`, projecting `t1.c1` and `t2.c2`, and
/// checks the single result batch against the expected table.
///
/// NOTE(review): despite the `left_` prefix the query uses a plain `JOIN`,
/// not `LEFT JOIN` — confirm whether the name or the SQL is intended.
#[tokio::test]
async fn left_join_using() -> Result<()> {
    let sql = "SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2";
    // The second argument is presumably the partition count — verify against `execute`.
    let batches = execute(sql, 1).await?;
    assert_eq!(batches.len(), 1);

    let expected_table = vec![
        "+----+----+",
        "| c1 | c2 |",
        "+----+----+",
        "| 0 | 1 |",
        "| 0 | 2 |",
        "| 0 | 3 |",
        "| 0 | 4 |",
        "| 0 | 5 |",
        "| 0 | 6 |",
        "| 0 | 7 |",
        "| 0 | 8 |",
        "| 0 | 9 |",
        "| 0 | 10 |",
        "+----+----+",
    ];

    assert_batches_eq!(expected_table, &batches);
    Ok(())
}

/// Self-joins `test` with `USING (c2)` and projects the join key from both
/// sides (`t1.c2` and `t2.c2`) alongside `t1.c1`, checking that both key
/// columns appear in the output.
///
/// NOTE(review): despite the `left_` prefix the query uses a plain `JOIN`,
/// not `LEFT JOIN` — confirm whether the name or the SQL is intended.
#[tokio::test]
async fn left_join_using_join_key_projection() -> Result<()> {
    let sql = "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2";
    // The second argument is presumably the partition count — verify against `execute`.
    let batches = execute(sql, 1).await?;
    assert_eq!(batches.len(), 1);

    let expected_table = vec![
        "+----+----+----+",
        "| c1 | c2 | c2 |",
        "+----+----+----+",
        "| 0 | 1 | 1 |",
        "| 0 | 2 | 2 |",
        "| 0 | 3 | 3 |",
        "| 0 | 4 | 4 |",
        "| 0 | 5 | 5 |",
        "| 0 | 6 | 6 |",
        "| 0 | 7 | 7 |",
        "| 0 | 8 | 8 |",
        "| 0 | 9 | 9 |",
        "| 0 | 10 | 10 |",
        "+----+----+----+",
    ];

    assert_batches_eq!(expected_table, &batches);
    Ok(())
}

/// Self-joins `test` with an explicit `ON t1.c2 = t2.c2` constraint and
/// projects the join key from both sides alongside `t1.c1`, checking the
/// single result batch against the expected table.
///
/// NOTE(review): despite the name the query uses a plain `JOIN`, not
/// `LEFT JOIN` — confirm whether the name or the SQL is intended.
#[tokio::test]
async fn left_join() -> Result<()> {
    let sql = "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2";
    // The second argument is presumably the partition count — verify against `execute`.
    let batches = execute(sql, 1).await?;
    assert_eq!(batches.len(), 1);

    let expected_table = vec![
        "+----+----+----+",
        "| c1 | c2 | c2 |",
        "+----+----+----+",
        "| 0 | 1 | 1 |",
        "| 0 | 2 | 2 |",
        "| 0 | 3 | 3 |",
        "| 0 | 4 | 4 |",
        "| 0 | 5 | 5 |",
        "| 0 | 6 | 6 |",
        "| 0 | 7 | 7 |",
        "| 0 | 8 | 8 |",
        "| 0 | 9 | 9 |",
        "| 0 | 10 | 10 |",
        "+----+----+----+",
    ];

    assert_batches_eq!(expected_table, &batches);
    Ok(())
}

#[tokio::test]
async fn window() -> Result<()> {
let results = execute(
Expand Down
Loading

0 comments on commit 18c581c

Please sign in to comment.