
Bump DynamicPPL to 0.23 #2001

Merged — 31 commits, merged on Jun 15, 2023 (changes shown from 16 commits)
976e621
bump dppl test versions
torfjelde Apr 18, 2023
4347bd9
also bump bijectors
torfjelde Apr 18, 2023
ee53d1c
bump AdvancedVI versions
torfjelde Apr 18, 2023
384158f
revert Bijectors bump
torfjelde Apr 18, 2023
a258d54
bumped vi and bijectors too
torfjelde Apr 20, 2023
75fe9c3
breaking change
torfjelde Apr 20, 2023
d02d2f5
Merge branch 'master' into torfjelde/dynamicppl-bump
torfjelde Apr 20, 2023
06a2e7d
removed reference to Bijectors.setadbackend
torfjelde Apr 20, 2023
320af8c
Merge branch 'torfjelde/dynamicppl-bump' of github.com:TuringLang/Tur…
torfjelde Apr 20, 2023
d8f7086
Merge branch 'master' into torfjelde/dynamicppl-bump
torfjelde Jun 6, 2023
d1054ee
make use of DynamicPPL.make_evaluate_args_and_kwargs
torfjelde Jun 6, 2023
1785cf7
bump DPPL version
torfjelde Jun 6, 2023
6f4fecb
bump DPPL version for tests
torfjelde Jun 6, 2023
149c3d6
fixed bug in TracedModel
torfjelde Jun 7, 2023
c6b4d28
forgot to remove some lines
torfjelde Jun 7, 2023
2bafbc3
just drop the kwargs completely :(
torfjelde Jun 8, 2023
36ae046
Update container.jl
yebai Jun 10, 2023
55f8504
Update container.jl
yebai Jun 10, 2023
b007749
will now error if we're using a model with kwargs and SMC
torfjelde Jun 12, 2023
6bf0980
added reference to issue
torfjelde Jun 12, 2023
b3f13a1
added test for keyword models failing
torfjelde Jun 12, 2023
404882c
make this a breaking change
torfjelde Jun 12, 2023
3c036eb
made error message more informative
torfjelde Jun 13, 2023
56ce7f1
made it slightly less informative
torfjelde Jun 13, 2023
52071b0
fixed typo in checking for TRaceModel
torfjelde Jun 13, 2023
0ed4f59
Merge branch 'master' into torfjelde/dynamicppl-bump
torfjelde Jun 13, 2023
6f580fb
finally fixed the if-statement..
torfjelde Jun 13, 2023
400915b
Fix test error
yebai Jun 14, 2023
394d38d
fixed tests maybe
torfjelde Jun 14, 2023
0eb9445
now fixed maybe
torfjelde Jun 14, 2023
30ad2ff
Update test/inference/Inference.jl
torfjelde Jun 15, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -48,7 +48,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.21.5, 0.22"
DynamicPPL = "0.23"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
44 changes: 15 additions & 29 deletions src/essential/container.jl
@@ -9,36 +9,20 @@ function TracedModel(
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
rng::Random.AbstractRNG
)
rng::Random.AbstractRNG,
)
context = SamplingContext(rng, sampler, DefaultContext())
evaluator = _get_evaluator(model, varinfo, context)
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
# FIXME: We're just dropping the `kwargs` here. I'm guessing this can
# cause issues if the model ends up mutating any of the keyword arguments.
args, _ = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
torfjelde (Member, Author) commented:

This is bad.

See other comments on alternatives: #2001 (comment)

torfjelde (Member, Author) commented:

Okay, so I've figured out why it breaks, but I think it's going to be difficult to fix it 😕

Libtask only handles nested produce in the case where each instruction it sees contains at most one produce.

julia> using Libtask

julia> f(x) = (produce(x); produce(2x); produce(3x); return nothing)
f (generic function with 1 method)

julia> g(x) = f(x)
g (generic function with 1 method)

julia> task = Libtask.TapedTask(f, 1);

julia> consume(task), consume(task), consume(task)
(1, 2, 3)

julia> task = Libtask.TapedTask(g, 1);  # tracing of nested call

julia> consume(task)   # goes through all the `produce` calls before even calling the `callback` (which is `Libtask.producer`)
counter=1
tf=TapedFunction:
* .func => g
* .ir   =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Core.Const(nothing)
└──      return %1
)
------------------

ErrorException("There is a produced value which is not consumed.")
Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x00007fa8d200eeeb, Ptr{Nothing} @0x00007fa8a0a30f29, Ptr{Nothing} @0x00007fa8a0a36844, Ptr{Nothing} @0x00007fa8a0a36865, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a366e3, Ptr{Nothing} @0x00007fa8a0a36802, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a35f25, Ptr{Nothing} @0x00007fa8a0a361dd, Ptr{Nothing} @0x00007fa8a0a36512, Ptr{Nothing} @0x00007fa8a0a3652f, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8e6b6656f]
ERROR: There is a produced value which is not consumed.
Stacktrace:
 [1] consume(ttask::TapedTask{typeof(g), Tuple{Int64}})
   @ Libtask ~/.julia/packages/Libtask/h7Kal/src/tapedtask.jl:153
 [2] top-level scope
   @ REPL[9]:1

This also makes me realize that Libtask.jl will silently do the wrong thing in certain cases where we use @submodel and the inner model contains observations.

torfjelde (Member, Author) commented:

And all of this effectively means that, with the current Libtask, we cannot support kwargs in a model without converting them all to positional arguments.

yebai (Member) commented on Jun 8, 2023:

Libtask only handles nested produce in the case where each instruction it sees contains at most one produce.

Yes, this was a design decision -- it becomes very complicated to trace every function in a nested way. In theory, this can be fixed by adding submodels to the list of functions we need to trace into.

yebai (Member) commented:

And all of this effectively means that, with the current Libtask, we cannot support kwargs in a model without converting them all to positional arguments.

That could be a temporary hacky solution before @KDr2 and I add submodels to the list of functions we need to trace recursively.

yebai (Member) commented on Jun 8, 2023:

We wanted to make the dependency light so Libtask is more stable for new Julia releases. Also, recursively unrolling will create unnecessarily large tape in most use cases, i.e. for SMC/PG, it has significant performance penalties.

A member commented:

Ah gotcha. But yeah, it should be possible. I'm actually trying to implement exactly that right now.

It's ok to ignore that for now.
It is very time-consuming to make this feature right if I remember the details correctly.

torfjelde (Member, Author) commented:

I'd love to :) But this completely breaks SMC samplers for any model with kwargs 😕

yebai (Member) commented on Jun 8, 2023:

I think your previous commit fixed all tests except the new grammar model.

torfjelde (Member, Author) commented:

So it will run with all kwarg models, but it will silently do the incorrect thing if the default kwargs are not the ones we're supposed to use. In this PR we currently just drop them completely, i.e. we always end up using the default kwargs.

I don't think this quite qualifies as good enough :/

return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model,
sampler,
varinfo,
(model.f, args...),
)
end
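The silent-default problem discussed in the thread above can be made concrete with a small sketch (hypothetical model, not from this PR): because `TracedModel` now forwards only the positional `args`, any keyword argument supplied at model construction is lost during particle-based evaluation and the default is used instead.

```julia
using Turing

# Hypothetical kwarg model; `variance` defaults to 1.0.
@model function demo(x; variance=1.0)
    m ~ Normal(0, 1)
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(variance))
    end
end

model = demo([0.1, -0.2]; variance=100.0)

# Under SMC/PG, evaluation goes through `TracedModel`, which (at this
# commit) forwards only `(model.f, args...)` — so the sampler would
# silently evaluate with `variance=1.0`. Later commits in this PR turn
# this case into an explicit error for models with keyword arguments.
# sample(model, SMC(), 100)  # would use the wrong variance here
```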

# TODO: maybe move to DynamicPPL
@generated function _get_evaluator(
model::Model{_F,argnames}, varinfo, context
) where {_F,argnames}
unwrap_args = [
:($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
]
# We want to give `context` precedence over `model.context` while also
# preserving the leaf context of `context`. We can do this by
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
# 2. Set leaf context of `context` to the context resulting from (1).
# The result is:
# `context` -> `childcontext(context)` -> ... -> `model.context`
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
return quote
context_new = DynamicPPL.setleafcontext(
context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
)
(model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...))
end
end
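To make the ordering comment inside `_get_evaluator` concrete, here is a hedged sketch (context names are purely illustrative) of what the two nested `setleafcontext` calls produce:

```julia
# Suppose, schematically:
#   context       = A -> B -> LeafC    (the sampling context)
#   model.context = C -> D -> LeafM    (the model's own context)
#
# Step 1: DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
#   yields C -> D -> LeafC
# Step 2: DynamicPPL.setleafcontext(context, <step-1 result>)
#   yields A -> B -> C -> D -> LeafC
#
# So `context` takes precedence, `model.context` is preserved in the
# middle, and the leaf of `context` wins — the chain described in the
# comment above.
```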


function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
newtask = copy(model.ctask)
newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
@@ -73,9 +57,11 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
return trace
end

function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
# Extract the `args`.
args = trace.model.ctask.args
_, _, container, = args
# From `args`, extract the RNG-container, i.e. `SamplingContext`.
container = args[3]
rng = container.rng
trace.rng = rng
return trace
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.21.5, 0.22"
DynamicPPL = "0.23"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"