Skip to content

Commit

Permalink
ensure no alias collision
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamedAbdeen21 committed Jun 11, 2024
1 parent eef86f9 commit 72e16a4
Showing 1 changed file with 44 additions and 37 deletions.
81 changes: 44 additions & 37 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use datafusion_common::tree_node::{
TreeNodeVisitor,
};
use datafusion_common::{
internal_err, qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result,
qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
Expand Down Expand Up @@ -166,6 +166,15 @@ impl CommonSubexprEliminate {
) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
let mut common_exprs = IndexMap::new();

input.schema().iter().for_each(|(qualifier, field)| {
let name = field.name();
if name.starts_with('#') {
common_exprs.insert(name.clone(), Expr::from((qualifier, field)));
}
});

let input_cse_len = common_exprs.len();

let rewrite_exprs = self.rewrite_exprs_list(
exprs_list,
arrays_list,
Expand All @@ -176,9 +185,9 @@ impl CommonSubexprEliminate {
let mut new_input = self
.try_optimize(input, config)?
.unwrap_or_else(|| input.clone());
if !common_exprs.is_empty() {
new_input =
build_common_expr_project_plan(new_input, common_exprs, expr_stats)?;

if common_exprs.len() > input_cse_len {
new_input = build_common_expr_project_plan(new_input, common_exprs)?;
}

Ok((rewrite_exprs, new_input))
Expand Down Expand Up @@ -517,18 +526,15 @@ fn to_arrays(
fn build_common_expr_project_plan(
input: LogicalPlan,
common_exprs: CommonExprs,
expr_stats: &ExprStats,
) -> Result<LogicalPlan> {
let mut fields_set = BTreeSet::new();
let mut project_exprs = common_exprs
.into_iter()
.enumerate()
.map(|(index, (expr_id, expr))| {
let Some((_, data_type)) = expr_stats.get(&expr_id) else {
return internal_err!("expr_stats invalid state");
};
.map(|(index, (_, expr))| {
let alias = format!("#{}", index + 1);
let field = Field::new(&alias, data_type.clone(), true);
let (dt, nullable) = expr.data_type_and_nullable(input.schema())?;
let field = Field::new(&alias, dt, nullable);
fields_set.insert(field.name().to_owned());
Ok(expr.alias(alias))
})
Expand Down Expand Up @@ -1225,28 +1231,16 @@ mod test {
#[test]
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
let expr_stats_1 = ExprStats::from([
("c+a".to_string(), (1, DataType::UInt32)),
("b+a".to_string(), (1, DataType::UInt32)),
]);
let common_exprs_1 = CommonExprs::from([
("c+a".to_string(), col("c") + col("a")),
("b+a".to_string(), col("b") + col("a")),
]);
let exprs_stats_2 = ExprStats::from([
("c+a".to_string(), (1, DataType::UInt32)),
("b+a".to_string(), (1, DataType::UInt32)),
]);
let common_exprs_2 = CommonExprs::from([
("c+a".to_string(), col("#1")),
("b+a".to_string(), col("#2")),
]);
let project =
build_common_expr_project_plan(table_scan, common_exprs_1, &expr_stats_1)
.unwrap();
let project_2 =
build_common_expr_project_plan(project, common_exprs_2, &exprs_stats_2)
.unwrap();
let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();

let mut field_set = BTreeSet::new();
for name in project_2.schema().field_names() {
Expand All @@ -1263,10 +1257,6 @@ mod test {
.unwrap()
.build()
.unwrap();
let expr_stats_1 = ExprStats::from([
("test1.c+test1.a".to_string(), (1, DataType::UInt32)),
("test1.b+test1.a".to_string(), (1, DataType::UInt32)),
]);
let common_exprs_1 = CommonExprs::from([
(
"test1.c+test1.a".to_string(),
Expand All @@ -1277,19 +1267,12 @@ mod test {
col("test1.b") + col("test1.a"),
),
]);
let expr_stats_2 = ExprStats::from([
("test1.c+test1.a".to_string(), (1, DataType::UInt32)),
("test1.b+test1.a".to_string(), (1, DataType::UInt32)),
]);
let common_exprs_2 = CommonExprs::from([
("test1.c+test1.a".to_string(), col("#1")),
("test1.b+test1.a".to_string(), col("#2")),
]);
let project =
build_common_expr_project_plan(join, common_exprs_1, &expr_stats_1).unwrap();
let project_2 =
build_common_expr_project_plan(project, common_exprs_2, &expr_stats_2)
.unwrap();
let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();

let mut field_set = BTreeSet::new();
for name in project_2.schema().field_names() {
Expand Down Expand Up @@ -1402,6 +1385,30 @@ mod test {
Ok(())
}

/// Regression test: a column in the input that is already named like a
/// generated common-expression alias (`#1`) must not be shadowed when the
/// optimizer introduces its own alias for a newly found common subexpression.
///
/// The inner projection deliberately names `a + b` as `#1`; the outer
/// projection then repeats both `#1` and `c + 2`. The expected plan keeps the
/// existing `#1` and assigns the next free alias (`#2`) to `c + 2`.
#[test]
fn test_alias_collision() -> Result<()> {
    let scan = test_table_scan()?;

    // Inner projection: a user-visible column that happens to be called `#1`.
    let inner_exprs = vec![(col("a") + col("b")).alias("#1"), col("c")];

    // Outer projection: `#1` is referenced twice and `c + 2` is computed twice,
    // so only `c + 2` is a new candidate for extraction.
    let outer_exprs = vec![
        col("#1").alias("c1"),
        col("#1").alias("c2"),
        (col("c") + lit(2)).alias("c3"),
        (col("c") + lit(2)).alias("c4"),
    ];

    let plan = LogicalPlanBuilder::from(scan.clone())
        .project(inner_exprs)?
        .project(outer_exprs)?
        .build()?;

    // NOTE(review): the whitespace after each `\n` is reproduced exactly as it
    // appears in this view; confirm it matches the plan's real indentation.
    let expected = "Projection: #1 AS c1, #1 AS c2, #2 AS c3, #2 AS c4\
        \n Projection: #1 AS #1, test.c + Int32(2) AS #2, test.c\
        \n Projection: test.a + test.b AS #1, test.c\
        \n TableScan: test";

    assert_optimized_plan_eq(expected, &plan);

    Ok(())
}

#[test]
fn test_extract_expressions_from_col() -> Result<()> {
let mut result = Vec::with_capacity(1);
Expand Down

0 comments on commit 72e16a4

Please sign in to comment.