Skip to content

Commit

Permalink
use only current expr symbol + remove visit stack
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamedAbdeen21 committed May 4, 2024
1 parent 5d3bbaa commit eca91a5
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 151 deletions.
98 changes: 26 additions & 72 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ impl ExprSet {
expr.visit(&mut ExprIdentifierVisitor {
expr_set: self,
input_schema,
visit_stack: vec![],
node_count: 0,
expr_mask,
})?;

Expand Down Expand Up @@ -590,6 +588,8 @@ impl ExprMask {
}
}

// TODO: Docs

/// Go through an expression tree and generate identifiers for each subexpression.
///
/// An identifier contains information about the expression itself and its sub-expressions.
Expand All @@ -615,86 +615,40 @@ struct ExprIdentifierVisitor<'a> {
/// input schema for the node that we're optimizing, so we can determine the correct datatype
/// for each subexpression
input_schema: DFSchemaRef,
// inner states
visit_stack: Vec<VisitRecord>,
/// Incremented in f_down; starts from 0.
node_count: usize,
/// which expression should be skipped?
expr_mask: ExprMask,
}

/// Record item that is used when traversing an expression tree.
enum VisitRecord {
/// `usize` is the monotonically increasing sequence number assigned in pre_visit().
/// It starts from 0 and is used to index the identifier array `id_array` in post_visit().
EnterMark(usize),
/// The node's children were skipped, so traversal jumps to f_up on the same node.
JumpMark(usize),
/// Accumulated identifier of a sub-expression.
ExprItem(Identifier),
}

impl ExprIdentifierVisitor<'_> {
/// Pops records off the stack until the first `EnterMark` (or `JumpMark`) is found,
/// accumulating every `ExprItem` popped before it.
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
let mut desc = String::new();

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
return (idx, desc);
}
VisitRecord::ExprItem(id) => {
desc.push_str(&id);
}
}
}
unreachable!("Enter mark should paired with node number");
}
}

impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
type Node = Expr;

fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
// related to https://github.com/apache/datafusion/issues/8814
// If the expr contains a volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || expr.is_volatile()? {
self.visit_stack
.push(VisitRecord::JumpMark(self.node_count));
return Ok(TreeNodeRecursion::Jump); // go to f_up
}

self.visit_stack
.push(VisitRecord::EnterMark(self.node_count));
self.node_count += 1;

Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
let (_idx, sub_expr_identifier) = self.pop_enter_mark();

// Skip expressions that should not be recognized.
if self.expr_mask.ignores(expr) {
let curr_expr_identifier = ExprSet::expr_identifier(expr);
self.visit_stack
.push(VisitRecord::ExprItem(curr_expr_identifier));
return Ok(TreeNodeRecursion::Continue);
}
let curr_expr_identifier = ExprSet::expr_identifier(expr);
let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}");

self.visit_stack
.push(VisitRecord::ExprItem(alias_symbol.clone()));
let curr_expr_identifier = ExprSet::expr_identifier(expr);

let data_type = expr.get_type(&self.input_schema)?;

let alias_symbol = format!("#{{{curr_expr_identifier}}}");

self.expr_set
.entry(curr_expr_identifier)
.or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol))
.1 += 1;

Ok(TreeNodeRecursion::Continue)
}
}
Expand Down Expand Up @@ -811,8 +765,8 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
\n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#{test.a * (Int32(1) - test.b)} AS test.a * Int32(1) - test.b), SUM(#{test.a * (Int32(1) - test.b)} AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
\n Projection: test.a * (Int32(1) - test.b) AS #{test.a * (Int32(1) - test.b)}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down Expand Up @@ -864,8 +818,8 @@ mod test {
)?
.build()?;

let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1, AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
let expected = "Projection: #{AVG(test.a)} AS AVG(test.a) AS col1, #{AVG(test.a)} AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), #{my_agg(test.a)} AS my_agg(test.a) AS col4, #{my_agg(test.a)} AS my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS #{AVG(test.a)}, my_agg(test.a) AS #{my_agg(test.a)}, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand All @@ -883,8 +837,8 @@ mod test {
)?
.build()?;

