
Bump DynamicPPL to 0.23 #2001

Merged: 31 commits into master, Jun 15, 2023
Conversation

@torfjelde (Member) commented Jun 6, 2023

Closes #2004 and #2005

@torfjelde (Member Author):

Would be nice to get this through once tests pass:)

@torfjelde (Member Author):

Nvm, it seems like things are broken. The resulting errors are quite cryptic to me, though, so if someone has seen these before, I'd appreciate some guidance 🙏

@torfjelde (Member Author):

If I drop the usage of ArgsAndKwargsF (which also means I drop support for kwargs in the model), everything seems to work again.

I'm a bit uncertain what's going on here 😕

Comment on lines 22 to 27
    return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
        model,
        sampler,
        varinfo,
        (ArgsAndKwargsF(model.f), args, kwargs),
    )
@torfjelde (Member Author):

After some digging, I thought the issue was a lack of deepcopy for the model when not explicitly present in args. So I tried making this into

(f::ArgsAndKwargsF)(model_arg::DynamicPPL.Model, args, kwargs) = f.f(model_arg, args...; kwargs...)

and

    return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
        model,
        sampler,
        varinfo,
        (ArgsAndKwargsF(model.f), args[1], args[2:end], kwargs),
    )

With that, the deepcopy is correctly triggered. But the error in the tests is still occurring 😕

Member:

@KDr2 do you remember whether Libtask.TapedFunction supports keyword arguments?

@torfjelde (Member Author):

AFAIK we don't. E.g. TapedTask uses kwargs to pass the deepcopy_types.
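
For reference, a small runnable sketch of the point above, assuming (as the name and the usage described here suggest) that deepcopy_types takes a type deciding which task-local arguments get deep-copied when the task is copied:

    using Libtask

    f(x) = (produce(sum(x)); nothing)

    # The constructor's kwargs configure the task itself (here: which argument
    # types to `deepcopy` when the task is copied); they are never forwarded
    # to `f` as keyword arguments.
    task = Libtask.TapedTask(f, [1.0, 2.0]; deepcopy_types = Vector{Float64})
    consume(task)  # == 3.0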

@torfjelde (Member Author):

Also, just for the record, if I do:

(model.f, args...)

i.e. drop the kwargs + the diversion completely, then everything works again.

@torfjelde (Member Author):

Moreover, it only seems to happen when combined with Gibbs. I haven't run into these issues when running PG on its own.

@torfjelde (Member Author):

For reference, before I used to do:

struct ArgsAndKwargsF{F}
    f::F
end

(f::ArgsAndKwargsF)(args, kwargs) = f.f(args...; kwargs...)

and

    return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
        model,
        sampler,
        varinfo,
        (ArgsAndKwargsF(model.f), args, kwargs),
    )

in an attempt to also support kwargs, but this was causing strange behaviors when using something like PG within Gibbs. The above comments refer to the code that was doing this.

@torfjelde (Member Author):

Also, now that we have Core.kwcall, it should be possible to support this from Julia 1.9 onwards in Libtask.jl.
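
For reference, a minimal sketch of the lowering this refers to (runnable on Julia >= 1.9; h is just a toy function):

    h(x; a = 1) = x + a

    # On Julia >= 1.9, `h(1; a = 2)` lowers to the call below, so a tracer
    # that special-cases `Core.kwcall` could handle keyword calls explicitly:
    Core.kwcall((a = 2,), h, 1)  # == 3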

    return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
    # FIXME: We're just dropping the `kwargs` here. I'm guessing this can
    # cause issues if the model ends up mutating any of the keyword arguments.
    args, _ = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
@torfjelde (Member Author):

This is bad.

See other comments on alternatives: #2001 (comment)

@torfjelde (Member Author):

Okay, so I've figured out why it breaks, but I think it's going to be difficult to fix it 😕

Libtask only handles nested produce in the case where each instruction it sees contains at most one produce.

julia> using Libtask

julia> f(x) = (produce(x); produce(2x); produce(3x); return nothing)
f (generic function with 1 method)

julia> g(x) = f(x)
g (generic function with 1 method)

julia> task = Libtask.TapedTask(f, 1);

julia> consume(task), consume(task), consume(task)
(1, 2, 3)

julia> task = Libtask.TapedTask(g, 1);  # tracing of nested call

julia> consume(task)   # goes through all the `produce` calls before even calling the `callback` (which is `Libtask.producer`)
counter=1
tf=TapedFunction:
* .func => g
* .ir   =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Core.Const(nothing)
└──      return %1
)
------------------

ErrorException("There is a produced value which is not consumed.")Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x00007fa8d200eeeb, Ptr{Nothing} @0x00007fa8a0a30f29, Ptr{Nothing} @0x00007fa8a0a36844, Ptr{Nothing} @0x00007fa8a0a36865, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a366e3, Ptr{Nothing} @0x00007fa8a0a36802, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a35f25, Ptr{Nothing} @0x00007fa8a0a361dd, Ptr{Nothing} @0x00007fa8a0a36512, Ptr{Nothing} @0x00007fa8a0a3652f, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8e6b6656f]
ERROR: There is a produced value which is not consumed.
Stacktrace:
 [1] consume(ttask::TapedTask{typeof(g), Tuple{Int64}})
   @ Libtask ~/.julia/packages/Libtask/h7Kal/src/tapedtask.jl:153
 [2] top-level scope
   @ REPL[9]:1

This also makes me realize that Libtask.jl will just silently do the wrong thing in certain cases where we use @submodel and the inner model contains observations.

@torfjelde (Member Author):

And all of this effectively means that, with the current Libtask, we cannot support kwargs in a model without converting them all to positional arguments.

@yebai (Member) commented Jun 8, 2023:

> Libtask only handles nested produce in the case where each instruction it sees contains at most one produce.

Yes, this was a design decision -- it becomes very complicated to trace every function in a nested way. In theory, this can be fixed by adding submodels to the list of functions we need to trace into.

Member:

> And all of this effectively means that, with the current Libtask, we cannot support kwargs in a model without converting them all to positional arguments.

That could be a temporary hacky solution before @KDr2 and I add submodels to the list of functions we need to trace recursively.

@yebai (Member) commented Jun 8, 2023:

We wanted to keep the dependency light so that Libtask is more stable across new Julia releases. Also, recursive unrolling would create an unnecessarily large tape in most use cases; for SMC/PG this has significant performance penalties.

Member:

> Ah gotcha. But yeah, it should be possible. I'm actually trying to implement exactly that right now.

It's ok to ignore that for now. It is very time-consuming to make this feature right, if I remember the details correctly.

@torfjelde (Member Author):

I'd love to:) But this completely breaks SMC samplers for any model with kwargs 😕

@yebai (Member) commented Jun 8, 2023:

I think your previous commit fixed all tests except the new grammar model.

@torfjelde (Member Author):

So it will run with all kwarg models, but it will silently do the incorrect thing if the default kwargs are not the ones we're supposed to use. In this PR we currently just drop them completely, i.e. we always end up using the default kwargs.

I don't think this quite qualifies as good enough :/

@yebai (Member) commented Jun 9, 2023

Looking at the new grammar test again, I am not quite sure why keyword arguments worked before this PR. If we don't have support for keyword arguments in Libtask, shouldn't these tests always fail?

@torfjelde (Member Author):

> Looking at the new grammar test again, I am not quite sure why keyword arguments worked before this PR. If we don't have support for keyword arguments in Libtask, shouldn't these tests always fail?

Kwargs were supported before DPPL 0.23 only to the extent that kwargs were supported by DPPL.

By that I mean that splatted kwargs, e.g. @model f(; kwargs...), were not supported, while explicit kwargs, e.g. @model f(; x, y=1), were supported, because they were converted into positional arguments in the evaluator. The evaluator (i.e. model.f) never had kwargs (though the constructor did), only positional arguments.

After DPPL 0.23, the evaluator supports kwargs properly, but this means we've added an additional diversion before actually hitting the "body" of model.f, hence breaking interactions with Libtask.jl.
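
For illustration, the two flavours mentioned above with a toy model (demo is hypothetical):

    using Turing

    # Explicit kwargs: supported pre-0.23 by being rewritten into positional
    # arguments of the evaluator.
    @model function demo(; x, y = 1)
        m ~ Normal(0, 1)
        x ~ Normal(m, y)
    end

    model = demo(x = 0.5)

    # Splatted kwargs were NOT supported pre-0.23:
    # @model function demo2(; kwargs...) ... end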

A few comments on the above:

  1. Usage of @submodel with observations inside completely breaks any usage of SMC samplers (because Libtask.jl does not recurse). This is the case on the current version of Turing.jl, and it's effectively the same issue causing the incompatibilities with DPPL 0.23.
  2. The fact that DPPL 0.23 was merged into master and released is an indication of how lacking the team's knowledge is of the limitations of Libtask.jl (it was looked at by several of the members). And if we don't even know the limitations, how are users supposed to know what's okay and what's not? 😕

There seem to be two ways to go about this:

  1. Make Libtask.jl support recursing into methods to construct tape.
  2. Implement what DPPL tried to achieve differently.

IMHO, we should really address (1). And in the process, we should document how Libtask.jl actually works, so we ensure that this doesn't happen again. But I've tried doing this now, and boy, it seems pretty darn difficult with the current approach. The reason is that we currently rely on Base.code_typed to give us the IR, which we then Libtask.translate! into a tape from which we can start running from any given point (which, IIUC, is crucial for the copy-able continuations). Hence, properly addressing this issue will require recursively "inlining" methods in such a way that every instruction on the tape contains at most a single produce.
But recursive inlining requires type inference (and type stability), so it doesn't seem like a viable approach either.

All in all, I don't quite see how the current approach taken in Libtask.jl can support these things. Are there any concrete plans regarding this @yebai?

We can implement (2) as follows.

Instead of converting

@model f(args...; kwargs...)

into the evaluator

f(__model__, __varinfo__, __context__, args...; kwargs...)

as is done in DPPL 0.23, we can convert it into

f(__model__, __varinfo__, __context__, kwargs::NamedTuple, args...)

taking inspiration from Base.kwcall.
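
For concreteness, a minimal runnable sketch of this lowering (the names are hypothetical, not the actual DPPL evaluator):

    # Hypothetical evaluator for `@model f(x; a = 1)`: the kwargs arrive as a
    # leading NamedTuple, so the whole call chain is positional and there is
    # no keyword-sorter frame for Libtask to trace through.
    f_evaluator(__model__, __varinfo__, __context__, kwargs::NamedTuple, x) =
        x + get(kwargs, :a, 1)  # default `a = 1` resolved inside the body

    f_evaluator(nothing, nothing, nothing, (a = 2,), 1)  # == 3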

But approach (2) will not change the "silently producing the incorrect result for @submodel" problem.

@yebai (Member) commented Jun 11, 2023

> 1. Make Libtask.jl support recursing into methods to construct tape.
> 2. Implement what DPPL tried to achieve differently.

Let's take the 2nd route since it doesn't block this PR or subsequent PRs in DynamicPPL. @KDr2 and I can add support for (1) later -- we are aware of the limitation but decided it wasn't very widely used. But I agree that to obtain consistent behavior, we should raise the bar and properly support it.

> But approach (2) will not change the "silently producing the incorrect result for @submodel" problem.

Yes, but let's consider that in a separate PR.

> taking inspiration from Base.kwcall.

I'm not aware of this functionality. Can you elaborate?

@torfjelde (Member Author):

> Let's take the 2nd route since it doesn't block this PR or subsequent PRs in DynamicPPL.

Already on it:)

> I'm not aware of this functionality. Can you elaborate?

https://docs.julialang.org/en/v1/devdocs/functions/#Keyword-arguments

Just saying that the approach I suggest is inspired by that.

@torfjelde (Member Author):

> Let's take the 2nd route since it doesn't block this PR or subsequent PRs in DynamicPPL.

On 2nd thought, this won't solve our issues 😕 We would need to make the evaluator of the model a generated function to achieve what I was thinking 😕

@torfjelde (Member Author):

> We would need to make the evaluator of the model a generated function to achieve what I was thinking 😕

So you can actually do this: TuringLang/DynamicPPL.jl#483

Not certain if this is a good idea or not.
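
A minimal sketch of why a generated function helps here (hypothetical helper, not the actual code in that PR): it can unpack the kwargs NamedTuple field by field at compile time, so the emitted body contains no keyword call at all.

    @generated function evaluate_kw(kwargs::NamedTuple{names}, x) where {names}
        unpack = [:($n = kwargs.$n) for n in names]
        return quote
            a = 1          # default value for `a`
            $(unpack...)   # overridden if `a` is present in `kwargs`
            return x + a
        end
    end

    evaluate_kw((a = 2,), 1)       # == 3
    evaluate_kw(NamedTuple(), 1)   # == 2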

Comment on lines 249 to 250
@test_throws ErrorException chain = sample(gauss(x), PG(10), 10)
@test_throws ErrorException chain = sample(gauss(x), SMC(), 10)
@torfjelde (Member Author):

These should be passing. gauss2 should be the one causing issues.

@torfjelde (Member Author):

Needs TuringLang/Bijectors.jl#269

@codecov (bot) commented Jun 15, 2023

Codecov Report

Patch and project coverage have no change.

Comparison is base (861ae37) 0.00% compared to head (30ad2ff) 0.00%.

Additional details and impacted files:

    @@          Coverage Diff           @@
    ##           master   #2001   +/-   ##
    ======================================
      Coverage    0.00%   0.00%
    ======================================
      Files          21      21
      Lines        1424    1427    +3
    ======================================
    - Misses       1424    1427    +3

    Impacted Files               Coverage Δ
    src/essential/container.jl   0.00% <0.00%> (ø)
    src/inference/hmc.jl         0.00% <0.00%> (ø)

... and 1 file with indirect coverage changes


@yebai merged commit 6a307cd into master on Jun 15, 2023
@yebai deleted the torfjelde/dynamicppl-bump branch on June 15, 2023 at 11:42