Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jaystarshot committed Nov 16, 2023
1 parent 01a7b4c commit 101c9d3
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ public PlanOptimizers(
new PruneLimitColumns(),
new PruneTableScanColumns());

builder.add(new LogicalCteOptimizer(metadata));
builder.add(new LogicalCteOptimizer(metadata, costComparator, costCalculator, statsCalculator));

IterativeOptimizer inlineProjections = new IterativeOptimizer(
metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.CachingCostProvider;
import com.facebook.presto.cost.CachingStatsProvider;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
Expand All @@ -25,23 +34,30 @@
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.SequenceNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Ordering;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.MutableGraph;
import com.google.common.graph.Traverser;
import io.airlift.units.Duration;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;

import static com.facebook.presto.SystemSessionProperties.isMaterializeAllCtes;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/*
* Transformation of CTE Reference Nodes:
Expand Down Expand Up @@ -71,9 +87,21 @@ public class LogicalCteOptimizer
{
private final Metadata metadata;

public LogicalCteOptimizer(Metadata metadata)
private final CostComparator costComparator;

private final StatsCalculator statsCalculator;

private final CostCalculator costCalculator;

/**
 * Creates the optimizer together with the collaborators used for cost-based CTE selection.
 *
 * @param metadata metadata handle retained by the optimizer
 * @param costComparator ranks candidate plans by their estimated cost
 * @param costCalculator used to build the cost provider for each optimize call
 * @param statsCalculator used to build the stats provider for each optimize call
 * @throws NullPointerException if any argument is null
 */
public LogicalCteOptimizer(
        Metadata metadata,
        CostComparator costComparator,
        CostCalculator costCalculator,
        StatsCalculator statsCalculator)
{
    // Fail fast here instead of surfacing a NullPointerException in the middle of planning
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.costComparator = requireNonNull(costComparator, "costComparator is null");
    this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
    this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
}

@Override
Expand All @@ -83,26 +111,49 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider
|| session.getCteInformationCollector().getCTEInformationList().stream().noneMatch(CTEInformation::isMaterialized)) {
return PlanOptimizerResult.optimizerResult(plan, false);
}
PlanNode rewrittenPlan = new CteEnumerator(idAllocator, variableAllocator).transformPersistentCtes(plan);
PlanNode rewrittenPlan = plan;
Duration timeout = SystemSessionProperties.getOptimizerTimeout(session);
StatsProvider statsProvider = new CachingStatsProvider(
statsCalculator,
Optional.empty(),
Lookup.noLookup(),
session,
TypeProvider.viewOf(variableAllocator.getVariables()));
CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session);
if (isMaterializeAllCtes(session)) {
rewrittenPlan = new CteEnumerator(session, idAllocator, variableAllocator, costProvider).persistAllCtes(plan);
}
else {
rewrittenPlan = new CteEnumerator(session, idAllocator, variableAllocator, costProvider).choosePersistentCtes(plan);
}

return PlanOptimizerResult.optimizerResult(rewrittenPlan, !rewrittenPlan.equals(plan));
}

