Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump DynamicPPL to 0.23 #2001

Merged
merged 31 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
976e621
bump dppl test versions
torfjelde Apr 18, 2023
4347bd9
also bump bijectors
torfjelde Apr 18, 2023
ee53d1c
bump AdvancedVI versions
torfjelde Apr 18, 2023
384158f
revert Bijectors bump
torfjelde Apr 18, 2023
a258d54
bumped vi and bijectors too
torfjelde Apr 20, 2023
75fe9c3
breaking change
torfjelde Apr 20, 2023
d02d2f5
Merge branch 'master' into torfjelde/dynamicppl-bump
torfjelde Apr 20, 2023
06a2e7d
removed refernce to Bijectors.setadbackend
torfjelde Apr 20, 2023
320af8c
Merge branch 'torfjelde/dynamicppl-bump' of github.com:TuringLang/Tur…
torfjelde Apr 20, 2023
d8f7086
Merge branch 'master' into torfjelde/dynamicppl-bump
torfjelde Jun 6, 2023
d1054ee
make use of DynamicPPL.make_evaluate_args_and_kwargs
torfjelde Jun 6, 2023
1785cf7
bump DPPL version
torfjelde Jun 6, 2023
6f4fecb
bump DPPL version for tests
torfjelde Jun 6, 2023
149c3d6
fixed bug in TracedModel
torfjelde Jun 7, 2023
c6b4d28
forgot to remove some lines
torfjelde Jun 7, 2023
2bafbc3
just drop the kwargs completely :(
torfjelde Jun 8, 2023
36ae046
Update container.jl
yebai Jun 10, 2023
55f8504
Update container.jl
yebai Jun 10, 2023
b007749
will now error if we're using a model with kwargs and SMC
torfjelde Jun 12, 2023
6bf0980
added reference to issue
torfjelde Jun 12, 2023
b3f13a1
added test for keyword models failing
torfjelde Jun 12, 2023
404882c
make this a breaking change
torfjelde Jun 12, 2023
3c036eb
made error message more informative
torfjelde Jun 13, 2023
56ce7f1
makde it slightly less informative
torfjelde Jun 13, 2023
52071b0
fixed typo in checking for TRaceModel
torfjelde Jun 13, 2023
0ed4f59
Merge branch 'master' into torfjelde/dynamicppl-bump
torfjelde Jun 13, 2023
6f580fb
finally fixed the if-statement..
torfjelde Jun 13, 2023
400915b
Fix test error
yebai Jun 14, 2023
394d38d
fixed tests maybe
torfjelde Jun 14, 2023
0eb9445
now fixed maybe
torfjelde Jun 14, 2023
30ad2ff
Update test/inference/Inference.jl
torfjelde Jun 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.25.2"
version = "0.26.0"


[deps]
Expand Down Expand Up @@ -48,7 +48,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.21.5, 0.22"
DynamicPPL = "0.23"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand Down
45 changes: 16 additions & 29 deletions src/essential/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,21 @@ function TracedModel(
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
rng::Random.AbstractRNG
)
rng::Random.AbstractRNG,
)
context = SamplingContext(rng, sampler, DefaultContext())
evaluator = _get_evaluator(model, varinfo, context)
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
end

# TODO: maybe move to DynamicPPL
@generated function _get_evaluator(
model::Model{_F,argnames}, varinfo, context
) where {_F,argnames}
unwrap_args = [
:($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
]
# We want to give `context` precedence over `model.context` while also
# preserving the leaf context of `context`. We can do this by
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
# 2. Set leaf context of `context` to the context resulting from (1).
# The result is:
# `context` -> `childcontext(context)` -> ... -> `model.context`
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
return quote
context_new = DynamicPPL.setleafcontext(
context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
)
(model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...))
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
if kwargs !== nothing || !isempty(kwargs)
error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.")
end
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model,
sampler,
varinfo,
(model.f, args...)
)
end


function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
newtask = copy(model.ctask)
newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
Expand Down Expand Up @@ -73,10 +58,12 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
return trace
end

function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
# Extract the `args`.
args = trace.model.ctask.args
_, _, container, = args
rng = container.rng
# From `args`, extract the `SamplingContext`, which contains the RNG.
sampling_context = args[3]
rng = sampling_context.rng
trace.rng = rng
return trace
end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.21.5, 0.22"
DynamicPPL = "0.23"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"
Expand Down
6 changes: 6 additions & 0 deletions test/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ end
@test length(unique(c[:m])) == 1
@test length(unique(c[:s])) == 1
end

# https://github.com/TuringLang/Turing.jl/issues/2007
@turing_testset "keyword arguments not supported" begin
@model kwarg_demo(; x = 2) = return x
@test_throws ErrorException sample(kwarg_demo(), PG(1), 10)
end
end

# @testset "pmmh.jl" begin
Expand Down