diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index e817e0bd927fe1..a6e0463e34b3d9 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -544,6 +544,10 @@ function abstract_call_method(interp::AbstractInterpreter, sigtuple = unwrap_unionall(sig) sigtuple isa DataType || return MethodCallResult(Any, false, false, nothing, Effects()) + if is_noinfer(method) + sig = get_nospecialize_sig(method, sig, sparams) + end + # Limit argument type tuple growth of functions: # look through the parents list to see if there's a call to the same method # and from the same method. @@ -1075,7 +1079,11 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, return nothing end force |= all_overridden - mi = specialize_method(match; preexisting=!force) + if is_noinfer(method) + mi = specialize_method_noinfer(match; preexisting=!force) + else + mi = specialize_method(match; preexisting=!force) + end if mi === nothing add_remark!(interp, sv, "[constprop] Failed to specialize") return nothing diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 3b1cb2c46ce6e1..eed7aa99265308 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -952,7 +952,11 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, end # See if there exists a specialization for this method signature - mi = specialize_method(match; preexisting=true) # Union{Nothing, MethodInstance} + if is_noinfer(method) + mi = specialize_method_noinfer(match; preexisting=true) + else + mi = specialize_method(match; preexisting=true) + end if mi === nothing et = InliningEdgeTracker(state.et, invokesig) effects = info_effects(nothing, match, state) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index bae8ef5bae2421..0d81008125ddc3 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -107,6 +107,11 @@ function is_inlineable_constant(@nospecialize(x)) return count_const_size(x) <= MAX_INLINE_CONST_SIZE end +is_nospecialized(method::Method) = method.nospecialize ≠ 0 + +is_noinfer(method::Method) = method.noinfer && is_nospecialized(method) +# is_noinfer(method::Method) = is_nospecialized(method) && is_declared_noinline(method) + ########################### # MethodInstance/CodeInfo # ########################### @@ -158,6 +163,19 @@ function get_compileable_sig(method::Method, @nospecialize(atype), sparams::Simp mt, atype, sparams, method) end +function get_nospecialize_sig(method::Method, @nospecialize(atype), sparams::SimpleVector) + if isa(atype, UnionAll) + atype, sparams = normalize_typevars(method, atype, sparams) + end + isa(atype, DataType) || return method.sig + mt = ccall(:jl_method_table_for, Any, (Any,), atype) + mt === nothing && return method.sig + # TODO allow uncompileable signatures to be returned here + sig = ccall(:jl_normalize_to_compilable_sig, Any, (Any, Any, Any, Any), + mt, atype, sparams, method) + return sig === nothing ? method.sig : sig +end + isa_compileable_sig(@nospecialize(atype), sparams::SimpleVector, method::Method) = !iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any, Any), atype, sparams, method)) @@ -199,7 +217,8 @@ function normalize_typevars(method::Method, @nospecialize(atype), sparams::Simpl end # get a handle to the unique specialization object representing a particular instantiation of a call -function specialize_method(method::Method, @nospecialize(atype), sparams::SimpleVector; preexisting::Bool=false, compilesig::Bool=false) +function specialize_method(method::Method, @nospecialize(atype), sparams::SimpleVector; + preexisting::Bool=false, compilesig::Bool=false) if isa(atype, UnionAll) atype, sparams = normalize_typevars(method, atype, sparams) end @@ -225,6 +244,11 @@ function specialize_method(match::MethodMatch; kwargs...) return specialize_method(match.method, match.spec_types, match.sparams; kwargs...) end +function specialize_method_noinfer((; method, spec_types, sparams)::MethodMatch; kwargs...) + atype = get_nospecialize_sig(method, spec_types, sparams) + return specialize_method(method, atype, sparams; kwargs...) +end + """ is_declared_inline(method::Method) -> Bool diff --git a/base/essentials.jl b/base/essentials.jl index 829341c4823833..c059bba21588bd 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -85,7 +85,8 @@ f(y) = [x for x in y] !!! note `@nospecialize` affects code generation but not inference: it limits the diversity of the resulting native code, but it does not impose any limitations (beyond the - standard ones) on type-inference. + standard ones) on type-inference. Use [`Base.@noinfer`](@ref) together with + `@nospecialize` to additionally suppress inference. # Example diff --git a/base/expr.jl b/base/expr.jl index 5649303b41ef4a..c61a6605455ab9 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -339,7 +339,6 @@ macro noinline(x) return annotate_meta_def_or_block(x, :noinline) end - """ @constprop setting [ex] @@ -753,6 +752,39 @@ function compute_assumed_setting(@nospecialize(setting), val::Bool=true) end end +""" + Base.@noinfer function f(args...) + @nospecialize ... + ... + end + Base.@noinfer f(@nospecialize args...) = ... + +Tells the compiler to infer `f` using the declared types of `@nospecialize`d arguments. +This can be used to limit the number of compiler-generated specializations during inference. + +# Example + +```julia +julia> f(A::AbstractArray) = g(A) +f (generic function with 1 method) + +julia> @noinline Base.@noinfer g(@nospecialize(A::AbstractArray)) = A[1] +g (generic function with 1 method) + +julia> @code_typed f([1.0]) +CodeInfo( +1 ─ %1 = invoke Main.g(_2::AbstractArray)::Any +└── return %1 +) => Any +``` + +In this example, `f` will be inferred for each specific type of `A`, +but `g` will only be inferred once. +""" +macro noinfer(ex) + esc(isa(ex, Expr) ? pushmeta!(ex, :noinfer) : ex) +end + """ @propagate_inbounds diff --git a/doc/src/base/base.md b/doc/src/base/base.md index 7e45e2176478d3..72f0a635e6b389 100644 --- a/doc/src/base/base.md +++ b/doc/src/base/base.md @@ -285,6 +285,8 @@ Base.@inline Base.@noinline Base.@nospecialize Base.@specialize +Base.@noinfer +Base.@constprop Base.gensym Base.@gensym var"name" diff --git a/src/ast.c b/src/ast.c index 3f3d6176d342e5..91fbae5f80d5df 100644 --- a/src/ast.c +++ b/src/ast.c @@ -83,6 +83,7 @@ JL_DLLEXPORT jl_sym_t *jl_aggressive_constprop_sym; JL_DLLEXPORT jl_sym_t *jl_no_constprop_sym; JL_DLLEXPORT jl_sym_t *jl_purity_sym; JL_DLLEXPORT jl_sym_t *jl_nospecialize_sym; +JL_DLLEXPORT jl_sym_t *jl_noinfer_sym; JL_DLLEXPORT jl_sym_t *jl_macrocall_sym; JL_DLLEXPORT jl_sym_t *jl_colon_sym; JL_DLLEXPORT jl_sym_t *jl_hygienicscope_sym; @@ -342,6 +343,7 @@ void jl_init_common_symbols(void) jl_isdefined_sym = jl_symbol("isdefined"); jl_nospecialize_sym = jl_symbol("nospecialize"); jl_specialize_sym = jl_symbol("specialize"); + jl_noinfer_sym = jl_symbol("noinfer"); jl_optlevel_sym = jl_symbol("optlevel"); jl_compile_sym = jl_symbol("compile"); jl_force_compile_sym = jl_symbol("force_compile"); diff --git a/src/ircode.c b/src/ircode.c index f967dd1a29f510..6aa5199faf46a5 100644 --- a/src/ircode.c +++ b/src/ircode.c @@ -434,13 +434,14 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal) } } -static jl_code_info_flags_t code_info_flags(uint8_t inferred, uint8_t propagate_inbounds, - uint8_t has_fcall, uint8_t inlining, uint8_t constprop) +static jl_code_info_flags_t code_info_flags(uint8_t inferred, uint8_t propagate_inbounds, uint8_t has_fcall, + uint8_t noinfer, uint8_t inlining, uint8_t constprop) { jl_code_info_flags_t flags; flags.bits.inferred = inferred; flags.bits.propagate_inbounds = propagate_inbounds; flags.bits.has_fcall = has_fcall; + flags.bits.noinfer = noinfer; flags.bits.inlining = inlining; flags.bits.constprop = constprop; return flags; @@ -780,8 +781,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) 1 }; - jl_code_info_flags_t flags = code_info_flags(code->inferred, code->propagate_inbounds, - code->has_fcall, code->inlining, code->constprop); + jl_code_info_flags_t flags = code_info_flags(code->inferred, code->propagate_inbounds, code->has_fcall, + code->noinfer, code->inlining, code->constprop); write_uint8(s.s, flags.packed); write_uint8(s.s, code->purity.bits); write_uint16(s.s, code->inlining_cost); @@ -880,6 +881,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t code->inferred = flags.bits.inferred; code->propagate_inbounds = flags.bits.propagate_inbounds; code->has_fcall = flags.bits.has_fcall; + code->noinfer = flags.bits.noinfer; code->purity.bits = read_uint8(s.s); code->inlining_cost = read_uint16(s.s); diff --git a/src/jltypes.c b/src/jltypes.c index 2aa8385e744a34..4dbf13137fc5a5 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2681,7 +2681,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_code_info_type = jl_new_datatype(jl_symbol("CodeInfo"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(21, + jl_perm_symsvec(22, "code", "codelocs", "ssavaluetypes", @@ -2699,11 +2699,12 @@ void jl_init_types(void) JL_GC_DISABLED "inferred", "propagate_inbounds", "has_fcall", + "noinfer", "inlining", "constprop", "purity", "inlining_cost"), - jl_svec(21, + jl_svec(22, jl_array_any_type, jl_array_int32_type, jl_any_type, @@ -2721,17 +2722,18 @@ void jl_init_types(void) JL_GC_DISABLED jl_bool_type, jl_bool_type, jl_bool_type, + jl_bool_type, jl_uint8_type, jl_uint8_type, jl_uint8_type, jl_uint16_type), jl_emptysvec, - 0, 1, 20); + 0, 1, 22); jl_method_type = jl_new_datatype(jl_symbol("Method"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(28, + jl_perm_symsvec(29, "name", "module", "file", @@ -2758,9 +2760,10 @@ void jl_init_types(void) JL_GC_DISABLED "nkw", "isva", "is_for_opaque_closure", + "noinfer", "constprop", "purity"), - jl_svec(28, + jl_svec(29, jl_symbol_type, jl_module_type, jl_symbol_type, @@ -2787,6 +2790,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_int32_type, jl_bool_type, jl_bool_type, + jl_bool_type, jl_uint8_type, jl_uint8_type), jl_emptysvec, diff --git a/src/julia.h b/src/julia.h index 45363c8092bb13..ab0fdd85026123 100644 --- a/src/julia.h +++ b/src/julia.h @@ -286,6 +286,7 @@ typedef struct _jl_code_info_t { uint8_t inferred; uint8_t propagate_inbounds; uint8_t has_fcall; + uint8_t noinfer; // uint8 settings uint8_t inlining; // 0 = default; 1 = @inline; 2 = @noinline uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none @@ -343,6 +344,7 @@ typedef struct _jl_method_t { // various boolean properties uint8_t isva; uint8_t is_for_opaque_closure; + uint8_t noinfer; // uint8 settings uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none diff --git a/src/julia_internal.h b/src/julia_internal.h index 4f1a0b4513d8d7..6c0cdb3248777b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -599,6 +599,7 @@ typedef struct { uint8_t inferred:1; uint8_t propagate_inbounds:1; uint8_t has_fcall:1; + uint8_t noinfer:1; uint8_t inlining:2; // 0 = use heuristic; 1 = aggressive; 2 = none uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none } jl_code_info_flags_bitfield_t; @@ -1566,6 +1567,7 @@ extern JL_DLLEXPORT jl_sym_t *jl_aggressive_constprop_sym; extern JL_DLLEXPORT jl_sym_t *jl_no_constprop_sym; extern JL_DLLEXPORT jl_sym_t *jl_purity_sym; extern JL_DLLEXPORT jl_sym_t *jl_nospecialize_sym; +extern JL_DLLEXPORT jl_sym_t *jl_noinfer_sym; extern JL_DLLEXPORT jl_sym_t *jl_macrocall_sym; extern JL_DLLEXPORT jl_sym_t *jl_colon_sym; extern JL_DLLEXPORT jl_sym_t *jl_hygienicscope_sym; diff --git a/src/method.c b/src/method.c index 8b4c87a46ecd97..7cc50d29cfb248 100644 --- a/src/method.c +++ b/src/method.c @@ -320,6 +320,8 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) li->inlining = 2; else if (ma == (jl_value_t*)jl_propagate_inbounds_sym) li->propagate_inbounds = 1; + else if (ma == (jl_value_t*)jl_noinfer_sym) + li->noinfer = 1; else if (ma == (jl_value_t*)jl_aggressive_constprop_sym) li->constprop = 1; else if (ma == (jl_value_t*)jl_no_constprop_sym) @@ -476,6 +478,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void) src->inferred = 0; src->propagate_inbounds = 0; src->has_fcall = 0; + src->noinfer = 0; src->edges = jl_nothing; src->constprop = 0; src->inlining = 0; @@ -679,6 +682,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) } } m->called = called; + m->noinfer = src->noinfer; m->constprop = src->constprop; m->purity.bits = src->purity.bits; jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name); @@ -808,6 +812,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module) m->primary_world = 1; m->deleted_world = ~(size_t)0; m->is_for_opaque_closure = 0; + m->noinfer = 0; m->constprop = 0; JL_MUTEX_INIT(&m->writelock); return m; diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index 630185ebd575a0..411f6c0ce5f05e 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -80,7 +80,7 @@ const TAGS = Any[ const NTAGS = length(TAGS) @assert NTAGS == 255 -const ser_version = 23 # do not make changes without bumping the version #! +const ser_version = 24 # do not make changes without bumping the version #! format_version(::AbstractSerializer) = ser_version format_version(s::Serializer) = s.version @@ -418,6 +418,7 @@ function serialize(s::AbstractSerializer, meth::Method) serialize(s, meth.nargs) serialize(s, meth.isva) serialize(s, meth.is_for_opaque_closure) + serialize(s, meth.noinfer) serialize(s, meth.constprop) serialize(s, meth.purity) if isdefined(meth, :source) @@ -1026,10 +1027,14 @@ function deserialize(s::AbstractSerializer, ::Type{Method}) nargs = deserialize(s)::Int32 isva = deserialize(s)::Bool is_for_opaque_closure = false + noinfer = false constprop = purity = 0x00 template_or_is_opaque = deserialize(s) if isa(template_or_is_opaque, Bool) is_for_opaque_closure = template_or_is_opaque + if format_version(s) >= 24 + noinfer = deserialize(s)::Bool + end if format_version(s) >= 14 constprop = deserialize(s)::UInt8 end @@ -1054,6 +1059,7 @@ function deserialize(s::AbstractSerializer, ::Type{Method}) meth.nargs = nargs meth.isva = isva meth.is_for_opaque_closure = is_for_opaque_closure + meth.noinfer = noinfer meth.constprop = constprop meth.purity = purity if template !== nothing @@ -1195,6 +1201,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) if format_version(s) >= 20 ci.has_fcall = deserialize(s) end + if format_version(s) >= 24 + ci.noinfer = deserialize(s)::Bool + end if format_version(s) >= 21 ci.inlining = deserialize(s)::UInt8 end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index d97f5d3a8a0951..bbc60db8a4b7b9 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1164,25 +1164,18 @@ let typeargs = Tuple{Type{Int},Type{Int},Type{Int},Type{Int},Type{Int},Type{Int} @test only(Base.return_types(promote_type, typeargs)) === Type{Int} end -function count_specializations(method::Method) - specs = method.specializations - specs isa Core.MethodInstance && return 1 - n = count(!isnothing, specs::Core.SimpleVector) - return n -end - # demonstrate that inference can complete without waiting for MAX_TYPE_DEPTH copy_dims_out(out) = () copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...) copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...) @test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}] -@test all(m -> 4 < count_specializations(m) < 15, methods(copy_dims_out)) # currently about 5 +@test all(m -> 4 < length(Base.specializations(m)) < 15, methods(copy_dims_out)) # currently about 5 copy_dims_pair(out) = () copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...) copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...) @test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}] -@test all(m -> 3 < count_specializations(m) < 15, methods(copy_dims_pair)) # currently about 5 +@test all(m -> 3 < length(Base.specializations(m)) < 15, methods(copy_dims_pair)) # currently about 5 # splatting an ::Any should still allow inference to use types of parameters preceding it f22364(::Int, ::Any...) = 0 @@ -4157,6 +4150,83 @@ Base.getproperty(x::Interface41024Extended, sym::Symbol) = x.x end |> only === Int +count_specialization(@nospecialize f) = length(Base.specializations(only(methods(f)))) + +function invokef(f, itr) + local r = 0 + r += f(itr[1]) + r += f(itr[2]) + r += f(itr[3]) + r += f(itr[4]) + r += f(itr[5]) + r +end +global inline_checker = c -> c # untyped global prevents inlining +# if `f` is inlined, `GlobalRef(m, :inline_checker)` should appear within the body of `invokef` +function is_inline_checker(@nospecialize stmt) + isa(stmt, GlobalRef) && stmt.name === :inline_checker +end + +function nospecialize(@nospecialize a) + c = isa(a, Function) + inline_checker(c) +end + +@inline function inline_nospecialize(@nospecialize a) + c = isa(a, Function) + inline_checker(c) +end + +Base.@noinfer function noinfer(@nospecialize a) + c = isa(a, Function) + inline_checker(c) +end + +# `@noinfer` should still allow inlinining +Base.@noinfer @inline function inline_noinfer(@nospecialize a) + c = isa(a, Function) + inline_checker(c) +end + +@testset "compilation annotations" begin + dispatchonly = Any[sin, muladd, "foo", nothing, Dict] # untyped container can cause excessive runtime dispatch + withinfernce = tuple(sin, muladd, "foo", nothing, Dict) # typed container can cause excessive inference + + @testset "@nospecialize" begin + # `@nospecialize` should suppress runtime dispatches of `nospecialize` + invokef(nospecialize, dispatchonly) + @test count_specialization(nospecialize) == 1 + # `@nospecialize` should allow inference to happen + invokef(nospecialize, withinfernce) + @test count_specialization(nospecialize) == 6 + @test !any(is_inline_checker, @get_code invokef(nospecialize, dispatchonly)) + + # `@nospecialize` should allow inlinining + invokef(inline_nospecialize, dispatchonly) + @test count_specialization(inline_nospecialize) == 1 + invokef(inline_nospecialize, withinfernce) + @test count_specialization(inline_nospecialize) == 6 + @test any(is_inline_checker, @get_code invokef(inline_nospecialize, dispatchonly)) + end + + @testset "@noinfer" begin + # `@nospecialize` should suppress runtime dispatches of `nospecialize` + invokef(noinfer, dispatchonly) + @test count_specialization(noinfer) == 1 + # `@noinfer` suppresses inference also + invokef(noinfer, withinfernce) + @test count_specialization(noinfer) == 1 + @test !any(is_inline_checker, @get_code invokef(noinfer, dispatchonly)) + + # `@noinfer` should still allow inlinining + invokef(inline_noinfer, dispatchonly) + @test count_specialization(inline_noinfer) == 1 + invokef(inline_noinfer, withinfernce) + @test count_specialization(inline_noinfer) == 1 + @test any(is_inline_checker, @get_code invokef(inline_noinfer, dispatchonly)) + end +end + @testset "fieldtype for unions" begin # e.g. issue #40177 f40177(::Type{T}) where {T} = fieldtype(T, 1) for T in [ diff --git a/test/compiler/irutils.jl b/test/compiler/irutils.jl index 95ac0d555ef889..00de9b2472de4b 100644 --- a/test/compiler/irutils.jl +++ b/test/compiler/irutils.jl @@ -1,10 +1,17 @@ -import Core: CodeInfo, ReturnNode, MethodInstance -import Core.Compiler: IRCode, IncrementalCompact, VarState, argextype, singleton_type -import Base.Meta: isexpr +using Core: CodeInfo, ReturnNode, MethodInstance +using Core.Compiler: IRCode, IncrementalCompact, singleton_type, VarState +using Base.Meta: isexpr +using InteractiveUtils: gen_call_with_extracted_types_and_kwargs -argextype(@nospecialize args...) = argextype(args..., VarState[]) +argextype(@nospecialize args...) = Core.Compiler.argextype(args..., VarState[]) code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::CodeInfo +macro code_typed1(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :code_typed1, ex0) +end get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code +macro get_code(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :get_code, ex0) +end # check if `x` is a statement with a given `head` isnew(@nospecialize x) = isexpr(x, :new) @@ -45,3 +52,6 @@ function fully_eliminated(@nospecialize args...; retval=(@__FILE__), kwargs...) return length(code) == 1 && isreturn(code[1]) end end +macro fully_eliminated(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :fully_eliminated, ex0) +end