diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 41d448aa5b6333..5200d3dbf6d94c 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1832,11 +1832,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) empty!(frame.pclimitations) break end - if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condx, SlotNumber) - # if this non-`Conditional` object is a slot, we form and propagate - # the conditional constraint on it - condt = Conditional(condx, Const(true), Const(false)) - end condval = maybe_extract_const_bool(condt) l = stmt.dest::Int if !isempty(frame.pclimitations) @@ -1857,6 +1852,17 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) changes_else = conditional_changes(changes_else, condt.elsetype, condt.var) changes = conditional_changes(changes, condt.vtype, condt.var) end + if isa(condx, SlotNumber) + if isa(condt, Conditional) + tfalse = Conditional(condt.var, Bottom, condt.elsetype) + ttrue = Conditional(condt.var, condt.vtype, Bottom) + else + tfalse = Const(false) + ttrue = Const(true) + end + changes_else = add_change!(changes_else, condx, tfalse, true) + changes = add_change!(changes, condx, ttrue, true) + end newstate_else = stupdate!(states[l], changes_else) if newstate_else !== nothing # add else branch to active IP list diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 0ec151b1cb4a7e..ec4a00e64637b0 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1869,7 +1869,8 @@ let M = Module() @test rt == Tuple{Union{Nothing,Int},Any} end -@testset "conditional constraint propagation from non-`Conditional` object" begin +@testset "state update on branching" begin + # refine condition type into constant boolean value on branching @test Base.return_types((Bool,)) do b if b return !b ? nothing : 1 # ::Int @@ -1878,6 +1879,7 @@ end end end == Any[Int] + # even when the original type isn't boolean type @test Base.return_types((Any,)) do b if b return b # ::Bool @@ -1885,6 +1887,16 @@ end return nothing end end == Any[Union{Bool,Nothing}] + + # even when the original type is `Conditional` + @test Base.return_types((Any,)) do a + b = isa(a, Int) + if b + return !b ? nothing : a # ::Int + else + return 0 + end + end == Any[Int] end function f25579(g)