From 101c9d37ce0162cdb98640141b0c971e108831c9 Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Tue, 7 Nov 2023 13:33:02 +0530 Subject: [PATCH] Initial commit --- .../presto/sql/planner/PlanOptimizers.java | 2 +- .../optimizations/LogicalCteOptimizer.java | 179 ++++++++++++++++-- 2 files changed, 168 insertions(+), 13 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index ac61a9c46824..d80a002ad136 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -277,7 +277,7 @@ public PlanOptimizers( new PruneLimitColumns(), new PruneTableScanColumns()); - builder.add(new LogicalCteOptimizer(metadata)); + builder.add(new LogicalCteOptimizer(metadata, costComparator, costCalculator, statsCalculator)); IterativeOptimizer inlineProjections = new IterativeOptimizer( metadata, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LogicalCteOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LogicalCteOptimizer.java index f10cfe14a4f2..1b5cba1e8ad5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LogicalCteOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LogicalCteOptimizer.java @@ -14,6 +14,15 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.cost.CachingCostProvider; +import com.facebook.presto.cost.CachingStatsProvider; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.CostProvider; +import com.facebook.presto.cost.PlanCostEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; @@ -25,11 +34,15 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.SequenceNode; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Ordering; import com.google.common.graph.GraphBuilder; import com.google.common.graph.MutableGraph; import com.google.common.graph.Traverser; +import io.airlift.units.Duration; import java.util.ArrayList; import java.util.Arrays; @@ -37,11 +50,14 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.Stack; import static com.facebook.presto.SystemSessionProperties.isMaterializeAllCtes; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; /* * Transformation of CTE Reference Nodes: @@ -71,9 +87,21 @@ public class LogicalCteOptimizer { private final Metadata metadata; - public LogicalCteOptimizer(Metadata metadata) + private final CostComparator costComparator; + + private final StatsCalculator statsCalculator; + + private final CostCalculator costCalculator; + + public LogicalCteOptimizer(Metadata metadata, + CostComparator costComparator, + CostCalculator costCalculator, + StatsCalculator statsCalculator) { this.metadata = metadata; + this.costComparator = costComparator; + this.costCalculator = costCalculator; + this.statsCalculator = statsCalculator; } @Override @@ -83,26 +111,49 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider || session.getCteInformationCollector().getCTEInformationList().stream().noneMatch(CTEInformation::isMaterialized)) { return PlanOptimizerResult.optimizerResult(plan, false); } - PlanNode rewrittenPlan = new CteEnumerator(idAllocator, variableAllocator).transformPersistentCtes(plan); + PlanNode rewrittenPlan = plan; + Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); + StatsProvider statsProvider = new CachingStatsProvider( + statsCalculator, + Optional.empty(), + Lookup.noLookup(), + session, + TypeProvider.viewOf(variableAllocator.getVariables())); + CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session); + if (isMaterializeAllCtes(session)) { + rewrittenPlan = new CteEnumerator(session, idAllocator, variableAllocator, costProvider).persistAllCtes(plan); + } + else { + rewrittenPlan = new CteEnumerator(session, idAllocator, variableAllocator, costProvider).choosePersistentCtes(plan); + } + return PlanOptimizerResult.optimizerResult(rewrittenPlan, !rewrittenPlan.equals(plan)); } public class CteEnumerator { - PlanNodeIdAllocator planNodeIdAllocator; - VariableAllocator variableAllocator; + private final Session session; + private final PlanNodeIdAllocator planNodeIdAllocator; + private final VariableAllocator variableAllocator; + + private final CostProvider costProvider; - public CteEnumerator(PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator variableAllocator) + public CteEnumerator(Session session, + PlanNodeIdAllocator planNodeIdAllocator, + VariableAllocator variableAllocator, + CostProvider costProvider) { + this.session = session; this.planNodeIdAllocator = planNodeIdAllocator; this.variableAllocator = variableAllocator; + this.costProvider = costProvider; } - public PlanNode transformPersistentCtes(PlanNode root) + public PlanNode persistAllCtes(PlanNode root) { checkArgument(root.getSources().size() == 1, "expected newChildren to contain 1 node"); - CteTransformerContext context = new CteTransformerContext(); - PlanNode transformedCte = SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator), + CteTransformerContext context = new CteTransformerContext(Optional.empty()); + PlanNode transformedCte = SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator, CteConsumerTransformer.Operation.REWRITE), root, context); List topologicalOrderedList = context.getTopologicalOrdering(); if (topologicalOrderedList.isEmpty()) { @@ -112,19 +163,72 @@ public PlanNode transformPersistentCtes(PlanNode root) transformedCte.getSources().get(0)); return root.replaceChildren(Arrays.asList(sequenceNode)); } + + public PlanNode choosePersistentCtes(PlanNode root) + { + // cost based + // ToDo cleanup + CteTransformerContext context = new CteTransformerContext(Optional.empty()); + SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator, CteConsumerTransformer.Operation.EXPLORE), + root, context); + List cteProducerList = context.getTopologicalOrdering(); + if (session.getCteInformationCollector().getCTEInformationList().size() > 20) { + // 2^n combinations which will be processed + return root; + } + int numberOfCtes = cteProducerList.size(); + int combinations = 1 << numberOfCtes; // 2^n combinations + + List candidates = new ArrayList<>(); + for (int i = 0; i < combinations; i++) { + // For each combination, decide which CTEs to materialize + List materializedCtes = new ArrayList<>(); + for (int j = 0; j < numberOfCtes; j++) { + if ((i & (1 << j)) != 0) { + materializedCtes.add(cteProducerList.get(j)); + } + } + // Generate the plan for this combination + candidates.add(generatePlanForCombination(root, materializedCtes)); + } + Ordering resultComparator = costComparator.forSession(session).onResultOf(result -> result.cost); + return resultComparator.min(candidates).getPlanNode().orElse(root); + } + + public CteEnumerationResult generatePlanForCombination(PlanNode root, List cteProducers) + { + CteTransformerContext context = new CteTransformerContext(Optional.of(cteProducers.stream() + .map(node -> ((CteProducerNode) node).getCteName()).collect(toImmutableSet()))); + PlanNode transformedCte = + SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator, CteConsumerTransformer.Operation.REWRITE), + root, context); + PlanNode sequencePlan = new SequenceNode(root.getSourceLocation(), planNodeIdAllocator.getNextId(), cteProducers, + transformedCte.getSources().get(0)); + return CteEnumerationResult.createCteEnumeration(Optional.of(sequencePlan) + , costProvider.getCost(sequencePlan)); + } } - public class CteConsumerTransformer + public static class CteConsumerTransformer extends SimplePlanRewriter { private final PlanNodeIdAllocator idAllocator; private final VariableAllocator variableAllocator; - public CteConsumerTransformer(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator) + enum Operation + { + EXPLORE, + REWRITE, + } + + private final Operation operation; + + public CteConsumerTransformer(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Operation operation) { this.idAllocator = idAllocator; this.variableAllocator = variableAllocator; + this.operation = operation; } @Override @@ -141,6 +245,9 @@ public PlanNode visitCteReference(CteReferenceNode node, RewriteContext node.getMayParticipateInAntiJoin()); }} - public class CteTransformerContext + public static class CteTransformerContext { public Map cteProducerMap; @@ -167,8 +274,11 @@ public class CteTransformerContext MutableGraph graph; public Stack activeCteStack; - public CteTransformerContext() + public final Optional> ctesToMaterialize; + + public CteTransformerContext(Optional> ctesToMaterialize) { + this.ctesToMaterialize = ctesToMaterialize; cteProducerMap = new HashMap<>(); // The cte graph will never have cycles because sql won't allow it graph = GraphBuilder.directed().build(); @@ -180,6 +290,11 @@ public Map getCteProducerMap() return cteProducerMap; } + public boolean shouldMaterialize(String cteName) + { + return ctesToMaterialize.map(strings -> strings.contains(cteName)).orElse(true); + } + public void addProducer(String cteName, CteProducerNode cteProducer) { cteProducerMap.putIfAbsent(cteName, cteProducer); @@ -215,4 +330,44 @@ public List getTopologicalOrdering() return topSortedCteProducerList; } } + + @VisibleForTesting + static class CteEnumerationResult + { + public static final CteEnumerationResult UNKNOWN_COST_RESULT = new CteEnumerationResult(Optional.empty(), PlanCostEstimate.unknown()); + public static final CteEnumerationResult INFINITE_COST_RESULT = new CteEnumerationResult(Optional.empty(), PlanCostEstimate.infinite()); + + private final Optional planNode; + private final PlanCostEstimate cost; + + private CteEnumerationResult(Optional planNode, PlanCostEstimate cost) + { + this.planNode = requireNonNull(planNode, "planNode is null"); + this.cost = requireNonNull(cost, "cost is null"); + checkArgument((cost.hasUnknownComponents() || cost.equals(PlanCostEstimate.infinite())) && !planNode.isPresent() + || (!cost.hasUnknownComponents() || !cost.equals(PlanCostEstimate.infinite())) && planNode.isPresent(), + "planNode should be present if and only if cost is known"); + } + + public Optional getPlanNode() + { + return planNode; + } + + public PlanCostEstimate getCost() + { + return cost; + } + + static CteEnumerationResult createCteEnumeration(Optional planNode, PlanCostEstimate cost) + { + if (cost.hasUnknownComponents()) { + return UNKNOWN_COST_RESULT; + } + if (cost.equals(PlanCostEstimate.infinite())) { + return INFINITE_COST_RESULT; + } + return new CteEnumerationResult(planNode, cost); + } + } }