Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: propagate variable changes to all exception frames #42081

Merged
merged 4 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1764,18 +1764,16 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
while frame.pc´´ <= n
# make progress on the active ip set
local pc::Int = frame.pc´´ # current program-counter
local pc::Int = frame.pc´´
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
#print(pc,": ",s[pc],"\n")
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
if pc == frame.pc´´
# need to update pc´´ to point at the new lowest instruction in W
min_pc = _bits_findnext(W.bits, pc + 1)
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
# want to update pc´´ to point at the new lowest instruction in W
frame.pc´´ = pc´
end
delete!(W, pc)
frame.currpc = pc
frame.cur_hand = frame.handler_at[pc]
edges = frame.stmt_edges[pc]
edges === nothing || empty!(edges)
frame.stmt_info[pc] = nothing
Expand Down Expand Up @@ -1817,7 +1815,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
pc´ = l
else
# general case
frame.handler_at[l] = frame.cur_hand
changes_else = changes
if isa(condt, Conditional)
changes_else = conditional_changes(changes_else, condt.elsetype, condt.var)
Expand Down Expand Up @@ -1877,7 +1874,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif hd === :enter
stmt = stmt::Expr
l = stmt.args[1]::Int
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
# propagate type info to exception handler
old = states[l]
newstate_catch = stupdate!(old, changes)
Expand All @@ -1889,12 +1885,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
states[l] = newstate_catch
end
typeassert(states[l], VarTable)
frame.handler_at[l] = frame.cur_hand
elseif hd === :leave
stmt = stmt::Expr
for i = 1:((stmt.args[1])::Int)
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
end
else
if hd === :(=)
stmt = stmt::Expr
Expand Down Expand Up @@ -1928,16 +1919,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
ssavaluetypes[pc] = t
end
end
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
l = frame.cur_hand.first::Int
if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false
if l < frame.pc´´
frame.pc´´ = l
if isa(changes, StateUpdate)
let cur_hand = frame.handler_at[pc], l, enter
while cur_hand != 0
enter = frame.src.code[cur_hand]
l = (enter::Expr).args[1]::Int
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false
if l < frame.pc´´
frame.pc´´ = l
end
push!(W, l)
end
Comment on lines +1930 to +1935
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, why do we want to propagate changes via :enter ? Can't we propagate changes directly to :leave statements here ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The l value here is short for :leave (strictly speaking, it doesn't have to be a :leave, but it normally will be)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

frame.src.code[l] is supposed to be :enter expression always, no ?
My understanding is that the algorithm accumulates all changes within try/catch clauses to :enter expression's states, and then :enter will propagate the changes to :leave.

So I wonder why we don't want to do:

if isa(changes, StateUpdate)
    while cur_hand != 0
        let l = frame.handler_at[cur_hand + 1]
            # propagate new type info to exception handler
            enter = frame.src.code[l]
            @assert isexpr(enter, :enter)
            leavestate = states[enter.args[1]::Int]::VarTable
            stupdate1!(leavestate, changes::StateUpdate) !== false
        end
        cur_hand = frame.handler_at[cur_hand]
    end
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because l is almost always a :leave expr there

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, or was supposed to be, until I rearranged the code, and then it wasn't anymore

cur_hand = frame.handler_at[cur_hand]
end
push!(W, l)
end
end
end
Expand All @@ -1950,7 +1947,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end

pc´ > n && break # can't proceed with the fast-path fall-through
frame.handler_at[pc´] = frame.cur_hand
newstate = stupdate!(states[pc´], changes)
if isa(stmt, GotoNode) && frame.pc´´ < pc´
# if we are processing a goto node anyways,
Expand All @@ -1961,7 +1957,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
states[pc´] = newstate
end
push!(W, pc´)
pc = frame.pc´´
break
elseif newstate !== nothing
states[pc´] = newstate
pc = pc´
Expand All @@ -1971,6 +1967,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
break
end
end
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
end
frame.dont_work_on_me = false
nothing
Expand Down
103 changes: 91 additions & 12 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ mutable struct InferenceState
pc´´::LineNum
nstmts::Int
# current exception handler info
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
handler_at::Vector{Any}
n_handlers::Int
handler_at::Vector{LineNum}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "handler_at" naming is actually a bit confusing ? To me "entered_at" sounds more reasonable ("handler" usually mean catch clause ... ?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is the same to say it is handled by the try statement. The catch / leave is where we stop handling errors (which is why it is simpler to point to the try statement also, since we have more information there)

# ssavalue sparsity and restart info
ssavalue_uses::Vector{BitSet}
throw_blocks::BitSet
Expand Down Expand Up @@ -86,25 +84,21 @@ mutable struct InferenceState
throw_blocks = find_throw_blocks(code)

# exception handlers
cur_hand = nothing
handler_at = Any[ nothing for i=1:n ]
n_handlers = 0

W = BitSet()
push!(W, 1) #initial pc to visit
ip = BitSet()
handler_at = compute_trycatch(src.code, ip)
push!(ip, 1)

mod = isa(def, Method) ? def.module : def

valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)

frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, mod, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
Union{}, ip, 1, n, handler_at,
ssavalue_uses, throw_blocks,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
Expand All @@ -118,6 +112,91 @@ mutable struct InferenceState
end
end

function compute_trycatch(code::Vector{Any}, ip::BitSet)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 3: (expr) # == 1
# 3: (leave 1) # == 1
# 4: (expr) # == 0
# then we can find all trys by walking backwards from :enter statements,
# and all catches by looking at the statement after the :enter
n = length(code)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(ip, pc + 1)
handler_at[l] = pc
push!(ip, l)
end
end

# now forward those marks to all :leave statements
pc´´ = 0
while true
# make progress on the active ip set
pc = _bits_findnext(ip.bits, pc´´)::Int
pc > n && break
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
if pc == pc´´
pc´´ = pc´
end
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if l < pc´´
pc´´ = l
end
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
elseif head === :leave
l = stmt.args[1]::Int
for i = 1:l
cur_hand = handler_at[cur_hand]
end
cur_hand == 0 && break
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
@assert handler_at[pc´] == 0 "unbalanced try/catch"
handler_at[pc´] = cur_hand
elseif !in(pc´, ip)
break # already visited
end
pc = pc´
end
end

@assert first(ip) == n + 1
return handler_at
end


"""
Iterate through all callers of the given InferenceState in the abstract
interpretation stack (including the given InferenceState itself), vising
Expand Down
45 changes: 45 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3428,3 +3428,48 @@ end
f41908(x::Complex{T}) where {String<:T<:String} = 1
g41908() = f41908(Any[1][1])
@test only(Base.return_types(g41908, ())) <: Int

# issue #42022
let x = Tuple{Int,Any}[
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
#= 2=# (0, Expr(:enter, 18))
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
#= 4=# (2, Expr(:enter, 12))
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
#= 7=# (4, Expr(:leave, 2))
#= 8=# (0, Core.ReturnNode(1))
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
#=10=# (4, Expr(:leave, 1))
#=11=# (2, Core.GotoNode(16))
#=12=# (4, Expr(:leave, 1))
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
#=16=# (2, Expr(:leave, 1))
#=17=# (0, Core.GotoNode(22))
#=18=# (2, Expr(:leave, 1))
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
#=20=# (0, nothing)
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
]
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
@test handler_at == first.(x)
end

@test only(Base.return_types((Bool,)) do y
x = 1
try
x = 2.0
try
x = '3'
y ? (return 1) : throw()
catch ex1
rethrow()
end
catch ex2
nothing
end
return x
end) === Union{Int, Float64, Char}