Skip to content

Commit

Permalink
Modify FunctionRegistry APIs to retrieve function metadata for covera…
Browse files Browse the repository at this point in the history
…ge map
  • Loading branch information
pramodsatya committed Aug 27, 2024
1 parent d33cdb2 commit 2b18863
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 74 deletions.
19 changes: 19 additions & 0 deletions velox/exec/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/exec/AggregateCompanionSignatures.h"
#include "velox/exec/AggregateWindow.h"
#include "velox/expression/SignatureBinder.h"
#include "velox/functions/FunctionRegistry.h"

namespace facebook::velox::exec {

Expand All @@ -39,6 +40,24 @@ AggregateFunctionMap& aggregateFunctions() {
return functions;
}

const AggregateFunctionMapBase aggregateFunctions(
bool includeCompanionFunctions) {
AggregateFunctionMapBase aggregateFunctions;
exec::aggregateFunctions().withRLock([&](const auto& functions) {
for (const auto& entry : functions) {
auto isCompanionFunction =
isCompanionFunctionName(entry.first, functions);
if (includeCompanionFunctions && isCompanionFunction) {
aggregateFunctions.insert(entry);
} else if (!isCompanionFunction) {
aggregateFunctions.insert(entry);
}
}
});

return aggregateFunctions;
}

const AggregateFunctionEntry* FOLLY_NULLABLE
getAggregateFunctionEntry(const std::string& name) {
auto sanitizedName = sanitizeName(name);
Expand Down
8 changes: 6 additions & 2 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,15 @@ struct AggregateFunctionEntry {
AggregateFunctionMetadata metadata;
};

using AggregateFunctionMap = folly::Synchronized<
std::unordered_map<std::string, AggregateFunctionEntry>>;
using AggregateFunctionMapBase =
std::unordered_map<std::string, AggregateFunctionEntry>;
using AggregateFunctionMap = folly::Synchronized<AggregateFunctionMapBase>;

AggregateFunctionMap& aggregateFunctions();

const AggregateFunctionMapBase aggregateFunctions(
bool includeCompanionFunctions);

const AggregateFunctionEntry* FOLLY_NULLABLE
getAggregateFunctionEntry(const std::string& name);

Expand Down
21 changes: 20 additions & 1 deletion velox/exec/WindowFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
*/

#include "velox/exec/WindowFunction.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/exec/Aggregate.h"
#include "velox/expression/SignatureBinder.h"
#include "velox/functions/FunctionRegistry.h"

namespace facebook::velox::exec {

Expand All @@ -25,6 +26,24 @@ WindowFunctionMap& windowFunctions() {
return functions;
}

const WindowFunctionMap windowFunctions(bool includeAggregates) {
const auto functions = windowFunctions();
if (includeAggregates) {
return functions;
}

WindowFunctionMap windowMap;
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, aggregateFunctions) &&
aggregateFunctions.count(entry.first) == 0) {
windowMap.insert(entry);
}
}
});
return windowMap;
}

namespace {
std::optional<const WindowFunctionEntry*> getWindowFunctionEntry(
const std::string& name) {
Expand Down
2 changes: 2 additions & 0 deletions velox/exec/WindowFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,6 @@ using WindowFunctionMap = std::unordered_map<std::string, WindowFunctionEntry>;

/// Returns a map of all window function names to their registrations.
WindowFunctionMap& windowFunctions();

const WindowFunctionMap windowFunctions(bool includeAggregates);
} // namespace facebook::velox::exec
25 changes: 25 additions & 0 deletions velox/exec/tests/WindowFunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "velox/exec/WindowFunction.h"
#include "velox/expression/SignatureBinder.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"

namespace facebook::velox::exec::test {
Expand Down Expand Up @@ -134,4 +135,28 @@ TEST_F(WindowFunctionRegistryTest, prefix) {
}
}

TEST_F(WindowFunctionRegistryTest, includeAggregates) {
aggregate::prestosql::registerAllAggregateFunctions();
window::prestosql::registerAllWindowFunctions();
const auto expectedWindows = {"lead", "ntile", "nth_value", "first_value"};
const auto aggregates = {
"approx_distinct", "covar_pop", "count", "map_union_sum"};

const auto& windowMapNoAggs = exec::windowFunctions(false);
for (const auto& function : expectedWindows) {
ASSERT_TRUE(windowMapNoAggs.find(function) != windowMapNoAggs.end());
}
for (const auto& function : aggregates) {
ASSERT_FALSE(windowMapNoAggs.find(function) != windowMapNoAggs.end());
}

const auto& windowMapWithAggs = exec::windowFunctions(true);
for (const auto& function : expectedWindows) {
ASSERT_TRUE(windowMapWithAggs.find(function) != windowMapWithAggs.end());
}
for (const auto& function : aggregates) {
ASSERT_TRUE(windowMapWithAggs.find(function) != windowMapWithAggs.end());
}
}

} // namespace facebook::velox::exec::test
1 change: 1 addition & 0 deletions velox/functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ velox_add_library(velox_coverage_util CoverageUtil.cpp)

