Weird inference failure in promote_u0 #918

Closed
fjebaker opened this issue Jul 28, 2023 · 4 comments
@fjebaker

fjebaker commented Jul 28, 2023

Originally opened in SciML/OrdinaryDiffEq.jl#2001

MWE:

using DiffEqBase
using InteractiveUtils  # provides @code_warntype when running outside the REPL

struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

thing = Thing(1.0)
x = 1.0

DiffEqBase.promote_u0(x, Wrapper1(thing), (0.0, 1.0))
@code_warntype DiffEqBase.promote_u0(x, Wrapper1(thing), (0.0, 1.0))

DiffEqBase.promote_u0(x, Wrapper2(thing), (0.0, 1.0))
@code_warntype DiffEqBase.promote_u0(x, Wrapper2(thing), (0.0, 1.0))

Output:

MethodInstance for DiffEqBase.promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:208
Arguments
  #self#::Core.Const(DiffEqBase.promote_u0)
  u0::Float64
  p::Wrapper1{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Any
Body::Any
1 ─       nothing
│         Core.NewvarNode(:(T))
│   %3  = DiffEqBase.eltype(u0)::Core.Const(Float64)
│   %4  = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5  = (%3 <: %4)::Core.Const(false)
│   %6  = !%5::Core.Const(true)
└──       goto #6 if not %6
2 ─       (T = DiffEqBase.anyeltypedual(p))
│   %9  = (T === DiffEqBase.Any)::Bool
└──       goto #4 if not %9
3 ─       return u0
4 ─ %12 = T::Any
│   %13 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %14 = (%12 <: %13)::Bool
└──       goto #6 if not %14
5 ─ %16 = Base.broadcasted(T, u0)::Any
│   %17 = Base.materialize(%16)::Any
└──       return %17
6 ┄       return u0

MethodInstance for DiffEqBase.promote_u0(::Float64, ::Wrapper2{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:208
Arguments
  #self#::Core.Const(DiffEqBase.promote_u0)
  u0::Float64
  p::Wrapper2{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Type{Any}
Body::Float64
1 ─      nothing
│        Core.NewvarNode(:(T))
│   %3 = DiffEqBase.eltype(u0)::Core.Const(Float64)
│   %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5 = (%3 <: %4)::Core.Const(false)
│   %6 = !%5::Core.Const(true)
└──      goto #5 if not %6
2 ─      (T = DiffEqBase.anyeltypedual(p))
│   %9 = (T::Core.Const(Any) === DiffEqBase.Any)::Core.Const(true)
└──      goto #4 if not %9
3 ─      return u0
4 ─      Core.Const(:(T))
│        Core.Const(:(ForwardDiff.Dual))
│        Core.Const(:(%12 <: %13))
│        Core.Const(:(goto %19 if not %14))
│        Core.Const(:(Base.broadcasted(T, u0)))
│        Core.Const(:(Base.materialize(%16)))
└──      Core.Const(:(return %17))
5 ┄      Core.Const(:(return u0))

Whichever call comes first always fails to infer (i.e. running the Wrapper2 call first makes Wrapper2 the one that fails to infer).

I have no idea why, since anyeltypedual on its own seems to be constant-folded:

@code_warntype DiffEqBase.anyeltypedual(Wrapper1(thing))
@code_warntype DiffEqBase.anyeltypedual(Wrapper2(thing))
MethodInstance for DiffEqBase.anyeltypedual(::Wrapper1{Thing})
  from anyeltypedual(x) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:101
Arguments
  #self#::Core.Const(DiffEqBase.anyeltypedual)
  x::Wrapper1{Thing}
Body::Type{Any}
1 ─ %1 = (#self#)(x, 0)::Core.Const(Any)
└──      return %1

MethodInstance for DiffEqBase.anyeltypedual(::Wrapper2{Thing})
  from anyeltypedual(x) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:101
Arguments
  #self#::Core.Const(DiffEqBase.anyeltypedual)
  x::Wrapper2{Thing}
Body::Type{Any}
1 ─ %1 = (#self#)(x, 0)::Core.Const(Any)
└──      return %1
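A sketch (not part of the original report) that queries inference directly with Base.return_types instead of reading @code_warntype output. Assuming the order dependence described above, running this in a fresh session on an affected DiffEqBase version should report Any for whichever wrapper comes first in the loop, and reversing the tuple should move the failure to the other wrapper; this has not been verified here.

```julia
using DiffEqBase

struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

x = 1.0
tspan = (0.0, 1.0)

# Ask the compiler for the inferred return type(s) of promote_u0 for each
# wrapper, in a fixed order. Run in a fresh session, then reverse the tuple
# to check whether the failure follows whichever wrapper is inferred first.
for W in (Wrapper1, Wrapper2)
    argtypes = Tuple{Float64, W{Thing}, typeof(tspan)}
    println(W, " => ", Base.return_types(DiffEqBase.promote_u0, argtypes))
end
```

Whatever inference reports, the call itself returns u0 unchanged here, because anyeltypedual(p) is Any for these parameter types.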
@fjebaker
Author

fjebaker commented Jul 28, 2023

I'm not really sure what I'm doing with these tools, but I'm having a look anyway. SnoopCompile finds the same:

julia> inv1 = @snoopi_deep DiffEqBase.promote_u0(x, Wrapper1(thing), (0.0, 1.0)) 
InferenceTimingNode: 0.005841/0.056050 on Core.Compiler.Timings.ROOT() with 3 direct children

julia> inv2 = @snoopi_deep DiffEqBase.promote_u0(x, Wrapper2(thing), (0.0, 1.0)) 
InferenceTimingNode: 0.004530/0.005901 on Core.Compiler.Timings.ROOT() with 2 direct children

julia> inference_triggers(inv1)
1-element Vector{InferenceTrigger}:
 Inference triggered to call DiffEqBase.diffeqmapreduce(::DiffEqBase.DualEltypeChecker{Wrapper1{Thing}}, ::typeof(DiffEqBase.promote_dual), ::Tuple{Val{:thing}}) from anyeltypedual (/Users/lx21966/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:105) inlined into DiffEqBase.promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64}) (/Users/lx21966/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:210)

julia> inference_triggers(inv2)
InferenceTrigger[]

julia> summary(accumulate_by_source(Method, inference_triggers(inv1)))
promote_u0(u0, p, t0) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:208 had 1 specializations
Triggering calls:
Inlined anyeltypedual at /Users/lx21966/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:105: calling diffeqmapreduce (1 instances)

It even narrows the problem down to anyeltypedual. Looking deeper:

[Screenshot 2023-07-28 at 16 00 56 (image not reproduced)]

Update: pulling the code apart a bit boils the problem down to this trigger:

1-element Vector{InferenceTrigger}:
 Inference triggered to call map(::DiffEqBase.DualEltypeChecker{Wrapper1{Thing}}, ::Tuple{Val{:thing}}) from anyeltypedual (/Users/lx21966/Developer/jl/test-inf3.jl:38) with specialization DiffEqBase.anyeltypedual(::Wrapper1{Thing}, ::Int64)

@ChrisRackauckas
Member

Moved to Julia's Base: JuliaLang/julia#50735. Indeed, that's weird.

@thomvet
Contributor

thomvet commented Feb 28, 2024

Not sure why, but when I run the above code on my computer, the return type now seems to be inferred correctly. @code_warntype output (after directly copy-pasting the code sample above into a fresh Julia session):

@code_warntype DiffEqBase.promote_u0(x, Wrapper1(thing), (0.0, 1.0))
MethodInstance for DiffEqBase.promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ DiffEqBase C:\Users\thvt\.julia\packages\DiffEqBase\4084R\src\forwarddiff.jl:250
Arguments
  #self#::Core.Const(DiffEqBase.promote_u0)
  u0::Float64
  p::Wrapper1{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Type{Any}
Body::Float64
1 ─      nothing
│        Core.NewvarNode(:(T))
│   %3 = DiffEqBase.eltype(u0)::Core.Const(Float64)
│   %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5 = (%3 <: %4)::Core.Const(false)
│   %6 = !%5::Core.Const(true)
└──      goto #6 if not %6
2 ─      (T = DiffEqBase.anyeltypedual(p))
│   %9 = (T::Core.Const(Any) === DiffEqBase.Any)::Core.Const(true)
└──      goto #5 if not %9
3 ─      return u0
4 ─      Core.Const(:(goto %13))
5 ┄      Core.Const(:(T))
│        Core.Const(:(ForwardDiff.Dual))
│        Core.Const(:(%13 <: %14))
│        Core.Const(:(goto %20 if not %15))
│        Core.Const(:(Base.broadcasted(T, u0)))
│        Core.Const(:(Base.materialize(%17)))
└──      Core.Const(:(return %18))
6 ┄      Core.Const(:(return u0))

The second example (Wrapper2) produces the same output as given above, i.e., inference still works.

Pkg status:
[2b5f629d] DiffEqBase v6.147.1

Manifest:

  [47edcb42] ADTypes v0.2.6
  [79e6a3ab] Adapt v4.0.1
  [4fba245c] ArrayInterface v7.7.1
  [62783981] BitTwiddlingConvenienceFunctions v0.1.5
  [2a0fbf3d] CPUSummary v0.2.4
  [fb6a15b2] CloseOpenIntervals v0.1.12
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.14.0
  [187b0558] ConstructionBase v1.5.4
  [adafc99b] CpuId v0.3.1
  [9a962f9c] DataAPI v1.16.0
  [864edb3b] DataStructures v0.18.17
  [e2d170a0] DataValueInterfaces v1.0.0
  [2b5f629d] DiffEqBase v6.147.1
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [ffbed154] DocStringExtensions v0.9.3
  [4e289a0a] EnumX v1.0.4
  [f151be2c] EnzymeCore v0.6.5
  [e2ba6199] ExprTools v0.1.10
  [7034ab61] FastBroadcast v0.2.8
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [46192b85] GPUArraysCore v0.1.6
  [615f187c] IfElse v0.1.1
  [92d709cd] IrrationalConstants v0.2.2
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.5.0
  [10f19ff3] LayoutPointers v0.1.15
  [2ab3a3ac] LogExpFunctions v0.3.27
  [1914dd2f] MacroTools v0.5.13
  [d125e4d3] ManualMemory v0.1.8
  [46d2c3a1] MuladdMacro v0.2.4
  [77ba4419] NaNMath v1.0.2
  [bac558e1] OrderedCollections v1.6.3
  [d96e819e] Parameters v0.12.3
  [f517fe37] Polyester v0.7.9
  [1d0040c9] PolyesterWeave v0.2.1
  [d236fae5] PreallocationTools v0.4.20
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.1
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v3.10.1
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [7e49a35a] RuntimeGeneratedFunctions v0.5.12
  [94e857df] SIMDTypes v0.1.0
  [0bca4576] SciMLBase v2.29.0
  [c0aeaf25] SciMLOperators v0.3.8
  [efcf1570] Setfield v1.1.1
  [276daf66] SpecialFunctions v2.3.1
  [aedffcd0] Static v0.8.10
  [0d7ed370] StaticArrayInterface v1.5.0
  [1e83bf80] StaticArraysCore v1.4.2
  [7792a7ef] StrideArraysCore v0.5.2
  [2efcf032] SymbolicIndexingInterface v0.3.8
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.11.1
  [8290d209] ThreadingUtilities v0.5.2
  [410a4b4d] Tricks v0.1.8
  [781d530d] TruncatedStacktraces v1.4.0
  [3a884ed6] UnPack v1.0.2
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.10.0
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [6462fe0b] Sockets
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+1
  [deac9b47] LibCURL_jll v8.4.0+0
  [e37daf67] LibGit2_jll v1.6.4+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.2+1
  [14a3606d] MozillaCACerts_jll v2023.1.10
  [4536629a] OpenBLAS_jll v0.3.23+2
  [05823500] OpenLibm_jll v0.8.1+2
  [bea87d4a] SuiteSparse_jll v7.2.1+1
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.8.0+1
  [8e850ede] nghttp2_jll v1.52.0+1
  [3f19e933] p7zip_jll v17.4.0+2

Versioninfo:

Julia Version 1.10.0
Commit 3120989f39 (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × 12th Gen Intel(R) Core(TM) i7-12850HX
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
  Threads: 35 on 24 virtual cores

@ChrisRackauckas
Member

It recently got some tweaks.

#1017

That would have fixed these cases.
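A sketch of a regression test for these cases, using @inferred from the Test standard library. It is hypothetical: this thread does not say whether #1017 added an equivalent test, and the testset name is illustrative.

```julia
using DiffEqBase, Test

struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

# @inferred throws if the compiler's inferred return type differs from the
# actual return type (Float64 here), so this is expected to fail on versions
# affected by the bug and to pass once inference works for both wrappers.
@testset "promote_u0 infers through wrapped parameters" begin
    for p in (Wrapper1(Thing(1.0)), Wrapper2(Thing(1.0)))
        @test @inferred(DiffEqBase.promote_u0(1.0, p, (0.0, 1.0))) === 1.0
    end
end
```

Running it in a fresh session matters: per the original report, the failure only showed up for whichever call was inferred first.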
