From 976e621a6a67663471794220d9d85a05a3433b0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Apr 2023 19:41:36 +0100 Subject: [PATCH 01/27] bump dppl test versions --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6433bd3c8..dd5b5fd07 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.21.5" +DynamicPPL = "0.21.5, 0.22" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" diff --git a/test/Project.toml b/test/Project.toml index 8af7826b5..9eaab9066 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.21.5" +DynamicPPL = "0.21.5, 0.22" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" LogDensityProblems = "2" From 4347bd9ca65bd65ec068556f86662de0151b4b5e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Apr 2023 19:45:02 +0100 Subject: [PATCH 02/27] also bump bijectors --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dd5b5fd07..da9dbb2e7 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.1" BangBang = "0.3" -Bijectors = "0.8, 0.9, 0.10" +Bijectors = "0.8, 0.9, 0.10, 0.11, 0.12" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" From ee53d1c02ba06f68804ea83a5d4bd68768c27fc1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Apr 2023 19:46:21 +0100 Subject: [PATCH 03/27] bump AdvancedVI versions --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index da9dbb2e7..b63b53bbf 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ AbstractMCMC = "4" AdvancedHMC = "0.3.0, 0.4" AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" -AdvancedVI = "0.1" +AdvancedVI = "0.1, 0.2" BangBang = "0.3" Bijectors = "0.8, 0.9, 0.10, 0.11, 0.12" DataStructures = "0.18" diff --git a/test/Project.toml b/test/Project.toml index 9eaab9066..93cf7d437 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -35,7 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "4" AdvancedMH = "0.6, 0.7" AdvancedPS = "0.4" -AdvancedVI = "0.1" +AdvancedVI = "0.1, 0.2" Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" From 384158fd0259fedac6ee902dc2270ae58d0cf0f1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Apr 2023 19:55:07 +0100 Subject: [PATCH 04/27] revert Bijectors bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b63b53bbf..47afd9057 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.1, 0.2" BangBang = "0.3" -Bijectors = "0.8, 0.9, 0.10, 0.11, 0.12" +Bijectors = "0.8, 0.9, 0.10" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" From a258d54e8e80c8d911ad439d90eb713674654f42 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 Apr 2023 13:39:00 +0100 Subject: [PATCH 05/27] bumped vi and bijectors too --- Project.toml | 4 ++-- src/variational/advi.jl | 2 +- test/Project.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 47afd9057..b00816fa3 100644 --- a/Project.toml +++ b/Project.toml @@ -40,9 +40,9 @@ AbstractMCMC = "4" AdvancedHMC = "0.3.0, 0.4" AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" -AdvancedVI = "0.1, 0.2" +AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.8, 0.9, 0.10" +Bijectors = "0.12" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" diff --git a/src/variational/advi.jl b/src/variational/advi.jl index 3169b9461..b9074515c 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -1,5 +1,5 @@ # TODO(torfjelde): Find a better solution. -struct Vec{N, B<:Bijectors.Bijector{N}} <: Bijectors.Bijector{1} +struct Vec{N,B} <: Bijectors.Bijector b::B size::NTuple{N, Int} end diff --git a/test/Project.toml b/test/Project.toml index 93cf7d437..d3f7747e0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -35,7 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "4" AdvancedMH = "0.6, 0.7" AdvancedPS = "0.4" -AdvancedVI = "0.1, 0.2" +AdvancedVI = "0.2" Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" From 75fe9c3178ca558f33d6f14eebbd9407b7f76df5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 Apr 2023 13:39:23 +0100 Subject: [PATCH 06/27] breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b00816fa3..a3a879d1f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.24.3" +version = "0.25.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 06a2e7d93d8f8da4a3037403a7a247fd23f821ab Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 Apr 2023 14:26:43 +0100 Subject: [PATCH 07/27] removed refernce to Bijectors.setadbackend --- src/essential/ad.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 3000f46c9..61dee85e1 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -6,7 +6,6 @@ setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) function setadbackend(backend::Val) _setadbackend(backend) AdvancedVI.setadbackend(backend) - Bijectors.setadbackend(backend) end function _setadbackend(::Val{:forwarddiff}) From d1054eeb6208f8a000f91619269d7c54ce9033e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jun 2023 15:13:54 +0100 Subject: [PATCH 08/27] make use of DynamicPPL.make_evaluate_args_and_kwargs --- src/essential/container.jl | 42 ++++++++++++++------------------------ 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 19abca423..27a176c2a 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -1,3 +1,9 @@ +struct ArgsAndKwargsF{F} + f::F +end + +(f::ArgsAndKwargsF)(args, kwargs) = f.f(args...; kwargs...) + struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} model::M sampler::S @@ -9,36 +15,18 @@ function TracedModel( model::Model, sampler::AbstractSampler, varinfo::AbstractVarInfo, - rng::Random.AbstractRNG -) + rng::Random.AbstractRNG, +) context = SamplingContext(rng, sampler, DefaultContext()) - evaluator = _get_evaluator(model, varinfo, context) - return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( + model, + sampler, + varinfo, + (ArgsAndKwargsF(model.f), args, kwargs), + ) end -# TODO: maybe move to DynamicPPL -@generated function _get_evaluator( - model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [ - :($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames - ] - # We want to give `context` precedence over `model.context` while also - # preserving the leaf context of `context`. We can do this by - # 1. Set the leaf context of `model.context` to `leafcontext(context)`. - # 2. Set leaf context of `context` to the context resulting from (1). - # The result is: - # `context` -> `childcontext(context)` -> ... -> `model.context` - # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` - return quote - context_new = DynamicPPL.setleafcontext( - context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context)) - ) - (model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...)) - end -end - - function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel}) newtask = copy(model.ctask) newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator)) From 1785cf7203698546fd6acb84fb0831534c23cd3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jun 2023 15:14:43 +0100 Subject: [PATCH 09/27] bump DPPL version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2c5730c4f..124c3d53f 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.21.5, 0.22" +DynamicPPL = "0.23" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From 6f4fecbb124477dee2e01e7c6414694252f53467 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jun 2023 15:15:08 +0100 Subject: [PATCH 10/27] bump DPPL version for tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 790eb4f33..adf489a6c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.21.5, 0.22" +DynamicPPL = "0.23" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" LogDensityProblems = "2" From 149c3d677f22fe03204bc1aa29e2dca0730368fe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 01:25:12 +0100 Subject: [PATCH 11/27] fixed bug in TracedModel --- src/essential/container.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/essential/container.jl b/src/essential/container.jl index 27a176c2a..9be82cf26 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -64,6 +64,10 @@ end function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} args = trace.model.ctask.args _, _, container, = args + # Extract the `args`. + args, _ = trace.model.ctask.args + # From `args`, extract the RNG-container, i.e. `SamplingContext`. + container = args[3] rng = container.rng trace.rng = rng return trace From c6b4d28c15b1f5e27c7d832ed9df104bcb482c93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 01:28:51 +0100 Subject: [PATCH 12/27] forgot to remove some lines --- src/essential/container.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 9be82cf26..20902b2d9 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -62,8 +62,6 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) end function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} - args = trace.model.ctask.args - _, _, container, = args # Extract the `args`. args, _ = trace.model.ctask.args # From `args`, extract the RNG-container, i.e. `SamplingContext`. From 2bafbc3dafb70f7ca822f7ca7bdd775bade34952 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jun 2023 11:38:30 +0100 Subject: [PATCH 13/27] just drop the kwargs completely :( --- src/essential/container.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 20902b2d9..69cf9dd5b 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -1,9 +1,3 @@ -struct ArgsAndKwargsF{F} - f::F -end - -(f::ArgsAndKwargsF)(args, kwargs) = f.f(args...; kwargs...) - struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} model::M sampler::S @@ -18,12 +12,14 @@ function TracedModel( rng::Random.AbstractRNG, ) context = SamplingContext(rng, sampler, DefaultContext()) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + # FIXME: We're just dropping the `kwargs` here. I'm guessing this can + # cause issues if the model ends up mutating any of the keyword arguments. + args, _ = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, sampler, varinfo, - (ArgsAndKwargsF(model.f), args, kwargs), + (model.f, args...), ) end @@ -61,9 +57,9 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) return trace end -function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} +function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} # Extract the `args`. - args, _ = trace.model.ctask.args + args = trace.model.ctask.args # From `args`, extract the RNG-container, i.e. `SamplingContext`. container = args[3] rng = container.rng From 36ae04629d7b87868ac015b2d48195d5f0d84609 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sat, 10 Jun 2023 13:50:02 +0100 Subject: [PATCH 14/27] Update container.jl --- src/essential/container.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 69cf9dd5b..34cb503c9 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -14,12 +14,15 @@ function TracedModel( context = SamplingContext(rng, sampler, DefaultContext()) # FIXME: We're just dropping the `kwargs` here. I'm guessing this can # cause issues if the model ends up mutating any of the keyword arguments. - args, _ = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + args_and_kwwargs = [args..., kwwargs...] return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, sampler, varinfo, - (model.f, args...), + # FIXME: this is a hacky way of preserving a previous DynamicPPL behavior prior + # https://github.com/TuringLang/DynamicPPL.jl/pull/477 + (model.f, args_and_kwwargs...), ) end From 55f85041d0925c495e153f068132fa384c410ce4 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sat, 10 Jun 2023 14:48:51 +0100 Subject: [PATCH 15/27] Update container.jl --- src/essential/container.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 34cb503c9..2e5d9b99a 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -15,14 +15,14 @@ function TracedModel( # FIXME: We're just dropping the `kwargs` here. I'm guessing this can # cause issues if the model ends up mutating any of the keyword arguments. args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) - args_and_kwwargs = [args..., kwwargs...] + args_and_kwargs = [args..., kwargs...] return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, sampler, varinfo, # FIXME: this is a hacky way of preserving a previous DynamicPPL behavior prior # https://github.com/TuringLang/DynamicPPL.jl/pull/477 - (model.f, args_and_kwwargs...), + (model.f, args_and_kwargs...), ) end From b0077490904aff58c78968232de925da2adc8b0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 12 Jun 2023 13:45:50 +0100 Subject: [PATCH 16/27] will now error if we're using a model with kwargs and SMC --- src/essential/container.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 2e5d9b99a..7dbe90612 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -12,17 +12,15 @@ function TracedModel( rng::Random.AbstractRNG, ) context = SamplingContext(rng, sampler, DefaultContext()) - # FIXME: We're just dropping the `kwargs` here. I'm guessing this can - # cause issues if the model ends up mutating any of the keyword arguments. args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) - args_and_kwargs = [args..., kwargs...] + if kwargs !== nothing || isempty(kwargs) + error("Sampling with `$(sampler.alg)` does not support keyword arguments.") + end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, sampler, varinfo, - # FIXME: this is a hacky way of preserving a previous DynamicPPL behavior prior - # https://github.com/TuringLang/DynamicPPL.jl/pull/477 - (model.f, args_and_kwargs...), + (model.f, args...) ) end @@ -63,9 +61,9 @@ end function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R} # Extract the `args`. args = trace.model.ctask.args - # From `args`, extract the RNG-container, i.e. `SamplingContext`. - container = args[3] - rng = container.rng + # From `args`, extract the `SamplingContext`, which contains the RNG. + sampling_context = args[3] + rng = sampling_context.rng trace.rng = rng return trace end From 6bf0980b18ae4b5df5c73fe631ed8e1109e19755 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 12 Jun 2023 13:48:14 +0100 Subject: [PATCH 17/27] added reference to issue --- src/essential/container.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 7dbe90612..bc01df3ed 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -14,7 +14,7 @@ function TracedModel( context = SamplingContext(rng, sampler, DefaultContext()) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) if kwargs !== nothing || isempty(kwargs) - error("Sampling with `$(sampler.alg)` does not support keyword arguments.") + error("Sampling with `$(sampler.alg)` does not support keyword arguments. See issue #2007 for more details.") end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, From b3f13a1f779f8000d359fc0655ac36c4e1869706 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 12 Jun 2023 13:51:52 +0100 Subject: [PATCH 18/27] added test for keyword models failing --- test/inference/AdvancedSMC.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/inference/AdvancedSMC.jl b/test/inference/AdvancedSMC.jl index c0bf738df..de6e65e40 100644 --- a/test/inference/AdvancedSMC.jl +++ b/test/inference/AdvancedSMC.jl @@ -173,6 +173,12 @@ end @test length(unique(c[:m])) == 1 @test length(unique(c[:s])) == 1 end + + # https://github.com/TuringLang/Turing.jl/issues/2007 + @turing_testset "keyword arguments not supported" begin + @model kwarg_demo(; x = 2) = return x + @test_throws ErrorException sample(kwarg_demo(), PG(1), 10) + end end # @testset "pmmh.jl" begin From 404882cf6b4e2ffea631194d922557a37e0f56ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 12 Jun 2023 13:52:40 +0100 Subject: [PATCH 19/27] make this a breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 124c3d53f..11ee56215 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.25.2" +version = "0.26.0" [deps] From 3c036ebd640c9142e223e23c933f443582b473c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jun 2023 18:06:24 +0100 Subject: [PATCH 20/27] made error message more informative --- src/essential/container.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index bc01df3ed..c015040d1 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -14,7 +14,7 @@ function TracedModel( context = SamplingContext(rng, sampler, DefaultContext()) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) if kwargs !== nothing || isempty(kwargs) - error("Sampling with `$(sampler.alg)` does not support keyword arguments. See issue #2007 for more details.") + error("Sampling with `$(sampler.alg)` does not support models with keyword arguments (in this case $(model.f)). See issue #2007 for more details.") end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, From 56ce7f1f23adca5c62f828642826420a786ab7fb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jun 2023 18:06:49 +0100 Subject: [PATCH 21/27] makde it slightly less informative --- src/essential/container.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index c015040d1..b401b4389 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -14,7 +14,7 @@ function TracedModel( context = SamplingContext(rng, sampler, DefaultContext()) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) if kwargs !== nothing || isempty(kwargs) - error("Sampling with `$(sampler.alg)` does not support models with keyword arguments (in this case $(model.f)). See issue #2007 for more details.") + error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.") end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( model, From 52071b09a20504a50b8cf3d99629a5b9165f1168 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jun 2023 18:08:16 +0100 Subject: [PATCH 22/27] fixed typo in checking for TRaceModel --- src/essential/container.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index b401b4389..0c8f68e55 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -13,7 +13,7 @@ function TracedModel( ) context = SamplingContext(rng, sampler, DefaultContext()) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) - if kwargs !== nothing || isempty(kwargs) + if kwargs !== nothing || !isempty(kwargs) error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.") end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( From 6f580fb0195ddc89d1a85b0a3dd5772b8acab81d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jun 2023 23:49:58 +0100 Subject: [PATCH 23/27] finally fixed the if-statement.. --- src/essential/container.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 0c8f68e55..b68cb5b09 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -13,7 +13,7 @@ function TracedModel( ) context = SamplingContext(rng, sampler, DefaultContext()) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) - if kwargs !== nothing || !isempty(kwargs) + if kwargs !== nothing && !isempty(kwargs) error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.") end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( From 400915b3330698b7a9638e002aae6aadaf905bf5 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 14 Jun 2023 01:00:40 +0100 Subject: [PATCH 24/27] Fix test error --- test/inference/Inference.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index bafe119bd..92ef92ab5 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -246,8 +246,8 @@ return priors end - chain = sample(gauss(x), PG(10), 10) - chain = sample(gauss(x), SMC(), 10) + @test_throws ErrorException chain = sample(gauss(x), PG(10), 10) + @test_throws ErrorException chain = sample(gauss(x), SMC(), 10) @model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV} priors = TV(undef, 2) From 394d38dd7c54480b34f78b23d45a8f1f7023c9ca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 14 Jun 2023 01:06:49 +0100 Subject: [PATCH 25/27] fixed tests maybe --- test/inference/Inference.jl | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index 92ef92ab5..f8b468138 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -246,8 +246,8 @@ return priors end - @test_throws ErrorException chain = sample(gauss(x), PG(10), 10) - @test_throws ErrorException chain = sample(gauss(x), SMC(), 10) + chain = sample(gauss(x), PG(10), 10) + chain = sample(gauss(x), SMC(), 10) @model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV} priors = TV(undef, 2) @@ -259,11 +259,28 @@ return priors end - chain = sample(gauss2(; x=x), PG(10), 10) - chain = sample(gauss2(; x=x), SMC(), 10) + @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) + @test_throws ErrorExceptionchain = sample(gauss2(; x=x), SMC(), 10) + + @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) + @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10) + + @model function gauss3(x, ::Type{TV}=Vector{Float64}) where {TV} + priors = TV(undef, 2) + priors[1] ~ InverseGamma(2, 3) # s + priors[2] ~ Normal(0, sqrt(priors[1])) # m + for i in 1:length(x) + x[i] ~ Normal(priors[2], sqrt(priors[1])) + end + return priors + end + + chain = sample(gauss3(; x=x), PG(10), 10) + @test_throws ErrorExceptionchain = sample(gauss3(; x=x), SMC(), 10) + + chain = sample(gauss3(Vector{Float64}; x=x), PG(10), 10) + chain = sample(gauss3(Vector{Float64}; x=x), SMC(), 10) - chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) - chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10) end @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] From 0eb94454a49d56f593bccd6ed9ea911758d90cac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 14 Jun 2023 01:07:57 +0100 Subject: [PATCH 26/27] now fixed maybe --- test/inference/Inference.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index f8b468138..253418400 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -275,12 +275,11 @@ return priors end - chain = sample(gauss3(; x=x), PG(10), 10) - @test_throws ErrorExceptionchain = sample(gauss3(; x=x), SMC(), 10) - - chain = sample(gauss3(Vector{Float64}; x=x), PG(10), 10) - chain = sample(gauss3(Vector{Float64}; x=x), SMC(), 10) + chain = sample(gauss3(x), PG(10), 10) + chain = sample(gauss3(x), SMC(), 10) + chain = sample(gauss3(x, Vector{Real}), PG(10), 10) + chain = sample(gauss3(x, Vector{Real}), SMC(), 10) end @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] From 30ad2ffee08e5ee1fb41b53bb66974c7b9096a00 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jun 2023 09:34:52 +0100 Subject: [PATCH 27/27] Update test/inference/Inference.jl --- test/inference/Inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/Inference.jl b/test/inference/Inference.jl index 253418400..13658c3a4 100644 --- a/test/inference/Inference.jl +++ b/test/inference/Inference.jl @@ -260,7 +260,7 @@ end @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) - @test_throws ErrorExceptionchain = sample(gauss2(; x=x), SMC(), 10) + @test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10) @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)