Skip to content

Commit

Permalink
Use LogDensityProblems instead of gradient_logp (#1877)
Browse files Browse the repository at this point in the history
* Use LogDensityProblems instead of `gradient_logp`

* Update Project.toml

* Update Project.toml

* Import LogDensityProblems in submodule

* Fix Gibbs sampling with DynamicHMC

* Update ad.jl

* Add LogDensityProblems test dependency

* Qualify `LogDensityFunction`
  • Loading branch information
devmotion authored Aug 26, 2022
1 parent 3531d75 commit 5990fae
Show file tree
Hide file tree
Showing 16 changed files with 128 additions and 270 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.21.10"
version = "0.21.11"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -11,7 +11,6 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -20,6 +19,7 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -32,7 +32,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "4"
Expand All @@ -43,14 +42,14 @@ AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8, 0.9, 0.10"
DataStructures = "0.18"
DiffResults = "1"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.20"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.6.7, 0.7"
LogDensityProblems = "0.12"
MCMCChains = "5"
NamedArrays = "0.9"
Reexport = "0.2, 1"
Expand All @@ -60,5 +59,4 @@ SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
ZygoteRules = "0.2"
julia = "1.6"
2 changes: 1 addition & 1 deletion docs/src/using-turing/autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ title: Automatic Differentiation

Turing supports four packages of automatic differentiation (AD) in the back end during sampling. The default AD backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for forward-mode AD. Three reverse-mode AD backends are also supported, namely [Tracker](https://github.com/FluxML/Tracker.jl), [Zygote](https://github.com/FluxML/Zygote.jl) and [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl). `Zygote` and `ReverseDiff` are supported optionally if explicitly loaded by the user with `using Zygote` or `using ReverseDiff` next to `using Turing`.

To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user needs to load [Memoization.jl](https://github.com/marius311/Memoization.jl) first with `using Memoization` then call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. To empty the cache, you can call `Turing.emptyrdcache()`.
To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user has to call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine.


## Compositional Sampling with Differing AD Modes
Expand Down
9 changes: 9 additions & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Tracker: Tracker

import AdvancedVI
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import Random

const PROGRESS = Ref(true)
Expand Down Expand Up @@ -37,12 +38,20 @@ function (f::LogDensityFunction)(θ::AbstractVector)
return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))
end

# LogDensityProblems interface
LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) = f(θ)
LogDensityProblems.dimension(f::LogDensityFunction) = length(f.varinfo[f.sampler])
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
return LogDensityProblems.LogDensityOrder{0}()
end

# Standard tag: Improves stacktraces
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct TuringTag end

# Allow Turing tag in gradient etc. calls of the log density function
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:LogDensityFunction}, ::AbstractArray{V}) where {V} = true

# Random probability measures.
include("stdlib/distributions.jl")
Expand Down
41 changes: 13 additions & 28 deletions src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,6 @@ DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()

DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

# Only define traits for `DynamicNUTS` sampler to avoid type piracy and surprises
# TODO: Implement generally with `LogDensityProblems`
const DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}

function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity)
return length(ℓ.varinfo[ℓ.sampler])
end

function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity})
return DynamicHMC.LogDensityOrder{1}()
end

function DynamicHMC.logdensity_and_gradient(
::DynamicHMCLogDensity,
x::AbstractVector,
)
return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler, ℓ.context)
end

"""
DynamicNUTSState
Expand All @@ -46,9 +27,10 @@ State of the [`DynamicNUTS`](@ref) sampler.
# Fields
$(TYPEDFIELDS)
"""
struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S}
struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S}
logdensity::L
vi::V
"Cache of sample, log density, and gradient of log density."
"Cache of sample, log density, and gradient of log density evaluation."
cache::C
metric::M
stepsize::S
Expand All @@ -61,10 +43,10 @@ function gibbs_state(
state::DynamicNUTSState,
varinfo::AbstractVarInfo,
)
# Update the previous evaluation.
= Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext())
# Update the log density function and its cached evaluation.
= LogDensityProblems.ADgradient(Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext()))
Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize)
return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize)
end

DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
Expand All @@ -82,10 +64,13 @@ function DynamicPPL.initialstep(
model(rng, vi, spl)
end

# Define log-density function.
= LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))

