Non-deterministic inference results? #50735

Open
ChrisRackauckas opened this issue Jul 31, 2023 · 4 comments
Labels: compiler:inference (Type inference)

Comments

@ChrisRackauckas
Member

ChrisRackauckas commented Jul 31, 2023

MWE: here is the setup code, isolated so that it depends only on ForwardDiff:

using ForwardDiff
const DUALCHECK_RECURSION_MAX = 10

"""
  reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}

An optimized `reduce` for tuples. `Base.reduce`'s `afoldl` will often not inline.
Additionally, `reduce_tup` attempts to order the reduction in an optimal manner.

More importantly, `reduce_tup(_pick_range, inds)` often performs better than `reduce(_pick_range, inds)`.
"""
@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any, N}}) where {F, N}
    q = Expr(:block, Expr(:meta, :inline, :propagate_inbounds))
    if N == 1
        push!(q.args, :(inds[1]))
        return q
    end
    syms = Vector{Symbol}(undef, N)
    i = 0
    for n in 1:N
        syms[n] = iₙ = Symbol(:i_, (i += 1))
        push!(q.args, Expr(:(=), iₙ, Expr(:ref, :inds, n)))
    end
    W = 1 << (8sizeof(N) - 2 - leading_zeros(N))
    while W > 0
        _N = length(syms)
        for _ in (2W):W:_N
            for w in 1:W
                new_sym = Symbol(:i_, (i += 1))
                push!(q.args, Expr(:(=), new_sym, Expr(:call, :f, syms[w], syms[w + W])))
                syms[w] = new_sym
            end
            deleteat!(syms, (1 + W):(2W))
        end
        W >>>= 1
    end
    q
end

"""
    promote_dual(::Type{T},::Type{T2})

Works like the number promotion system, but always prefers a dual number type above
anything else. For higher-order differentiation, it returns the most deeply nested
dual type of them all. This is then used to promote `u0` into the suspected highest
differentiation space for solving the equation.
"""
promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T
promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T
function promote_dual(::Type{T},
    ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual}
    T
end
promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2

function promote_dual(::Type{T},
    ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2,
    T <: ForwardDiff.Dual{T3, V, N},
    T2 <: ForwardDiff.Dual{T4, V2, N2}}
    T2
end
function promote_dual(::Type{T},
    ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2,
    T <: ForwardDiff.Dual{T3, V, N},
    T2 <: ForwardDiff.Dual{T4, V2, N2}}
    T
end
function promote_dual(::Type{T},
    ::Type{T2}) where {
    T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual,
    N,
    T <: ForwardDiff.Dual{T3, V, N},
    T2 <: ForwardDiff.Dual{T3, V2, N}}
    ForwardDiff.Dual{T3, promote_dual(V, V2), N}
end

# `reduce` and `map` are specialized on tuples to be unrolled (via recursion).
# Therefore, they can be type stable even with heterogeneous input types.
# We also don't care about allocating any temporaries with them, as it should
# all be unrolled and optimized away.
# Being unrolled also means const prop can work for things like
# `mapreduce(f, op, propertynames(x))`
# where `f` may call `getproperty` and thus have return type dependent
# on the particular symbol.
# `mapreduce` hasn't received any such specialization.
@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x))
@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP}
    reduce_tup(op, map(f, x))
end
# For other container types, we probably just want to call `mapreduce`
@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x)

"""
    anyeltypedual(x)

Searches through a type to see if any of its values are parameters. This is used to
then promote other values to match the dual type. For example, if a user passes a parameter
which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u`
will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted
to a dual number before the solve. Worse still, this needs to be done in the case of
`f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid
way to calculate the required state type.

But given the properties of automatic differentiation requiring that differentiation of parameters
implies differentiation of state, we assume any dual parameter implies differentiation of state
and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs
to be specified at compile time and thus cannot have a Bool-based opt-out, so in the future this
may be extended to use a preference system to opt out with a `UPCONVERT_DUALS`. In the case where
upconversion is not done automatically, the user is required to upconvert all initial conditions
themselves. For an example of how this can be confusing to a user, see
https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
"""
function anyeltypedual(x, counter = 0)
    if propertynames(x) === ()
        Any
    elseif counter < DUALCHECK_RECURSION_MAX
        diffeqmapreduce(DualEltypeChecker(x, counter), promote_dual,
            map(Val, propertynames(x)))
    else
        Any
    end
end

# Opt out, since these are used for preallocation, not differentiation
anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, counter = 0) = Any
anyeltypedual(x::Type{T}, counter = 0) where {T <: ForwardDiff.AbstractConfig} = Any

Base.@pure function __anyeltypedual(::Type{T}) where {T}
    hasproperty(T, :parameters) ?
    mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) : T
end
anyeltypedual(::Type{T}, counter = 0) where {T} = __anyeltypedual(T)
anyeltypedual(::Type{T}, counter = 0) where {T <: ForwardDiff.Dual} = T
function anyeltypedual(::Type{T}, counter = 0) where {T <: Union{AbstractArray, Set}}
    anyeltypedual(eltype(T))
end
Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple}
    if isconcretetype(eltype(T))
        return eltype(T)
    end
    if isempty(T.parameters)
        Any
    else
        mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any)
    end
end
anyeltypedual(::Type{T}, counter = 0) where {T <: NTuple} = __anyeltypedual_ntuple(T)

# Any in this context just means not Dual
anyeltypedual(x::Number, counter = 0) = anyeltypedual(typeof(x))
anyeltypedual(x::Union{String, Symbol}, counter = 0) = typeof(x)
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
    counter = 0) where {
    T <:
    Union{Number,
        Symbol,
        String}}
    anyeltypedual(T)
end
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
    counter = 0) where {
    T <: Union{
        AbstractArray{
            <:Number,
        },
        Set{
            <:Number,
        }}}
    anyeltypedual(eltype(x))
end
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
    counter = 0) where {N, T <: NTuple{N, <:Number}}
    anyeltypedual(eltype(x))
end

# Try to avoid this dispatch because it can lead to type inference issues when `!isconcretetype(eltype(x))`
function anyeltypedual(x::AbstractArray, counter = 0)
    if isconcretetype(eltype(x))
        anyeltypedual(eltype(x))
    elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) &&
           counter < DUALCHECK_RECURSION_MAX
        counter += 1
        mapreduce(y -> anyeltypedual(y, counter), promote_dual, x)
    else
        # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
        #  misses cases of
        Any
    end
end

function anyeltypedual(x::Set, counter = 0)
    if isconcretetype(eltype(x))
        anyeltypedual(eltype(x))
    else
        # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
        Any
    end
end

function anyeltypedual(x::Tuple, counter = 0)
    # Handle the empty tuple case separately for inference and to avoid mapreduce error
    if x === ()
        Any
    else
        diffeqmapreduce(anyeltypedual, promote_dual, x)
    end
end
function anyeltypedual(x::Dict, counter = 0)
    isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x))
end
function anyeltypedual(x::NamedTuple, counter = 0)
    isempty(x) ? Any : diffeqmapreduce(anyeltypedual, promote_dual, values(x))
end
@inline function promote_u0(u0, p, t0)
    if !(eltype(u0) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            return T.(u0)
        end
    end
    u0
end

@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
    if !(real(eltype(u0)) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            Ts = promote_type(T, eltype(u0))
            return Ts.(u0)
        end
    end
    u0
end

struct DualEltypeChecker{T}
    x::T
    counter::Int
    DualEltypeChecker(x::T, counter::Int) where {T} = new{T}(x, counter + 1)
end

function (dec::DualEltypeChecker)(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getproperty(dec.x, Y), dec.counter)
end

# use `getfield` on `Pairs`, see https://github.com/JuliaLang/julia/pull/39448
function (dec::DualEltypeChecker{<:Base.Pairs})(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getfield(dec.x, Y), dec.counter)
end

struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

thing = Thing(1.0)
x = 1.0
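The comment above `diffeqmapreduce` relies on `map` over a tuple keeping each element's own type, and the `reduce_tup` docstring describes combining tuple elements in a chosen order rather than as a left fold. As a rough illustration only, here is a plain recursive sketch of that map-then-pairwise-reduce pattern. `tree_reduce`, `tuple_mapreduce` and `inc_all` are hypothetical names, not part of ForwardDiff or DiffEqBase, and this version makes no inlining guarantees.

```julia
# Hypothetical stand-in for `reduce_tup`: a pairwise (tree-shaped) reduction over a tuple.
# Unlike the @generated original, this pairs adjacent elements and uses plain recursion.
tree_reduce(f, ::Tuple{}) = throw(ArgumentError("cannot reduce an empty tuple"))
tree_reduce(f, t::Tuple{Any}) = t[1]
function tree_reduce(f, t::Tuple)
    h = length(t) ÷ 2
    # Reduce each half separately, then combine the two partial results with `f`.
    return f(tree_reduce(f, t[1:h]), tree_reduce(f, t[(h + 1):end]))
end

# Hypothetical stand-in for `diffeqmapreduce(f, op, x::Tuple)`: map first, then reduce.
tuple_mapreduce(f, op, t::Tuple) = tree_reduce(op, map(f, t))

# `map` over a tuple keeps each element's own type, so this infers a concrete tuple type.
inc_all(t) = map(x -> x + 1, t)

order = tree_reduce((a, b) -> "($a$b)", ("1", "2", "3", "4"))        # "((12)(34))"
total = tuple_mapreduce(abs, +, (-1, 2.5, -3))                        # 6.5
inferred = only(Base.return_types(inc_all, (Tuple{Int, Float64},)))  # Tuple{Int64, Float64}
```

The string result makes the combination order visible. The generated `reduce_tup` pairs elements differently (for four elements it combines the first with the third and the second with the fourth before combining those two results), so treat this only as an illustration of the idea.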

Now the checks:

promote_u0(x, Wrapper1(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper1(thing), (0.0, 1.0))
MethodInstance for promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:346
Arguments
  #self#::Core.Const(promote_u0)
  u0::Float64
  p::Wrapper1{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Any
Body::Any
1 ─       nothing
│         Core.NewvarNode(:(T))
│   %3  = Main.eltype(u0)::Core.Const(Float64)
│   %4  = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5  = (%3 <: %4)::Core.Const(false)
│   %6  = !%5::Core.Const(true)
└──       goto #6 if not %6
2 ─       (T = Main.anyeltypedual(p))
│   %9  = (T === Main.Any)::Bool
└──       goto #4 if not %9
3 ─       return u0
4 ─ %12 = T::Any
│   %13 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %14 = (%12 <: %13)::Bool
└──       goto #6 if not %14
5 ─ %16 = Base.broadcasted(T, u0)::Base.Broadcast.Broadcasted{Style, Nothing} where Style<:Union{Nothing, Base.Broadcast.BroadcastStyle}
│   %17 = Base.materialize(%16)::Any
└──       return %17
6 ┄       return u0
promote_u0(x, Wrapper2(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper2(thing), (0.0, 1.0))
MethodInstance for promote_u0(::Float64, ::Wrapper2{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:346
Arguments
  #self#::Core.Const(promote_u0)
  u0::Float64
  p::Wrapper2{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Type{Any}
Body::Float64
1 ─      nothing
│        Core.NewvarNode(:(T))
│   %3 = Main.eltype(u0)::Core.Const(Float64)
│   %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5 = (%3 <: %4)::Core.Const(false)
│   %6 = !%5::Core.Const(true)
└──      goto #5 if not %6
2 ─      (T = Main.anyeltypedual(p))
│   %9 = (T::Core.Const(Any) === Main.Any)::Core.Const(true)
└──      goto #4 if not %9
3 ─      return u0
4 ─      Core.Const(:(T))
│        Core.Const(:(ForwardDiff.Dual))
│        Core.Const(:(%12 <: %13))
│        Core.Const(:(goto %19 if not %14))
│        Core.Const(:(Base.broadcasted(T, u0)))
│        Core.Const(:(Base.materialize(%16)))
└──      Core.Const(:(return %17))
5 ┄      Core.Const(:(return u0))

First reported as SciML/DiffEqBase.jl#918
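To turn a `@code_warntype` comparison like the one above into a check that fails loudly, the standard library's `Test.@inferred` can be used: it throws when the runtime return type does not match the inferred one. A minimal sketch on hypothetical toy functions follows. With the MWE loaded, the inference results shown above would suggest that the `Wrapper1` call throws (inferred `Any`, runtime `Float64`) while the `Wrapper2` call passes; that is an expectation, not a verified result.

```julia
using Test

# Hypothetical toy functions, not part of the MWE.
stable(x) = x + 1               # inferred as Int64 for an Int64 argument
unstable(x) = x > 0 ? 1 : 1.0   # inferred as Union{Float64, Int64}

# Passes: the inferred type matches the runtime type, and the value is returned.
y = @inferred stable(1)

# Throws an ErrorException: inferred Union{Float64, Int64} vs runtime Int64.
threw = try
    @inferred unstable(1)
    false
catch err
    err isa ErrorException
end

# Applied to the MWE (not run here):
# @inferred promote_u0(x, Wrapper1(thing), (0.0, 1.0))
# @inferred promote_u0(x, Wrapper2(thing), (0.0, 1.0))
```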

@quinnj
Member

quinnj commented Jul 31, 2023

Perhaps related to the issue I ran into w/ quinnj/JSONBase.jl#2, which @aviatesk helped diagnose as an issue in inference w/ @generated functions called recursively.

@vtjnash
Member

vtjnash commented Jul 31, 2023

Can you try with #50696? There appear to be several things in here that are expected to trigger undefined behavior (such as `__anyeltypedual`), but it looks like it might just be hitting unnecessary restrictions on recursive THTTs.

@ChrisRackauckas
Member Author

Same behavior.

@fjebaker

fjebaker commented Aug 2, 2023

A slightly boiled-down MWE:

const DUALCHECK_RECURSION_MAX = 10

promote_dual(A, B) = A

diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce(op, map(f, x))

function anyeltypedual(x, counter = 0)
    diffeqmapreduce(DualEltypeChecker(x, counter), promote_dual,
        map(Val, propertynames(x)))
end
function __anyeltypedual(::Type{T}) where {T}
    hasproperty(T, :parameters) ?
    mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) : T
end
anyeltypedual(::Type{T}, counter = 0) where {T} = __anyeltypedual(T)
anyeltypedual(x::Number, counter = 0) = anyeltypedual(typeof(x))

promote_u0(u0, p, t0) = anyeltypedual(p)

struct DualEltypeChecker{T}
    x::T
    counter::Int
    DualEltypeChecker(x::T, counter::Int) where {T} = new{T}(x, counter + 1)
end

function (dec::DualEltypeChecker)(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getproperty(dec.x, Y), dec.counter)
end

struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

thing = Thing(1.0)
x = 1.0
promote_u0(x, Wrapper1(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper1(thing), (0.0, 1.0))

MethodInstance for promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ Main ~/Developer/jl-forward-diff/mwe-chris.jl:19
Arguments
  #self#::Core.Const(promote_u0)
  u0::Float64
  p::Wrapper1{Thing}
  t0::Tuple{Float64, Float64}
Body::Any
1 ─ %1 = Main.anyeltypedual(p)::Any
└──      return %1
promote_u0(x, Wrapper2(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper2(thing), (0.0, 1.0))

MethodInstance for promote_u0(::Float64, ::Wrapper2{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ Main ~/Developer/jl-forward-diff/mwe-chris.jl:19
Arguments
  #self#::Core.Const(promote_u0)
  u0::Float64
  p::Wrapper2{Thing}
  t0::Tuple{Float64, Float64}
Body::Type{Any}
1 ─ %1 = Main.anyeltypedual(p)::Core.Const(Any)
└──      return %1

Weirdly, however, calling `anyeltypedual` directly

@code_warntype anyeltypedual(Wrapper1(thing))

MethodInstance for anyeltypedual(::Wrapper1{Thing})
  from anyeltypedual(x) @ Main ~/Developer/jl-forward-diff/mwe-chris.jl:9
Arguments
  #self#::Core.Const(anyeltypedual)
  x::Wrapper1{Thing}
Body::Type{Any}
1 ─ %1 = (#self#)(x, 0)::Core.Const(Any)
└──      return %1

and the result is similar for `Wrapper2`.
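A quicker way to compare the direct call with the call through `promote_u0`, without reading the full `@code_warntype` listing, is to ask for the inferred return type. `inferred_type` below is a hypothetical convenience wrapper around `Base.return_types`, which infers from the argument types only; the toy function is also hypothetical. Since this issue suggests the results may depend on what was inferred earlier in the session, it may be worth repeating such checks in a fresh session for each call order.

```julia
# Hypothetical helper: the return type inference assigns to `f(args...)`,
# computed from the argument types only (via `Base.return_types`).
inferred_type(f, args...) = only(Base.return_types(f, Base.typesof(args...)))

# Hypothetical toy function with a type-unstable return.
halve_or_one(x) = x > 0 ? x / 2 : 1

t_id = inferred_type(identity, 1.0)     # Float64
t_toy = inferred_type(halve_or_one, 3)  # Union{Float64, Int64}

# With the MWE loaded (not run here), the two paths compared in this comment:
# inferred_type(promote_u0, x, Wrapper1(thing), (0.0, 1.0))
# inferred_type(anyeltypedual, Wrapper1(thing))
```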
