From b672e6c79e73a6519a0f4fee34bf29a955283981 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 1 Sep 2021 13:09:18 -0400 Subject: [PATCH 1/4] inference: propagate variable changes to all exception frames Fix #42022 --- base/compiler/abstractinterpretation.jl | 42 +++++----- base/compiler/inferencestate.jl | 103 +++++++++++++++++++++--- test/compiler/inference.jl | 45 +++++++++++ 3 files changed, 155 insertions(+), 35 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index d37ad96adfefb..c2a8b19b6af3a 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1764,18 +1764,16 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) ssavaluetypes = frame.src.ssavaluetypes::Vector{Any} while frame.pc´´ <= n # make progress on the active ip set - local pc::Int = frame.pc´´ # current program-counter + local pc::Int = frame.pc´´ while true # inner loop optimizes the common case where it can run straight from pc to pc + 1 #print(pc,": ",s[pc],"\n") local pc´::Int = pc + 1 # next program-counter (after executing instruction) if pc == frame.pc´´ - # need to update pc´´ to point at the new lowest instruction in W - min_pc = _bits_findnext(W.bits, pc + 1) - frame.pc´´ = min_pc == -1 ? n + 1 : min_pc + # want to update pc´´ to point at the new lowest instruction in W + frame.pc´´ = pc´ end delete!(W, pc) frame.currpc = pc - frame.cur_hand = frame.handler_at[pc] edges = frame.stmt_edges[pc] edges === nothing || empty!(edges) frame.stmt_info[pc] = nothing @@ -1817,7 +1815,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) pc´ = l else # general case - frame.handler_at[l] = frame.cur_hand changes_else = changes if isa(condt, Conditional) changes_else = conditional_changes(changes_else, condt.elsetype, condt.var) @@ -1877,7 +1874,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif hd === :enter stmt = stmt::Expr l = stmt.args[1]::Int - frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand) # propagate type info to exception handler old = states[l] newstate_catch = stupdate!(old, changes) @@ -1889,12 +1885,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) states[l] = newstate_catch end typeassert(states[l], VarTable) - frame.handler_at[l] = frame.cur_hand elseif hd === :leave - stmt = stmt::Expr - for i = 1:((stmt.args[1])::Int) - frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second - end else if hd === :(=) stmt = stmt::Expr @@ -1928,16 +1919,21 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) ssavaluetypes[pc] = t end end - if frame.cur_hand !== nothing && isa(changes, StateUpdate) - # propagate new type info to exception handler - # the handling for Expr(:enter) propagates all changes from before the try/catch - # so this only needs to propagate any changes - l = frame.cur_hand.first::Int - if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false - if l < frame.pc´´ - frame.pc´´ = l + cur_hand = frame.handler_at[pc] + if isa(changes, StateUpdate) + while cur_hand != 0 + let l = frame.handler_at[cur_hand + 1] + # propagate new type info to exception handler + # the handling for Expr(:enter) propagates all changes from before the try/catch + # so this only needs to propagate any changes + if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false + if l < frame.pc´´ + frame.pc´´ = l + end + push!(W, l) + end end - push!(W, l) + cur_hand = frame.handler_at[cur_hand] end end end @@ -1950,7 +1946,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end pc´ > n && break # can't proceed with the fast-path fall-through - frame.handler_at[pc´] = frame.cur_hand newstate = stupdate!(states[pc´], changes) if isa(stmt, GotoNode) && frame.pc´´ < pc´ # if we are processing a goto node anyways, @@ -1961,7 +1956,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) states[pc´] = newstate end push!(W, pc´) - pc = frame.pc´´ + break elseif newstate !== nothing states[pc´] = newstate pc = pc´ @@ -1971,6 +1966,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) break end end + frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter end frame.dont_work_on_me = false nothing diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 216c397af31e4..c3b78d984f2eb 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -28,9 +28,7 @@ mutable struct InferenceState pc´´::LineNum nstmts::Int # current exception handler info - cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}} - handler_at::Vector{Any} - n_handlers::Int + handler_at::Vector{LineNum} # ssavalue sparsity and restart info ssavalue_uses::Vector{BitSet} throw_blocks::BitSet @@ -86,25 +84,21 @@ mutable struct InferenceState throw_blocks = find_throw_blocks(code) # exception handlers - cur_hand = nothing - handler_at = Any[ nothing for i=1:n ] - n_handlers = 0 - - W = BitSet() - push!(W, 1) #initial pc to visit + ip = BitSet() + handler_at = compute_trycatch(src.code, ip) + push!(ip, 1) mod = isa(def, Method) ? def.module : def - valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world) + frame = new( InferenceParams(interp), result, linfo, sp, slottypes, mod, 0, IdSet{InferenceState}(), IdSet{InferenceState}(), src, get_world_counter(interp), valid_worlds, nargs, s_types, s_edges, stmt_info, - Union{}, W, 1, n, - cur_hand, handler_at, n_handlers, + Union{}, ip, 1, n, handler_at, ssavalue_uses, throw_blocks, Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges Vector{InferenceState}(), # callers_in_cycle @@ -118,6 +112,91 @@ mutable struct InferenceState end end +function compute_trycatch(code::Vector{Any}, ip::BitSet) + # The goal initially is to record the frame like this for the state at exit: + # 1: (enter 3) # == 0 + # 3: (expr) # == 1 + # 3: (leave 1) # == 1 + # 4: (expr) # == 0 + # then we can find all trys by walking backwards from :enter statements, + # and all catches by looking at the statement after the :enter + n = length(code) + empty!(ip) + ip.offset = 0 # for _bits_findnext + push!(ip, n + 1) + handler_at = fill(0, n) + + # start from all :enter statements and record the location of the try + for pc = 1:n + stmt = code[pc] + if isexpr(stmt, :enter) + l = stmt.args[1]::Int + handler_at[pc + 1] = pc + push!(ip, pc + 1) + handler_at[l] = pc + push!(ip, l) + end + end + + # now forward those marks to all :leave statements + pc´´ = 0 + while true + # make progress on the active ip set + pc = _bits_findnext(ip.bits, pc´´)::Int + pc > n && break + while true # inner loop optimizes the common case where it can run straight from pc to pc + 1 + pc´ = pc + 1 # next program-counter (after executing instruction) + if pc == pc´´ + pc´´ = pc´ + end + delete!(ip, pc) + cur_hand = handler_at[pc] + @assert cur_hand != 0 "unbalanced try/catch" + stmt = code[pc] + if isa(stmt, GotoNode) + pc´ = stmt.label + elseif isa(stmt, GotoIfNot) + l = stmt.dest::Int + if handler_at[l] != cur_hand + @assert handler_at[l] == 0 "unbalanced try/catch" + handler_at[l] = cur_hand + if l < pc´´ + pc´´ = l + end + push!(ip, l) + end + elseif isa(stmt, ReturnNode) + @assert !isdefined(stmt, :val) "unbalanced try/catch" + break + elseif isa(stmt, Expr) + head = stmt.head + if head === :enter + cur_hand = pc + elseif head === :leave + l = stmt.args[1]::Int + for i = 1:l + cur_hand = handler_at[cur_hand] + end + cur_hand == 0 && break + end + end + + pc´ > n && break # can't proceed with the fast-path fall-through + if handler_at[pc´] != cur_hand + @assert handler_at[pc´] == 0 "unbalanced try/catch" + handler_at[pc´] = cur_hand + elseif !in(pc´, ip) + break # already visited + end + pc = pc´ + end + end + + @assert first(ip) == n + 1 + return handler_at +end + + """ Iterate through all callers of the given InferenceState in the abstract interpretation stack (including the given InferenceState itself), vising diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index c78cd52297581..7270bfdd8beb8 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3428,3 +3428,48 @@ end f41908(x::Complex{T}) where {String<:T<:String} = 1 g41908() = f41908(Any[1][1]) @test only(Base.return_types(g41908, ())) <: Int + +# issue #42022 +let x = Any[ + Expr(:(=), Core.SlotNumber(3), 1) + Expr(:enter, 18) + Expr(:(=), Core.SlotNumber(3), 2.0) + Expr(:enter, 12) + Expr(:(=), Core.SlotNumber(3), '3') + Core.GotoIfNot(Core.SlotNumber(2), 9) + Expr(:leave, 2) + Core.ReturnNode(1) + Expr(:call, GlobalRef(Main, :throw)) + Expr(:leave, 1) + Core.GotoNode(16) + Expr(:leave, 1) + Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)) + Expr(:call, GlobalRef(Main, :rethrow)) + Expr(:pop_exception, Core.SSAValue(4)) + Expr(:leave, 1) + Core.GotoNode(22) + Expr(:leave, 1) + Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)) + nothing + Expr(:pop_exception, Core.SSAValue(2)) + Core.ReturnNode(Core.SlotNumber(3)) + ] + handler_at = Core.Compiler.compute_trycatch(x, Core.Compiler.BitSet()) + @test handler_at == [0, 0, 2, 2, 4, 4, 4, 0, 4, 4, 2, 4, 2, 2, 2, 2, 0, 2, 0, 0, 0, 0] +end + +@test only(Base.return_types((Bool,)) do y + x = 1 + try + x = 2.0 + try + x = '3' + y ? (return 1) : throw() + catch ex1 + rethrow() + end + catch ex2 + nothing + end + return x + end) === Union{Int64, Float64, Char} From ec9f9a5b2c9e7496e88c748fad3374cbe7d43259 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 1 Sep 2021 16:20:06 -0400 Subject: [PATCH 2/4] Update test/compiler/inference.jl --- test/compiler/inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 7270bfdd8beb8..d83d0dbd8a0b9 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3472,4 +3472,4 @@ end nothing end return x - end) === Union{Int64, Float64, Char} + end) === Union{Int, Float64, Char} From c759eeebf6cc2fc1ba802c0ca807d50d5d7ad6ba Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Thu, 2 Sep 2021 11:34:42 -0400 Subject: [PATCH 3/4] Update test/compiler/inference.jl Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> --- test/compiler/inference.jl | 50 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index d83d0dbd8a0b9..acdfa4d4632f8 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3430,32 +3430,32 @@ g41908() = f41908(Any[1][1]) @test only(Base.return_types(g41908, ())) <: Int # issue #42022 -let x = Any[ - Expr(:(=), Core.SlotNumber(3), 1) - Expr(:enter, 18) - Expr(:(=), Core.SlotNumber(3), 2.0) - Expr(:enter, 12) - Expr(:(=), Core.SlotNumber(3), '3') - Core.GotoIfNot(Core.SlotNumber(2), 9) - Expr(:leave, 2) - Core.ReturnNode(1) - Expr(:call, GlobalRef(Main, :throw)) - Expr(:leave, 1) - Core.GotoNode(16) - Expr(:leave, 1) - Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)) - Expr(:call, GlobalRef(Main, :rethrow)) - Expr(:pop_exception, Core.SSAValue(4)) - Expr(:leave, 1) - Core.GotoNode(22) - Expr(:leave, 1) - Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)) - nothing - Expr(:pop_exception, Core.SSAValue(2)) - Core.ReturnNode(Core.SlotNumber(3)) +let x = Tuple{Int,Any}[ + #= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1)) + #= 2=# (0, Expr(:enter, 18)) + #= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0)) + #= 4=# (2, Expr(:enter, 12)) + #= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3')) + #= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9)) + #= 7=# (4, Expr(:leave, 2)) + #= 8=# (0, Core.ReturnNode(1)) + #= 9=# (4, Expr(:call, GlobalRef(Main, :throw))) + #=10=# (4, Expr(:leave, 1)) + #=11=# (2, Core.GotoNode(16)) + #=12=# (4, Expr(:leave, 1)) + #=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception))) + #=14=# (2, Expr(:call, GlobalRef(Main, :rethrow))) + #=15=# (2, Expr(:pop_exception, Core.SSAValue(4))) + #=16=# (2, Expr(:leave, 1)) + #=17=# (0, Core.GotoNode(22)) + #=18=# (2, Expr(:leave, 1)) + #=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception))) + #=20=# (0, nothing) + #=21=# (0, Expr(:pop_exception, Core.SSAValue(2))) + #=22=# (0, Core.ReturnNode(Core.SlotNumber(3))) ] - handler_at = Core.Compiler.compute_trycatch(x, Core.Compiler.BitSet()) - @test handler_at == [0, 0, 2, 2, 4, 4, 4, 0, 4, 4, 2, 4, 2, 2, 2, 2, 0, 2, 0, 0, 0, 0] + handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet()) + @test handler_at == first.(x) end @test only(Base.return_types((Bool,)) do y From 695f9f8ad9e3eeb1e606fc274614281a0b01318f Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Thu, 2 Sep 2021 15:10:13 -0400 Subject: [PATCH 4/4] fixup! inference: propagate variable changes to all exception frames --- base/compiler/abstractinterpretation.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index c2a8b19b6af3a..b3acefb895762 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1919,10 +1919,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) ssavaluetypes[pc] = t end end - cur_hand = frame.handler_at[pc] if isa(changes, StateUpdate) - while cur_hand != 0 - let l = frame.handler_at[cur_hand + 1] + let cur_hand = frame.handler_at[pc], l, enter + while cur_hand != 0 + enter = frame.src.code[cur_hand] + l = (enter::Expr).args[1]::Int # propagate new type info to exception handler # the handling for Expr(:enter) propagates all changes from before the try/catch # so this only needs to propagate any changes @@ -1932,8 +1933,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end push!(W, l) end + cur_hand = frame.handler_at[cur_hand] end - cur_hand = frame.handler_at[cur_hand] end end end