Skip to content

Commit

Permalink
use CustomRewriter reimplement, and rewrite upper exprid to prevent e…
Browse files Browse the repository at this point in the history
…xprid becoming duplicated
  • Loading branch information
feiniaofeiafei committed Nov 14, 2024
1 parent 6f13f10 commit 6b4a551
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new EliminateJoinByUnique())
),
topic("eliminate Aggregate according to fd items",
topDown(new EliminateGroupByKeyByUniform()),
custom(RuleType.ELIMINATE_GROUP_BY_KEY_BY_UNIFORM, EliminateGroupByKeyByUniform::new),
topDown(new EliminateGroupByKey()),
topDown(new PushDownAggThroughJoinOnPkFk()),
topDown(new PullUpJoinFromUnionAll())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ public enum RuleType {
REWRITE_HAVING_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_OLAP_TABLE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
* expression of plan rewrite rule.
*/
public class ExpressionRewrite implements RewriteRuleFactory {
private final ExpressionRuleExecutor rewriter;
protected final ExpressionRuleExecutor rewriter;

public ExpressionRewrite(ExpressionRewriteRule... rules) {
this.rewriter = new ExpressionRuleExecutor(ImmutableList.copyOf(rules));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,25 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
Expand All @@ -40,22 +44,37 @@
* ->
* +--aggregate(group by b output b,any_value(a) as a,max(c))
* */
public class EliminateGroupByKeyByUniform extends OneRewriteRuleFactory {
public class EliminateGroupByKeyByUniform extends DefaultPlanRewriter<Map<ExprId, ExprId>> implements CustomRewriter {
private ExprIdRewriter exprIdReplacer;

@Override
public Rule build() {
return logicalAggregate().whenNot(agg -> agg.getSourceRepeat().isPresent())
.whenNot(agg -> agg.getGroupByExpressions().isEmpty())
.then(EliminateGroupByKeyByUniform::eliminate)
.toRule(RuleType.ELIMINATE_GROUP_BY_KEY_BY_UNIFORM);
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
    // One mapping (old ExprId -> new ExprId) is shared by the traversal and by the
    // ExprId replacer: the same map instance is handed to the ReplaceRule and passed
    // as the visitor context, so entries recorded while visiting are seen by the replacer.
    Map<ExprId, ExprId> oldToNewIds = new HashMap<>();
    exprIdReplacer = new ExprIdRewriter(new ExprIdRewriter.ReplaceRule(oldToNewIds), jobContext);
    return plan.accept(this, oldToNewIds);
}

/**
 * Default handling for every plan node: rewrite the children first, then run the
 * ExprId replacer over the resulting node.
 *
 * @param plan the plan node being visited
 * @param replaceMap mapping from old ExprId to new ExprId, forwarded to the children
 * @return the plan returned by the ExprId replacer for the node with rewritten children
 */
@Override
public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
    Plan withRewrittenChildren = visitChildren(this, plan, replaceMap);
    return exprIdReplacer.rewriteExpr(withRewrittenChildren);
}

private static Plan eliminate(LogicalAggregate<Plan> agg) {
DataTrait aggChildTrait = agg.child().getLogicalProperties().getTrait();
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Map<ExprId, ExprId> replaceMap) {
aggregate = visitChildren(this, aggregate, replaceMap);
aggregate = (LogicalAggregate<? extends Plan>) exprIdReplacer.rewriteExpr(aggregate);

if (aggregate.getGroupByExpressions().isEmpty() || aggregate.getSourceRepeat().isPresent()) {
return aggregate;
}
DataTrait aggChildTrait = aggregate.child().getLogicalProperties().getTrait();
// Get the Group by column of agg. If there is a uniform one, delete the group by key.
Set<Expression> removedExpression = new LinkedHashSet<>();
List<Expression> newGroupBy = new ArrayList<>();
for (Expression groupBy : agg.getGroupByExpressions()) {
for (Expression groupBy : aggregate.getGroupByExpressions()) {
if (!(groupBy instanceof Slot)) {
newGroupBy.add(groupBy);
continue;
Expand All @@ -67,7 +86,7 @@ private static Plan eliminate(LogicalAggregate<Plan> agg) {
}
}
if (removedExpression.isEmpty()) {
return null;
return aggregate;
}
// when newGroupBy is empty, need retain one expr in group by, otherwise the result may be wrong in empty table
if (newGroupBy.isEmpty()) {
Expand All @@ -76,20 +95,22 @@ private static Plan eliminate(LogicalAggregate<Plan> agg) {
removedExpression.remove(expr);
}
if (removedExpression.isEmpty()) {
return null;
return aggregate;
}
List<NamedExpression> newOutputs = new ArrayList<>();
// If this output appears in the removedExpression column, replace it with any_value
for (NamedExpression output : agg.getOutputExpressions()) {
for (NamedExpression output : aggregate.getOutputExpressions()) {
if (output instanceof Slot) {
if (removedExpression.contains(output)) {
newOutputs.add(new Alias(output.getExprId(), new AnyValue(false, output), output.getName()));
Alias alias = new Alias(new AnyValue(false, output), output.getName());
newOutputs.add(alias);
replaceMap.put(output.getExprId(), alias.getExprId());
} else {
newOutputs.add(output);
}
} else if (output instanceof Alias) {
if (removedExpression.contains(output.child(0))) {
newOutputs.add(new Alias(output.getExprId(),
newOutputs.add(new Alias(
new AnyValue(false, output.child(0)), output.getName()));
} else {
newOutputs.add(output);
Expand All @@ -111,6 +132,6 @@ private static Plan eliminate(LogicalAggregate<Plan> agg) {
}
}
orderOutput.addAll(aggFuncs);
return agg.withGroupByAndOutput(newGroupBy, orderOutput);
return aggregate.withGroupByAndOutput(newGroupBy, orderOutput);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;

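/**
 * ExprIdRewriter: an {@link ExpressionRewrite} that replaces the ExprIds of slot references
 * in a plan's expressions according to the mapping held by its {@link ReplaceRule},
 * and additionally rewrites the output expressions of sink plans.
 */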
public class ExprIdRewriter extends ExpressionRewrite {
private final List<Rule> rules;
private final JobContext jobContext;

/**
 * Creates a rewriter whose expression executor consists of the given replace rule
 * wrapped by {@code bottomUp(...)}, and pre-builds the plan rules used by {@code rewriteExpr}.
 *
 * @param replaceRule the rule that maps old ExprIds to new ExprIds
 * @param jobContext the job context; its cascades context is passed to each rule when it is applied
 */
public ExprIdRewriter(ReplaceRule replaceRule, JobContext jobContext) {
    super(new ExpressionRuleExecutor(ImmutableList.of(bottomUp(replaceRule))));
    this.jobContext = jobContext;
    // buildRules() does not read jobContext, so the two assignments are independent.
    this.rules = buildRules();
}

/**
 * Builds the full rule list: every rule contributed by the parent class, followed by
 * one rule per sink type defined in this class.
 *
 * @return an immutable list with the parent's rules first and the sink rules after them
 */
@Override
public List<Rule> buildRules() {
    return ImmutableList.<Rule>builder()
            .addAll(super.buildRules())
            .add(new LogicalResultSinkRewrite().build())
            .add(new LogicalFileSinkRewrite().build())
            .add(new LogicalHiveTableSinkRewrite().build())
            .add(new LogicalIcebergTableSinkRewrite().build())
            .add(new LogicalJdbcTableSinkRewrite().build())
            .add(new LogicalOlapTableSinkRewrite().build())
            .add(new LogicalDeferMaterializeResultSinkRewrite().build())
            .build();
}

/**
 * Applies the pre-built rules to a single plan node.
 *
 * <p>Rules are tried in list order. The first rule whose pattern matches the plan and
 * whose first transform result is not {@code deepEquals} to the input wins, and that
 * result is returned immediately; later rules are not tried.
 *
 * @param plan the plan node to rewrite
 * @return the rewritten plan, or {@code plan} itself when no rule changes it
 */
public Plan rewriteExpr(Plan plan) {
    for (Rule candidate : rules) {
        Pattern<Plan> matcher = (Pattern<Plan>) candidate.getPattern();
        if (!matcher.matchPlanTree(plan)) {
            continue;
        }
        List<Plan> results = candidate.transform(plan, jobContext.getCascadesContext());
        Plan rewritten = results.get(0);
        if (rewritten.deepEquals(plan)) {
            // The rule matched but produced an equal plan; keep trying the remaining rules.
            continue;
        }
        return rewritten;
    }
    return plan;
}

/**
 * Expression rule that redirects {@link SlotReference}s to their replacement ExprIds.
 *
 * <p>The mapping is stored by reference, not copied, so entries added to the map after
 * this rule is constructed are honoured the next time the rule runs.
 */
public static class ReplaceRule implements ExpressionPatternRuleFactory {
    private final Map<ExprId, ExprId> replaceMap;

    /**
     * @param replaceMap mapping from old ExprId to new ExprId; kept by reference
     */
    public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
        this.replaceMap = replaceMap;
    }

    /**
     * Builds a single matcher for {@link SlotReference}: a slot whose ExprId is a key of
     * the map is returned with the final ExprId of its replacement chain; any other slot
     * is returned unchanged.
     */
    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesType(SlotReference.class).thenApply(ctx -> {
                    Slot slot = ctx.expr;
                    if (replaceMap.containsKey(slot.getExprId())) {
                        ExprId newId = replaceMap.get(slot.getExprId());
                        // Follow the chain a -> b -> c to its end. Each step must look up
                        // the CURRENT id; looking up the slot's original id again would
                        // never advance and would loop forever once a chain exists.
                        while (replaceMap.containsKey(newId)) {
                            newId = replaceMap.get(newId);
                        }
                        return slot.withExprId(newId);
                    }
                    return slot;
                })
        );
    }
}

/**
 * Rule factory for plans matched by {@code logicalResultSink()}: the matched sink is
 * handed to the enclosing instance's {@code applyRewrite}, and the rule is registered
 * under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalResultSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalResultSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}

/**
 * Rule factory for plans matched by {@code logicalFileSink()}: the matched sink is
 * handed to the enclosing instance's {@code applyRewrite}, and the rule is registered
 * under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalFileSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalFileSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}

/**
 * Rule factory for plans matched by {@code logicalHiveTableSink()}: the matched sink is
 * handed to the enclosing instance's {@code applyRewrite}, and the rule is registered
 * under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalHiveTableSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}

/**
 * Rule factory for plans matched by {@code logicalIcebergTableSink()}: the matched sink
 * is handed to the enclosing instance's {@code applyRewrite}, and the rule is registered
 * under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalIcebergTableSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}

/**
 * Rule factory for plans matched by {@code logicalJdbcTableSink()}: the matched sink is
 * handed to the enclosing instance's {@code applyRewrite}, and the rule is registered
 * under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalJdbcTableSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}

/**
 * Rule factory for plans matched by {@code logicalOlapTableSink()}: the matched sink is
 * handed to the enclosing instance's {@code applyRewrite}, and the rule is registered
 * under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalOlapTableSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}

/**
 * Rule factory for plans matched by {@code logicalDeferMaterializeResultSink()}: the
 * matched sink is handed to the enclosing instance's {@code applyRewrite}, and the rule
 * is registered under {@code RuleType.REWRITE_SINK_EXPRESSION}.
 */
private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
// Inner (non-static) class: the method reference binds to the enclosing ExprIdRewriter.
return logicalDeferMaterializeResultSink().thenApply(ExprIdRewriter.this::applyRewrite)
.toRule(RuleType.REWRITE_SINK_EXPRESSION);
}
}


/**
 * Rewrites the output expressions of a matched sink with this rewriter's expression executor.
 *
 * @param ctx matching context whose root is the sink to rewrite
 * @return the same sink instance when the rewritten output list equals the original one,
 *         otherwise the sink rebuilt with the rewritten output expressions
 */
private LogicalSink<Plan> applyRewrite(MatchingContext<? extends LogicalSink<Plan>> ctx) {
    LogicalSink<Plan> matchedSink = ctx.root;
    ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
    List<NamedExpression> before = matchedSink.getOutputExprs();
    List<NamedExpression> after = rewriteAll(before, rewriter, rewriteContext);
    if (!before.equals(after)) {
        return matchedSink.withOutputExprs(after);
    }
    // Nothing changed: hand back the original instance.
    return matchedSink;
}
}
Loading

0 comments on commit 6b4a551

Please sign in to comment.