diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index b70af8884625a..a9cc839dcc4e2 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -2534,17 +2534,8 @@ fn extract_join_keys( extract_join_keys(*right, accum, accum_filter); } } - _other - if matches!(**left, Expr::Column(_)) - || matches!(**right, Expr::Column(_)) => - { - accum_filter.push(expr); - } _other => { - if let Expr::BinaryExpr { left, op: _, right } = expr { - extract_join_keys(*left, accum, accum_filter); - extract_join_keys(*right, accum, accum_filter); - } + accum_filter.push(expr); } }, _other => { @@ -4775,6 +4766,32 @@ mod tests { quick_test(sql, expected); } + #[test] + fn join_on_disjunction_condition() { + let sql = "SELECT id, order_id \ + FROM person \ + JOIN orders ON id = customer_id OR person.age > 30"; + let expected = "Projection: #person.id, #orders.order_id\ + \n Filter: #person.id = #orders.customer_id OR #person.age > Int64(30)\ + \n CrossJoin:\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn join_on_complex_condition() { + let sql = "SELECT id, order_id \ + FROM person \ + JOIN orders ON id = customer_id AND (person.age > 30 OR person.last_name = 'X')"; + let expected = "Projection: #person.id, #orders.order_id\ + \n Filter: #person.age > Int64(30) OR #person.last_name = Utf8(\"X\")\ + \n Inner Join: #person.id = #orders.customer_id\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => {