Make sure that the data types are supported in hashjoin before genera… #2702

Merged · 5 commits · Jun 9, 2022

Changes from 4 commits
138 changes: 138 additions & 0 deletions datafusion/core/tests/sql/joins.rs
@@ -1204,3 +1204,141 @@ async fn join_partitioned() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn join_with_hash_unsupported_data_type() -> Result<()> {
let ctx = SessionContext::new();

let schema = Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
Field::new("c3", DataType::Int64, true),
Field::new("c4", DataType::Date32, true),
]);
let data = RecordBatch::try_new(
Arc::new(schema),
vec![
Arc::new(Int32Array::from_slice(&[1, 2, 3])),
Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])),
Arc::new(Int64Array::from_slice(&[100, 200, 300])),
Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])),
],
)?;
let table = MemTable::try_new(data.schema(), vec![vec![data]])?;
ctx.register_table("foo", Arc::new(table))?;

// join on a hash-unsupported data type (Date32); use cross join instead of hash join
let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4";
let msg = format!("Creating logical plan for '{}'", sql);
Comment on lines +1230 to +1232
Contributor:

So I think CrossJoin is almost never what the user would want: once the tables grow beyond any trivial size, the query will effectively never finish or will run out of memory. An error is clearer.

From the issue description #2145 (comment), I think @pjmore's idea to cast unsupported types to a supported type is a good one -- the arrow cast kernels are quite efficient for things like Date32 -> Int32 (no copies), as the representations are the same.

@pjmore what do you think?

Contributor Author (@HuSen8891, Jun 9, 2022):

> So I think CrossJoin is almost never what the user would want ... An error is clearer.

I agree: supporting more data types in hash join is the better way to solve this issue, and I'm already working on it. This PR only aims to run hash-unsupported joins as cross joins instead of erroring or panicking; we can keep adding data type support to hash join incrementally.
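For reference, a minimal sketch of the casting idea discussed above, normalizing a Date32 key to Int32 with arrow's cast kernel before hashing. This is illustrative only and not part of this PR:

use std::sync::Arc;

use arrow::array::{ArrayRef, Date32Array};
use arrow::compute::cast;
use arrow::datatypes::DataType;

fn main() -> arrow::error::Result<()> {
    // A Date32 join key: days since the Unix epoch, physically an i32.
    let dates: ArrayRef = Arc::new(Date32Array::from(vec![Some(1), Some(2), None]));

    // Date32 -> Int32 reinterprets the same buffer (no copy), so a hash
    // join could hash the Int32 view of the key without materializing
    // new data.
    let as_int32 = cast(&dates, &DataType::Int32)?;
    assert_eq!(as_int32.data_type(), &DataType::Int32);
    Ok(())
}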

let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+----+-----+-----+------------+----+-----+-----+------------+",
"| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
"+----+-----+-----+------------+----+-----+-----+------------+",
"| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
"| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
"| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
"+----+-----+-----+------------+----+-----+-----+------------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

// join on a hash-supported data type (Int32); use hash join
let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1";
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+----+-----+-----+------------+----+-----+-----+------------+",
"| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
"+----+-----+-----+------------+----+-----+-----+------------+",
"| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
"| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
"| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
"+----+-----+-----+------------+----+-----+-----+------------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

// join on two columns: a hash-supported data type (Int64) and a hash-unsupported data type (Date32);
// use a hash join on the Int64 column and a filter on the Date32 column.
let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4";
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
" TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+----+-----+-----+------------+----+-----+-----+------------+",
"| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
"+----+-----+-----+------------+----+-----+-----+------------+",
"| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
"| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
"| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
"+----+-----+-----+------------+----+-----+-----+------------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}
54 changes: 42 additions & 12 deletions datafusion/expr/src/logical_plan/builder.rs
@@ -19,6 +19,7 @@

use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
use crate::{and, binary_expr, Operator};
use crate::{
logical_plan::{
Aggregate, Analyze, CrossJoin, EmptyRelation, Explain, Filter, Join,
@@ -27,7 +28,7 @@ use crate::{
Union, Values, Window,
},
utils::{
expand_qualified_wildcard, expand_wildcard, expr_to_columns,
can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns,
group_window_expr_by_sort_keys,
},
Expr, ExprSchemable, TableSource,
@@ -605,17 +606,46 @@ impl LogicalPlanBuilder {
let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;

Ok(Self::from(LogicalPlan::Join(Join {
left: Arc::new(self.plan.clone()),
right: Arc::new(right.clone()),
on,
filter: None,
join_type,
join_constraint: JoinConstraint::Using,
schema: DFSchemaRef::new(join_schema),
null_equals_null: false,
})))
let mut join_on: Vec<(Column, Column)> = vec![];
let mut filters: Option<Expr> = None;
for (l, r) in &on {
if self.plan.schema().field_from_column(l).is_ok()
&& right.schema().field_from_column(r).is_ok()
&& can_hash(self.plan.schema().field_from_column(l).unwrap().data_type())
{
join_on.push((l.clone(), r.clone()));
} else if self.plan.schema().field_from_column(r).is_ok()
&& right.schema().field_from_column(l).is_ok()
&& can_hash(self.plan.schema().field_from_column(r).unwrap().data_type())
{
join_on.push((r.clone(), l.clone()));
} else {
let expr = binary_expr(
Expr::Column(l.clone()),
Operator::Eq,
Expr::Column(r.clone()),
);
match filters {
None => filters = Some(expr),
Some(filter_expr) => filters = Some(and(expr, filter_expr)),
}
}
}
if join_on.is_empty() {
let join = Self::from(self.plan.clone()).cross_join(&right.clone())?;
join.filter(filters.unwrap())
} else {
Ok(Self::from(LogicalPlan::Join(Join {
left: Arc::new(self.plan.clone()),
right: Arc::new(right.clone()),
on: join_on,
filter: filters,
join_type,
join_constraint: JoinConstraint::Using,
schema: DFSchemaRef::new(join_schema),
null_equals_null: false,
})))
}
}

/// Apply a cross join
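The accumulation pattern in join_using above (and its twin in the SQL planner below) can be read as folding the non-hashable equality pairs into one conjunctive filter. A standalone sketch using the same datafusion_expr helpers the diff imports; the fold_equalities name is hypothetical, not part of the PR:

use datafusion_common::Column;
use datafusion_expr::{and, binary_expr, Expr, Operator};

// Fold leftover (l, r) equality pairs into a single AND-ed predicate,
// mirroring the `filters` accumulation in `join_using`.
fn fold_equalities(pairs: &[(Column, Column)]) -> Option<Expr> {
    pairs.iter().fold(None, |acc, (l, r)| {
        let eq = binary_expr(
            Expr::Column(l.clone()),
            Operator::Eq,
            Expr::Column(r.clone()),
        );
        Some(match acc {
            None => eq,
            Some(prev) => and(eq, prev),
        })
    })
}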
30 changes: 30 additions & 0 deletions datafusion/expr/src/utils.rs
@@ -25,6 +25,7 @@ use crate::logical_plan::{
Values, Window,
};
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
@@ -643,6 +644,35 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
}
}

/// Can this data type be used in hash join equal conditions?
/// If more data types are supported in hash join, add those data types here
/// to generate the join logical plan.
pub fn can_hash(data_type: &DataType) -> bool {
match data_type {
Contributor:

Decimal should be supported here.

Contributor Author (@HuSen8891, Jun 7, 2022):

The data types here come from the function equal_rows; decimal is not supported in equal_rows, so hash join currently cannot join on columns of decimal data type. That's why decimal is not listed here.

Contributor:

I think it may help to add a comment here (or in equal_rows) mentioning that they need to remain in sync.

Contributor Author:

> I think it may help to add a comment here (or in equal_rows) mentioning that they need to remain in sync.

I'll add this comment.

DataType::Null => true,
DataType::Boolean => true,
DataType::Int8 => true,
DataType::Int16 => true,
DataType::Int32 => true,
DataType::Int64 => true,
DataType::UInt8 => true,
DataType::UInt16 => true,
DataType::UInt32 => true,
DataType::UInt64 => true,
DataType::Float32 => true,
DataType::Float64 => true,
DataType::Timestamp(time_unit, None) => match time_unit {
TimeUnit::Second => true,
TimeUnit::Millisecond => true,
TimeUnit::Microsecond => true,
TimeUnit::Nanosecond => true,
},
DataType::Utf8 => true,
DataType::LargeUtf8 => true,
_ => false,
}
}

#[cfg(test)]
mod tests {
use super::*;
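For orientation, a few hedged examples of what can_hash accepts and rejects, following the match arms above; the Date32 case is what drives the cross-join fallback in this PR:

use arrow::datatypes::DataType;
use datafusion_expr::utils::can_hash;

fn main() {
    // Primitive integers, floats, timestamps, and strings hash fine.
    assert!(can_hash(&DataType::Int64));
    assert!(can_hash(&DataType::Utf8));

    // Date32 (like Decimal) falls through to the `_ => false` arm, so an
    // equality on such a column becomes a cross join plus a filter.
    assert!(!can_hash(&DataType::Date32));
}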
57 changes: 51 additions & 6 deletions datafusion/sql/src/planner.rs
@@ -29,8 +29,9 @@ use datafusion_expr::logical_plan::{
ToStringifiedPlan,
};
use datafusion_expr::utils::{
expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns,
find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION,
can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs,
COUNT_STAR_EXPANSION,
};
use datafusion_expr::{
and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
@@ -600,7 +601,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut filter = vec![];

// extract join keys
extract_join_keys(expr, &mut keys, &mut filter);
extract_join_keys(
expr,
&mut keys,
&mut filter,
left.schema(),
right.schema(),
);

let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
keys.into_iter().unzip();
@@ -819,10 +826,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
for (l, r) in &possible_join_keys {
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
&& can_hash(
left_schema
.field_from_column(l)
.unwrap()
.data_type(),
)
{
join_keys.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
&& can_hash(
left_schema
.field_from_column(r)
.unwrap()
.data_type(),
)
{
join_keys.push((r.clone(), l.clone()));
}
@@ -2516,21 +2535,47 @@ fn extract_join_keys(
expr: Expr,
accum: &mut Vec<(Column, Column)>,
accum_filter: &mut Vec<Expr>,
left_schema: &Arc<DFSchema>,
right_schema: &Arc<DFSchema>,
) {
match &expr {
Expr::BinaryExpr { left, op, right } => match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
accum.push((l.clone(), r.clone()));
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
&& can_hash(left_schema.field_from_column(l).unwrap().data_type())
{
accum.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
&& can_hash(left_schema.field_from_column(r).unwrap().data_type())
{
accum.push((r.clone(), l.clone()));
} else {
accum_filter.push(expr);
}
}
_other => {
accum_filter.push(expr);
}
},
Operator::And => {
if let Expr::BinaryExpr { left, op: _, right } = expr {
extract_join_keys(*left, accum, accum_filter);
extract_join_keys(*right, accum, accum_filter);
extract_join_keys(
*left,
accum,
accum_filter,
left_schema,
right_schema,
);
extract_join_keys(
*right,
accum,
accum_filter,
left_schema,
right_schema,
);
}
}
_other => {
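The same resolve-and-orient check now appears three times: in join_using in builder.rs, in the SQL planner's implicit-join key extraction, and in extract_join_keys here. A consolidated sketch of that check; the orient_key helper is hypothetical and not part of the PR:

use datafusion_common::{Column, DFSchema};
use datafusion_expr::utils::can_hash;

// Orient an equality pair (l, r) so the first column resolves against the
// left schema and the second against the right; return None when the pair
// cannot serve as a hashable join key, in which case the caller keeps the
// equality as a filter predicate.
fn orient_key(
    l: &Column,
    r: &Column,
    left: &DFSchema,
    right: &DFSchema,
) -> Option<(Column, Column)> {
    if left.field_from_column(l).is_ok()
        && right.field_from_column(r).is_ok()
        && can_hash(left.field_from_column(l).unwrap().data_type())
    {
        Some((l.clone(), r.clone()))
    } else if left.field_from_column(r).is_ok()
        && right.field_from_column(l).is_ok()
        && can_hash(left.field_from_column(r).unwrap().data_type())
    {
        Some((r.clone(), l.clone()))
    } else {
        None
    }
}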