Skip to content

Commit

Permalink
inference: propagate variable changes to all exception frames #42081 (#…
Browse files Browse the repository at this point in the history
…42110)

cherry-picked from #42081

Co-Authored-By: Jameson Nash <vtjnash+github@gmail.com>
  • Loading branch information
2 people authored and KristofferC committed Sep 6, 2021
1 parent 9d13e16 commit 19e66b3
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 43 deletions.
45 changes: 22 additions & 23 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1346,19 +1346,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
n = frame.nstmts
while frame.pc´´ <= n
# make progress on the active ip set
local pc::Int = frame.pc´´ # current program-counter
local pc::Int = frame.pc´´
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
#print(pc,": ",s[pc],"\n")
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
if pc == frame.pc´´
# need to update pc´´ to point at the new lowest instruction in W
min_pc = _bits_findnext(W.bits, pc + 1)
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
# want to update pc´´ to point at the new lowest instruction in W
frame.pc´´ = pc´
end
delete!(W, pc)
frame.currpc = pc
frame.cur_hand = frame.handler_at[pc]
frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc])
edges = frame.stmt_edges[pc]
edges === nothing || empty!(edges)
frame.stmt_info[pc] = nothing
stmt = frame.src.code[pc]
changes = s[pc]::VarTable
Expand Down Expand Up @@ -1392,7 +1391,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
pc´ = l
else
# general case
frame.handler_at[l] = frame.cur_hand
changes_else = changes
if isa(condt, Conditional)
if condt.elsetype !== Any && condt.elsetype !== changes[slot_id(condt.var)]
Expand Down Expand Up @@ -1440,7 +1438,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif hd === :enter
l = stmt.args[1]::Int
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
# propagate type info to exception handler
old = s[l]
newstate_catch = stupdate!(old, changes)
Expand All @@ -1452,11 +1449,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
s[l] = newstate_catch
end
typeassert(s[l], VarTable)
frame.handler_at[l] = frame.cur_hand
elseif hd === :leave
for i = 1:((stmt.args[1])::Int)
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
end
else
if hd === :(=)
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
Expand All @@ -1482,16 +1475,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.src.ssavaluetypes[pc] = t
end
end
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
l = frame.cur_hand.first::Int
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
if l < frame.pc´´
frame.pc´´ = l
if isa(changes, StateUpdate)
let cur_hand = frame.handler_at[pc], l, enter
while cur_hand != 0
enter = frame.src.code[cur_hand]
l = (enter::Expr).args[1]::Int
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
if l < frame.pc´´
frame.pc´´ = l
end
push!(W, l)
end
cur_hand = frame.handler_at[cur_hand]
end
push!(W, l)
end
end
end
Expand All @@ -1504,7 +1503,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end

pc´ > n && break # can't proceed with the fast-path fall-through
frame.handler_at[pc´] = frame.cur_hand
newstate = stupdate!(s[pc´], changes)
if isa(stmt, GotoNode) && frame.pc´´ < pc´
# if we are processing a goto node anyways,
Expand All @@ -1515,7 +1513,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
s[pc´] = newstate
end
push!(W, pc´)
pc = frame.pc´´
break
elseif newstate !== nothing
s[pc´] = newstate
pc = pc´
Expand All @@ -1525,6 +1523,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
break
end
end
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
end
frame.dont_work_on_me = false
nothing
Expand Down
114 changes: 94 additions & 20 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ mutable struct InferenceState
pc´´::LineNum
nstmts::Int
# current exception handler info
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
handler_at::Vector{Any}
n_handlers::Int
handler_at::Vector{LineNum}
# ssavalue sparsity and restart info
ssavalue_uses::Vector{BitSet}
throw_blocks::BitSet
Expand All @@ -57,8 +55,9 @@ mutable struct InferenceState
function InferenceState(result::InferenceResult, src::CodeInfo,
cached::Bool, interp::AbstractInterpreter)
linfo = result.linfo
def = linfo.def
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)
toplevel = !isa(def, Method)

sp = sptypes_from_meth_instance(linfo::MethodInstance)

Expand Down Expand Up @@ -87,30 +86,21 @@ mutable struct InferenceState
throw_blocks = find_throw_blocks(code)

# exception handlers
cur_hand = nothing
handler_at = Any[ nothing for i=1:n ]
n_handlers = 0

W = BitSet()
push!(W, 1) #initial pc to visit

if !toplevel
meth = linfo.def
inmodule = meth.module
else
inmodule = linfo.def::Module
end
ip = BitSet()
handler_at = compute_trycatch(src.code, ip)
push!(ip, 1)

mod = isa(def, Method) ? def.module : def
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)

frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
sp, slottypes, mod, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
Union{}, ip, 1, n, handler_at,
ssavalue_uses, throw_blocks,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
Expand All @@ -124,6 +114,90 @@ mutable struct InferenceState
end
end

function compute_trycatch(code::Vector{Any}, ip::BitSet)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 3: (expr) # == 1
# 3: (leave 1) # == 1
# 4: (expr) # == 0
# then we can find all trys by walking backwards from :enter statements,
# and all catches by looking at the statement after the :enter
n = length(code)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(ip, pc + 1)
handler_at[l] = pc
push!(ip, l)
end
end

# now forward those marks to all :leave statements
pc´´ = 0
while true
# make progress on the active ip set
pc = _bits_findnext(ip.bits, pc´´)::Int
pc > n && break
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
if pc == pc´´
pc´´ = pc´
end
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if l < pc´´
pc´´ = l
end
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
elseif head === :leave
l = stmt.args[1]::Int
for i = 1:l
cur_hand = handler_at[cur_hand]
end
cur_hand == 0 && break
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
@assert handler_at[pc´] == 0 "unbalanced try/catch"
handler_at[pc´] = cur_hand
elseif !in(pc´, ip)
break # already visited
end
pc = pc´
end
end

@assert first(ip) == n + 1
return handler_at
end

method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table

function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
Expand Down
45 changes: 45 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3040,3 +3040,48 @@ Base.return_types((Union{Int,Nothing},)) do x
end
x
end == [Int]

# issue #42022
let x = Tuple{Int,Any}[
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
#= 2=# (0, Expr(:enter, 18))
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
#= 4=# (2, Expr(:enter, 12))
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
#= 7=# (4, Expr(:leave, 2))
#= 8=# (0, Core.ReturnNode(1))
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
#=10=# (4, Expr(:leave, 1))
#=11=# (2, Core.GotoNode(16))
#=12=# (4, Expr(:leave, 1))
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
#=16=# (2, Expr(:leave, 1))
#=17=# (0, Core.GotoNode(22))
#=18=# (2, Expr(:leave, 1))
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
#=20=# (0, nothing)
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
]
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
@test handler_at == first.(x)
end

@test only(Base.return_types((Bool,)) do y
x = 1
try
x = 2.0
try
x = '3'
y ? (return 1) : throw()
catch ex1
rethrow()
end
catch ex2
nothing
end
return x
end) === Union{Int, Float64, Char}

0 comments on commit 19e66b3

Please sign in to comment.