Skip to content

Commit

Permalink
Let HashProbe keep track of memory consumption when listing join resu…
Browse files Browse the repository at this point in the history
…lts (facebookincubator#10652) (#495)

Summary:
Hash probe currently has limited memory control when extracting results from the hash table. When a small number of large sized rows from the build side is frequently joined with the left side, the total extracted size will explode, making HashProbe using a large amount of memory. And the process of filling output is not in spillable state, and will often cause OOM.
This PR computes the total size when listing join results in hash probe if there are any variable size columns from the build side that is going to be extracted. It stops listing further when it reaches the maximum size. This can help to control hash probe side memory usage to a confined limit.

Pull Request resolved: facebookincubator#10652

Reviewed By: xiaoxmeng

Differential Revision: D60771773

Pulled By: tanjialiang

fbshipit-source-id: 2cb8c58ba795a0aa1df0485b58e4f6d0100be8f8
(cherry picked from commit 82e5492)

Co-authored-by: Jialiang Tan <jacob.jialiang.tan@gmail.com>
  • Loading branch information
zsmj2017 and tanjialiang authored Sep 18, 2024
1 parent 88856e6 commit 9e22c2e
Show file tree
Hide file tree
Showing 11 changed files with 492 additions and 142 deletions.
52 changes: 40 additions & 12 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,30 @@ void HashProbe::maybeSetupSpillInputReader(
spillPartitionSet_.erase(iter);
}

void HashProbe::initializeResultIter() {
VELOX_CHECK_NOT_NULL(table_);
if (resultIter_ != nullptr) {
return;
}
std::vector<vector_size_t> listColumns;
listColumns.reserve(tableOutputProjections_.size());
for (const auto& projection : tableOutputProjections_) {
listColumns.push_back(projection.inputChannel);
}
std::vector<vector_size_t> varSizeListColumns;
uint64_t fixedSizeListColumnsSizeSum{0};
varSizeListColumns.reserve(tableOutputProjections_.size());
for (const auto column : listColumns) {
if (table_->rows()->columnTypes()[column]->isFixedWidth()) {
fixedSizeListColumnsSizeSum += table_->rows()->fixedSizeAt(column);
} else {
varSizeListColumns.push_back(column);
}
}
resultIter_ = std::make_unique<BaseHashTable::JoinResultIterator>(
std::move(varSizeListColumns), fixedSizeListColumnsSizeSum);
}

void HashProbe::asyncWaitForHashTable() {
checkRunning();
VELOX_CHECK_NULL(table_);
Expand All @@ -309,6 +333,8 @@ void HashProbe::asyncWaitForHashTable() {
}

table_ = std::move(hashBuildResult->table);
initializeResultIter();

VELOX_CHECK_NOT_NULL(table_);

maybeSetupSpillInputReader(hashBuildResult->restoredPartitionId);
Expand Down Expand Up @@ -660,7 +686,8 @@ void HashProbe::addInput(RowVectorPtr input) {
lookup_->hits.resize(lookup_->rows.back() + 1);
table_->joinProbe(*lookup_);
}
results_.reset(*lookup_);

resultIter_->reset(*lookup_);
}

void HashProbe::prepareOutput(vector_size_t size) {
Expand Down Expand Up @@ -995,10 +1022,11 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
}
} else {
numOut = table_->listJoinResults(
results_,
*resultIter_,
joinIncludesMissesFromLeft(joinType_),
mapping,
folly::Range(outputTableRows_.data(), outputTableRows_.size()));
folly::Range(outputTableRows_.data(), outputTableRows_.size()),
operatorCtx_->driverCtx()->queryConfig().preferredOutputBatchBytes());
}

