From 7ef53af3830d78568dd018aec90a5b0ab6566368 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 26 Jul 2023 23:31:16 -0400 Subject: [PATCH] inference: permit non-direct recursion reducers Fix #45759 Fix #46557 Fix #31485 --- base/compiler/abstractinterpretation.jl | 20 ++++++++++++------- test/compiler/inference.jl | 26 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 6e552423cc25e..e8161a75e71e0 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -543,25 +543,30 @@ function abstract_call_method(interp::AbstractInterpreter, if topmost !== nothing msig = unwrap_unionall(method.sig)::DataType spec_len = length(msig.parameters) + 1 - ls = length(sigtuple.parameters) mi = frame_instance(sv) + if isdefined(method, :recursion_relation) + # We don't require the recursion_relation to be transitive, so + # apply a hard limit + hardlimit = true + end + if method === mi.def # Under direct self-recursion, permit much greater use of reducers. # here we assume that complexity(specTypes) :>= complexity(sig) comparison = mi.specTypes l_comparison = length((unwrap_unionall(comparison)::DataType).parameters) spec_len = max(spec_len, l_comparison) + elseif !hardlimit && isa(topmost, InferenceState) + # Without a hardlimit, permit use of reducers too. + comparison = frame_instance(topmost).specTypes + # n.b. currently don't allow vararg reducers + #l_comparison = length((unwrap_unionall(comparison)::DataType).parameters) + #spec_len = max(spec_len, l_comparison) else comparison = method.sig end - if isdefined(method, :recursion_relation) - # We don't require the recursion_relation to be transitive, so - # apply a hard limit - hardlimit = true - end - # see if the type is actually too big (relative to the caller), and limit it if required newsig = limit_type_size(sig, comparison, hardlimit ? comparison : mi.specTypes, InferenceParams(interp).tuple_complexity_limit_depth, spec_len) @@ -588,6 +593,7 @@ function abstract_call_method(interp::AbstractInterpreter, poison_callstack!(sv, parentframe === nothing ? topmost : parentframe) end end + # n.b. this heuristic depends on the non-local state, so we must record the limit later sig = newsig sparams = svec() edgelimited = true diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 4f5cdea59da44..46875e8d52f3e 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5099,3 +5099,29 @@ end refine_partial_struct2(42, s) end |> only === String # JET.test_call(s::AbstractString->Base._string(s, 'c')) + +# issue #45759 #46557 +g45759(x::Tuple{Any,Vararg}) = x[1] + _g45759(x[2:end]) +g45759(x::Tuple{}) = 0 +_g45759(x) = g45759(x) +@test only(Base.return_types(g45759, Tuple{Tuple{Int,Int,Int,Int,Int,Int,Int}})) == Int + +h45759(x::Tuple{Any,Vararg}; kwargs...) = x[1] + h45759(x[2:end]; kwargs...) +h45759(x::Tuple{}; kwargs...) = 0 +@test only(Base.return_types(h45759, Tuple{Tuple{Int,Int,Int,Int,Int,Int,Int}})) == Int + +@test only(Base.return_types((typeof([[[1]]]))) do x + sum(x) do v; sum(length, v); end +end) == Int + +struct FunctionSum{Tf} + functions::Tf +end +(F::FunctionSum)(x) = sum(f -> f(x), F.functions) +F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3)))) +@test @inferred(F(1.)) === 3.0 + +f31485(arr::AbstractArray{T, 0}) where {T} = arr +indirect31485(arr) = f31485(arr) +f31485(arr::AbstractArray{T, N}) where {T, N} = indirect31485(view(arr, 1, ntuple(i -> :, Val(N-1))...)) +@test @inferred(f31485(zeros(3,3,3,3,3),)) == fill(0.0)