diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index a2ad56dd45f0..7da117b21905 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -242,6 +242,36 @@ struct TypeAnalysis> { } }; +template +struct TypeAnalysis> { + void run(TypeAnalysisResults& results) { + results.stats.concreteCount++; + + const auto p = P::name(); + const auto s = S::name(); + results.out << fmt::format("decimal({},{})", p, s); + results.addVariable(exec::SignatureVariable( + p, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable(exec::SignatureVariable( + s, std::nullopt, exec::ParameterType::kIntegerParameter)); + } +}; + +template +struct TypeAnalysis> { + void run(TypeAnalysisResults& results) { + results.stats.concreteCount++; + + const auto p = P::name(); + const auto s = S::name(); + results.out << fmt::format("decimal({},{})", p, s); + results.addVariable(exec::SignatureVariable( + p, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable(exec::SignatureVariable( + s, std::nullopt, exec::ParameterType::kIntegerParameter)); + } +}; + template struct TypeAnalysis> { void run(TypeAnalysisResults& results) { @@ -329,6 +359,8 @@ class ISimpleFunctionMetadata { virtual bool isDeterministic() const = 0; virtual uint32_t priority() const = 0; virtual const std::shared_ptr signature() const = 0; + virtual TypeKind resultTypeKind() const = 0; + virtual const std::vector& argTypeKinds() const = 0; virtual std::string helpMessage(const std::string& name) const = 0; virtual ~ISimpleFunctionMetadata() = default; }; @@ -407,10 +439,14 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { } } - explicit SimpleFunctionMetadata() { - auto analysis = analyzeSignatureTypes(); + explicit SimpleFunctionMetadata( + const std::vector& constraints) { + auto analysis = analyzeSignatureTypes(constraints); + buildSignature(analysis); priority_ = analysis.stats.computePriority(); + resultTypeKind_ = analysis.resultTypeKind; + argTypeKinds_ = analysis.argTypeKinds; } ~SimpleFunctionMetadata() override = default; @@ -419,6 +455,14 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { return signature_; } + TypeKind resultTypeKind() const override { + return resultTypeKind_; + } + + const std::vector& argTypeKinds() const override { + return argTypeKinds_; + } + std::string helpMessage(const std::string& name) const final { // return fmt::format("{}({})", name, signature_->toString()); std::string s{name}; @@ -446,15 +490,21 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { std::string outputType; std::map variables; TypeAnalysisResults::Stats stats; + TypeKind resultTypeKind; + std::vector argTypeKinds; }; - SignatureTypesAnalysisResults analyzeSignatureTypes() { + SignatureTypesAnalysisResults analyzeSignatureTypes( + const std::vector& constraints) { std::vector argsTypes; TypeAnalysisResults results; TypeAnalysis().run(results); std::string outputType = results.typeAsString(); + auto resultTypeKind = SimpleTypeTrait::typeKind; + std::vector argTypeKinds; + ( [&]() { // Clear string representation but keep other collected information @@ -462,14 +512,30 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { results.resetTypeString(); TypeAnalysis().run(results); argsTypes.push_back(results.typeAsString()); + + if constexpr (!isVariadicType::value) { + argTypeKinds.push_back(SimpleTypeTrait::typeKind); + } }(), ...); + for (const auto& constraint : constraints) { + VELOX_CHECK( + !constraint.constraint().empty(), + "Constraint must be set for variable {}", + constraint.name()); + + results.variablesInformation.erase(constraint.name()); + results.variablesInformation.emplace(constraint.name(), constraint); + } + return SignatureTypesAnalysisResults{ std::move(argsTypes), std::move(outputType), std::move(results.variablesInformation), - std::move(results.stats)}; + std::move(results.stats), + resultTypeKind, + argTypeKinds}; } void buildSignature(const SignatureTypesAnalysisResults& analysis) { @@ -497,6 +563,8 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { exec::FunctionSignaturePtr signature_; uint32_t priority_; + TypeKind resultTypeKind_; + std::vector argTypeKinds_; }; // wraps a UDF object to provide the inheritance @@ -544,6 +612,7 @@ class UDFHolder final DECLARE_METHOD_RESOLVER(callNullFree_method_resolver, callNullFree); DECLARE_METHOD_RESOLVER(callAscii_method_resolver, callAscii); DECLARE_METHOD_RESOLVER(initialize_method_resolver, initialize); + DECLARE_METHOD_RESOLVER(initializeTypes_method_resolver, initializeTypes); // Check which flavor of the call() method is provided by the UDF object. UDFs // are required to provide at least one of the following methods: @@ -650,6 +719,13 @@ class UDFHolder final const core::QueryConfig&, const exec_arg_type*...>::value; + // initializeTypes(): + static constexpr bool udf_has_initializeTypes = util::has_method< + Fun, + initializeTypes_method_resolver, + void, + const std::vector&>::value; + static_assert( udf_has_call || udf_has_callNullable || udf_has_callNullFree, "UDF must implement at least one of `call`, `callNullable`, or `callNullFree` functions.\n" @@ -703,7 +779,8 @@ class UDFHolder final template using exec_type_at = typename std::tuple_element::type; - explicit UDFHolder() : Metadata(), instance_{} {} + explicit UDFHolder(const std::vector& constraints) + : Metadata(constraints), instance_{} {} FOLLY_ALWAYS_INLINE void initialize( const core::QueryConfig& config, @@ -713,6 +790,13 @@ class UDFHolder final } } + FOLLY_ALWAYS_INLINE void initializeTypes( + const std::vector& argTypes) { + if constexpr (udf_has_initializeTypes) { + return instance_.initializeTypes(argTypes); + } + } + FOLLY_ALWAYS_INLINE bool call( exec_return_type& out, const typename exec_resolver::in_type&... args) { diff --git a/velox/expression/ExprCompiler.cpp b/velox/expression/ExprCompiler.cpp index 93f6080ee027..df05e7fd5cd3 100644 --- a/velox/expression/ExprCompiler.cpp +++ b/velox/expression/ExprCompiler.cpp @@ -435,7 +435,7 @@ ExprPtr compileRewrittenExpression( resultType, folly::join(", ", inputTypes)); auto func_2 = simpleFunctionEntry->createFunction()->createVectorFunction( - getConstantInputs(compiledInputs), config); + inputTypes, getConstantInputs(compiledInputs), config); result = std::make_shared( resultType, std::move(compiledInputs), diff --git a/velox/expression/SimpleFunctionAdapter.h b/velox/expression/SimpleFunctionAdapter.h index 784576c28d5c..af06c780c7fb 100644 --- a/velox/expression/SimpleFunctionAdapter.h +++ b/velox/expression/SimpleFunctionAdapter.h @@ -228,9 +228,10 @@ class SimpleFunctionAdapter : public VectorFunction { public: explicit SimpleFunctionAdapter( + const std::vector& inputTypes, const core::QueryConfig& config, const std::vector& constantInputs) - : fn_{std::make_unique()} { + : fn_{std::make_unique(std::vector{})} { if constexpr (FUNC::udf_has_initialize) { try { unpackInitialize<0>(config, constantInputs); @@ -240,6 +241,10 @@ class SimpleFunctionAdapter : public VectorFunction { initializeException_ = std::current_exception(); } } + + if constexpr (FUNC::udf_has_initializeTypes) { + (*fn_).initializeTypes(inputTypes); + } } explicit SimpleFunctionAdapter() {} @@ -901,10 +906,11 @@ class SimpleFunctionAdapterFactoryImpl : public SimpleFunctionAdapterFactory { explicit SimpleFunctionAdapterFactoryImpl() {} std::unique_ptr createVectorFunction( + const std::vector& inputTypes, const std::vector& constantInputs, const core::QueryConfig& config) const override { return std::make_unique>( - config, constantInputs); + inputTypes, config, constantInputs); } }; diff --git a/velox/expression/SimpleFunctionRegistry.cpp b/velox/expression/SimpleFunctionRegistry.cpp index 36c6849f48b4..f5d9cab8cd73 100644 --- a/velox/expression/SimpleFunctionRegistry.cpp +++ b/velox/expression/SimpleFunctionRegistry.cpp @@ -40,8 +40,8 @@ void SimpleFunctionRegistry::registerFunctionInternal( const auto sanitizedName = sanitizeName(name); registeredFunctions_.withWLock([&](auto& map) { SignatureMap& signatureMap = map[sanitizedName]; - signatureMap[*metadata->signature()] = - std::make_unique(metadata, factory); + signatureMap[*metadata->signature()].emplace_back( + std::make_unique(metadata, factory)); }); } @@ -83,12 +83,34 @@ SimpleFunctionRegistry::resolveFunction( for (const auto& [candidateSignature, functionEntry] : *signatureMap) { SignatureBinder binder(candidateSignature, argTypes); if (binder.tryBind()) { - auto* currentCandidate = functionEntry.get(); - if (!selectedCandidate || - currentCandidate->getMetadata().priority() < - selectedCandidate->getMetadata().priority()) { - selectedCandidate = currentCandidate; - selectedCandidateType = binder.tryResolveReturnType(); + for (const auto& currentCandidate : functionEntry) { + const auto& m = currentCandidate->getMetadata(); + + // Check that TypeKinds of arguments match. + bool match = true; + for (auto i = 0; i < m.argTypeKinds().size(); ++i) { + if (argTypes[i]->kind() != m.argTypeKinds()[i]) { + match = false; + } + } + + if (!match) { + continue; + } + + if (!selectedCandidate || + currentCandidate->getMetadata().priority() < + selectedCandidate->getMetadata().priority()) { + auto resultType = binder.tryResolveReturnType(); + VELOX_CHECK_NOT_NULL(resultType); + + if (resultType->kind() != m.resultTypeKind()) { + continue; + } + + selectedCandidate = currentCandidate.get(); + selectedCandidateType = resultType; + } } } } diff --git a/velox/expression/SimpleFunctionRegistry.h b/velox/expression/SimpleFunctionRegistry.h index 2d1e1a431c83..ad456aa1a667 100644 --- a/velox/expression/SimpleFunctionRegistry.h +++ b/velox/expression/SimpleFunctionRegistry.h @@ -24,8 +24,9 @@ namespace facebook::velox::exec { template -const std::shared_ptr& singletonUdfMetadata() { - static auto instance = std::make_shared(); +const std::shared_ptr& singletonUdfMetadata( + const std::vector& constraints) { + static auto instance = std::make_shared(constraints); return instance; } @@ -52,15 +53,19 @@ struct FunctionEntry { const FunctionFactory factory_; }; -using SignatureMap = - std::unordered_map>; +using SignatureMap = std::unordered_map< + FunctionSignature, + std::vector>>; using FunctionMap = std::unordered_map; class SimpleFunctionRegistry { public: template - void registerFunction(const std::vector& aliases = {}) { - const auto& metadata = singletonUdfMetadata(); + void registerFunction( + const std::vector& aliases, + const std::vector& constraints) { + const auto& metadata = + singletonUdfMetadata(constraints); const auto factory = [metadata]() { return CreateUdf(); }; if (aliases.empty()) { @@ -139,9 +144,12 @@ SimpleFunctionRegistry& mutableSimpleFunctions(); // This function should be called once and alone. template -void registerSimpleFunction(const std::vector& names) { +void registerSimpleFunction( + const std::vector& names, + const std::vector& constraints) { mutableSimpleFunctions() - .registerFunction>(names); + .registerFunction>( + names, constraints); } } // namespace facebook::velox::exec diff --git a/velox/expression/UdfTypeResolver.h b/velox/expression/UdfTypeResolver.h index 8d188ef97c64..827be7ab4a37 100644 --- a/velox/expression/UdfTypeResolver.h +++ b/velox/expression/UdfTypeResolver.h @@ -101,6 +101,20 @@ struct resolver> { using out_type = ArrayWriter; }; +template +struct resolver> { + using in_type = int64_t; + using null_free_in_type = in_type; + using out_type = int64_t; +}; + +template +struct resolver> { + using in_type = int128_t; + using null_free_in_type = in_type; + using out_type = int128_t; +}; + template <> struct resolver { using in_type = StringView; diff --git a/velox/expression/VectorFunction.h b/velox/expression/VectorFunction.h index 6442cfd7b502..323606acf44f 100644 --- a/velox/expression/VectorFunction.h +++ b/velox/expression/VectorFunction.h @@ -182,6 +182,7 @@ class ApplyNeverCalled final : public VectorFunction { class SimpleFunctionAdapterFactory { public: virtual std::unique_ptr createVectorFunction( + const std::vector& inputTypes, const std::vector& constantInputs, const core::QueryConfig& config) const = 0; virtual ~SimpleFunctionAdapterFactory() = default; diff --git a/velox/expression/tests/SimpleFunctionTest.cpp b/velox/expression/tests/SimpleFunctionTest.cpp index a77c3de60eca..bfa34590b723 100644 --- a/velox/expression/tests/SimpleFunctionTest.cpp +++ b/velox/expression/tests/SimpleFunctionTest.cpp @@ -992,7 +992,7 @@ VectorPtr testVariadicArgReuse( exec::simpleFunctions() .resolveFunction(functionName, {}) ->createFunction() - ->createVectorFunction({}, execCtx->queryCtx()->queryConfig()); + ->createVectorFunction({}, {}, execCtx->queryCtx()->queryConfig()); // Create a dummy EvalCtx. SelectivityVector rows(inputs[0]->size()); diff --git a/velox/functions/Registerer.h b/velox/functions/Registerer.h index 7b893c6c3928..52e4fbf61b29 100644 --- a/velox/functions/Registerer.h +++ b/velox/functions/Registerer.h @@ -43,7 +43,7 @@ void registerFunction(const std::vector& aliases = {}) { TReturn, ConstantChecker, typename UnwrapConstantType::type...>; - exec::registerSimpleFunction(aliases); + exec::registerSimpleFunction(aliases, {}); } // New registration function; mostly a copy from the function above, but taking @@ -51,7 +51,9 @@ void registerFunction(const std::vector& aliases = {}) { // a while to maintain backwards compatibility, but the idea is to remove the // one above eventually. template