[multistage] fix usage of metadata overrides (apache#11587)
walterddr authored Sep 13, 2023
1 parent ae77667 commit 3caf9ab
Showing 10 changed files with 51 additions and 46 deletions.
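
A note on the shape of the fix before the per-file hunks: previously, config-level overrides were written into the stage's mutable customProperties map while several consumers kept reading the raw request metadata, so the two sources could disagree. This commit consolidates both into a single read-only opChainMetadata map with a fixed precedence (request metadata as the base, stage custom properties on top, config defaults only for options still unset). A minimal sketch of that precedence rule, using plain JDK maps with made-up option keys and values rather than the actual Pinot constants:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class MetadataPrecedenceSketch {
  public static void main(String[] args) {
    Map<String, String> requestMetadata = new HashMap<>();
    requestMetadata.put("numGroupsLimit", "100000"); // set per query by the broker

    Map<String, String> stageCustomProperties = new HashMap<>();
    stageCustomProperties.put("numGroupsLimit", "50000"); // stage-level override

    Map<String, String> opChainMetadata = new HashMap<>();
    opChainMetadata.putAll(requestMetadata);       // 1. request metadata is the base
    opChainMetadata.putAll(stageCustomProperties); // 2. stage properties win on collision
    opChainMetadata.putIfAbsent("maxRowsInJoin", "1048576"); // 3. config default, only if unset

    Map<String, String> readOnly = Collections.unmodifiableMap(opChainMetadata);
    System.out.println(readOnly.get("numGroupsLimit")); // 50000: stage override wins
    System.out.println(readOnly.get("maxRowsInJoin"));  // 1048576: config default filled in
  }
}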
---------- File 1 of 10 ----------
@@ -20,6 +20,7 @@

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
@@ -152,14 +153,14 @@ public void shutDown() {
public void processQuery(DistributedStagePlan distributedStagePlan, Map<String, String> requestMetadata) {
long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
Map<String, String> opChainMetadata = consolidateMetadata(
distributedStagePlan.getStageMetadata().getCustomProperties(), requestMetadata);
long deadlineMs = System.currentTimeMillis() + timeoutMs;

setStageCustomProperties(distributedStagePlan.getStageMetadata().getCustomProperties(), requestMetadata);

// run pre-stage execution for all pipeline breakers
PipelineBreakerResult pipelineBreakerResult =
PipelineBreakerExecutor.executePipelineBreakers(_opChainScheduler, _mailboxService, distributedStagePlan,
requestMetadata, requestId, deadlineMs);
opChainMetadata, requestId, deadlineMs);

// Send error block to all the receivers if pipeline breaker fails
if (pipelineBreakerResult != null && pipelineBreakerResult.getErrorBlock() != null) {
Expand Down Expand Up @@ -188,7 +189,7 @@ public void processQuery(DistributedStagePlan distributedStagePlan, Map<String,
// run OpChain
OpChainExecutionContext executionContext =
new OpChainExecutionContext(_mailboxService, requestId, distributedStagePlan.getStageId(),
distributedStagePlan.getServer(), deadlineMs, requestMetadata, distributedStagePlan.getStageMetadata(),
distributedStagePlan.getServer(), deadlineMs, opChainMetadata, distributedStagePlan.getStageMetadata(),
pipelineBreakerResult);
OpChain opChain;
if (DistributedStagePlan.isLeafStage(distributedStagePlan)) {
@@ -199,39 +200,47 @@ public void processQuery(DistributedStagePlan distributedStagePlan, Map<String,
_opChainScheduler.register(opChain);
}

private void setStageCustomProperties(Map<String, String> customProperties, Map<String, String> requestMetadata) {
Integer numGroupsLimit = QueryOptionsUtils.getNumGroupsLimit(requestMetadata);
private Map<String, String> consolidateMetadata(Map<String, String> customProperties,
Map<String, String> requestMetadata) {
Map<String, String> opChainMetadata = new HashMap<>();
// 1. put all request-level metadata
opChainMetadata.putAll(requestMetadata);
// 2. put all stageMetadata custom properties; these overwrite request metadata on collision
opChainMetadata.putAll(customProperties);
// 3. fill in overrides from instance config for any option still unset
Integer numGroupsLimit = QueryOptionsUtils.getNumGroupsLimit(opChainMetadata);
if (numGroupsLimit == null) {
numGroupsLimit = _numGroupsLimit;
}
if (numGroupsLimit != null) {
customProperties.put(QueryOptionKey.NUM_GROUPS_LIMIT, Integer.toString(numGroupsLimit));
opChainMetadata.put(QueryOptionKey.NUM_GROUPS_LIMIT, Integer.toString(numGroupsLimit));
}

Integer maxInitialResultHolderCapacity = QueryOptionsUtils.getMaxInitialResultHolderCapacity(requestMetadata);
Integer maxInitialResultHolderCapacity = QueryOptionsUtils.getMaxInitialResultHolderCapacity(opChainMetadata);
if (maxInitialResultHolderCapacity == null) {
maxInitialResultHolderCapacity = _maxInitialResultHolderCapacity;
}
if (maxInitialResultHolderCapacity != null) {
customProperties.put(QueryOptionKey.MAX_INITIAL_RESULT_HOLDER_CAPACITY,
opChainMetadata.put(QueryOptionKey.MAX_INITIAL_RESULT_HOLDER_CAPACITY,
Integer.toString(maxInitialResultHolderCapacity));
}

Integer maxRowsInJoin = QueryOptionsUtils.getMaxRowsInJoin(requestMetadata);
Integer maxRowsInJoin = QueryOptionsUtils.getMaxRowsInJoin(opChainMetadata);
if (maxRowsInJoin == null) {
maxRowsInJoin = _maxRowsInJoin;
}
if (maxRowsInJoin != null) {
customProperties.put(QueryOptionKey.MAX_ROWS_IN_JOIN, Integer.toString(maxRowsInJoin));
opChainMetadata.put(QueryOptionKey.MAX_ROWS_IN_JOIN, Integer.toString(maxRowsInJoin));
}

JoinOverFlowMode joinOverflowMode = QueryOptionsUtils.getJoinOverflowMode(requestMetadata);
JoinOverFlowMode joinOverflowMode = QueryOptionsUtils.getJoinOverflowMode(opChainMetadata);
if (joinOverflowMode == null) {
joinOverflowMode = _joinOverflowMode;
}
if (joinOverflowMode != null) {
customProperties.put(QueryOptionKey.JOIN_OVERFLOW_MODE, joinOverflowMode.name());
opChainMetadata.put(QueryOptionKey.JOIN_OVERFLOW_MODE, joinOverflowMode.name());
}
return opChainMetadata;
}

public void cancel(long requestId) {
---------- File 2 of 10 ----------
@@ -48,7 +48,6 @@
import org.apache.pinot.query.runtime.operator.block.DataBlockValSet;
import org.apache.pinot.query.runtime.operator.block.FilteredDataBlockValSet;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.query.runtime.plan.StageMetadata;
import org.apache.pinot.spi.data.FieldSpec;


@@ -134,12 +133,10 @@ public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inp
// Initialize the appropriate executor.
if (!groupSet.isEmpty()) {
_isGroupByAggregation = true;
StageMetadata stageMetadata = context.getStageMetadata();
Map<String, String> customProperties =
stageMetadata != null ? stageMetadata.getCustomProperties() : Collections.emptyMap();
Map<String, String> opChainMetadata = context.getOpChainMetadata();
_groupByExecutor =
new MultistageGroupByExecutor(groupByExpr, aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap,
_resultSchema, customProperties, nodeHint);
_resultSchema, opChainMetadata, nodeHint);
} else {
_isGroupByAggregation = false;
_aggregationExecutor =
---------- File 3 of 10 ----------
@@ -22,7 +22,6 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -46,7 +45,6 @@
import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
import org.apache.pinot.query.runtime.operator.operands.TransformOperandFactory;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.query.runtime.plan.StageMetadata;
import org.apache.pinot.spi.utils.BooleanUtils;
import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.JoinOverFlowMode;

@@ -145,14 +143,12 @@ public HashJoinOperator(OpChainExecutionContext context, MultiStageOperator left
} else {
_matchedRightRows = null;
}
StageMetadata stageMetadata = context.getStageMetadata();
Map<String, String> customProperties =
stageMetadata != null ? stageMetadata.getCustomProperties() : Collections.emptyMap();
_maxRowsInHashTable = getMaxRowInJoin(customProperties, node.getJoinHints());
_joinOverflowMode = getJoinOverflowMode(customProperties, node.getJoinHints());
Map<String, String> metadata = context.getOpChainMetadata();
_maxRowsInHashTable = getMaxRowInJoin(metadata, node.getJoinHints());
_joinOverflowMode = getJoinOverflowMode(metadata, node.getJoinHints());
}

private int getMaxRowInJoin(Map<String, String> customProperties, @Nullable AbstractPlanNode.NodeHint nodeHint) {
private int getMaxRowInJoin(Map<String, String> opChainMetadata, @Nullable AbstractPlanNode.NodeHint nodeHint) {
if (nodeHint != null) {
Map<String, String> joinOptions = nodeHint._hintOptions.get(PinotHintOptions.JOIN_HINT_OPTIONS);
if (joinOptions != null) {
@@ -162,11 +158,11 @@ private int getMaxRowInJoin(Map<String, String> customProperties, @Nullable Abst
}
}
}
Integer maxRowsInJoin = QueryOptionsUtils.getMaxRowsInJoin(customProperties);
Integer maxRowsInJoin = QueryOptionsUtils.getMaxRowsInJoin(opChainMetadata);
return maxRowsInJoin != null ? maxRowsInJoin : DEFAULT_MAX_ROWS_IN_JOIN;
}

private JoinOverFlowMode getJoinOverflowMode(Map<String, String> customProperties,
private JoinOverFlowMode getJoinOverflowMode(Map<String, String> contextMetadata,
@Nullable AbstractPlanNode.NodeHint nodeHint) {
if (nodeHint != null) {
Map<String, String> joinOptions = nodeHint._hintOptions.get(PinotHintOptions.JOIN_HINT_OPTIONS);
@@ -177,7 +173,7 @@ private JoinOverFlowMode getJoinOverflowMode(Map<String, String> customPropertie
}
}
}
JoinOverFlowMode joinOverflowMode = QueryOptionsUtils.getJoinOverflowMode(customProperties);
JoinOverFlowMode joinOverflowMode = QueryOptionsUtils.getJoinOverflowMode(contextMetadata);
return joinOverflowMode != null ? joinOverflowMode : DEFAULT_JOIN_OVERFLOW_MODE;
}

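
The two resolvers above encode a three-tier lookup: an explicit join hint from the planner beats the consolidated opChainMetadata, which beats the compiled-in default. A self-contained sketch of the same tiering, with hypothetical key names and a made-up default standing in for Pinot's hint options and QueryOptionsUtils:

import java.util.HashMap;
import java.util.Map;

public class OptionResolutionSketch {
  static final int DEFAULT_MAX_ROWS_IN_JOIN = 1_048_576; // illustrative default

  // Tiering mirrors getMaxRowInJoin: hint > metadata > default.
  static int resolveMaxRowsInJoin(Map<String, String> joinHintOptions,
      Map<String, String> opChainMetadata) {
    String hint = joinHintOptions.get("max_rows_in_join");
    if (hint != null) {
      return Integer.parseInt(hint);
    }
    String fromMetadata = opChainMetadata.get("maxRowsInJoin");
    return fromMetadata != null ? Integer.parseInt(fromMetadata) : DEFAULT_MAX_ROWS_IN_JOIN;
  }

  public static void main(String[] args) {
    Map<String, String> hints = new HashMap<>();
    Map<String, String> metadata = new HashMap<>();
    metadata.put("maxRowsInJoin", "200000");
    System.out.println(resolveMaxRowsInJoin(hints, metadata)); // 200000: metadata tier
    hints.put("max_rows_in_join", "500");
    System.out.println(resolveMaxRowsInJoin(hints, metadata)); // 500: hint wins
  }
}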
---------- File 4 of 10 ----------
@@ -93,7 +93,7 @@ public LeafStageTransferableBlockOperator(OpChainExecutionContext context, List<
_dataSchema = dataSchema;
_queryExecutor = queryExecutor;
_executorService = executorService;
Integer maxStreamingPendingBlocks = QueryOptionsUtils.getMaxStreamingPendingBlocks(context.getRequestMetadata());
Integer maxStreamingPendingBlocks = QueryOptionsUtils.getMaxStreamingPendingBlocks(context.getOpChainMetadata());
_blockingQueue = new ArrayBlockingQueue<>(maxStreamingPendingBlocks != null ? maxStreamingPendingBlocks
: QueryOptionValue.DEFAULT_MAX_STREAMING_PENDING_BLOCKS);
}
---------- File 5 of 10 ----------
@@ -72,7 +72,7 @@ public class MultistageGroupByExecutor {

public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
@Nullable int[] filterArgIndices, AggType aggType, Map<String, Integer> colNameToIndexMap,
DataSchema resultSchema, Map<String, String> customProperties, @Nullable AbstractPlanNode.NodeHint nodeHint) {
DataSchema resultSchema, Map<String, String> opChainMetadata, @Nullable AbstractPlanNode.NodeHint nodeHint) {
_aggType = aggType;
_colNameToIndexMap = colNameToIndexMap;
_groupSet = groupByExpr;
@@ -85,8 +85,8 @@ public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, Aggregatio

_groupKeyToIdMap = new HashMap<>();

_numGroupsLimit = getNumGroupsLimit(customProperties, nodeHint);
_maxInitialResultHolderCapacity = getMaxInitialResultHolderCapacity(customProperties, nodeHint);
_numGroupsLimit = getNumGroupsLimit(opChainMetadata, nodeHint);
_maxInitialResultHolderCapacity = getMaxInitialResultHolderCapacity(opChainMetadata, nodeHint);

for (int i = 0; i < _aggFunctions.length; i++) {
_aggregateResultHolders[i] =
---------- File 6 of 10 ----------
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.query.runtime.plan;

import java.util.Collections;
import java.util.Map;
import org.apache.pinot.query.mailbox.MailboxService;
import org.apache.pinot.query.routing.VirtualServerAddress;
@@ -38,30 +39,30 @@ public class OpChainExecutionContext {
private final int _stageId;
private final VirtualServerAddress _server;
private final long _deadlineMs;
private final Map<String, String> _requestMetadata;
private final Map<String, String> _opChainMetadata;
private final StageMetadata _stageMetadata;
private final OpChainId _id;
private final OpChainStats _stats;
private final PipelineBreakerResult _pipelineBreakerResult;
private final boolean _traceEnabled;

public OpChainExecutionContext(MailboxService mailboxService, long requestId, int stageId,
VirtualServerAddress server, long deadlineMs, Map<String, String> requestMetadata, StageMetadata stageMetadata,
VirtualServerAddress server, long deadlineMs, Map<String, String> opChainMetadata, StageMetadata stageMetadata,
PipelineBreakerResult pipelineBreakerResult) {
_mailboxService = mailboxService;
_requestId = requestId;
_stageId = stageId;
_server = server;
_deadlineMs = deadlineMs;
_requestMetadata = requestMetadata;
_opChainMetadata = Collections.unmodifiableMap(opChainMetadata);
_stageMetadata = stageMetadata;
_id = new OpChainId(requestId, server.workerId(), stageId);
_stats = new OpChainStats(_id.toString());
_pipelineBreakerResult = pipelineBreakerResult;
if (pipelineBreakerResult != null && pipelineBreakerResult.getOpChainStats() != null) {
_stats.getOperatorStatsMap().putAll(pipelineBreakerResult.getOpChainStats().getOperatorStatsMap());
}
_traceEnabled = Boolean.parseBoolean(requestMetadata.get(CommonConstants.Broker.Request.TRACE));
_traceEnabled = Boolean.parseBoolean(opChainMetadata.get(CommonConstants.Broker.Request.TRACE));
}

public MailboxService getMailboxService() {
@@ -84,8 +85,8 @@ public long getDeadlineMs() {
return _deadlineMs;
}

public Map<String, String> getRequestMetadata() {
return _requestMetadata;
public Map<String, String> getOpChainMetadata() {
return _opChainMetadata;
}

public StageMetadata getStageMetadata() {
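
One consequence of wrapping the map in Collections.unmodifiableMap in the constructor above: every operator sharing an OpChainExecutionContext gets a read-only view, so a stray write fails fast instead of silently leaking options across op chains. A tiny sketch of that contract, with plain JDK types and an illustrative key:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class UnmodifiableViewSketch {
  public static void main(String[] args) {
    Map<String, String> backing = new HashMap<>();
    backing.put("trace", "true");

    // The view reads through to the backing map but rejects all writes.
    Map<String, String> view = Collections.unmodifiableMap(backing);
    System.out.println(view.get("trace")); // true

    try {
      view.put("trace", "false"); // any mutation attempt...
    } catch (UnsupportedOperationException e) {
      System.out.println("write rejected"); // ...fails fast here
    }
  }
}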
---------- File 7 of 10 ----------
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.query.runtime.plan;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -34,7 +35,7 @@ public class StageMetadata {

StageMetadata(List<WorkerMetadata> workerMetadataList, Map<String, String> customProperties) {
_workerMetadataList = workerMetadataList;
_customProperties = customProperties;
_customProperties = Collections.unmodifiableMap(customProperties);
}

public List<WorkerMetadata> getWorkerMetadataList() {
---------- File 8 of 10 ----------
@@ -56,7 +56,7 @@ private PipelineBreakerExecutor() {
* @param scheduler scheduler service to run the pipeline breaker main thread.
* @param mailboxService mailbox service to attach the {@link MailboxReceiveNode} against.
* @param distributedStagePlan the distributed stage plan to run pipeline breaker on.
* @param requestMetadata request metadata, including query options
* @param opChainMetadata consolidated opChain metadata, including query options
* @param requestId request ID
* @param deadlineMs execution deadline
* @return pipeline breaker result;
@@ -65,7 +65,7 @@
*/
@Nullable
public static PipelineBreakerResult executePipelineBreakers(OpChainSchedulerService scheduler,
MailboxService mailboxService, DistributedStagePlan distributedStagePlan, Map<String, String> requestMetadata,
MailboxService mailboxService, DistributedStagePlan distributedStagePlan, Map<String, String> opChainMetadata,
long requestId, long deadlineMs) {
PipelineBreakerContext pipelineBreakerContext = new PipelineBreakerContext();
PipelineBreakerVisitor.visitPlanRoot(distributedStagePlan.getStageRoot(), pipelineBreakerContext);
Expand All @@ -76,7 +76,7 @@ public static PipelineBreakerResult executePipelineBreakers(OpChainSchedulerServ
// see also: MailboxIdUtils TODOs, de-couple mailbox id from query information
OpChainExecutionContext opChainExecutionContext =
new OpChainExecutionContext(mailboxService, requestId, distributedStagePlan.getStageId(),
distributedStagePlan.getServer(), deadlineMs, requestMetadata, distributedStagePlan.getStageMetadata(),
distributedStagePlan.getServer(), deadlineMs, opChainMetadata, distributedStagePlan.getStageMetadata(),
null);
return execute(scheduler, pipelineBreakerContext, opChainExecutionContext);
} catch (Exception e) {
---------- File 9 of 10 ----------
@@ -122,7 +122,7 @@ private static ServerPlanRequestContext build(OpChainExecutionContext executionC
long requestId = (executionContext.getRequestId() << 16) + ((long) stagePlan.getStageId() << 8) + (
tableType == TableType.REALTIME ? 1 : 0);
PinotQuery pinotQuery = new PinotQuery();
Integer leafNodeLimit = QueryOptionsUtils.getMultiStageLeafLimit(executionContext.getRequestMetadata());
Integer leafNodeLimit = QueryOptionsUtils.getMultiStageLeafLimit(executionContext.getOpChainMetadata());
if (leafNodeLimit != null) {
pinotQuery.setLimit(leafNodeLimit);
} else {
@@ -174,7 +174,7 @@ private static ServerPlanRequestContext build(OpChainExecutionContext executionC
* Helper method to update query options.
*/
private static void updateQueryOptions(PinotQuery pinotQuery, OpChainExecutionContext executionContext) {
Map<String, String> queryOptions = new HashMap<>(executionContext.getRequestMetadata());
Map<String, String> queryOptions = new HashMap<>(executionContext.getOpChainMetadata());
queryOptions.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS,
Long.toString(executionContext.getDeadlineMs() - System.currentTimeMillis()));
pinotQuery.setQueryOptions(queryOptions);
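
Because the context's metadata is now an unmodifiable view, updateQueryOptions copies it into a fresh HashMap before stamping the remaining timeout, as the hunk above shows. The same copy-then-extend step in isolation; the literal "timeoutMs" key and the ten-second deadline are illustrative stand-ins for the Pinot constants:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class CopyThenExtendSketch {
  public static void main(String[] args) {
    Map<String, String> base = new HashMap<>();
    base.put("numGroupsLimit", "50000");
    Map<String, String> opChainMetadata = Collections.unmodifiableMap(base);

    long deadlineMs = System.currentTimeMillis() + 10_000; // hypothetical 10s deadline

    // Copy the read-only view, then add the per-dispatch remaining timeout.
    Map<String, String> queryOptions = new HashMap<>(opChainMetadata);
    queryOptions.put("timeoutMs", Long.toString(deadlineMs - System.currentTimeMillis()));
    System.out.println(queryOptions); // e.g. {numGroupsLimit=50000, timeoutMs=9999}
  }
}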
---------- File 10 of 10 ----------
@@ -22,6 +22,7 @@
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
@@ -99,7 +100,7 @@ public void submit(Worker.QueryRequest request, StreamObserver<Worker.QueryRespo
// Deserialize the request
List<DistributedStagePlan> distributedStagePlans;
Map<String, String> requestMetadata;
requestMetadata = request.getMetadataMap();
requestMetadata = Collections.unmodifiableMap(request.getMetadataMap());
long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
long deadlineMs = System.currentTimeMillis() + timeoutMs;
Expand Down
