Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to AD backend usage in externalsampler #2223

Merged
merged 9 commits into from
May 18, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.31.4"
version = "0.31.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 4 additions & 2 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
return Transition(f.model, varinfo, transition)
end

state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state) = state_to_turing(parent(f), state)
# Bundle the sampler state with the AD-wrapped log-density itself, so the
# `ADGradientWrapper` travels with the state instead of being unwrapped.
function state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state)
    return TuringState(state, f)
end
# Delegate transition conversion to the method for the wrapped log-density.
transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition) =
    transition_to_turing(parent(f), transition)
Expand All @@ -29,7 +29,9 @@ getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
# Look through the AD wrapper and read the varinfo off the wrapped log-density.
function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
    return getvarinfo(parent(f))
end

# Return a copy of `f` with its `varinfo` field swapped for the given one
# (non-mutating update via Accessors).
function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo)
    return Accessors.@set f.varinfo = varinfo
end
setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo) = setvarinfo(parent(f), varinfo)
# Replace the varinfo of the log-density stored inside the AD wrapper by
# updating the wrapper's `ℓ` field (presumably the wrapped log-density — the
# recursive `setvarinfo(f.ℓ, ...)` call supports this). Crucially this returns
# a new *wrapper*, so the `ADGradientWrapper` is preserved rather than dropped
# the way a `setvarinfo(parent(f), ...)`-style delegation would drop it
# (the point of PR #2223).
function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
    return Accessors.@set f.ℓ = setvarinfo(f.ℓ, varinfo)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
Expand Down
13 changes: 13 additions & 0 deletions test/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ end
end
end
end

# Regression test for PR #2223: the `ADGradientWrapper` around the log-density
# must survive both the initial and subsequent sampling steps instead of being
# unwrapped when the state/varinfo is updated.
@testset "don't drop `ADgradient` (PR: #2223)" begin
    rng = Random.default_rng()
    model = DynamicPPL.TestUtils.DEMO_MODELS[1]
    sampler = initialize_nuts(model)
    sampler_ext = externalsampler(sampler; unconstrained=true, adtype=AutoForwardDiff())
    # Initial step. NOTE: `n_adapts` is a keyword of `AbstractMCMC.step`, not of
    # `Base.last` (which accepts no keywords and would throw a `MethodError`),
    # so it must sit inside the `step(...)` call.
    state = last(AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler_ext); n_adapts=0))
    @test state.logdensity isa LogDensityProblemsAD.ADGradientWrapper
    # Subsequent step: the wrapper must also survive the state round-trip.
    state = last(
        AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler_ext), state; n_adapts=0)
    )
    @test state.logdensity isa LogDensityProblemsAD.ADGradientWrapper
end
end

@turing_testset "AdvancedMH.jl" begin
Expand Down
Loading