From 9fb7f9852a7c75ae9d244d2a9e8e54290e3274b2 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 11 Jun 2024 22:28:20 +0300 Subject: [PATCH] ensure no alias collision --- .../optimizer/src/common_subexpr_eliminate.rs | 81 ++++++++++--------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ae9ec2c6c8d6e..42d419e452d09 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -28,7 +28,7 @@ use datafusion_common::tree_node::{ TreeNodeVisitor, }; use datafusion_common::{ - internal_err, qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result, + qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; @@ -166,6 +166,15 @@ impl CommonSubexprEliminate { ) -> Result<(Vec>, LogicalPlan)> { let mut common_exprs = IndexMap::new(); + input.schema().iter().for_each(|(qualifier, field)| { + let name = field.name(); + if name.starts_with('#') { + common_exprs.insert(name.clone(), Expr::from((qualifier, field))); + } + }); + + let input_cse_len = common_exprs.len(); + let rewrite_exprs = self.rewrite_exprs_list( exprs_list, arrays_list, @@ -176,9 +185,9 @@ impl CommonSubexprEliminate { let mut new_input = self .try_optimize(input, config)? .unwrap_or_else(|| input.clone()); - if !common_exprs.is_empty() { - new_input = - build_common_expr_project_plan(new_input, common_exprs, expr_stats)?; + + if common_exprs.len() > input_cse_len { + new_input = build_common_expr_project_plan(new_input, common_exprs)?; } Ok((rewrite_exprs, new_input)) @@ -517,18 +526,15 @@ fn to_arrays( fn build_common_expr_project_plan( input: LogicalPlan, common_exprs: CommonExprs, - expr_stats: &ExprStats, ) -> Result { let mut fields_set = BTreeSet::new(); let mut project_exprs = common_exprs .into_iter() .enumerate() - .map(|(index, (expr_id, expr))| { - let Some((_, data_type)) = expr_stats.get(&expr_id) else { - return internal_err!("expr_stats invalid state"); - }; + .map(|(index, (_, expr))| { let alias = format!("#{}", index + 1); - let field = Field::new(&alias, data_type.clone(), true); + let (dt, nullable) = expr.data_type_and_nullable(input.schema())?; + let field = Field::new(&alias, dt, nullable); fields_set.insert(field.name().to_owned()); Ok(expr.alias(alias)) }) @@ -1225,28 +1231,16 @@ mod test { #[test] fn redundant_project_fields() { let table_scan = test_table_scan().unwrap(); - let expr_stats_1 = ExprStats::from([ - ("c+a".to_string(), (1, DataType::UInt32)), - ("b+a".to_string(), (1, DataType::UInt32)), - ]); let common_exprs_1 = CommonExprs::from([ ("c+a".to_string(), col("c") + col("a")), ("b+a".to_string(), col("b") + col("a")), ]); - let exprs_stats_2 = ExprStats::from([ - ("c+a".to_string(), (1, DataType::UInt32)), - ("b+a".to_string(), (1, DataType::UInt32)), - ]); let common_exprs_2 = CommonExprs::from([ ("c+a".to_string(), col("#1")), ("b+a".to_string(), col("#2")), ]); - let project = - build_common_expr_project_plan(table_scan, common_exprs_1, &expr_stats_1) - .unwrap(); - let project_2 = - build_common_expr_project_plan(project, common_exprs_2, &exprs_stats_2) - .unwrap(); + let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap(); + let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap(); let mut field_set = BTreeSet::new(); for name in project_2.schema().field_names() { @@ -1263,10 +1257,6 @@ mod test { .unwrap() .build() .unwrap(); - let expr_stats_1 = ExprStats::from([ - ("test1.c+test1.a".to_string(), (1, DataType::UInt32)), - ("test1.b+test1.a".to_string(), (1, DataType::UInt32)), - ]); let common_exprs_1 = CommonExprs::from([ ( "test1.c+test1.a".to_string(), @@ -1277,19 +1267,12 @@ mod test { col("test1.b") + col("test1.a"), ), ]); - let expr_stats_2 = ExprStats::from([ - ("test1.c+test1.a".to_string(), (1, DataType::UInt32)), - ("test1.b+test1.a".to_string(), (1, DataType::UInt32)), - ]); let common_exprs_2 = CommonExprs::from([ ("test1.c+test1.a".to_string(), col("#1")), ("test1.b+test1.a".to_string(), col("#2")), ]); - let project = - build_common_expr_project_plan(join, common_exprs_1, &expr_stats_1).unwrap(); - let project_2 = - build_common_expr_project_plan(project, common_exprs_2, &expr_stats_2) - .unwrap(); + let project = build_common_expr_project_plan(join, common_exprs_1).unwrap(); + let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap(); let mut field_set = BTreeSet::new(); for name in project_2.schema().field_names() { @@ -1402,6 +1385,30 @@ mod test { Ok(()) } + #[test] + fn test_alias_collision() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![(col("a") + col("b")).alias("#1"), col("c")])? + .project(vec![ + col("#1").alias("c1"), + col("#1").alias("c2"), + (col("c") + lit(2)).alias("c3"), + (col("c") + lit(2)).alias("c4"), + ])? + .build()?; + + let expected = "Projection: #1 AS c1, #1 AS c2, #2 AS c3, #2 AS c4\ + \n Projection: #1 AS #1, test.c + Int32(2) AS #2, test.c\ + \n Projection: test.a + test.b AS #1, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, &plan); + + Ok(()) + } + #[test] fn test_extract_expressions_from_col() -> Result<()> { let mut result = Vec::with_capacity(1);