Fix double memory accounting for spilling aggregations
InMemoryHashAggregationBuilder was doing its own memory
accounting for the purpose of GroupByHash yielding.
When SpillableHashAggregationBuilder is used, it also
accounts for InMemoryHashAggregationBuilder's memory as
revocable memory, so the same memory was counted twice.

This commit removes the memory accounting from
InMemoryHashAggregationBuilder and lets the caller supply
an UpdateMemory callback that does the accounting instead.
sopel39 committed Mar 12, 2019
1 parent 5ceaa08 commit 627d631
Showing 4 changed files with 50 additions and 81 deletions.
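As context for the diff below: the fix moves all memory bookkeeping out of the builder and into a single UpdateMemory callback supplied by the caller. The following is a minimal sketch of that pattern, under simplified assumptions — Builder, Owner, and the plain long standing in for a LocalMemoryContext are made up for illustration; only the UpdateMemory interface shape matches the code in this commit.

// Minimal sketch of the callback pattern used by this commit; the class and
// field names are simplified stand-ins, not the real Presto classes.
interface UpdateMemory
{
    // Returns false when the caller should yield because memory is not available.
    boolean update();
}

class Builder
{
    private final UpdateMemory updateMemory;
    private long estimatedBytes;

    Builder(UpdateMemory updateMemory)
    {
        this.updateMemory = updateMemory;
    }

    boolean addRow()
    {
        estimatedBytes += 64; // pretend every row costs 64 bytes
        // Single accounting path: the builder never touches a memory context itself.
        return updateMemory.update();
    }

    long getSizeInMemory()
    {
        return estimatedBytes;
    }
}

class Owner
{
    private long reservedBytes; // stands in for a LocalMemoryContext
    private Builder builder;

    Builder createBuilder()
    {
        builder = new Builder(() -> {
            // The owner reports the builder's size exactly once, so the same
            // bytes are never charged to two memory contexts.
            reservedBytes = builder.getSizeInMemory();
            return true; // e.g. a partial aggregation that never yields on memory
        });
        return builder;
    }
}

In the actual diff, HashAggregationOperator passes a lambda that writes the builder's size into either its user or system memory context and decides whether to yield, while SpillableHashAggregationBuilder passes a lambda that never yields, since the spillable path already accounts the builder as revocable memory.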
HashAggregationOperator.java
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.operator.aggregation.Accumulator;
import io.prestosql.operator.aggregation.AccumulatorFactory;
import io.prestosql.operator.aggregation.builder.HashAggregationBuilder;
@@ -272,6 +273,7 @@ public OperatorFactory duplicate()
private final HashCollisionsCounter hashCollisionsCounter;

private HashAggregationBuilder aggregationBuilder;
private LocalMemoryContext memoryContext;
private WorkProcessor<Page> outputPages;
private boolean inputProcessed;
private boolean finishing;
@@ -323,6 +325,11 @@ public HashAggregationOperator(
this.hashCollisionsCounter = new HashCollisionsCounter(operatorContext);
operatorContext.setInfoSupplier(hashCollisionsCounter);
this.useSystemMemory = useSystemMemory;

this.memoryContext = operatorContext.localUserMemoryContext();
if (useSystemMemory) {
this.memoryContext = operatorContext.localSystemMemoryContext();
}
}

@Override
@@ -378,8 +385,14 @@ public void addInput(Page page)
operatorContext,
maxPartialMemory,
joinCompiler,
true,
useSystemMemory);
() -> {
memoryContext.setBytes(((InMemoryHashAggregationBuilder) aggregationBuilder).getSizeInMemory());
if (step.isOutputPartial() && maxPartialMemory.isPresent()) {
// do not yield on memory for partial aggregations
return true;
}
return operatorContext.isWaitingForMemory().isDone();
});
}
else {
verify(!useSystemMemory, "using system memory in spillable aggregations is not supported");
@@ -511,8 +524,7 @@ private void closeAggregationBuilder()
// The reference must be set to null afterwards to avoid unaccounted memory.
aggregationBuilder = null;
}
operatorContext.localUserMemoryContext().setBytes(0);
operatorContext.localRevocableMemoryContext().setBytes(0);
memoryContext.setBytes(0);
}

private Page getGlobalAggregationOutput()
InMemoryHashAggregationBuilder.java
@@ -19,7 +19,6 @@
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import io.prestosql.array.IntBigArray;
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.operator.GroupByHash;
import io.prestosql.operator.GroupByIdBlock;
import io.prestosql.operator.HashCollisionsCounter;
@@ -58,12 +57,9 @@ public class InMemoryHashAggregationBuilder
{
private final GroupByHash groupByHash;
private final List<Aggregator> aggregators;
private final OperatorContext operatorContext;
private final boolean partial;
private final OptionalLong maxPartialMemory;
private final LocalMemoryContext systemMemoryContext;
private final LocalMemoryContext localUserMemoryContext;
private final boolean useSystemMemory;
private final UpdateMemory updateMemory;

private boolean full;

@@ -77,8 +73,7 @@ public InMemoryHashAggregationBuilder(
OperatorContext operatorContext,
Optional<DataSize> maxPartialMemory,
JoinCompiler joinCompiler,
boolean yieldForMemoryReservation,
boolean useSystemMemory)
UpdateMemory updateMemory)
{
this(accumulatorFactories,
step,
@@ -90,8 +85,7 @@ public InMemoryHashAggregationBuilder(
maxPartialMemory,
Optional.empty(),
joinCompiler,
yieldForMemoryReservation,
useSystemMemory);
updateMemory);
}

public InMemoryHashAggregationBuilder(
@@ -105,22 +99,8 @@ public InMemoryHashAggregationBuilder(
Optional<DataSize> maxPartialMemory,
Optional<Integer> overwriteIntermediateChannelOffset,
JoinCompiler joinCompiler,
boolean yieldForMemoryReservation,
boolean useSystemMemory)
UpdateMemory updateMemory)
{
UpdateMemory updateMemory;
if (yieldForMemoryReservation) {
updateMemory = this::updateMemoryWithYieldInfo;
}
else {
// Report memory usage but do not yield for memory.
// This is specially used for spillable hash aggregation operator.
// TODO: revisit this when spillable hash aggregation operator is turned on
updateMemory = () -> {
updateMemoryWithYieldInfo();
return true;
};
}
this.groupByHash = createGroupByHash(
groupByTypes,
Ints.toArray(groupByChannels),
@@ -129,12 +109,9 @@ public InMemoryHashAggregationBuilder(
isDictionaryAggregationEnabled(operatorContext.getSession()),
joinCompiler,
updateMemory);
this.operatorContext = operatorContext;
this.partial = step.isOutputPartial();
this.maxPartialMemory = maxPartialMemory.map(dataSize -> OptionalLong.of(dataSize.toBytes())).orElseGet(OptionalLong::empty);
this.systemMemoryContext = operatorContext.newLocalSystemMemoryContext(InMemoryHashAggregationBuilder.class.getSimpleName());
this.localUserMemoryContext = operatorContext.localUserMemoryContext();
this.useSystemMemory = useSystemMemory;
this.updateMemory = requireNonNull(updateMemory, "updateMemory is null");

// wrapper each function with an aggregator
ImmutableList.Builder<Aggregator> builder = ImmutableList.builder();
@@ -151,10 +128,7 @@ public InMemoryHashAggregationBuilder(
}

@Override
public void close()
{
updateMemory(0);
}
public void close() {}

@Override
public Work<?> processPage(Page page)
@@ -178,7 +152,7 @@ public Work<?> processPage(Page page)
@Override
public void updateMemory()
{
updateMemoryWithYieldInfo();
updateMemory.update();
}

@Override
@@ -221,9 +195,20 @@ public long getSizeInMemory()
for (Aggregator aggregator : aggregators) {
sizeInMemory += aggregator.getEstimatedSize();
}

updateIsFull(sizeInMemory);
return sizeInMemory;
}

private void updateIsFull(long sizeInMemory)
{
if (!partial || !maxPartialMemory.isPresent()) {
return;
}

full = sizeInMemory > maxPartialMemory.getAsLong();
}

/**
* building hash sorted results requires memory for sorting group IDs.
* This method returns size of that memory requirement.
@@ -316,39 +301,6 @@ public List<Type> buildTypes()
return types;
}

/**
* Update memory usage with extra memory needed.
*
* @return true to if the reservation is within the limit
*/
// TODO: update in the interface after the new memory tracking framework is landed
// Essentially we would love to have clean interfaces to support both pushing and pulling memory usage
// The following implementation is a hybrid model, where the push model is going to call the pull model causing reentrancy
private boolean updateMemoryWithYieldInfo()
{
long memorySize = getSizeInMemory();
if (partial && maxPartialMemory.isPresent()) {
updateMemory(memorySize);
full = (memorySize > maxPartialMemory.getAsLong());
return true;
}
// Operator/driver will be blocked on memory after we call setBytes.
// If memory is not available, once we return, this operator will be blocked until memory is available.
updateMemory(memorySize);
// If memory is not available, inform the caller that we cannot proceed for allocation.
return operatorContext.isWaitingForMemory().isDone();
}

private void updateMemory(long memorySize)
{
if (useSystemMemory) {
systemMemoryContext.setBytes(memorySize);
}
else {
localUserMemoryContext.setBytes(memorySize);
}
}

private IntIterator consecutiveGroupIds()
{
return IntIterators.fromTo(0, groupByHash.getGroupCount());
MergingHashAggregationBuilder.java
@@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.prestosql.memory.context.AggregatedMemoryContext;
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.operator.OperatorContext;
import io.prestosql.operator.WorkProcessor;
@@ -44,7 +45,7 @@ public class MergingHashAggregationBuilder
private final WorkProcessor<Page> sortedPages;
private InMemoryHashAggregationBuilder hashAggregationBuilder;
private final List<Type> groupByTypes;
private final LocalMemoryContext systemMemoryContext;
private final LocalMemoryContext memoryContext;
private final long memoryLimitForMerge;
private final int overwriteIntermediateChannelOffset;
private final JoinCompiler joinCompiler;
Expand All @@ -57,7 +58,7 @@ public MergingHashAggregationBuilder(
Optional<Integer> hashChannel,
OperatorContext operatorContext,
WorkProcessor<Page> sortedPages,
LocalMemoryContext systemMemoryContext,
AggregatedMemoryContext aggregatedMemoryContext,
long memoryLimitForMerge,
int overwriteIntermediateChannelOffset,
JoinCompiler joinCompiler)
@@ -75,7 +76,7 @@ public MergingHashAggregationBuilder(
this.operatorContext = operatorContext;
this.sortedPages = sortedPages;
this.groupByTypes = groupByTypes;
this.systemMemoryContext = systemMemoryContext;
this.memoryContext = aggregatedMemoryContext.newLocalMemoryContext(MergingHashAggregationBuilder.class.getSimpleName());
this.memoryLimitForMerge = memoryLimitForMerge;
this.overwriteIntermediateChannelOffset = overwriteIntermediateChannelOffset;
this.joinCompiler = joinCompiler;
@@ -110,7 +111,7 @@ public TransformationState<WorkProcessor<Page>> process(Optional<Page> inputPage
// TODO: this class does not yield wrt memory limit; enable it
verify(done);
memorySize = hashAggregationBuilder.getSizeInMemory();
systemMemoryContext.setBytes(memorySize);
memoryContext.setBytes(memorySize);

if (!shouldProduceOutput(memorySize)) {
return TransformationState.needsMoreData();
@@ -149,7 +150,7 @@ private void rebuildHashAggregationBuilder()
Optional.of(DataSize.succinctBytes(0)),
Optional.of(overwriteIntermediateChannelOffset),
joinCompiler,
false,
false);
// TODO: merging should also yield on memory reservations
() -> true);
}
}
SpillableHashAggregationBuilder.java
@@ -106,8 +106,9 @@ public SpillableHashAggregationBuilder(
public Work<?> processPage(Page page)
{
checkState(hasPreviousSpillCompletedSuccessfully(), "Previous spill hasn't yet finished");
// hashAggregationBuilder is constructed with yieldForMemoryReservation = false
// Therefore the processing of the returned Work should always be true
// hashAggregationBuilder is constructed with non yielding UpdateMemory instance.
// Therefore the processing of the returned Work should always be true.
// It is not possible to spill during processing of a page.
return hashAggregationBuilder.processPage(page);
}

@@ -305,7 +306,7 @@ private WorkProcessor<Page> mergeSortedPages(WorkProcessor<Page> sortedPages, lo
hashChannel,
operatorContext,
sortedPages,
operatorContext.newLocalSystemMemoryContext(SpillableHashAggregationBuilder.class.getSimpleName()),
operatorContext.aggregateSystemMemoryContext(),
memoryLimitForMerge,
hashAggregationBuilder.getKeyChannels(),
joinCompiler));
@@ -331,8 +332,11 @@ private void rebuildHashAggregationBuilder()
operatorContext,
Optional.of(DataSize.succinctBytes(0)),
joinCompiler,
false,
false);
() -> {
updateMemory();
// TODO: Support GroupByHash yielding in spillable hash aggregation (https://github.com/prestosql/presto/issues/460)
return true;
});
emptyHashAggregationBuilderSize = hashAggregationBuilder.getSizeInMemory();
}
}
