Commit
Add support for DECIMAL types to Simple Function API (facebookincubator#9096)

Summary: Use the new functionality to rewrite decimal plus, minus, multiply, divide, between, negate, floor and round.

**type/Type.h**

Add P1, P2, P3, P4, S1, S2, S3, S4 types to specify precision and scale parameters for decimal types during function registration.

Add LongDecimal<P, S> and ShortDecimal<P, S> templates to specify decimal argument and return types during function registration.

```
registerFunction<
    DecimalPlusFunction,
    LongDecimal<P3, S3>,
    LongDecimal<P1, S1>,
    LongDecimal<P2, S2>>({"plus"}, constraints);
```

**expression/UdfTypeResolver.h**

Define arg_type and out_type for LongDecimal and ShortDecimal.

```
arg_type<LongDecimal> = int128_t
out_type<LongDecimal> = int128_t

arg_type<ShortDecimal> = int64_t
out_type<ShortDecimal> = int64_t
```

**functions/Registerer.h**

Add an optional 'constraints' parameter to the registerFunction template. This allows specifying rules for calculating precision and scale for decimal return types.

```
template <template <class> typename Func, typename TReturn, typename... TArgs>
void registerFunction(
    const std::vector<std::string>& aliases = {},
    const std::vector<exec::SignatureVariable>& constraints = {})
```

Here is how to specify the calculation of precision and scale for the return type of plus(decimal, decimal):
```
std::vector<exec::SignatureVariable> constraints = {
    exec::SignatureVariable(
        P3::name(),
        fmt::format(
            "min(38, max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1)",
            fmt::arg("a_precision", P1::name()),
            fmt::arg("b_precision", P2::name()),
            fmt::arg("a_scale", S1::name()),
            fmt::arg("b_scale", S2::name())),
        exec::ParameterType::kIntegerParameter),
    exec::SignatureVariable(
        S3::name(),
        fmt::format(
            "max({a_scale}, {b_scale})",
            fmt::arg("a_scale", S1::name()),
            fmt::arg("b_scale", S2::name())),
        exec::ParameterType::kIntegerParameter),
};
```

**core/SimpleFunctionMetadata.h**

Extend SimpleFunctionMetadata to store physical types (TypeKind) of input arguments and return type in addition to signature.

Decimal "plus" function has a single signature:

```
(decimal(p1, s1), decimal(p2, s2)) -> decimal(p3, s3)
```

But 5 implementations:

```
(int64_t, int64_t) -> int64_t
(int64_t, int64_t) -> int128_t
(int64_t, int128_t) -> int128_t
(int128_t, int64_t) -> int128_t
(int128_t, int128_t) -> int128_t
```

We need a way to distinguish between these.

**expression/SimpleFunctionRegistry.h/cpp**

Allow for storing multiple implementations for a single signature.

```
using SignatureMap = std::unordered_map<
    FunctionSignature,
    std::vector<std::unique_ptr<const FunctionEntry>>>;
using FunctionMap = std::unordered_map<std::string, SignatureMap>;
```

Modify SimpleFunctionRegistry::resolveFunction method to find an implementation with matching signature and matching TypeKinds for arguments and return type.

**core/SimpleFunctionMetadata.h**

Add 'inputTypes' parameter to 'initialize' method. Functions that operate on decimal types use this parameter to get access to precision and scale of the arguments. Landed separately.

```
void initialize(
    const std::vector<TypePtr>& inputTypes,
    const core::QueryConfig& config,
    ...)
```

**Example: Decimal Plus**

Here is how a function that adds 2 decimal numbers can be defined.
This function supports adding decimal numbers with possibly different precision and scale.

```
template <typename TExec>
struct DecimalPlusFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  template <typename A, typename B>
  void initialize(
      const std::vector<TypePtr>& inputTypes,
      const core::QueryConfig& /*config*/,
      A* /*a*/,
      B* /*b*/) {
    auto aType = inputTypes[0];
    auto bType = inputTypes[1];
    auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
    auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
    auto [rPrecision, rScale] = Addition::computeResultPrecisionScale(
        aPrecision, aScale, bPrecision, bScale);
    aRescale_ = Addition::computeRescaleFactor(aScale, bScale, rScale);
    bRescale_ = Addition::computeRescaleFactor(bScale, aScale, rScale);
  }

  template <typename R, typename A, typename B>
  void call(R& out, const A& a, const B& b) {
    Addition::template apply<R, A, B>(out, a, b, aRescale_, bRescale_);
  }

 private:
  uint8_t aRescale_;
  uint8_t bRescale_;
};
```

The registration specifies a rule for calculating the precision and scale of the result from the precision and scale of the inputs, and provides 5 implementations covering all possible combinations of short and long decimals in the inputs and result.
```
std::vector<exec::SignatureVariable> constraints = {
    exec::SignatureVariable(
        P3::name(),
        fmt::format(
            "min(38, max({a_precision} - {a_scale}, {b_precision} - {b_scale}) + max({a_scale}, {b_scale}) + 1)",
            fmt::arg("a_precision", P1::name()),
            fmt::arg("b_precision", P2::name()),
            fmt::arg("a_scale", S1::name()),
            fmt::arg("b_scale", S2::name())),
        exec::ParameterType::kIntegerParameter),
    exec::SignatureVariable(
        S3::name(),
        fmt::format(
            "max({a_scale}, {b_scale})",
            fmt::arg("a_scale", S1::name()),
            fmt::arg("b_scale", S2::name())),
        exec::ParameterType::kIntegerParameter),
};

// (long, long) -> long
registerFunction<
    DecimalPlusFunction,
    LongDecimal<P3, S3>,
    LongDecimal<P1, S1>,
    LongDecimal<P2, S2>>({"plus"}, constraints);

// (short, short) -> short
registerFunction<
    DecimalPlusFunction,
    ShortDecimal<P3, S3>,
    ShortDecimal<P1, S1>,
    ShortDecimal<P2, S2>>({"plus"}, constraints);

// (short, short) -> long
registerFunction<
    DecimalPlusFunction,
    LongDecimal<P3, S3>,
    ShortDecimal<P1, S1>,
    ShortDecimal<P2, S2>>({"plus"}, constraints);

// (short, long) -> long
registerFunction<
    DecimalPlusFunction,
    LongDecimal<P3, S3>,
    ShortDecimal<P1, S1>,
    LongDecimal<P2, S2>>({"plus"}, constraints);

// (long, short) -> long
registerFunction<
    DecimalPlusFunction,
    LongDecimal<P3, S3>,
    LongDecimal<P1, S1>,
    ShortDecimal<P2, S2>>({"plus"}, constraints);
```

Pull Request resolved: facebookincubator#9096

Reviewed By: xiaoxmeng

Differential Revision: D54953663
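**Worked example: evaluating the plus constraints**

To make the precision and scale rule for plus(decimal, decimal) concrete, here is a minimal standalone sketch that evaluates the two constraint expressions for given argument types and then adds two unscaled values after aligning their scales. All names in it (DecimalParams, plusResultParams, isShortDecimal, rescaleDigits, addUnscaled) are hypothetical and are not Velox APIs. The 18-digit limit for short decimals and the meaning of the rescale factor are assumptions that this message does not state.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Illustrative only: these helpers are hypothetical, not part of Velox.
struct DecimalParams {
  int precision;
  int scale;
};

// Evaluates the two "plus" constraints from the registration:
//   p3 = min(38, max(p1 - s1, p2 - s2) + max(s1, s2) + 1)
//   s3 = max(s1, s2)
DecimalParams plusResultParams(DecimalParams a, DecimalParams b) {
  assert(a.scale <= a.precision && b.scale <= b.precision);
  const int scale = std::max(a.scale, b.scale);
  const int integerDigits =
      std::max(a.precision - a.scale, b.precision - b.scale);
  return {std::min(38, integerDigits + scale + 1), scale};
}

// Assumption: precision up to 18 maps to a short decimal (int64_t);
// anything larger maps to a long decimal (int128_t).
bool isShortDecimal(DecimalParams p) {
  return p.precision <= 18;
}

// Assumption: the rescale factor is the number of decimal digits by which
// an operand's unscaled value is shifted to reach the result scale.
int rescaleDigits(int fromScale, int resultScale) {
  return std::max(0, resultScale - fromScale);
}

// Multiplies an unscaled value by 10^digits (no overflow handling).
int64_t shiftLeftDecimal(int64_t value, int digits) {
  for (int i = 0; i < digits; ++i) {
    value *= 10;
  }
  return value;
}

// Adds two unscaled values with possibly different scales. The result is an
// unscaled value at scale max(aScale, bScale). int64_t is used for every
// operand to keep the sketch portable; overflow is not handled.
int64_t addUnscaled(int64_t a, int aScale, int64_t b, int bScale) {
  const int resultScale = std::max(aScale, bScale);
  return shiftLeftDecimal(a, rescaleDigits(aScale, resultScale)) +
      shiftLeftDecimal(b, rescaleDigits(bScale, resultScale));
}
```

Under these assumptions, decimal(10, 2) + decimal(12, 4) resolves to decimal(13, 4), which still fits a short decimal. decimal(18, 0) + decimal(18, 0) resolves to decimal(19, 0), which corresponds to the (short, short) -> long implementation. Adding 1.23 (unscaled 123 at scale 2) to 4.5678 (unscaled 45678 at scale 4) yields unscaled 57978 at scale 4, that is, 5.7978.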