# Perform initial step.
results = DynamicHMC.mcmc_keep_warmup(
rng,
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
,
0;
initialization = (q = vi[spl],),
reporter = DynamicHMC.NoProgressReport(),
Expand All @@ -99,7 +84,7 @@ function DynamicPPL.initialstep(

# Create first sample and state.
sample = Transition(vi)
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

return sample, state
end
Expand All @@ -113,7 +98,7 @@ function AbstractMCMC.step(
)
# Compute next sample.
vi = state.vi
= Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
= state.logdensity
steps = DynamicHMC.mcmc_steps(
rng,
DynamicHMC.NUTS(),
Expand All @@ -129,7 +114,7 @@ function AbstractMCMC.step(

# Create next sample and state.
sample = Transition(vi)
newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

return sample, newstate
end
22 changes: 14 additions & 8 deletions src/contrib/inference/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ function SGHMC{AD}(
return SGHMC{AD,space,typeof(_learning_rate)}(_learning_rate, _momentum_decay)
end

struct SGHMCState{V<:AbstractVarInfo, T<:AbstractVector{<:Real}}
struct SGHMCState{L,V<:AbstractVarInfo, T<:AbstractVector{<:Real}}
logdensity::L
vi::V
velocity::T
end
Expand All @@ -61,7 +62,8 @@ function DynamicPPL.initialstep(

# Compute initial sample and state.
sample = Transition(vi)
state = SGHMCState(vi, zero(vi[spl]))
= LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
state = SGHMCState(ℓ, vi, zero(vi[spl]))

return sample, state
end
Expand All @@ -74,9 +76,10 @@ function AbstractMCMC.step(
kwargs...
)
# Compute gradient of log density.
= state.logdensity
vi = state.vi
θ = vi[spl]
_, grad = gradient_logp(θ, vi, model, spl)
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))

# Update latent variables and velocity according to
# equation (15) of Chen et al. (2014)
Expand All @@ -92,7 +95,7 @@ function AbstractMCMC.step(

# Compute next sample and state.
sample = Transition(vi)
newstate = SGHMCState(vi, newv)
newstate = SGHMCState(ℓ, vi, newv)

return sample, newstate
end
Expand Down Expand Up @@ -191,7 +194,8 @@ metadata(t::SGLDTransition) = (lp = t.lp, SGLD_stepsize = t.stepsize)

DynamicPPL.getlogp(t::SGLDTransition) = t.lp

struct SGLDState{V<:AbstractVarInfo}
struct SGLDState{L,V<:AbstractVarInfo}
logdensity::L
vi::V
step::Int
end
Expand All @@ -211,7 +215,8 @@ function DynamicPPL.initialstep(

# Create first sample and state.
sample = SGLDTransition(vi, zero(spl.alg.stepsize(0)))
state = SGLDState(vi, 1)
= LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
state = SGLDState(ℓ, vi, 1)

return sample, state
end
Expand All @@ -224,9 +229,10 @@ function AbstractMCMC.step(
kwargs...
)
# Perform gradient step.
= state.logdensity
vi = state.vi
θ = vi[spl]
_, grad = gradient_logp(θ, vi, model, spl)
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
step = state.step
stepsize = spl.alg.stepsize(step)
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
Expand All @@ -237,7 +243,7 @@ function AbstractMCMC.step(

# Compute next sample and state.
sample = SGLDTransition(vi, stepsize)
newstate = SGLDState(vi, state.step + 1)
newstate = SGLDState(ℓ, vi, state.step + 1)

return sample, newstate
end
15 changes: 4 additions & 11 deletions src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,13 @@ import Bijectors: link, invlink
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using Requires

import AdvancedPS
import DiffResults
import ZygoteRules
import LogDensityProblems

include("container.jl")
include("ad.jl")

function __init__()
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("compat/reversediff.jl")
export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache
end
end

export @model,
@varname,
generate_observe,
Expand All @@ -53,11 +44,13 @@ export @model,
ForwardDiffAD,
TrackerAD,
ZygoteAD,
ReverseDiffAD,
value,
gradient_logp,
CHUNKSIZE,
ADBACKEND,
setchunksize,
setrdcache,
getrdcache,
verifygrad,
@logprob_str,
@prob_str
Expand Down
Loading

2 comments on commit 5990fae

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/67136

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.21.11 -m "<description of version>" 5990fae9f176e84d83bd119fa4a6b0e68f028493
git push origin v0.21.11

Please sign in to comment.