public class CteEnumerator
{
PlanNodeIdAllocator planNodeIdAllocator;
VariableAllocator variableAllocator;
private final Session session;
private final PlanNodeIdAllocator planNodeIdAllocator;
private final VariableAllocator variableAllocator;

private final CostProvider costProvider;

public CteEnumerator(PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator variableAllocator)
/**
 * Creates an enumerator bound to one planning session.
 *
 * @param session session whose CTE information is consulted and whose cost ordering is used
 * @param planNodeIdAllocator source of ids for the plan nodes created during rewriting
 * @param variableAllocator source of fresh variables for the rewritten plan
 * @param costProvider provider used to cost each candidate plan
 * @throws NullPointerException if any argument is null
 */
public CteEnumerator(
        Session session,
        PlanNodeIdAllocator planNodeIdAllocator,
        VariableAllocator variableAllocator,
        CostProvider costProvider)
{
    // Fail fast here instead of surfacing a NullPointerException while enumerating plans
    this.session = requireNonNull(session, "session is null");
    this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
    this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
    this.costProvider = requireNonNull(costProvider, "costProvider is null");
}

public PlanNode transformPersistentCtes(PlanNode root)
public PlanNode persistAllCtes(PlanNode root)
{
checkArgument(root.getSources().size() == 1, "expected newChildren to contain 1 node");
CteTransformerContext context = new CteTransformerContext();
PlanNode transformedCte = SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator),
CteTransformerContext context = new CteTransformerContext(Optional.empty());
PlanNode transformedCte = SimplePlanRewriter.rewriteWith(new CteConsumerTransformer(planNodeIdAllocator, variableAllocator, CteConsumerTransformer.Operation.REWRITE),
root, context);
List<PlanNode> topologicalOrderedList = context.getTopologicalOrdering();
if (topologicalOrderedList.isEmpty()) {
Expand All @@ -112,19 +163,72 @@ public PlanNode transformPersistentCtes(PlanNode root)
transformedCte.getSources().get(0));
return root.replaceChildren(Arrays.asList(sequenceNode));
}

/**
 * Cost-based selection of the CTEs to materialize.
 *
 * <p>Every subset of the CTE producers found in the plan is turned into a candidate plan
 * and costed; the cheapest candidate with a usable cost wins. The plan is returned unchanged
 * when there are too many CTEs to enumerate or when no candidate has a usable cost.
 *
 * @param root root of the plan to optimize
 * @return {@code root} with its child replaced by the cheapest candidate, or {@code root} itself
 */
public PlanNode choosePersistentCtes(PlanNode root)
{
    // Enumeration is exponential (2^n) in the number of CTEs, so give up on large queries
    // before doing any rewriting work.
    final int maxEnumeratedCtes = 20;
    if (session.getCteInformationCollector().getCTEInformationList().size() > maxEnumeratedCtes) {
        return root;
    }

    // Exploration pass: the rewritten plan is discarded, only the producers recorded
    // in the context are of interest.
    CteTransformerContext exploreContext = new CteTransformerContext(Optional.empty());
    SimplePlanRewriter.rewriteWith(
            new CteConsumerTransformer(planNodeIdAllocator, variableAllocator, CteConsumerTransformer.Operation.EXPLORE),
            root,
            exploreContext);
    List<PlanNode> cteProducerList = exploreContext.getTopologicalOrdering();
    int numberOfCtes = cteProducerList.size();
    if (numberOfCtes > maxEnumeratedCtes) {
        // Bound the count that is actually shifted below, so "1 << numberOfCtes" cannot overflow
        return root;
    }
    int combinations = 1 << numberOfCtes; // 2^n combinations

    // NOTE(review): the combination with no bits set passes an empty producer list to
    // generatePlanForCombination, which hands it to SequenceNode -- confirm that is accepted.
    List<CteEnumerationResult> candidates = new ArrayList<>();
    for (int mask = 0; mask < combinations; mask++) {
        // Bit j of the mask decides whether the j-th producer is materialized
        List<PlanNode> materializedCtes = new ArrayList<>();
        for (int bit = 0; bit < numberOfCtes; bit++) {
            if ((mask & (1 << bit)) != 0) {
                materializedCtes.add(cteProducerList.get(bit));
            }
        }
        CteEnumerationResult candidate = generatePlanForCombination(root, materializedCtes);
        // Results without a plan carry an unknown or infinite cost; they cannot win and
        // unknown costs are not safely comparable, so keep them out of the ordering.
        if (candidate.getPlanNode().isPresent()) {
            candidates.add(candidate);
        }
    }
    if (candidates.isEmpty()) {
        // Ordering.min throws on an empty collection; without any costed plan keep the original
        return root;
    }

    Ordering<CteEnumerationResult> resultComparator = costComparator.forSession(session).onResultOf(result -> result.getCost());
    // Each candidate is a SequenceNode built over the root's child, so re-attach the root
    // on top of it, the same way persistAllCtes does.
    return resultComparator.min(candidates).getPlanNode()
            .map(cheapestPlan -> root.replaceChildren(Arrays.asList(cheapestPlan)))
            .orElse(root);
}

