Skip to content

Commit

Permalink
Add isCompanionFunction to function metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
pramodsatya committed Aug 29, 2024
1 parent 2ce8f71 commit 0dccd1c
Show file tree
Hide file tree
Showing 24 changed files with 145 additions and 45 deletions.
21 changes: 17 additions & 4 deletions velox/exec/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,25 @@ AggregateRegistrationResult registerAggregateFunction(
}

// Register the aggregate as a window function also.
registerAggregateWindowFunction(sanitizedName);
registerAggregateWindowFunction(sanitizedName, metadata);

// Register companion function if needed.
if (registerCompanionFunctions) {
auto companionMetadata = metadata;
companionMetadata.isCompanionFunction = true;

registered.partialFunction =
CompanionFunctionsRegistrar::registerPartialFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.mergeFunction =
CompanionFunctionsRegistrar::registerMergeFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.extractFunction =
CompanionFunctionsRegistrar::registerExtractFunction(
name, signatures, overwrite);
registered.mergeExtractFunction =
CompanionFunctionsRegistrar::registerMergeExtractFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
}
return registered;
}
Expand Down Expand Up @@ -141,6 +144,16 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
return registrationResults;
}

/// Returns the metadata stored in the registry entry of the aggregate function
/// with the specified name. The name is sanitized before the lookup. Raises a
/// user error if no entry is registered under the sanitized name.
const AggregateFunctionMetadata& getAggregateFunctionMetadata(
    const std::string& name) {
  const auto lookupName = sanitizeName(name);
  const auto entry = getAggregateFunctionEntry(lookupName);
  if (!entry) {
    // The error message reports the name as supplied by the caller, not the
    // sanitized form used for the lookup.
    VELOX_USER_FAIL("Metadata not found for aggregate function: {}", name);
  }
  return entry->metadata;
}

std::unordered_map<
std::string,
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ struct AggregateFunctionMetadata {
/// True if results of the aggregation depend on the order of inputs. For
/// example, array_agg is order sensitive while count is not.
bool orderSensitive{true};

/// Indicates if this is a companion function.
bool isCompanionFunction{false};
};
/// Register an aggregate function with the specified name and signatures. If
/// registerCompanionFunctions is true, also register companion aggregate and
Expand Down Expand Up @@ -514,6 +517,9 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
bool registerCompanionFunctions,
bool overwrite);

/// Returns metadata of the aggregate function with the specified name.
/// Fails with a user error if no function with that name is found.
const AggregateFunctionMetadata& getAggregateFunctionMetadata(
const std::string& name);

/// Returns signatures of the aggregate function with the specified name.
/// Returns empty std::optional if function with that name is not found.
std::optional<std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
25 changes: 20 additions & 5 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
bool CompanionFunctionsRegistrar::registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto partialSignatures =
CompanionSignatures::partialFunctionSignatures(signatures);
Expand Down Expand Up @@ -280,6 +281,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
name,
CompanionSignatures::partialFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -288,6 +290,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
bool CompanionFunctionsRegistrar::registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto mergeSignatures =
CompanionSignatures::mergeFunctionSignatures(signatures);
Expand Down Expand Up @@ -320,16 +323,18 @@ bool CompanionFunctionsRegistrar::registerMergeFunction(
name,
CompanionSignatures::mergeFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
}