// We are done processing the input batch if there are no more joined rows
Expand All @@ -1024,7 +1052,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
// Right semi join only returns the build side output when the probe side
// is fully complete. Do not return anything here.
if (isRightSemiFilterJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
if (results_.atEnd()) {
if (resultIter_->atEnd()) {
input_ = nullptr;
}
return nullptr;
Expand Down Expand Up @@ -1329,7 +1357,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
}

noMatchDetector_.finishIteration(
addMiss, results_.atEnd(), outputTableRows_.size() - numPassed);
addMiss, resultIter_->atEnd(), outputTableRows_.size() - numPassed);
} else if (isLeftSemiFilterJoin(joinType_)) {
auto addLastMatch = [&](auto row) {
outputTableRows_[numPassed] = nullptr;
Expand All @@ -1341,7 +1369,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
rawOutputProbeRowMapping[i], addLastMatch);
}
}
if (results_.atEnd()) {
if (resultIter_->atEnd()) {
leftSemiFilterJoinTracker_.finish(addLastMatch);
}
} else if (isLeftSemiProjectJoin(joinType_)) {
Expand Down Expand Up @@ -1378,7 +1406,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
leftSemiProjectJoinTracker_.advance(probeRow, passed, addLast);
}
leftSemiProjectIsNull_.updateBounds();
if (results_.atEnd()) {
if (resultIter_->atEnd()) {
leftSemiProjectJoinTracker_.finish(addLast);
}
} else {
Expand All @@ -1391,7 +1419,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
leftSemiProjectJoinTracker_.advance(
rawOutputProbeRowMapping[i], filterPassed(i), addLast);
}
if (results_.atEnd()) {
if (resultIter_->atEnd()) {
leftSemiProjectJoinTracker_.finish(addLast);
}
}
Expand All @@ -1416,7 +1444,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
}

