Use LogDensityProblems instead of gradient_logp #1877

Merged: 8 commits, Aug 26, 2022
8 changes: 3 additions & 5 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.21.10"
version = "0.21.11"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,7 +11,6 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -20,6 +19,7 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -32,7 +32,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "4"
@@ -43,14 +42,14 @@ AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8, 0.9, 0.10"
DataStructures = "0.18"
DiffResults = "1"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.20"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.6.7, 0.7"
LogDensityProblems = "0.12"
MCMCChains = "5"
NamedArrays = "0.9"
Reexport = "0.2, 1"
@@ -60,5 +59,4 @@ SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
ZygoteRules = "0.2"
julia = "1.6"
2 changes: 1 addition & 1 deletion docs/src/using-turing/autodiff.md
@@ -10,7 +10,7 @@ title: Automatic Differentiation

Turing supports four packages of automatic differentiation (AD) in the back end during sampling. The default AD backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for forward-mode AD. Three reverse-mode AD backends are also supported, namely [Tracker](https://github.com/FluxML/Tracker.jl), [Zygote](https://github.com/FluxML/Zygote.jl) and [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl). `Zygote` and `ReverseDiff` are supported optionally if explicitly loaded by the user with `using Zygote` or `using ReverseDiff` next to `using Turing`.

To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user needs to load [Memoization.jl](https://github.com/marius311/Memoization.jl) first with `using Memoization` then call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. To empty the cache, you can call `Turing.emptyrdcache()`.
To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user has to call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine.


## Compositional Sampling with Differing AD Modes
9 changes: 9 additions & 0 deletions src/Turing.jl
@@ -9,6 +9,7 @@ using Tracker: Tracker

import AdvancedVI
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import Random

const PROGRESS = Ref(true)
@@ -37,12 +38,20 @@ function (f::LogDensityFunction)(θ::AbstractVector)
return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))
end

# LogDensityProblems interface
LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) = f(θ)
LogDensityProblems.dimension(f::LogDensityFunction) = length(f.varinfo[f.sampler])
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
return LogDensityProblems.LogDensityOrder{0}()
end

# Standard tag: Improves stacktraces
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct TuringTag end

# Allow Turing tag in gradient etc. calls of the log density function
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:LogDensityFunction}, ::AbstractArray{V}) where {V} = true

# Random probability measures.
include("stdlib/distributions.jl")
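The interface methods added in this hunk can be illustrated outside Turing. The sketch below assumes only that `LogDensityProblems` is installed. `ToyNormal` and `ToyNormalWithGradient` are hypothetical types, not part of this PR, and the gradient is written by hand, whereas the PR obtains it by wrapping the log density with `LogDensityProblems.ADgradient`.

```julia
import LogDensityProblems

# Hypothetical toy target: a standard normal in `dim` dimensions.
struct ToyNormal
    dim::Int
end

# Zeroth-order interface: log density, dimension, and declared capabilities.
LogDensityProblems.logdensity(ℓ::ToyNormal, θ::AbstractVector) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(ℓ::ToyNormal) = ℓ.dim
function LogDensityProblems.capabilities(::Type{ToyNormal})
    return LogDensityProblems.LogDensityOrder{0}()
end

# Hypothetical first-order wrapper with a hand-written gradient.
struct ToyNormalWithGradient
    ℓ::ToyNormal
end

function LogDensityProblems.logdensity(∇ℓ::ToyNormalWithGradient, θ::AbstractVector)
    return LogDensityProblems.logdensity(∇ℓ.ℓ, θ)
end
LogDensityProblems.dimension(∇ℓ::ToyNormalWithGradient) = LogDensityProblems.dimension(∇ℓ.ℓ)
function LogDensityProblems.capabilities(::Type{ToyNormalWithGradient})
    return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity_and_gradient(∇ℓ::ToyNormalWithGradient, θ::AbstractVector)
    # For a standard normal, the gradient of the log density is -θ.
    return LogDensityProblems.logdensity(∇ℓ.ℓ, θ), -θ
end

∇ℓ = ToyNormalWithGradient(ToyNormal(2))
lp, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ, [1.0, 2.0])
```

Declaring `LogDensityOrder{0}()` on the plain log density, as the hunk does for `LogDensityFunction`, signals that only values are available; gradients come from a separate wrapper that declares order 1.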
41 changes: 13 additions & 28 deletions src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,6 @@ DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()

DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

# Only define traits for `DynamicNUTS` sampler to avoid type piracy and surprises
# TODO: Implement generally with `LogDensityProblems`
const DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}

function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity)
return length(ℓ.varinfo[ℓ.sampler])
end

function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity})
return DynamicHMC.LogDensityOrder{1}()
end

function DynamicHMC.logdensity_and_gradient(
ℓ::DynamicHMCLogDensity,
x::AbstractVector,
)
return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler, ℓ.context)
end

"""
DynamicNUTSState

@@ -46,9 +27,10 @@ State of the [`DynamicNUTS`](@ref) sampler.
# Fields
$(TYPEDFIELDS)
"""
struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S}
struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S}
logdensity::L
vi::V
"Cache of sample, log density, and gradient of log density."
"Cache of sample, log density, and gradient of log density evaluation."
cache::C
metric::M
stepsize::S
@@ -61,10 +43,10 @@ function gibbs_state(
state::DynamicNUTSState,
varinfo::AbstractVarInfo,
)
# Update the previous evaluation.
ℓ = Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext())
# Update the log density function and its cached evaluation.
ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext()))
Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize)
return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize)
end

DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
@@ -82,10 +64,13 @@ function DynamicPPL.initialstep(
model(rng, vi, spl)
end

# Define log-density function.
ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))

# Perform initial step.
results = DynamicHMC.mcmc_keep_warmup(
rng,
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
ℓ,
0;
initialization = (q = vi[spl],),
reporter = DynamicHMC.NoProgressReport(),
@@ -99,7 +84,7 @@ function DynamicPPL.initialstep(

# Create first sample and state.
sample = Transition(vi)
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

return sample, state
end
@@ -113,7 +98,7 @@ function AbstractMCMC.step(
)
# Compute next sample.
vi = state.vi
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
ℓ = state.logdensity
steps = DynamicHMC.mcmc_steps(
rng,
DynamicHMC.NUTS(),
@@ -129,7 +114,7 @@ function AbstractMCMC.step(

# Create next sample and state.
sample = Transition(vi)
newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

return sample, newstate
end
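The pattern in this file, building the gradient-capable log density once in the initial step and carrying it in the sampler state, can be sketched in plain Julia. Everything below is hypothetical and for illustration only: `ToyGradient`, `ToyState`, `build_gradient`, `toy_initialstep`, `toy_step`, and `toy_run` are not Turing or DynamicHMC names, and the update is a plain gradient step, not NUTS.

```julia
# Counts how often the (hypothetically expensive) gradient wrapper is built.
const BUILD_COUNT = Ref(0)

# Hypothetical stand-in for the wrapped log density; it only stores a gradient function.
struct ToyGradient{F}
    grad::F
end

function build_gradient()
    BUILD_COUNT[] += 1
    return ToyGradient(θ -> -θ)  # gradient of a standard normal log density
end

# The state carries the wrapper as its first field, like the states in the diff.
struct ToyState{L,T<:AbstractVector{<:Real}}
    logdensity::L
    θ::T
end

# Build the wrapper exactly once, when the chain is initialised.
toy_initialstep(θ) = ToyState(build_gradient(), θ)

function toy_step(state::ToyState, stepsize)
    ℓ = state.logdensity                          # reuse the stored wrapper
    θ = state.θ .+ stepsize .* ℓ.grad(state.θ)    # plain gradient step
    return ToyState(ℓ, θ)                         # pass the wrapper on unchanged
end

function toy_run(n)
    state = toy_initialstep([1.0, -1.0])
    for _ in 1:n
        state = toy_step(state, 0.1)
    end
    return state
end

final = toy_run(10)
```

Here the wrapper is constructed once per chain regardless of the number of steps, which is the behaviour the added `logdensity` field makes possible.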
22 changes: 14 additions & 8 deletions src/contrib/inference/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ function SGHMC{AD}(
return SGHMC{AD,space,typeof(_learning_rate)}(_learning_rate, _momentum_decay)
end

struct SGHMCState{V<:AbstractVarInfo, T<:AbstractVector{<:Real}}
struct SGHMCState{L,V<:AbstractVarInfo, T<:AbstractVector{<:Real}}
logdensity::L
vi::V
velocity::T
end
Expand All @@ -61,7 +62,8 @@ function DynamicPPL.initialstep(

# Compute initial sample and state.
sample = Transition(vi)
state = SGHMCState(vi, zero(vi[spl]))
ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
state = SGHMCState(ℓ, vi, zero(vi[spl]))

return sample, state
end
@@ -74,9 +76,10 @@ function AbstractMCMC.step(
kwargs...
)
# Compute gradient of log density.
ℓ = state.logdensity
vi = state.vi
θ = vi[spl]
_, grad = gradient_logp(θ, vi, model, spl)
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))

# Update latent variables and velocity according to
# equation (15) of Chen et al. (2014)
@@ -92,7 +95,7 @@ function AbstractMCMC.step(

# Compute next sample and state.
sample = Transition(vi)
newstate = SGHMCState(vi, newv)
newstate = SGHMCState(ℓ, vi, newv)

return sample, newstate
end
@@ -191,7 +194,8 @@ metadata(t::SGLDTransition) = (lp = t.lp, SGLD_stepsize = t.stepsize)

DynamicPPL.getlogp(t::SGLDTransition) = t.lp

struct SGLDState{V<:AbstractVarInfo}
struct SGLDState{L,V<:AbstractVarInfo}
logdensity::L
vi::V
step::Int
end
Expand All @@ -211,7 +215,8 @@ function DynamicPPL.initialstep(

# Create first sample and state.
sample = SGLDTransition(vi, zero(spl.alg.stepsize(0)))
state = SGLDState(vi, 1)
ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
state = SGLDState(ℓ, vi, 1)

return sample, state
end
@@ -224,9 +229,10 @@ function AbstractMCMC.step(
kwargs...
)
# Perform gradient step.
ℓ = state.logdensity
vi = state.vi
θ = vi[spl]
_, grad = gradient_logp(θ, vi, model, spl)
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
step = state.step
stepsize = spl.alg.stepsize(step)
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
@@ -237,7 +243,7 @@ function AbstractMCMC.step(

# Compute next sample and state.
sample = SGLDTransition(vi, stepsize)
newstate = SGLDState(vi, state.step + 1)
newstate = SGLDState(ℓ, vi, state.step + 1)

return sample, newstate
end
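The SGLD update in the hunk above depends only on a function returning the log density and its gradient, so it can be tried in isolation. The sketch below is a minimal stand-alone version using only the `Random` standard library. The names `toy_logdensity_and_gradient` and `sgld_step!` are hypothetical, the target is assumed to be a standard normal, and the step size is held constant, whereas the sampler in the diff calls `spl.alg.stepsize(step)`.

```julia
using Random

# Hypothetical target: standard normal, log p(θ) = -θ'θ / 2 up to a constant.
# Returns (log density, gradient), the same shape as `logdensity_and_gradient`.
toy_logdensity_and_gradient(θ) = (-sum(abs2, θ) / 2, -θ)

# One SGLD update, mirroring the in-place update in the diff.
function sgld_step!(rng, θ, stepsize)
    grad = last(toy_logdensity_and_gradient(θ))
    θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
    return θ
end

rng = Random.MersenneTwister(1)
θ = [3.0, -3.0]
for _ in 1:1_000
    sgld_step!(rng, θ, 0.01)
end
```

Taking `last` of the returned tuple discards the log density value, matching how the new code replaces `_, grad = gradient_logp(θ, vi, model, spl)`.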
15 changes: 4 additions & 11 deletions src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,13 @@ import Bijectors: link, invlink
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using Requires

import AdvancedPS
import DiffResults
import ZygoteRules
import LogDensityProblems

include("container.jl")
include("ad.jl")

function __init__()
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("compat/reversediff.jl")
export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache
end
end

export @model,
@varname,
generate_observe,
@@ -53,11 +44,13 @@ export @model,
ForwardDiffAD,
TrackerAD,
ZygoteAD,
ReverseDiffAD,
value,
gradient_logp,
CHUNKSIZE,
ADBACKEND,
setchunksize,
setrdcache,
getrdcache,
verifygrad,
@logprob_str,
@prob_str