From 7ff56c656b1b9551c2baf89801d0facd61bfa9d9 Mon Sep 17 00:00:00 2001 From: Bikramjeet Vig Date: Thu, 25 Apr 2024 15:04:56 -0700 Subject: [PATCH] Back out "Back out "[velox][PR] Refactor greatest and least Presto functions using simple function API"" Summary: Re-introducing this as the issue that initiated the backout is resolved. Original PR: https://github.com/facebookincubator/velox/pull/9308 Reviewed By: mbasmanova, s4ayub Differential Revision: D56548695 fbshipit-source-id: d0a9032f5cc958c8f4a3124c1ad81f290e31800b --- velox/functions/prestosql/CMakeLists.txt | 1 - velox/functions/prestosql/GreatestLeast.cpp | 207 ------------------ velox/functions/prestosql/GreatestLeast.h | 101 +++++++++ .../GeneralFunctionsRegistration.cpp | 32 ++- .../prestosql/tests/GreatestLeastTest.cpp | 69 ++++-- 5 files changed, 175 insertions(+), 235 deletions(-) delete mode 100644 velox/functions/prestosql/GreatestLeast.cpp create mode 100644 velox/functions/prestosql/GreatestLeast.h diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 3a8008be601b..54d959cbbb37 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -34,7 +34,6 @@ add_library( FindFirst.cpp FromUnixTime.cpp FromUtf8.cpp - GreatestLeast.cpp InPredicate.cpp JsonFunctions.cpp Map.cpp diff --git a/velox/functions/prestosql/GreatestLeast.cpp b/velox/functions/prestosql/GreatestLeast.cpp deleted file mode 100644 index afc085d4a7ea..000000000000 --- a/velox/functions/prestosql/GreatestLeast.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "velox/common/base/Exceptions.h" -#include "velox/expression/Expr.h" -#include "velox/expression/VectorFunction.h" -#include "velox/type/Type.h" - -namespace facebook::velox::functions { - -namespace { - -template -class ExtremeValueFunction; - -using LeastFunction = ExtremeValueFunction; -using GreatestFunction = ExtremeValueFunction; - -/** - * This class implements two functions: - * - * greatest(value1, value2, ..., valueN) → [same as input] - * Returns the largest of the provided values. - * - * least(value1, value2, ..., valueN) → [same as input] - * Returns the smallest of the provided values. - **/ -template -class ExtremeValueFunction : public exec::VectorFunction { - private: - template - bool shouldOverride(const T& currentValue, const T& candidateValue) const { - return isLeast ? candidateValue < currentValue - : candidateValue > currentValue; - } - - // For double, presto should throw error if input is Nan - template - void checkNan(const T& value) const { - if constexpr (std::is_same_v::NativeType>) { - if (std::isnan(value)) { - VELOX_USER_FAIL( - "Invalid argument to {}: NaN", isLeast ? "least()" : "greatest()"); - } - } - } - - template - void applyTyped( - const SelectivityVector& rows, - const std::vector& args, - const TypePtr& outputType, - exec::EvalCtx& context, - VectorPtr& result) const { - context.ensureWritable(rows, outputType, result); - result->clearNulls(rows); - - auto* flatResult = result->as>(); - BufferPtr resultValues = flatResult->mutableValues(rows.end()); - T* __restrict rawResult = resultValues->asMutable(); - - exec::DecodedArgs decodedArgs(rows, args, context); - - std::set usedInputs; - context.applyToSelectedNoThrow(rows, [&](int row) { - size_t valueIndex = 0; - - T currentValue = decodedArgs.at(0)->valueAt(row); - checkNan(currentValue); - - for (auto i = 1; i < args.size(); ++i) { - auto candidateValue = decodedArgs.at(i)->template valueAt(row); - checkNan(candidateValue); - - if constexpr (isLeast) { - if (candidateValue < currentValue) { - currentValue = candidateValue; - valueIndex = i; - } - } else { - if (candidateValue > currentValue) { - currentValue = candidateValue; - valueIndex = i; - } - } - } - usedInputs.insert(valueIndex); - - if constexpr (std::is_same_v) { - flatResult->set(row, currentValue); - } else { - rawResult[row] = currentValue; - } - }); - - if constexpr (std::is_same_v) { - for (auto index : usedInputs) { - flatResult->acquireSharedStringBuffers(args[index].get()); - } - } - } - - public: - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& outputType, - exec::EvalCtx& context, - VectorPtr& result) const override { - switch (outputType.get()->kind()) { - case TypeKind::BOOLEAN: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::TINYINT: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::SMALLINT: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::INTEGER: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::BIGINT: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::HUGEINT: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::REAL: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::DOUBLE: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::VARCHAR: - applyTyped(rows, args, outputType, context, result); - return; - case TypeKind::TIMESTAMP: - applyTyped(rows, args, outputType, context, result); - return; - default: - VELOX_FAIL( - "Unsupported input type for {}: {}", - isLeast ? "least()" : "greatest()", - outputType->toString()); - } - } - - static std::vector> signatures() { - const std::vector types = { - "boolean", - "tinyint", - "smallint", - "integer", - "bigint", - "double", - "real", - "varchar", - "timestamp", - "date", - }; - std::vector> signatures; - for (const auto& type : types) { - signatures.emplace_back(exec::FunctionSignatureBuilder() - .returnType(type) - .argumentType(type) - .variableArity() - .build()); - } - signatures.emplace_back(exec::FunctionSignatureBuilder() - .integerVariable("precision") - .integerVariable("scale") - .returnType("DECIMAL(precision, scale)") - .argumentType("DECIMAL(precision, scale)") - .variableArity() - .build()); - return signatures; - } -}; -} // namespace - -VELOX_DECLARE_VECTOR_FUNCTION( - udf_least, - LeastFunction::signatures(), - std::make_unique()); - -VELOX_DECLARE_VECTOR_FUNCTION( - udf_greatest, - GreatestFunction::signatures(), - std::make_unique()); - -} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/GreatestLeast.h b/velox/functions/prestosql/GreatestLeast.h new file mode 100644 index 000000000000..a648aa5611d7 --- /dev/null +++ b/velox/functions/prestosql/GreatestLeast.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/functions/Macros.h" + +namespace facebook::velox::functions { +namespace details { +/** + * This class implements two functions: + * + * greatest(value1, value2, ..., valueN) → [same as input] + * Returns the largest of the provided values. + * + * least(value1, value2, ..., valueN) → [same as input] + * Returns the smallest of the provided values. + * + * For DOUBLE and REAL type, NaN is considered as the biggest according to + * https://github.com/prestodb/presto/issues/22391 + **/ +template +struct ExtremeValueFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& firstElement, + const arg_type>& remainingElement) { + auto currentValue = firstElement; + + for (auto element : remainingElement) { + auto candidateValue = element.value(); + + if constexpr (isLeast) { + if (smallerThan(candidateValue, currentValue)) { + currentValue = candidateValue; + } + } else { + if (greaterThan(candidateValue, currentValue)) { + currentValue = candidateValue; + } + } + } + + result = currentValue; + } + + private: + template + bool greaterThan(const K& lhs, const K& rhs) const { + if constexpr (std::is_same_v || std::is_same_v) { + if (std::isnan(lhs)) { + return true; + } + + if (std::isnan(rhs)) { + return false; + } + } + + return lhs > rhs; + } + + template + bool smallerThan(const K& lhs, const K& rhs) const { + if constexpr (std::is_same_v || std::is_same_v) { + if (std::isnan(lhs)) { + return false; + } + + if (std::isnan(rhs)) { + return true; + } + } + + return lhs < rhs; + } +}; +} // namespace details + +template +using LeastFunction = details::ExtremeValueFunction; + +template +using GreatestFunction = details::ExtremeValueFunction; + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp index c5a10411eb86..37acd5d68130 100644 --- a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp @@ -17,9 +17,35 @@ #include "velox/functions/Registerer.h" #include "velox/functions/lib/IsNull.h" #include "velox/functions/prestosql/Cardinality.h" +#include "velox/functions/prestosql/GreatestLeast.h" #include "velox/functions/prestosql/InPredicate.h" namespace facebook::velox::functions { + +template +inline void registerGreatestLeastFunction(const std::string& prefix) { + registerFunction, T, T, Variadic>( + {prefix + "greatest"}); + + registerFunction, T, T, Variadic>( + {prefix + "least"}); +} + +inline void registerAllGreatestLeastFunctions(const std::string& prefix) { + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction>(prefix); + registerGreatestLeastFunction>(prefix); + registerGreatestLeastFunction(prefix); + registerGreatestLeastFunction(prefix); +} + extern void registerSubscriptFunction( const std::string& name, bool enableCaching); @@ -47,12 +73,10 @@ void registerGeneralFunctions(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform"); VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "reduce"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, prefix + "filter"); - - VELOX_REGISTER_VECTOR_FUNCTION(udf_least, prefix + "least"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_greatest, prefix + "greatest"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_typeof, prefix + "typeof"); + registerAllGreatestLeastFunctions(prefix); + registerFunction>( {prefix + "cardinality"}); registerFunction>( diff --git a/velox/functions/prestosql/tests/GreatestLeastTest.cpp b/velox/functions/prestosql/tests/GreatestLeastTest.cpp index e19e13d61410..9bfddc552144 100644 --- a/velox/functions/prestosql/tests/GreatestLeastTest.cpp +++ b/velox/functions/prestosql/tests/GreatestLeastTest.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include +#include #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -83,29 +85,50 @@ TEST_F(GreatestLeastTest, leastReal) { {0, -100, -1.1}); } -TEST_F(GreatestLeastTest, nanInput) { - // Presto rejects NaN inputs of type DOUBLE, but allows NaN inputs of type - // REAL. - std::vector input{0, 1.1, std::nan("1")}; - VELOX_ASSERT_THROW( - runTest("least(c0)", {{0.0 / 0.0}}, {0}), - "Invalid argument to least(): NaN"); - runTest("try(least(c0, 1.0))", {input}, {0, 1.0, std::nullopt}); - - VELOX_ASSERT_THROW( - runTest("greatest(c0)", {1, {0.0 / 0.0}}, {1, 0}), - "Invalid argument to greatest(): NaN"); - runTest("try(greatest(c0, 1.0))", {input}, {1.0, 1.1, std::nullopt}); - - auto result = evaluateOnce( - "is_nan(least(c0))", std::nanf("1"), 1.2); - ASSERT_TRUE(result.has_value()); - ASSERT_TRUE(result.value()); - - result = evaluateOnce( - "is_nan(greatest(c0))", std::nanf("1"), 1.2); - ASSERT_TRUE(result.has_value()); - ASSERT_TRUE(result.value()); +TEST_F(GreatestLeastTest, greatestNanInput) { + auto constexpr kInf32 = std::numeric_limits::infinity(); + auto constexpr kInf64 = std::numeric_limits::infinity(); + + auto greatestFloat = [&](float a, float b, float c) { + return evaluateOnce( + "greatest(c0, c1, c2)", {a}, {b}, {c}) + .value(); + }; + + auto greatestDouble = [&](double a, double b, double c) { + return evaluateOnce( + "greatest(c0, c1, c2)", {a}, {b}, {c}) + .value(); + }; + + EXPECT_TRUE(std::isnan(greatestFloat(1.0, std::nanf("1"), 2.0))); + EXPECT_TRUE(std::isnan(greatestFloat(std::nanf("1"), 1.0, kInf32))); + + EXPECT_TRUE(std::isnan(greatestDouble(1.0, std::nan("1"), 2.0))); + EXPECT_TRUE(std::isnan(greatestDouble(std::nan("1"), 1.0, kInf64))); +} + +TEST_F(GreatestLeastTest, leastNanInput) { + auto constexpr kInf32 = std::numeric_limits::infinity(); + auto constexpr kInf64 = std::numeric_limits::infinity(); + + auto leastFloat = [&](float a, float b, float c) { + return evaluateOnce( + "least(c0, c1, c2)", {a}, {b}, {c}) + .value(); + }; + + auto leastDouble = [&](double a, double b, double c) { + return evaluateOnce( + "least(c0, c1, c2)", {a}, {b}, {c}) + .value(); + }; + + EXPECT_EQ(leastFloat(1.0, std::nanf("1"), 0.5), 0.5); + EXPECT_EQ(leastFloat(std::nanf("1"), 1.0, -kInf32), -kInf32); + + EXPECT_EQ(leastDouble(1.0, std::nan("1"), 0.5), 0.5); + EXPECT_EQ(leastDouble(std::nan("1"), 1.0, -kInf64), -kInf64); } TEST_F(GreatestLeastTest, greatestDouble) {