Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Jan 11, 2024
1 parent 8f2351f commit f581f34
Showing 1 changed file with 63 additions and 48 deletions.
111 changes: 63 additions & 48 deletions velox/exec/fuzzer/AggregationFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@

#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"

#include "velox/exec/PartitionFunction.h"
#include "velox/exec/fuzzer/AggregationFuzzerBase.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"

#include "velox/functions/prestosql/types/HyperLogLogType.h"

#include "velox/vector/VectorSaver.h"
#include "velox/vector/fuzzer/VectorFuzzer.h"

Expand Down Expand Up @@ -111,13 +113,14 @@ class AggregationFuzzer : public AggregationFuzzerBase {
void updateReferenceQueryStats(
AggregationFuzzerBase::ReferenceQueryErrorCode errorCode);

bool checkResult(
bool compareEquivalentPlanResults(
const core::PlanNodePtr& plan,
bool customVerification,
const std::vector<RowVectorPtr>& input,
const std::vector<PlanWithSplits>& plans,
const std::vector<std::shared_ptr<ResultVerifier>>& customVerifiers,
int32_t maxDrivers = 2);
int32_t maxDrivers = 2,
bool distinct = false);

// Return 'true' if query plans failed.
bool verifyWindow(
Expand Down Expand Up @@ -324,15 +327,14 @@ bool canSortInputs(const CallableSignature& signature) {

// Returns true if specified aggregate function can be applied to distinct
// inputs.
bool canDistinctInputs(const CallableSignature& signature) {
if (signature.args.empty()) {
bool supportsDistinctInputs(const CallableSignature& signature) {
if (signature.args.size() != 1) {
return false;
}

for (const auto& arg : signature.args) {
if (!arg->isComparable()) {
return false;
}
const auto& arg = signature.args.at(0);
if (!arg->isComparable() || arg == HYPERLOGLOG()) {
return false;
}

return true;
Expand All @@ -349,6 +351,18 @@ void AggregationFuzzer::go() {
while (!isDone(iteration, startTime)) {
LOG(INFO) << "==============================> Started iteration "
<< iteration << " (seed: " << currentSeed_ << ")";
// 10% of the time, test distinct aggregation.
if (vectorFuzzer_.coinToss(0.1)) {
++stats_.numDistinct;

std::vector<TypePtr> types;
std::vector<std::string> names;

auto groupingKeys = generateKeys("g", names, types);
auto input = generateInputData(names, types, std::nullopt);

verifyAggregation(groupingKeys, {}, {}, input, false, {});
} else {
// Pick a random signature.
auto signatureWithStats = pickSignature();
signatureWithStats.second.numRuns++;
Expand Down Expand Up @@ -391,14 +405,8 @@ void AggregationFuzzer::go() {
(signature.name.find("approx_") == std::string::npos) &&
vectorFuzzer_.coinToss(0.2);

// Exclude approx_xxx, merge aggregations since it conflicts with
// distinct semantics.
// TODO: Add an exclude list for aggregations that cannot be supported by
// DistinctAggregations.
const bool distinct = !sortedInputs && canDistinctInputs(signature) &&
(signature.name.find("approx_") == std::string::npos) &&
(signature.name.find("merge") == std::string::npos) &&
vectorFuzzer_.coinToss(0.2);
const bool distinct = !sortedInputs &&
supportsDistinctInputs(signature) && vectorFuzzer_.coinToss(0.2);

auto call =
makeFunctionCall(signature.name, argNames, sortedInputs, distinct);
Expand Down Expand Up @@ -460,7 +468,7 @@ void AggregationFuzzer::go() {
}
}
}

}
LOG(INFO) << "==============================> Done with iteration "
<< iteration;

Expand Down Expand Up @@ -720,7 +728,6 @@ bool AggregationFuzzer::verifyWindow(
}
}

namespace {
void resetCustomVerifiers(
const std::vector<std::shared_ptr<ResultVerifier>>& customVerifiers) {
for (auto& verifier : customVerifiers) {
Expand All @@ -730,33 +737,14 @@ void resetCustomVerifiers(
}
}

// Initializes every non-null custom verifier with the fuzzer input, the
// grouping keys, and the aggregate (together with its output name) located at
// the same position in 'aggregationNode'. Null entries in 'customVerifiers'
// are skipped.
//
// NOTE(review): 'aggregationNode' is dereferenced without a null check and is
// indexed using positions taken from 'customVerifiers'; presumably callers
// pass a non-null node with at least as many aggregates as there are
// verifiers - confirm against the call sites.
void initVerifier(
    const std::vector<std::shared_ptr<ResultVerifier>>& customVerifiers,
    const std::vector<RowVectorPtr>& input,
    const std::vector<std::string>& groupingKeys,
    const std::shared_ptr<const core::AggregationNode>& aggregationNode) {
  // Use an unsigned index so the comparison against size() does not mix
  // signed and unsigned operands.
  for (size_t i = 0; i < customVerifiers.size(); ++i) {
    const auto& verifier = customVerifiers[i];
    if (verifier == nullptr) {
      continue;
    }

    // The verifier at position i is paired with the aggregate at position i.
    verifier->initialize(
        input,
        groupingKeys,
        aggregationNode->aggregates()[i],
        aggregationNode->aggregateNames()[i]);
  }
}
} // namespace

bool AggregationFuzzer::checkResult(
bool AggregationFuzzer::compareEquivalentPlanResults(
const core::PlanNodePtr& plan,
bool customVerification,
const std::vector<RowVectorPtr>& input,
const std::vector<PlanWithSplits>& plans,
const std::vector<std::shared_ptr<ResultVerifier>>& customVerifiers,
int32_t maxDrivers) {
int32_t maxDrivers,
bool distinct) {
try {
auto resultOrError = execute(plan);
if (resultOrError.exceptionPtr) {
Expand Down Expand Up @@ -797,7 +785,12 @@ bool AggregationFuzzer::checkResult(
}

testPlans(
plans, customVerification, customVerifiers, resultOrError, maxDrivers);
plans,
customVerification,
customVerifiers,
resultOrError,
maxDrivers,
distinct);

return resultOrError.exceptionPtr != nullptr;
} catch (...) {
Expand All @@ -823,7 +816,18 @@ bool AggregationFuzzer::verifyAggregation(
if (customVerification) {
const auto& aggregationNode =
std::dynamic_pointer_cast<const core::AggregationNode>(firstPlan);
initVerifier(customVerifiers, input, groupingKeys, aggregationNode);
for (auto i = 0; i < customVerifiers.size(); ++i) {
auto& verifier = customVerifiers[i];
if (verifier == nullptr) {
continue;
}

verifier->initialize(
input,
groupingKeys,
aggregationNode->aggregates()[i],
aggregationNode->aggregateNames()[i]);
}
}

SCOPE_EXIT {
Expand Down Expand Up @@ -896,7 +900,7 @@ bool AggregationFuzzer::verifyAggregation(
persistReproInfo(plans, reproPersistPath_);
}

return checkResult(
return compareEquivalentPlanResults(
firstPlan, customVerification, input, plans, customVerifiers);
}

Expand Down Expand Up @@ -976,7 +980,7 @@ bool AggregationFuzzer::verifySortedAggregation(

// Set customVerification to false to trigger direct result comparison.
// TODO Figure out how to enable custom verify(), but not compare().
testPlans(plans, false, {}, resultOrError, 1, true);
testPlans(plans, false, {}, resultOrError, 1);

return resultOrError.exceptionPtr != nullptr;
}
Expand All @@ -996,7 +1000,18 @@ bool AggregationFuzzer::verifyDistinctAggregation(
if (customVerification) {
const auto& aggregationNode =
std::dynamic_pointer_cast<const core::AggregationNode>(firstPlan);
initVerifier(customVerifiers, input, groupingKeys, aggregationNode);
for (auto i = 0; i < customVerifiers.size(); ++i) {
auto& verifier = customVerifiers[i];
if (verifier == nullptr) {
continue;
}

verifier->initialize(
input,
groupingKeys,
aggregationNode->aggregates()[i],
aggregationNode->aggregateNames()[i]);
}
}

SCOPE_EXIT {
Expand Down Expand Up @@ -1033,8 +1048,8 @@ bool AggregationFuzzer::verifyDistinctAggregation(

// Distinct aggregations cannot be split into partial and final, hence we
// can only use a single thread to test the plan.
return checkResult(
firstPlan, customVerification, input, plans, customVerifiers, 1);
return compareEquivalentPlanResults(
firstPlan, customVerification, input, plans, customVerifiers, 1, true);
}

// verifyAggregation(std::vector<core::PlanNodePtr> plans) is tied to plan
Expand Down

0 comments on commit f581f34

Please sign in to comment.