let expected = "Projection: Int32(1) + AVG(test.a)test.a AS AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a), Int32(1) + my_agg(test.a)test.a AS my_agg(test.a), Int32(1) - my_agg(test.a)test.a AS my_agg(test.a)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a]]\
let expected = "Projection: Int32(1) + #{AVG(test.a)} AS AVG(test.a), Int32(1) - #{AVG(test.a)} AS AVG(test.a), Int32(1) + #{my_agg(test.a)} AS my_agg(test.a), Int32(1) - #{my_agg(test.a)} AS my_agg(test.a)\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS #{AVG(test.a)}, my_agg(test.a) AS #{my_agg(test.a)}]]\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand All @@ -900,8 +854,8 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col1, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col2]]\
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
let expected = "Aggregate: groupBy=[[]], aggr=[[AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a) AS col1, my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a) AS col2]]\
\n Projection: UInt32(1) + test.a AS #{UInt32(1) + test.a}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand All @@ -917,8 +871,8 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col1, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col2]]\
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
let expected = "Aggregate: groupBy=[[#{UInt32(1) + test.a} AS UInt32(1) + test.a]], aggr=[[AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a) AS col1, my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a) AS col2]]\
\n Projection: UInt32(1) + test.a AS #{UInt32(1) + test.a}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand All @@ -938,9 +892,9 @@ mod test {
)?
.build()?;

let expected = "Projection: UInt32(1) + test.a, UInt32(1) + AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a) AS col2, AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a), UInt32(1) + my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS my_agg(UInt32(1) + test.a) AS col4, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]]\
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
let expected = "Projection: UInt32(1) + test.a, UInt32(1) + #{AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - #{AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a) AS col2, #{AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a)} AS AVG(UInt32(1) + test.a), UInt32(1) + #{my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - #{my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a) AS col4, #{my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a)} AS my_agg(UInt32(1) + test.a)\
\n Aggregate: groupBy=[[#{UInt32(1) + test.a} AS UInt32(1) + test.a]], aggr=[[AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a) AS #{AVG(#{UInt32(1) + test.a} AS UInt32(1) + test.a)}, my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a) AS #{my_agg(#{UInt32(1) + test.a} AS UInt32(1) + test.a)}]]\
\n Projection: UInt32(1) + test.a AS #{UInt32(1) + test.a}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand All @@ -965,9 +919,9 @@ mod test {
)?
.build()?;

let expected = "Projection: table.test.col.a, UInt32(1) + AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a)UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a AS AVG(UInt32(1) + table.test.col.a), AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a)UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a AS AVG(UInt32(1) + table.test.col.a)\
\n Aggregate: groupBy=[[table.test.col.a]], aggr=[[AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a) AS AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a)UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a]]\
\n Projection: UInt32(1) + table.test.col.a AS UInt32(1) + table.test.col.atable.test.col.aUInt32(1), table.test.col.a\
let expected = "Projection: table.test.col.a, UInt32(1) + #{AVG(#{UInt32(1) + table.test.col.a} AS UInt32(1) + table.test.col.a)} AS AVG(UInt32(1) + table.test.col.a), #{AVG(#{UInt32(1) + table.test.col.a} AS UInt32(1) + table.test.col.a)} AS AVG(UInt32(1) + table.test.col.a)\
\n Aggregate: groupBy=[[table.test.col.a]], aggr=[[AVG(#{UInt32(1) + table.test.col.a} AS UInt32(1) + table.test.col.a) AS #{AVG(#{UInt32(1) + table.test.col.a} AS UInt32(1) + table.test.col.a)}]]\
\n Projection: UInt32(1) + table.test.col.a AS #{UInt32(1) + table.test.col.a}, table.test.col.a\
\n TableScan: table.test";

assert_optimized_plan_eq(expected, &plan);
Expand All @@ -986,8 +940,8 @@ mod test {
])?
.build()?;

let expected = "Projection: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a AS first, Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a AS second\
\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
let expected = "Projection: #{Int32(1) + test.a} AS Int32(1) + test.a AS first, #{Int32(1) + test.a} AS Int32(1) + test.a AS second\
\n Projection: Int32(1) + test.a AS #{Int32(1) + test.a}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down Expand Up @@ -1193,8 +1147,8 @@ mod test {
.build()?;

let expected = "Projection: test.a, test.b, test.c\
\n Filter: Int32(1) + test.atest.aInt32(1) - Int32(10) > Int32(1) + test.atest.aInt32(1)\
\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n Filter: #{Int32(1) + test.a} - Int32(10) > #{Int32(1) + test.a}\
\n Projection: Int32(1) + test.a AS #{Int32(1) + test.a}, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4187,8 +4187,8 @@ EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE))
logical_plan
01)Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x)
02)--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]]
03)----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]]
04)------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y
03)----Aggregate: groupBy=[[t1.y, #{CAST(t1.x AS Float64)} AS t1.x AS alias1]], aggr=[[]]
04)------Projection: CAST(t1.x AS Float64) AS #{CAST(t1.x AS Float64)}, t1.y
05)--------TableScan: t1 projection=[x, y]
physical_plan
01)ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)]
Expand All @@ -4200,8 +4200,8 @@ physical_plan
07)------------CoalesceBatchesExec: target_batch_size=2
08)--------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8
09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
10)------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[]
11)--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y]
10)------------------AggregateExec: mode=Partial, gby=[y@1 as y, #{CAST(t1.x AS Float64)}@0 as alias1], aggr=[]
11)--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as #{CAST(t1.x AS Float64)}, y@1 as y]
12)----------------------MemoryExec: partitions=1, partition_sizes=[1]