/**
 * Builds and costs the candidate plan in which the given CTE producers are materialized.
 *
 * <p>The plan is rewritten with only the named CTEs marked for materialization, and the
 * rewritten root's first source is placed under a SequenceNode together with the producers.
 *
 * @param root root of the plan being optimized
 * @param cteProducers producers to materialize; each element is cast to CteProducerNode
 * @return the sequence plan with its cost, as normalized by createCteEnumeration
 */
public CteEnumerationResult generatePlanForCombination(PlanNode root, List<PlanNode> cteProducers)
{
    Set<String> namesToMaterialize = cteProducers.stream()
            .map(producer -> ((CteProducerNode) producer).getCteName())
            .collect(toImmutableSet());
    CteTransformerContext rewriteContext = new CteTransformerContext(Optional.of(namesToMaterialize));

    CteConsumerTransformer rewriter = new CteConsumerTransformer(
            planNodeIdAllocator,
            variableAllocator,
            CteConsumerTransformer.Operation.REWRITE);
    PlanNode rewrittenRoot = SimplePlanRewriter.rewriteWith(rewriter, root, rewriteContext);

    // The sequence node sits directly on the rewritten root's first source
    PlanNode candidatePlan = new SequenceNode(
            root.getSourceLocation(),
            planNodeIdAllocator.getNextId(),
            cteProducers,
            rewrittenRoot.getSources().get(0));
    PlanCostEstimate candidateCost = costProvider.getCost(candidatePlan);
    return CteEnumerationResult.createCteEnumeration(Optional.of(candidatePlan), candidateCost);
}
}

public class CteConsumerTransformer
public static class CteConsumerTransformer
extends SimplePlanRewriter<CteTransformerContext>
{
private final PlanNodeIdAllocator idAllocator;

private final VariableAllocator variableAllocator;

public CteConsumerTransformer(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
// Selects what visitCteReference returns for a CTE reference node.
// In both modes the producer is registered in the context first.
enum Operation
{
    // The reference node is returned unchanged
    EXPLORE,
    // The reference is replaced by a CteConsumerNode when the context says the CTE
    // should be materialized; otherwise the reference node is returned unchanged
    REWRITE,
}

private final Operation operation;

/**
 * Creates a rewriter for CTE reference nodes.
 *
 * @param idAllocator source of ids for newly created plan nodes
 * @param variableAllocator source of fresh variables for newly created plan nodes
 * @param operation whether references are only explored or actually rewritten
 * @throws NullPointerException if any argument is null
 */
public CteConsumerTransformer(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Operation operation)
{
    // Fail fast here; a null operation would otherwise only blow up on the first CTE reference visited
    this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
    this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
    this.operation = requireNonNull(operation, "operation is null");
}

@Override
Expand All @@ -141,6 +245,9 @@ public PlanNode visitCteReference(CteReferenceNode node, RewriteContext<CteTrans
node.getCteName(),
variableAllocator.newVariable("rows", BIGINT), node.getOutputVariables());
context.get().addProducer(node.getCteName(), cteProducerSource);
if (operation.equals(Operation.EXPLORE) || !context.get().shouldMaterialize(node.getCteName())) {
return node;
}
return new CteConsumerNode(node.getSourceLocation(), idAllocator.getNextId(), actualSource.getOutputVariables(), node.getCteName());
}

Expand All @@ -159,16 +266,19 @@ public PlanNode visitApply(ApplyNode node, RewriteContext<CteTransformerContext>
node.getMayParticipateInAntiJoin());
}}