velox_link_libraries(
velox_function_registry
velox_exec
velox_expression
velox_type
velox_core
Expand Down
68 changes: 15 additions & 53 deletions velox/functions/CoverageUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,82 +270,44 @@ void printCoverageMap(
std::cout << out.str() << std::endl;
}

// A function name is a companion function's if the name is an existing
// aggregation functio name followed by a specific suffixes.
bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions) {
auto suffixOffset = name.rfind("_partial");
if (suffixOffset == std::string::npos) {
suffixOffset = name.rfind("_merge_extract");
}
if (suffixOffset == std::string::npos) {
suffixOffset = name.rfind("_merge");
}
if (suffixOffset == std::string::npos) {
suffixOffset = name.rfind("_extract");
}
if (suffixOffset == std::string::npos) {
return false;
}
return aggregateFunctions.count(name.substr(0, suffixOffset)) > 0;
}

/// Returns alphabetically sorted list of scalar functions available in Velox,
/// excluding companion functions.
std::vector<std::string> getSortedScalarNames() {
// Do not print "internal" functions.
static const std::unordered_set<std::string> kBlockList = {"row_constructor"};

auto functions = getFunctionSignatures();

const auto& functions = getFunctionSignatures(false);
std::vector<std::string> names;
names.reserve(functions.size());
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& func : functions) {
const auto& name = func.first;
if (!isCompanionFunctionName(name, aggregateFunctions) &&
kBlockList.count(name) == 0) {
names.emplace_back(name);
}
}
});
for (const auto& entry : functions) {
names.push_back(entry.first);
}

std::sort(names.begin(), names.end());
return names;
}

/// Returns alphabetically sorted list of aggregate functions available in
/// Velox, excluding compaion functions.
std::vector<std::string> getSortedAggregateNames() {
const auto& functions = exec::aggregateFunctions(false);
std::vector<std::string> names;
exec::aggregateFunctions().withRLock([&](const auto& functions) {
names.reserve(functions.size());
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, functions)) {
names.push_back(entry.first);
}
}
});
for (const auto& entry : functions) {
names.push_back(entry.first);
}

std::sort(names.begin(), names.end());
return names;
}

/// Returns alphabetically sorted list of window functions available in Velox,
/// excluding companion functions.
std::vector<std::string> getSortedWindowNames() {
const auto& functions = exec::windowFunctions();

const auto& functions = exec::windowFunctions(false);
std::vector<std::string> names;
names.reserve(functions.size());
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& entry : functions) {
if (!isCompanionFunctionName(entry.first, aggregateFunctions) &&
aggregateFunctions.count(entry.first) == 0) {
names.emplace_back(entry.first);
}
}
});
for (const auto& entry : functions) {
names.push_back(entry.first);
}

