diff --git a/velox/functions/sparksql/Arithmetic.h b/velox/functions/sparksql/Arithmetic.h index 62fdae2551c2..f2eb9aa990bd 100644 --- a/velox/functions/sparksql/Arithmetic.h +++ b/velox/functions/sparksql/Arithmetic.h @@ -25,7 +25,7 @@ namespace facebook::velox::functions::sparksql { template -struct PModFunction { +struct PModIntFunction { template FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a, const TInput n) #if defined(__has_feature) @@ -43,6 +43,20 @@ struct PModFunction { } }; +template +struct PModFloatFunction { + template + FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a, const TInput n) + { + if (UNLIKELY(n == (TInput)0)) { + return false; + } + TInput r = fmod(a, n); + result = (r > 0) ? r : fmod(r + n, n); + return true; + } +}; + template struct RemainderFunction { template diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 4f72fdf59c30..a17c66cc79ec 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -32,7 +32,8 @@ void registerArithmeticFunctions(const std::string& prefix) { // Math functions. registerUnaryNumeric({prefix + "abs"}); registerFunction({prefix + "exp"}); - registerBinaryIntegral({prefix + "pmod"}); + registerBinaryIntegral({prefix + "pmod"}); + registerBinaryFloatingPoint({prefix + "pmod"}); registerFunction({prefix + "power"}); registerUnaryNumeric({prefix + "round"}); registerFunction({prefix + "round"}); diff --git a/velox/functions/sparksql/tests/ArithmeticTest.cpp b/velox/functions/sparksql/tests/ArithmeticTest.cpp index 886123b2d05e..f3fb53f08df4 100644 --- a/velox/functions/sparksql/tests/ArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/ArithmeticTest.cpp @@ -67,6 +67,18 @@ TEST_F(PmodTest, int64) { EXPECT_EQ(INT64_MAX - 1, pmod(INT64_MIN, INT64_MAX)); } +TEST_F(PmodTest, float) { + EXPECT_FLOAT_EQ(0.2, pmod(0.5, 0.3).value()); + EXPECT_FLOAT_EQ(0.9, pmod(-1.1, 2).value()); + EXPECT_EQ(std::nullopt, pmod(2.14159, 0.0)); +} + +TEST_F(PmodTest, double) { + EXPECT_DOUBLE_EQ(0.2, pmod(0.5, 0.3).value()); + EXPECT_DOUBLE_EQ(0.9, pmod(-1.1, 2).value()); + EXPECT_EQ(std::nullopt, pmod(2.14159, 0.0)); +} + class RemainderTest : public SparkFunctionBaseTest { protected: template