public class CteTransformerContext
public static class CteTransformerContext
{
public Map<String, CteProducerNode> cteProducerMap;

// a -> b indicates that b needs to be processed before a
MutableGraph<String> graph;
public Stack<String> activeCteStack;

public CteTransformerContext()
public final Optional<Set<String>> ctesToMaterialize;

public CteTransformerContext(Optional<Set<String>> ctesToMaterialize)
{
this.ctesToMaterialize = ctesToMaterialize;
cteProducerMap = new HashMap<>();
// The CTE graph will never have cycles because SQL won't allow it
graph = GraphBuilder.directed().build();
Expand All @@ -180,6 +290,11 @@ public Map<String, CteProducerNode> getCteProducerMap()
return cteProducerMap;
}

/**
 * Tells whether the named CTE is to be materialized under this context.
 *
 * @param cteName name of the CTE to look up
 * @return {@code true} when no explicit set of CTEs was supplied, otherwise whether
 *         the supplied set contains {@code cteName}
 */
public boolean shouldMaterialize(String cteName)
{
    // An absent set means "no restriction": every CTE is materialized
    if (!ctesToMaterialize.isPresent()) {
        return true;
    }
    return ctesToMaterialize.get().contains(cteName);
}

public void addProducer(String cteName, CteProducerNode cteProducer)
{
cteProducerMap.putIfAbsent(cteName, cteProducer);
Expand Down Expand Up @@ -215,4 +330,44 @@ public List<PlanNode> getTopologicalOrdering()
return topSortedCteProducerList;
}
}

/**
 * Result of costing one candidate plan during CTE enumeration.
 *
 * <p>A result either holds a plan together with a usable cost, or is one of the two shared
 * sentinels ({@link #UNKNOWN_COST_RESULT}, {@link #INFINITE_COST_RESULT}) that hold no plan.
 * Instances are obtained through {@link #createCteEnumeration}.
 */
@VisibleForTesting
static class CteEnumerationResult
{
    public static final CteEnumerationResult UNKNOWN_COST_RESULT = new CteEnumerationResult(Optional.empty(), PlanCostEstimate.unknown());
    public static final CteEnumerationResult INFINITE_COST_RESULT = new CteEnumerationResult(Optional.empty(), PlanCostEstimate.infinite());

    private final Optional<PlanNode> planNode;
    private final PlanCostEstimate cost;

    /**
     * @param planNode the candidate plan; must be present exactly when the cost is usable
     * @param cost the estimated cost of the candidate
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if a plan is paired with an unknown or infinite cost,
     *         or a usable cost comes without a plan
     */
    private CteEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost)
    {
        this.planNode = requireNonNull(planNode, "planNode is null");
        this.cost = requireNonNull(cost, "cost is null");
        // A cost is usable only when it is neither unknown NOR infinite. The previous check
        // joined the two negations with "||", which is always true, so the invariant in the
        // message was never enforced for results that carry a plan.
        boolean costIsUsable = !cost.hasUnknownComponents() && !cost.equals(PlanCostEstimate.infinite());
        checkArgument(costIsUsable == planNode.isPresent(), "planNode should be present if and only if cost is known");
    }

    /**
     * @return the candidate plan, empty for the unknown-cost and infinite-cost sentinels
     */
    public Optional<PlanNode> getPlanNode()
    {
        return planNode;
    }

    /**
     * @return the estimated cost of this result
     */
    public PlanCostEstimate getCost()
    {
        return cost;
    }

    /**
     * Wraps a candidate plan and its cost, collapsing unusable costs to the shared sentinels.
     *
     * @param planNode the candidate plan
     * @param cost the estimated cost of the candidate
     * @return {@link #UNKNOWN_COST_RESULT} if the cost has unknown components,
     *         {@link #INFINITE_COST_RESULT} if the cost is infinite, otherwise a new result
     *         holding both arguments
     */
    static CteEnumerationResult createCteEnumeration(Optional<PlanNode> planNode, PlanCostEstimate cost)
    {
        if (cost.hasUnknownComponents()) {
            return UNKNOWN_COST_RESULT;
        }
        if (cost.equals(PlanCostEstimate.infinite())) {
            return INFINITE_COST_RESULT;
        }
        return new CteEnumerationResult(planNode, cost);
    }
}
}

0 comments on commit 101c9d3

Please sign in to comment.