Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check that the correct AD backend is being used #2291

Merged
merged 5 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module HMCTests

using ..Models: gdemo_default
using ..ADUtils: ADTypeCheckContext
#using ..Models: gdemo
using ..NumericalTests: check_gdemo, check_numerical
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
Expand Down Expand Up @@ -321,6 +322,15 @@ using Turing
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
end

@testset "Check ADType" begin
alg = HMC(0.1, 10; adtype=adbackend)
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
sample(rng, m, alg, 10)
end
end

end
14 changes: 13 additions & 1 deletion test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module OptimisationTests

using ..Models: gdemo, gdemo_default
using ..ADUtils: ADTypeCheckContext
using Distributions
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
Expand Down Expand Up @@ -140,7 +141,6 @@ using Turing
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
)
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
m4 = maximum_likelihood(
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
)
Expand Down Expand Up @@ -616,6 +616,18 @@ using Turing
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
@assert get(result, :c) == (; :c => Array{Float64}[])
end

@testset "ADType" begin
Random.seed!(222)
for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
maximum_likelihood(m; adtype=adbackend)
maximum_a_posteriori(m; adtype=adbackend)
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Turing

include(pkgdir(Turing) * "/test/test_utils/models.jl")
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")

Turing.setprogress!(false)

Expand Down
213 changes: 213 additions & 0 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
module ADUtils

import ForwardDiff
import ReverseDiff
import Test
import Tracker
import Turing
import Turing: DynamicPPL
import Zygote
yebai marked this conversation as resolved.
Show resolved Hide resolved

export ADTypeCheckContext

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)

"""A dictionary mapping ADTypes to the element types they use."""
const eltypes_by_adtype = Dict(
Turing.AutoForwardDiff => (ForwardDiff.Dual,),
Turing.AutoReverseDiff => (
ReverseDiff.TrackedArray,
ReverseDiff.TrackedMatrix,
ReverseDiff.TrackedReal,
ReverseDiff.TrackedStyle,
ReverseDiff.TrackedType,
ReverseDiff.TrackedVecOrMat,
ReverseDiff.TrackedVector,
),
# TODO(mhauru) Zygote.Dual is actually the same as ForwardDiff.Dual, so can't
# distinguish between the two.
Turing.AutoZygote => (Zygote.Dual,),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah there's no way you can do this here unfortunately 😕

Possibly crazy idea, but it might make sense to specifically overload the adjoint computation for a given backend which doesn't use types, e.g.

Zygote.@adjoint function check_adtype(...)
    # Should only be hit if we're using Zygote.jl.
    ...
end

? Could do the same with Enzyme.jl and Tapir.jl, both of which would suffer from the same issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this, and implemented it. Turns out there's Zygote.hook for exactly these sorts of things.

I thought a bit about switching to doing all checks like this, and giving up on the element type approach, but that seemed like it would get more complicated than is worth at this point. We now catch all cases except if one uses ForwardDiff when Zygote is expected, which seems sufficient to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very, very nice @mhauru :)

Turing.AutoTracker => (
Tracker.Tracked,
Tracker.TrackedArray,
Tracker.TrackedMatrix,
Tracker.TrackedReal,
Tracker.TrackedStyle,
Tracker.TrackedVecOrMat,
Tracker.TrackedVector,
),
)

"""
IncompatibleADTypeError

An error thrown when an element type is encountered that is unexpected for the given ADType.
"""
struct IncompatibleADTypeError <: Exception
valtype::Type
adtype::Type
end

function Base.showerror(io::IO, e::IncompatibleADTypeError)
return print(
io,
"Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)",
)
end

"""
ADTypeCheckContext{ADType,ChildContext}

A context for checking that the expected ADType is being used.

Evaluating a model with this context will check that the types of values in a `VarInfo` are
compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError`
is thrown.

For instance, evaluating a model with
`ADTypeCheckContext(AutoForwardDiff(), child_context)`
would throw an error if within the model a type associated with e.g. ReverseDiff was
encountered.

As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
DynamicPPL.AbstractContext
child::ChildContext

function ADTypeCheckContext(adbackend, child)
adtype = adbackend isa Type ? adbackend : typeof(adbackend)
if !any(adtype .<: keys(eltypes_by_adtype))
throw(ArgumentError("Unsupported ADType: $adtype"))
end
return new{adtype,typeof(child)}(child)
end
end

adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType

DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
return ADTypeCheckContext(adtype(c), child)
end

"""
valid_eltypes(context::ADTypeCheckContext)

Return the element types that are valid for the ADType of `context` as a tuple.
"""
function valid_eltypes(context::ADTypeCheckContext)
context_at = adtype(context)
for at in keys(eltypes_by_adtype)
if context_at <: at
return (eltypes_by_adtype[at]..., always_valid_eltypes...)
end
end
# This should never be reached due to the check in the inner constructor.
throw(ArgumentError("Unsupported ADType: $(adtype(context))"))
end

"""
check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo)

Check that the element types in `vi` are compatible with the ADType of `context`.

Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
valids = valid_eltypes(context)
for val in vi[:]
valtype = typeof(val)
if !any(valtype .<: valids)
throw(IncompatibleADTypeError(valtype, adtype(context)))
end
end
return nothing
end

# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
# context, and then call check_adtype on the result before returning the results from the
# child context.

function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi)
value, logp, vi = DynamicPPL.tilde_assume(
DynamicPPL.childcontext(context), right, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi)
value, logp, vi = DynamicPPL.tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi)
logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi)
check_adtype(context, vi)
return logp, vi
end

function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
logp, vi = DynamicPPL.tilde_observe(
DynamicPPL.childcontext(context), sampler, right, left, vi
)
check_adtype(context, vi)
return logp, vi
end

function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, vn, vi)
value, logp, vi = DynamicPPL.dot_tilde_assume(
DynamicPPL.childcontext(context), right, left, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.dot_tilde_assume(
rng, context::ADTypeCheckContext, sampler, right, left, vn, vi
)
value, logp, vi = DynamicPPL.dot_tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, left, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left, vi)
logp, vi = DynamicPPL.dot_tilde_observe(
DynamicPPL.childcontext(context), right, left, vi
)
check_adtype(context, vi)
return logp, vi
end

function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
logp, vi = DynamicPPL.dot_tilde_observe(
DynamicPPL.childcontext(context), sampler, right, left, vi
)
check_adtype(context, vi)
return logp, vi
end

# Check that the ADTypeCheckContext works as expected.
Test.@testset "ADTypeCheckContext" begin
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
tm = test_model()
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(Turing.AutoForwardDiff(), tm.context)
)
# This should not throw an error since we are using ForwardDiff as expected.
Turing.sample(contextualised_tm, Turing.NUTS(; adtype=Turing.AutoForwardDiff()), 100)
# Using ReverseDiff when ForwardDiff is expected should throw an error.
Test.@test_throws IncompatibleADTypeError Turing.sample(
contextualised_tm, Turing.NUTS(; adtype=Turing.AutoReverseDiff()), 100
)
end

end
Loading