bool registerAggregateFunction(
bool registerMergeExtractFunctionImpl(
const std::string& name,
const std::string& mergeExtractFunctionName,
const std::vector<std::shared_ptr<AggregateFunctionSignature>>&
mergeExtractSignatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
return exec::registerAggregateFunction(
mergeExtractFunctionName,
Expand Down Expand Up @@ -365,6 +370,7 @@ bool registerAggregateFunction(
name,
mergeExtractFunctionName);
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -373,6 +379,7 @@ bool registerAggregateFunction(
bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
Expand All @@ -387,10 +394,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |= registerAggregateFunction(
registered |= registerMergeExtractFunctionImpl(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}
return registered;
Expand All @@ -399,10 +407,12 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
return registerMergeExtractFunctionWithSuffix(
name, signatures, metadata, overwrite);
}

auto mergeExtractSignatures =
Expand All @@ -413,10 +423,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return registerAggregateFunction(
return registerMergeExtractFunctionImpl(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}

Expand Down Expand Up @@ -475,6 +486,7 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
std::move(factory),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.isCompanionFunction(true)
.build(),
overwrite);
}
Expand Down Expand Up @@ -502,7 +514,10 @@ bool CompanionFunctionsRegistrar::registerExtractFunction(
CompanionSignatures::extractFunctionName(originalName),
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.isCompanionFunction(true)
.build(),
overwrite);
}

Expand Down
4 changes: 4 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class CompanionFunctionsRegistrar {
static bool registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// When there is already a function of the same name as the merge companion
Expand All @@ -186,6 +187,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// If there are multiple signatures of the original aggregation function
Expand Down Expand Up @@ -213,6 +215,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

private:
Expand All @@ -227,6 +230,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite);
};

Expand Down
8 changes: 6 additions & 2 deletions velox/exec/AggregateWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,9 @@ class AggregateWindowFunction : public exec::WindowFunction {

} // namespace

void registerAggregateWindowFunction(const std::string& name) {
void registerAggregateWindowFunction(
const std::string& name,
const AggregateFunctionMetadata& metadata) {
auto aggregateFunctionSignatures = exec::getAggregateFunctionSignatures(name);
if (aggregateFunctionSignatures.has_value()) {
// This copy is needed to obtain a vector of the base FunctionSignaturePtr
Expand All @@ -410,7 +412,9 @@ void registerAggregateWindowFunction(const std::string& name) {
exec::registerWindowFunction(
name,
std::move(signatures),
{exec::WindowFunction::ProcessMode::kRows, true},
{exec::WindowFunction::ProcessMode::kRows,
true,
metadata.isCompanionFunction},
[name](
const std::vector<exec::WindowFunctionArg>& args,
const TypePtr& resultType,
Expand Down
5 changes: 4 additions & 1 deletion velox/exec/AggregateWindow.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
*/
#pragma once
#include <string>
#include "velox/exec/Aggregate.h"

namespace facebook::velox::exec {

void registerAggregateWindowFunction(const std::string& name);
void registerAggregateWindowFunction(
const std::string& name,
const AggregateFunctionMetadata& metadata);

} // namespace facebook::velox::exec
7 changes: 4 additions & 3 deletions velox/exec/WindowFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ class WindowFunction {
kRows,
};

/// Indicates whether this is an aggregate window function and its process
/// unit.
/// Indicates whether this is an aggregate window function, whether it is a
/// companion function, and its process unit.
struct Metadata {
ProcessMode processMode;
bool isAggregate;
bool isCompanionFunction;

static Metadata defaultMetadata() {
static Metadata defaultValue{ProcessMode::kPartition, false};
static Metadata defaultValue{ProcessMode::kPartition, false, false};
return defaultValue;
}
};
Expand Down
19 changes: 19 additions & 0 deletions velox/exec/tests/WindowFunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "velox/exec/WindowFunction.h"
#include "velox/expression/SignatureBinder.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"

namespace facebook::velox::exec::test {
Expand Down Expand Up @@ -134,4 +135,22 @@ TEST_F(WindowFunctionRegistryTest, prefix) {
}
}

// Verifies that the isCompanionFunction flag in window function metadata is
// false for regular window/aggregate functions and true for the companion
// functions registered alongside aggregates.
TEST_F(WindowFunctionRegistryTest, isCompanionFunction) {
  // Populate the registries with the Presto aggregate and window functions.
  aggregate::prestosql::registerAllAggregateFunctions();
  window::prestosql::registerAllWindowFunctions();

  // Functions that must not be flagged as companions.
  constexpr const char* kRegularFunctions[] = {
      "count", "lead", "ntile", "nth_value", "first_value", "map_union_sum"};
  // Companion functions (_partial, _merge and _merge_extract variants).
  constexpr const char* kCompanionFunctions[] = {
      "approx_most_frequent_partial",
      "approx_percentile_merge",
      "arbitrary_merge_extract"};

  for (const char* functionName : kRegularFunctions) {
    ASSERT_FALSE(getWindowFunctionMetadata(functionName).isCompanionFunction);
  }
  for (const char* functionName : kCompanionFunctions) {
    ASSERT_TRUE(getWindowFunctionMetadata(functionName).isCompanionFunction);
  }
}

} // namespace facebook::velox::exec::test
8 changes: 8 additions & 0 deletions velox/expression/FunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ struct VectorFunctionMetadata {
/// In this case, 'rows' in VectorFunction::apply will point only to positions
/// for which all arguments are not null.
bool defaultNullBehavior{true};

/// Indicates if this is a companion function.
bool isCompanionFunction{false};
};

class VectorFunctionMetadataBuilder {
Expand All @@ -59,6 +62,11 @@ class VectorFunctionMetadataBuilder {
return *this;
}

/// Sets the isCompanionFunction flag of the metadata being built and returns
/// this builder so that calls can be chained.
VectorFunctionMetadataBuilder& isCompanionFunction(bool isCompanionFunction) {
metadata_.isCompanionFunction = isCompanionFunction;
return *this;
}

const VectorFunctionMetadata& build() const {
return metadata_;
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerBitwise(
inputType->kindName());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/window/Rank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void registerRankInternal(
exec::registerWindowFunction(
name,
std::move(signatures),
{exec::WindowFunction::ProcessMode::kRows, false},
{exec::WindowFunction::ProcessMode::kRows, false, false},
std::move(windowFunctionFactory));
} else {
exec::registerWindowFunction(
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/window/RowNumber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void registerRowNumber(const std::string& name, TypeKind resultTypeKind) {
exec::registerWindowFunction(
name,
std::move(signatures),
{exec::WindowFunction::ProcessMode::kRows, false},
{exec::WindowFunction::ProcessMode::kRows, false, false},
[name](
const std::vector<exec::WindowFunctionArg>& /*args*/,
const TypePtr& resultType,
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void registerAverageAggregate(
}
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/BoolAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ exec::AggregateRegistrationResult registerBool(
inputType->kindName());
return std::make_unique<T>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/ChecksumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void registerChecksumAggregate(

return std::make_unique<ChecksumAggregate>(VARBINARY());
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void registerCountAggregate(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<CountAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountIfAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void registerCountIfAggregate(

return std::make_unique<CountIfAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*isCompanionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Loading

0 comments on commit 0dccd1c

Please sign in to comment.