noMatchDetector_.finishIteration(
addMiss, results_.atEnd(), outputTableRows_.size() - numPassed);
addMiss, resultIter_->atEnd(), outputTableRows_.size() - numPassed);
} else {
for (auto i = 0; i < numRows; ++i) {
if (filterPassed(i)) {
Expand All @@ -1429,7 +1457,7 @@ int32_t HashProbe::evalFilter(int32_t numRows) {
}

void HashProbe::ensureLoadedIfNotAtEnd(column_index_t channel) {
if (results_.atEnd()) {
if (resultIter_->atEnd()) {
return;
}

Expand Down Expand Up @@ -1683,7 +1711,7 @@ void HashProbe::spillOutput(const std::vector<HashProbe*>& operators) {
}
}

auto syncGuard = folly::makeGuard([&]() {
SCOPE_EXIT {
for (auto& spillTask : spillTasks) {
// We consume the result for the pending tasks. This is a cleanup in the
// guard and must not throw. The first error is already captured before
Expand All @@ -1693,7 +1721,7 @@ void HashProbe::spillOutput(const std::vector<HashProbe*>& operators) {
} catch (const std::exception&) {
}
}
});
};

for (auto& spillTask : spillTasks) {
const auto result = spillTask->move();
Expand Down
19 changes: 11 additions & 8 deletions velox/exec/HashProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,15 @@ class HashProbe : public Operator {
// the hash table.
void asyncWaitForHashTable();

// Sets up 'filter_' and related members.p
// Sets up 'filter_' and related members.
void initializeFilter(
const core::TypedExprPtr& filter,
const RowTypePtr& probeType,
const RowTypePtr& tableType);

// Setup 'resultIter_'.
void initializeResultIter();

// If 'toSpillOutput', the produced output is spilled to disk for memory
// arbitration.
RowVectorPtr getOutputInternal(bool toSpillOutput);
Expand Down Expand Up @@ -611,21 +614,21 @@ class HashProbe : public Operator {

BaseHashTable::RowsIterator lastProbeIterator_;

/// For left and anti join with filter, tracks the probe side rows which had
/// matches on the build side but didn't pass the filter.
// For left and anti join with filter, tracks the probe side rows which had
// matches on the build side but didn't pass the filter.
NoMatchDetector noMatchDetector_;

/// For left semi join filter with extra filter, de-duplicates probe side rows
/// with multiple matches.
// For left semi join filter with extra filter, de-duplicates probe side rows
// with multiple matches.
LeftSemiFilterJoinTracker leftSemiFilterJoinTracker_;

/// For left semi join project with filter, de-duplicates probe side rows with
/// multiple matches.
// For left semi join project with filter, de-duplicates probe side rows with
// multiple matches.
LeftSemiProjectJoinTracker leftSemiProjectJoinTracker_;

// Keeps track of returned results between successive batches of
// output for a batch of input.
BaseHashTable::JoinResultIterator results_;
std::unique_ptr<BaseHashTable::JoinResultIterator> resultIter_;

RowVectorPtr output_;

Expand Down
52 changes: 43 additions & 9 deletions velox/exec/HashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1728,18 +1728,39 @@ void HashTable<ignoreNullKeys>::prepareJoinTable(
checkHashBitsOverlap(spillInputStartPartitionBit);
}

template <bool ignoreNullKeys>
inline uint64_t HashTable<ignoreNullKeys>::joinProjectedVarColumnsSize(
const std::vector<vector_size_t>& columns,
const char* row) const {
uint64_t totalBytes{0};
for (const auto& column : columns) {
if (!rows_->columnTypes()[column]->isFixedWidth()) {
totalBytes += rows_->variableSizeAt(row, column);
}
}
return totalBytes;
}

template <bool ignoreNullKeys>
int32_t HashTable<ignoreNullKeys>::listJoinResults(
JoinResultIterator& iter,
bool includeMisses,
folly::Range<vector_size_t*> inputRows,
folly::Range<char**> hits) {
folly::Range<char**> hits,
uint64_t maxBytes) {
VELOX_CHECK_LE(inputRows.size(), hits.size());
if (!hasDuplicates_) {
return listJoinResultsNoDuplicates(iter, includeMisses, inputRows, hits);

if (iter.varSizeListColumns.empty() && !hasDuplicates_) {
// When there is no duplicates, and no variable length columns are selected
// to be projected, we are able to calculate fixed length columns total size
// directly and go through fast path.
return listJoinResultsFastPath(
iter, includeMisses, inputRows, hits, maxBytes);
}

size_t numOut = 0;
auto maxOut = inputRows.size();
uint64_t totalBytes{0};
while (iter.lastRowIndex < iter.rows->size()) {
auto row = (*iter.rows)[iter.lastRowIndex];
auto hit = (*iter.hits)[row]; // NOLINT
Expand All @@ -1762,6 +1783,9 @@ int32_t HashTable<ignoreNullKeys>::listJoinResults(
hits[numOut] = hit;
numOut++;
iter.lastRowIndex++;
totalBytes +=
(joinProjectedVarColumnsSize(iter.varSizeListColumns, hit) +
iter.fixedSizeListColumnsSizeSum);
} else {
auto numRows = rows->size();
auto num =
Expand All @@ -1773,36 +1797,46 @@ int32_t HashTable<ignoreNullKeys>::listJoinResults(
num * sizeof(char*));
iter.lastDuplicateRowIndex += num;
numOut += num;
for (const auto* dupRow : *rows) {
totalBytes +=
joinProjectedVarColumnsSize(iter.varSizeListColumns, dupRow);
}
totalBytes += (iter.fixedSizeListColumnsSizeSum * numRows);
if (iter.lastDuplicateRowIndex >= numRows) {
iter.lastDuplicateRowIndex = 0;
iter.lastRowIndex++;
}
}
if (numOut >= maxOut) {
if (numOut >= maxOut || totalBytes >= maxBytes) {
return numOut;
}
}
return numOut;
}

template <bool ignoreNullKeys>
int32_t HashTable<ignoreNullKeys>::listJoinResultsNoDuplicates(
int32_t HashTable<ignoreNullKeys>::listJoinResultsFastPath(
JoinResultIterator& iter,
bool includeMisses,
folly::Range<vector_size_t*> inputRows,
folly::Range<char**> hits) {
folly::Range<char**> hits,
uint64_t maxBytes) {
int32_t numOut = 0;
auto maxOut = inputRows.size();
const auto maxOut = std::min(
static_cast<uint64_t>(inputRows.size()),
(iter.fixedSizeListColumnsSizeSum != 0
? maxBytes / iter.fixedSizeListColumnsSizeSum
: std::numeric_limits<uint64_t>::max()));
int32_t i = iter.lastRowIndex;
auto numRows = iter.rows->size();
const auto numRows = iter.rows->size();

constexpr int32_t kWidth = xsimd::batch<int64_t>::size;
auto sourceHits = reinterpret_cast<int64_t*>(iter.hits->data());
auto sourceRows = iter.rows->data();
// We pass the pointers as int64_t's in 'hitWords'.
auto resultHits = reinterpret_cast<int64_t*>(hits.data());
auto resultRows = inputRows.data();
int32_t outLimit = maxOut - kWidth;
const auto outLimit = maxOut - kWidth;
for (; i + kWidth <= numRows && numOut < outLimit; i += kWidth) {
auto indices = simd::loadGatherIndices<int64_t, int32_t>(sourceRows + i);
auto hitWords = simd::gather(sourceHits, indices);
Expand Down
Loading

0 comments on commit 9e22c2e

Please sign in to comment.