diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 93f197885c0a..1bd8896331e6 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -29,9 +29,9 @@ use url::Url; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, - EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, - Values, + aggregate_function, builder::subquery_alias, expr::find_df_window_func, Aggregate, + BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, + Projection, ScalarUDF, Values, }; use datafusion::logical_expr::{ @@ -233,10 +233,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &function_extension).await?) + Ok(from_substrait_rel(ctx, rel, &function_extension, &mut Vec::new()).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension, &mut Vec::new()).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -413,12 +413,13 @@ pub async fn from_substrait_rel( ctx: &SessionContext, rel: &Rel, extensions: &HashMap, + seen_table_names: &mut Vec, ) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { let mut input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(ctx, input, extensions, seen_table_names).await?, ); let mut exprs: Vec = vec![]; for e in &p.expressions { @@ -446,7 +447,7 @@ pub async fn from_substrait_rel( Some(RelType::Filter(filter)) => { if let Some(input) = filter.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(ctx, input, extensions, seen_table_names).await?, ); if let Some(condition) = filter.condition.as_ref() { let expr = @@ -463,7 +464,7 @@ pub async fn from_substrait_rel( Some(RelType::Fetch(fetch)) => { if let Some(input) = fetch.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(ctx, input, extensions, seen_table_names).await?, ); let offset = fetch.offset as usize; // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` @@ -480,7 +481,7 @@ pub async fn from_substrait_rel( Some(RelType::Sort(sort)) => { if let Some(input) = sort.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(ctx, input, extensions, seen_table_names).await?, ); let sorts = from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) @@ -493,7 +494,7 @@ pub async fn from_substrait_rel( Some(RelType::Aggregate(agg)) => { if let Some(input) = agg.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(ctx, input, extensions, seen_table_names).await?, ); let mut group_expr = vec![]; let mut aggr_expr = vec![]; @@ -589,10 +590,22 @@ pub async fn from_substrait_rel( } let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel( + ctx, + join.left.as_ref().unwrap(), + extensions, + seen_table_names, + ) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?, + from_substrait_rel( + ctx, + join.right.as_ref().unwrap(), + extensions, + seen_table_names, + ) + .await?, ); let join_type = from_substrait_jointype(join.r#type)?; // The join condition expression needs full input schema and not the output schema from join since we lose columns from @@ -628,11 +641,21 @@ pub async fn from_substrait_rel( } Some(RelType::Cross(cross)) => { let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel( + ctx, + cross.left.as_ref().unwrap(), + extensions, + seen_table_names, + ) + .await?, ); - let right = - from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) - .await?; + let right = from_substrait_rel( + ctx, + cross.right.as_ref().unwrap(), + extensions, + seen_table_names, + ) + .await?; left.cross_join(right)?.build() } Some(RelType::Read(read)) => match &read.as_ref().read_type { @@ -656,7 +679,8 @@ pub async fn from_substrait_rel( }; let t = ctx.table(table_reference).await?; let t = t.into_optimized_plan()?; - extract_projection(t, &read.projection) + let t = extract_projection(t, &read.projection)?; + qualify_table_uniquely(t, seen_table_names) } Some(ReadType::VirtualTable(vt)) => { let base_schema = read.base_schema.as_ref().ok_or_else(|| { @@ -666,10 +690,13 @@ pub async fn from_substrait_rel( let schema = from_substrait_named_struct(base_schema)?; if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema, - })); + return qualify_table_uniquely( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema, + }), + seen_table_names, + ); } let values = vt @@ -699,8 +726,8 @@ pub async fn from_substrait_rel( Ok(lits) }) .collect::>()?; - - Ok(LogicalPlan::Values(Values { schema, values })) + let t = LogicalPlan::Values(Values { schema, values }); + qualify_table_uniquely(t, seen_table_names) } Some(ReadType::LocalFiles(lf)) => { fn extract_filename(name: &str) -> Option { @@ -735,7 +762,8 @@ pub async fn from_substrait_rel( let table_reference = TableReference::Bare { table: name.into() }; let t = ctx.table(table_reference).await?; let t = t.into_optimized_plan()?; - extract_projection(t, &read.projection) + let t = extract_projection(t, &read.projection)?; + qualify_table_uniquely(t, seen_table_names) } _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), }, @@ -744,11 +772,24 @@ pub async fn from_substrait_rel( set_rel::SetOp::UnionAll => { if !set.inputs.is_empty() { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &set.inputs[0], extensions).await?, + from_substrait_rel( + ctx, + &set.inputs[0], + extensions, + seen_table_names, + ) + .await?, )); for input in &set.inputs[1..] { - union_builder = union_builder? - .union(from_substrait_rel(ctx, input, extensions).await?); + union_builder = union_builder?.union( + from_substrait_rel( + ctx, + input, + extensions, + seen_table_names, + ) + .await?, + ); } union_builder?.build() } else { @@ -782,7 +823,8 @@ pub async fn from_substrait_rel( "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" ); }; - let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; + let input_plan = + from_substrait_rel(ctx, input_rel, extensions, seen_table_names).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -797,7 +839,8 @@ pub async fn from_substrait_rel( .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let mut inputs = Vec::with_capacity(extension.inputs.len()); for input in &extension.inputs { - let input_plan = from_substrait_rel(ctx, input, extensions).await?; + let input_plan = + from_substrait_rel(ctx, input, extensions, seen_table_names).await?; inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; @@ -807,7 +850,9 @@ pub async fn from_substrait_rel( let Some(input) = exchange.input.as_ref() else { return substrait_err!("Unexpected empty input in ExchangeRel"); }; - let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + let input = Arc::new( + from_substrait_rel(ctx, input, extensions, seen_table_names).await?, + ); let Some(exchange_kind) = &exchange.exchange_kind else { return substrait_err!("Unexpected empty input in ExchangeRel"); @@ -846,6 +891,61 @@ pub async fn from_substrait_rel( } } +// Substrait doesn't keep table aliases, nor column aliases within the plan - columns are just +// referred to by their index in the schema. This means that column names within the plan may not +// be unique, unless we ensure each input table is qualified with a unique table reference. +fn qualify_table_uniquely( + table: LogicalPlan, + seen_table_names: &mut Vec, +) -> Result { + let original_table_ref = match table.schema().columns().first() { + Some(col) => col.relation.to_owned(), + None => return Ok(table), + }; + + let original_table_name = original_table_ref + .to_owned() + .map(|r| r.table().to_owned()) + .unwrap_or("table".to_string()); + + let mut new_ref = original_table_ref + .to_owned() + .unwrap_or(TableReference::bare("table")); + let mut i = 1; + while seen_table_names.contains(&new_ref) { + match new_ref { + TableReference::Bare { table: _ } => { + new_ref = TableReference::bare(format!("{}_{}", original_table_name, i)); + } + TableReference::Partial { schema, table: _ } => { + new_ref = TableReference::partial( + schema.to_owned(), + format!("{}_{}", original_table_name, i), + ); + } + TableReference::Full { + catalog, + schema, + table: _, + } => { + new_ref = TableReference::full( + catalog.to_owned(), + schema.to_owned(), + format!("{}_{}", original_table_name, i), + ); + } + } + i += 1; + } + + seen_table_names.push(new_ref.to_owned()); + + if original_table_ref.as_ref() == Some(&new_ref) { + return Ok(table); + } + subquery_alias(table, new_ref) +} + fn from_substrait_jointype(join_type: i32) -> Result { if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { match substrait_join_type { @@ -1267,9 +1367,13 @@ pub async fn from_substrait_rex( let needle_expr = &in_predicate.needles[0]; let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { - let haystack_expr = - from_substrait_rel(ctx, haystack_expr, extensions) - .await?; + let haystack_expr = from_substrait_rel( + ctx, + haystack_expr, + extensions, + &mut Vec::new(), + ) + .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Arc::new(Expr::InSubquery(InSubquery { expr: Box::new( diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 4e4fa45a15a6..13a0ef005738 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -491,6 +491,24 @@ async fn roundtrip_outer_join() -> Result<()> { roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await } +#[tokio::test] +async fn roundtrip_self_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we generate unique aliases for each table in order to + // provide unique qualified names for each column. + // This test works because we set aliases to what the Substrait consumer will generate + roundtrip("SELECT data.a as data_a, data_1.a as data_1_a FROM data JOIN data AS data_1 ON data.a = data_1.a").await?; + roundtrip("SELECT data.a as data_a, data_1.a as data_1_a FROM data JOIN data AS data_1 ON data.b = data_1.b").await +} + +#[tokio::test] +async fn roundtrip_self_implicit_cross_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we generate unique aliases for each table. + // This test works because we set aliases to what the Substrait consumer will generate + roundtrip("SELECT data.a p1_a, data_1.a p2_a FROM data, data data_1").await +} + #[tokio::test] async fn roundtrip_arithmetic_ops() -> Result<()> { roundtrip("SELECT a - a FROM data").await?; @@ -576,20 +594,23 @@ async fn roundtrip_is_not_unknown() -> Result<()> { #[tokio::test] async fn roundtrip_union() -> Result<()> { - roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await + // Set aliases to what the Substrait consumer will generate + roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data data_1").await } #[tokio::test] async fn roundtrip_union2() -> Result<()> { + // Set aliases to what the Substrait consumer will generate roundtrip( - "SELECT a, b FROM data UNION SELECT a, b FROM data UNION SELECT a, b FROM data", + "SELECT a, b FROM data UNION SELECT a, b FROM data data_1 UNION SELECT a, b FROM data data_2", ) .await } #[tokio::test] async fn roundtrip_union_all() -> Result<()> { - roundtrip("SELECT a, e FROM data UNION ALL SELECT a, e FROM data").await + // Set aliases to what the Substrait consumer will generate + roundtrip("SELECT a, e FROM data UNION ALL SELECT a, e FROM data data_1").await } #[tokio::test] @@ -610,7 +631,8 @@ async fn simple_intersect() -> Result<()> { #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { - roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await + // Set aliases to what the Substrait consumer will generate + roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data_1.a FROM data data_1);").await } #[tokio::test] @@ -628,32 +650,6 @@ async fn qualified_catalog_schema_table_reference() -> Result<()> { roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await } -#[tokio::test] -async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { - assert_expected_plan( - "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", - "Projection: data.b, data.c\ - \n Inner Join: data.a = data.a\ - \n TableScan: data projection=[a, b]\ - \n TableScan: data projection=[a, c]", - false, // "d1" vs "data" field qualifier - ) - .await -} - -#[tokio::test] -async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { - assert_expected_plan( - "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", - "Projection: data.b, data.c\ - \n Inner Join: data.b = data.b\ - \n TableScan: data projection=[b]\ - \n TableScan: data projection=[b, c]", - false, // "d1" vs "data" field qualifier - ) - .await -} - /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported @@ -707,21 +703,21 @@ async fn roundtrip_literal_struct() -> Result<()> { #[tokio::test] async fn roundtrip_values() -> Result<()> { // TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently - let values = "(\ + // Using assert instead of roundtrip as the original plan's VALUES gets some extra aliases that get dropped + assert_expected_plan( + "SELECT * FROM (VALUES \ + (\ 1, \ 'a', \ [[-213.1, NULL, 5.5, 2.0, 1.0], []], \ arrow_cast([1,2,3], 'LargeList(Int64)'), \ STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \ [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\ - )"; - - // Test LogicalPlan::Values - assert_expected_plan( - format!("VALUES \ - {values}, \ - (NULL, NULL, NULL, NULL, NULL, NULL)").as_str(), - "Values: \ + ), \ + (NULL, NULL, NULL, NULL, NULL, NULL)\ + ) AS table", // Set aliases to what the Substrait consumer will generate + "SubqueryAlias: table\ + \n Values: \ (\ Int64(1), \ Utf8(\"a\"), \ @@ -731,11 +727,29 @@ async fn roundtrip_values() -> Result<()> { List([{struct_field: {string_field: a}}])\ ), \ (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())", - true) - .await?; + true + ) + .await +} + +#[tokio::test] +async fn roundtrip_values_empty_relation() -> Result<()> { + // Set aliases to what the Substrait consumer will generate + roundtrip("SELECT * FROM (VALUES (1)) AS table LIMIT 0").await +} - // Test LogicalPlan::EmptyRelation - roundtrip(format!("SELECT * FROM (VALUES {values}) LIMIT 0").as_str()).await +#[tokio::test] +async fn roundtrip_values_duplicate_column_join() -> Result<()> { + // Set aliases to what the Substrait consumer will generate + roundtrip( + "SELECT table.column1 as c1, table_1.column1 as c2 \ + FROM \ + (VALUES (1)) AS table \ + JOIN \ + (VALUES (2)) AS table_1 \ + ON table.column1 == table_1.column1", + ) + .await } /// Construct a plan that cast columns. Only those SQL types are supported for now.