std::sort(names.begin(), names.end());
return names;
}
Expand Down
76 changes: 59 additions & 17 deletions velox/functions/FunctionRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,62 @@
namespace facebook::velox {
namespace {

void populateSimpleFunctionSignatures(FunctionSignatureMap& map) {
void populateSimpleFunctionSignatures(
FunctionSignatureMap& map,
bool includeCompanionFunctions) {
const auto& simpleFunctions = exec::simpleFunctions();
for (const auto& functionName : simpleFunctions.getFunctionNames()) {
map[functionName] = simpleFunctions.getFunctionSignatures(functionName);
}
exec::aggregateFunctions().withRLock([&](const auto& aggregateFunctions) {
for (const auto& functionName : simpleFunctions.getFunctionNames()) {
auto isCompanionFunction =
isCompanionFunctionName(functionName, aggregateFunctions);
if (includeCompanionFunctions && isCompanionFunction) {
map[functionName] = simpleFunctions.getFunctionSignatures(functionName);
} else if (!isCompanionFunction) {
map[functionName] = simpleFunctions.getFunctionSignatures(functionName);
}
}
});
}

inline void insertVectorFunctionToSignatureMap(
FunctionSignatureMap& map,
const std::pair<std::string, exec::VectorFunctionEntry>& vectorFunction) {
const auto& allSignatures = vectorFunction.second.signatures;
auto& curSignatures = map[vectorFunction.first];
std::transform(
allSignatures.begin(),
allSignatures.end(),
std::back_inserter(curSignatures),
[](std::shared_ptr<exec::FunctionSignature> signature)
-> exec::FunctionSignature* { return signature.get(); });
}

void populateVectorFunctionSignatures(FunctionSignatureMap& map) {
void populateVectorFunctionSignatures(
FunctionSignatureMap& map,
bool includeCompanionFunctions = true) {
auto vectorFunctions = exec::vectorFunctionFactories();
vectorFunctions.withRLock([&map](const auto& locked) {
vectorFunctions.withRLock([&](const auto& locked) {
for (const auto& it : locked) {
const auto& allSignatures = it.second.signatures;
auto& curSignatures = map[it.first];
std::transform(
allSignatures.begin(),
allSignatures.end(),
std::back_inserter(curSignatures),
[](std::shared_ptr<exec::FunctionSignature> signature)
-> exec::FunctionSignature* { return signature.get(); });
if (includeCompanionFunctions) {
insertVectorFunctionToSignatureMap(map, it);
} else {
exec::aggregateFunctions().withRLock(
[&](const auto& aggregateFunctions) {
if (!isCompanionFunctionName(it.first, aggregateFunctions)) {
insertVectorFunctionToSignatureMap(map, it);
}
});
}
}
});
}

} // namespace

FunctionSignatureMap getFunctionSignatures() {
FunctionSignatureMap getFunctionSignatures(bool includeCompanionFunctions) {
FunctionSignatureMap result;
populateSimpleFunctionSignatures(result);
populateVectorFunctionSignatures(result);
populateSimpleFunctionSignatures(result, includeCompanionFunctions);
populateVectorFunctionSignatures(result, includeCompanionFunctions);
return result;
}

Expand Down Expand Up @@ -160,4 +187,19 @@ resolveVectorFunctionWithMetadata(
return exec::resolveVectorFunctionWithMetadata(functionName, argTypes);
}

bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions) {
static const std::vector<std::string> kCompanionFunctionSuffixList = {
"_partial", "_merge_extract", "_merge", "_extract"};
for (const auto& companionFunctionSuffix : kCompanionFunctionSuffixList) {
auto suffixOffset = name.rfind(companionFunctionSuffix);
if (suffixOffset != std::string::npos) {
return aggregateFunctions.count(name.substr(0, suffixOffset)) > 0;
}
}
return false;
}

} // namespace facebook::velox
11 changes: 10 additions & 1 deletion velox/functions/FunctionRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include <vector>

#include "velox/exec/Aggregate.h"
#include "velox/expression/FunctionMetadata.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/type/Type.h"
Expand All @@ -29,7 +30,8 @@ using FunctionSignatureMap = std::

/// Returns a mapping of all Simple and Vector functions registered in Velox
/// The mapping is function name -> list of function signatures
FunctionSignatureMap getFunctionSignatures();
FunctionSignatureMap getFunctionSignatures(
bool includeCompanionFunctions = true);

/// Returns a mapping of all Vector functions registered in Velox
/// The mapping is function name -> list of function signatures
Expand Down Expand Up @@ -95,4 +97,11 @@ resolveVectorFunctionWithMetadata(
/// Clears the function registry.
void clearFunctionRegistry();

/// A function name is a companion function's if the name is an existing
/// aggregation function name followed by one of specific suffixes.
bool isCompanionFunctionName(
const std::string& name,
const std::unordered_map<std::string, exec::AggregateFunctionEntry>&
aggregateFunctions);

} // namespace facebook::velox
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,32 @@ TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) {
clearAndCheckRegistry();
}

TEST_F(AggregationFunctionRegTest, includeCompanionFunctions) {
aggregate::prestosql::registerAllAggregateFunctions();
const auto expectedAggregates = {
"approx_distinct", "covar_pop", "count", "map_union_sum"};
const auto aggregateCompanions = {
"approx_distinct_partial", "approx_percentile_merge_extract_integer"};

const auto& aggrFuncsWoCompanions = exec::aggregateFunctions(false);
for (const auto& aggregate : expectedAggregates) {
ASSERT_TRUE(
aggrFuncsWoCompanions.find(aggregate) != aggrFuncsWoCompanions.end());
}
for (const auto& companion : aggregateCompanions) {
ASSERT_FALSE(
aggrFuncsWoCompanions.find(companion) != aggrFuncsWoCompanions.end());
}

const auto& aggrFuncWithCompanions = exec::aggregateFunctions(true);
for (const auto& aggregate : expectedAggregates) {
ASSERT_TRUE(
aggrFuncWithCompanions.find(aggregate) != aggrFuncWithCompanions.end());
}
for (const auto& companion : aggregateCompanions) {
ASSERT_TRUE(
aggrFuncWithCompanions.find(companion) != aggrFuncWithCompanions.end());
}
}

} // namespace facebook::velox::aggregate::test

0 comments on commit 2b18863

Please sign in to comment.