# create an unbounded table that contains ordered timestamp.
Expand Down
16 changes: 8 additions & 8 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1426,12 +1426,12 @@ query TT
EXPLAIN SELECT x/2, x/2+1 FROM t;
----
logical_plan
01)Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1)
02)--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x
01)Projection: #{t.x / Int64(2)} AS t.x / Int64(2), #{t.x / Int64(2)} AS t.x / Int64(2) + Int64(1)
02)--Projection: t.x / Int64(2) AS #{t.x / Int64(2)}
03)----TableScan: t projection=[x]
physical_plan
01)ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)]
02)--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x]
01)ProjectionExec: expr=[#{t.x / Int64(2)}@0 as t.x / Int64(2), #{t.x / Int64(2)}@0 + 1 as t.x / Int64(2) + Int64(1)]
02)--ProjectionExec: expr=[x@0 / 2 as #{t.x / Int64(2)}]
03)----MemoryExec: partitions=1, partition_sizes=[1]

query II
Expand All @@ -1444,12 +1444,12 @@ query TT
EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t;
----
logical_plan
01)Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y)
02)--Projection: abs(t.x) AS abs(t.x)t.x, t.y
01)Projection: #{abs(t.x)} AS abs(t.x), #{abs(t.x)} AS abs(t.x) + abs(t.y)
02)--Projection: abs(t.x) AS #{abs(t.x)}, t.y
03)----TableScan: t projection=[x, y]
physical_plan
01)ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)]
02)--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y]
01)ProjectionExec: expr=[#{abs(t.x)}@0 as abs(t.x), #{abs(t.x)}@0 + abs(y@1) as abs(t.x) + abs(t.y)]
02)--ProjectionExec: expr=[abs(x@0) as #{abs(t.x)}, y@1 as y]
03)----MemoryExec: partitions=1, partition_sizes=[1]

query II
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,8 @@ query TT
explain select a/2, a/2 + 1 from t
----
logical_plan
01)Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
02)--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
01)Projection: #{t.a / Int64(2)} AS t.a / Int64(2), #{t.a / Int64(2)} AS t.a / Int64(2) + Int64(1)
02)--Projection: t.a / Int64(2) AS #{t.a / Int64(2)}
03)----TableScan: t projection=[a]

statement ok
Expand All @@ -1083,8 +1083,8 @@ query TT
explain select a/2, a/2 + 1 from t
----
logical_plan
01)Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
02)--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
01)Projection: #{t.a / Int64(2)} AS t.a / Int64(2), #{t.a / Int64(2)} AS t.a / Int64(2) + Int64(1)
02)--Projection: t.a / Int64(2) AS #{t.a / Int64(2)}
03)----TableScan: t projection=[a]

###
Expand Down
Loading

0 comments on commit eca91a5

Please sign in to comment.