Use LogDensityProblems instead of gradient_logp
#1877
Conversation
Though I really appreciate the compatibility with Enzyme, I'm a bit uncertain about adopting LogDensityProblems 😕 It's somewhat rigid in what it expects/wants, and we'd be taking on a fair number of features that we don't need (e.g. TransformVariables.jl). Is this really the best way we can get support for Enzyme? I do agree with you about generalizing/separating out the AD stuff, though. I was hoping that we could at some point just be using AbstractDifferentiation.jl, but AFAIK this won't fix the issue of having to use Requires.jl.
I think I disagree quite a bit 😄 Maybe the title of the PR and the Enzyme example were both a bit misleading. Of course, we could implement support for Enzyme by basically the example above and by copying the code from LogDensityProblems. But Enzyme support is not the motivation here (apart from the fact that I think it might be better to improve our code before adding support for another AD), as currently it does not work anyway on the simple examples I tried.

The change is purely internal and does not require any major rewrite of Turing: it's just replacing the current `gradient_logp`. Generally, the LogDensityProblems interface is very simple and general: it's just `logdensity` (plus `dimension` and `capabilities`).

For quite some time I thought as well that we might want to switch to AbstractDifferentiation. However, I became more and more convinced that - apart from stability and implementation issues - it is the wrong level to target from our side. We don't just want to differentiate a function; rather, we have to build, from the log density function of the model, a function for which we can evaluate the primal and the gradient efficiently. In particular, when we want to compile tapes or cache stuff, it's mandatory to wrap the user/model-provided log density function when initializing the sampling procedure. But that's exactly what `ADgradient` does.

So the short summary of my longer reply is probably: I think the design of LogDensityProblems, and in particular of `ADgradient`, is exactly what we need here.
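To make that concrete, here is a minimal sketch of the interface and the `ADgradient` wrapper (`MyDensity` is a made-up example type, not anything in Turing or LogDensityProblems):

```julia
using LogDensityProblems

# hypothetical example type, used only to illustrate the interface
struct MyDensity
    dim::Int
end

# the whole interface: evaluate the log density, report the dimension,
# and declare what the object can do (order 0 = no built-in gradients)
LogDensityProblems.logdensity(ℓ::MyDensity, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(ℓ::MyDensity) = ℓ.dim
LogDensityProblems.capabilities(::Type{MyDensity}) =
    LogDensityProblems.LogDensityOrder{0}()

# gradients come from wrapping the problem with an AD backend
using ForwardDiff
∇ℓ = LogDensityProblems.ADgradient(:ForwardDiff, MyDensity(2))
LogDensityProblems.logdensity_and_gradient(∇ℓ, [1.0, 2.0])  # (-2.5, [-1.0, -2.0])
```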
I like it because it simplifies current code and adds features. I tend to agree with @torfjelde, though, and am concerned that we don't influence future maintenance of LogDensityProblems. @tpapp Would it make sense to change that, e.g. by transferring the package?

Cc @oschulz @phipsgabler, who are the developers of DensityInterface.jl.
I am happy to transfer LogDensityProblems.jl.
I like the idea and trust David's judgement that LogDensityProblems fits our needs. The only concern I have is that currently it has an empty intersection with DensityInterface, which I preferred as the requirement for AbstractPPL, but which was considered not mature enough to serve as a base for LogDensityProblems when this was raised (cf. tpapp/LogDensityProblems.jl#78). We would have two different kinds of interface, especially `logdensity` vs. `DensityInterface.logdensityof`.

So my question is, how do we best reconcile this?
My initial guess would be: LogDensityProblems could implement the DensityInterface interface by deprecating `logdensity` in favour of `DensityInterface.logdensityof`.
I agree with @devmotion - DensityInterface.jl was essentially designed for this use case (with input from some Turing devs like @phipsgabler). (Also, Distributions and MeasureBase/MeasureInterface already support DensityInterface.)

CC @cscherrer

We also have AbstractDifferentiation.jl now - maybe combining that with DensityInterface.jl will already be sufficient?
In Turing? Or LogDensityProblems? I think that it is not sufficient for us in Turing, as we need the wrappers of log density function + gradient that are provided by LogDensityProblems. Quoting from above:

> We don't just want to differentiate a function; rather, we have to build, from the log density function of the model, a function for which we can evaluate the primal and the gradient efficiently. In particular, when we want to compile tapes or cache stuff, it's mandatory to wrap the user/model-provided log density function when initializing the sampling procedure.
Shouldn't something like this work? For example (a rough sketch, assuming a simple `MvNormal` target):
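```julia
using AbstractDifferentiation, DensityInterface, Distributions, LinearAlgebra
import AbstractDifferentiation as AD
import ForwardDiff

# a density defined via DensityInterface (Distributions already supports it)
d = MvNormal(zeros(2), I)

# primal and gradient via AbstractDifferentiation, for some chosen backend
x = randn(2)
val, (grad,) = AD.value_and_gradient(AD.ForwardDiffBackend(), logdensityof(d), x)
```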
Surely you can use AbstractDifferentiation to get the primal and gradient of a function for some backend (even though personally I don't think it's mature and efficient enough for general adoption yet), but it's not sufficient for our use cases. The best example is tape compilation with ReverseDiff: we want to perform and store optimizations when we initialize the sampling procedure, so just calling `value_and_gradient` in every step is not enough; we need a wrapper in which e.g. the compiled tape can be stored.

Hence, to me it seems AbstractDifferentiation is too low-level for our purposes. But I assume that for at least some backends LogDensityProblems could probably use the AbstractDifferentiation API at some point. Currently, I don't see a clear benefit though, since it would not remove the Requires dependency, and the implementation in LogDensityProblems should be as efficient as possible, whereas I'm not fully convinced that's the case with AbstractDifferentiation yet.
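For concreteness, a sketch of how this looks with LogDensityProblems (reusing the hypothetical `MyDensity` from above; as I understand the API, the `compile` keyword requests ReverseDiff tape compilation):

```julia
using LogDensityProblems, ReverseDiff

# wrap once at initialization: the ReverseDiff tape is recorded (and here
# compiled) a single time and stored inside the wrapper
∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, MyDensity(2); compile = Val(true))

# every subsequent call reuses the stored compiled tape
LogDensityProblems.logdensity_and_gradient(∇ℓ, [1.0, 2.0])
```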
@tpapp Could LogDensityProblems support DensityInterface?

Hm, that's a useful thing in general. Maybe we should lobby for AbstractDifferentiation to add a caching mechanism like this?

There's already an issue, so I think people are aware of it: JuliaDiff/AbstractDifferentiation.jl#41. It's not clear though how it could be done.

Ah, thanks!

Regarding this PR here: the Tracker test errors are caused by Julia 1.8, which broke some things in Tracker (I've also seen test failures in e.g. AbstractDifferentiation and DiffRules). FluxML/Tracker.jl#125
Possibly, but I do not fully understand what the question is here and what that would involve. Can you please provide an example, e.g. what a user or implementor would need?

Yes, that's the intention, once the latter stabilizes. Currently LogDensityProblems comprises two things: an interface for log density problems, and AD glue code that wraps a log density function for various backends.

The reason I haven't split this into two packages is that I am really hoping that AbstractDifferentiation will take care of the AD glue in the long run, at which point it will be removed from the package. In any case, I am happy to extend LogDensityProblems, transfer it, and/or add people as maintainers, and PRs are welcome as always.
We built DensityInterface (with input from devs from Distributions, Turing, and other packages) so that people can define density-like objects using a super-lightweight dependency. Code that uses density-like objects (algorithm code) will typically need more deps, of course, like AD and so on. So I think it would be nice if LogDensityProblems (which is closer to the algorithmic side, since it also handles AD and specific transformations) would support DensityInterface densities. I have to admit I'm not quite sure how, but maybe we could figure something out together? In principle, a default implementation of `DensityInterface.logdensityof` in terms of the LogDensityProblems interface might be possible.
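A sketch of such glue (hypothetical, not part of either package; reusing the `MyDensity` example type from above):

```julia
using DensityInterface, LogDensityProblems

# hypothetical glue: declare that a LogDensityProblems-style object is a
# density and forward DensityInterface.logdensityof to logdensity
DensityInterface.DensityKind(::MyDensity) = DensityInterface.IsDensity()
DensityInterface.logdensityof(ℓ::MyDensity, x) = LogDensityProblems.logdensity(ℓ, x)

# now both interfaces work on the same object
logdensityof(MyDensity(2), [1.0, 2.0])  # == LogDensityProblems.logdensity(...)
```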
I think our "problem" isn't technically difficult -- I have created a PoC port of LogDensityProblems to use DensityInterface here (diff view). Rather, it is of a "social" nature: it requires messing with the interface of an existing package, deviating from the original considerations of the package author, and surprising the existing user base, both of which make me hesitate. If anything, we all need to have a discussion about how to reconcile the interfaces -- not everything is immediately clear (especially …).
I'd have some comments, but I guess it would be better to discuss these things in an issue over at LogDensityProblems? (IIRC there was already an issue, which is closed now.) There's nothing user-facing in this PR and no new API, so I think in principle it does not require any changes to or discussions of interfaces.

I agree with @devmotion, but want to reiterate that I am happy to add anything to LogDensityProblems that helps with this PR.

Appreciate the write-up @devmotion! And I'm also quite happy with the "interface" that is provided by LogDensityProblems.jl. One of my main concerns was that we'd be depending on TransformVariables.jl but wouldn't be using it; but, as you pointed out, it doesn't seem like a particularly heavy dependency. If we're happy with this, I'm happy to take it on as a dep :)

Can this be replaced by depending on ChangesOfVariables.jl (much lighter), since TransformVariables.jl supports it now?

It does not seem completely trivial; I ran into some problems: tpapp/LogDensityProblems.jl#88
Codecov Report

```
@@            Coverage Diff             @@
##           master    #1877      +/-   ##
==========================================
+ Coverage   81.03%   82.21%   +1.18%
==========================================
  Files          24       21       -3
  Lines        1466     1406      -60
==========================================
- Hits         1188     1156      -32
+ Misses        278      250      -28
```
In response to the discussion above, I have excised TransformVariables.jl as a dependency from LogDensityProblems.jl. See https://github.com/tpapp/TransformedLogDensities.jl/ (where it ended up) and tpapp/LogDensityProblems.jl#89 (which will be merged once the first package is registered and stuff is cleaned up).
It looks like this PR brings some excellent improvements. Many thanks, @phipsgabler, @devmotion and @tpapp!
The PR requires tpapp/LogDensityProblems.jl#86.

The PR removes `gradient_logp` and instead implements the LogDensityProblems.jl interface for `Turing.LogDensityFunction`. It seemed to work on simple examples (I haven't run the full test suite yet) and allows us to re-use existing implementations and functionality in LogDensityProblems, such as compilation of tapes for ReverseDiff and support for Enzyme (see comment below). That simplifies the code significantly and e.g. lets us remove the Requires block for ReverseDiff (taken care of in LogDensityProblems) and Memoization.

Probably in the future we might want to move `LogDensityFunction` and its implementation of the LogDensityProblems interface to DynamicPPL. And we might want to move the AD backend types (`TrackerAD`, `ForwardDiffAD`, etc.) and their (to be improved?) API to a lightweight package (TuringADCore?) so that AdvancedVI, AdvancedHMC, Bijectors, etc. could also use it instead of duplicating the implementation.

Enzyme support can be added by something like the following.
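A rough sketch (Enzyme's API has shifted between versions and the exact code in this PR comment may differ, so treat this as illustrative; the helper name is made up):

```julia
using Enzyme, LogDensityProblems

# illustrative only: compute the primal and gradient of a log density with
# Enzyme's reverse mode; the gradient accumulates into the shadow vector ∂x
function enzyme_logdensity_and_gradient(ℓ, x::AbstractVector)
    ∂x = zero(x)
    _, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        Base.Fix1(LogDensityProblems.logdensity, ℓ),
        Enzyme.Active,
        Enzyme.Duplicated(x, ∂x),
    )
    return y, ∂x
end
```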
Unfortunately, it seems it does not even work on simple examples such as the one below.
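A representative example of the kind of model meant here (not necessarily the exact one tested):

```julia
using Turing

# about as simple as a Turing model gets: normal prior, normal likelihood
@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

# with the default ForwardDiff backend this works; swapping in an
# Enzyme-based gradient is where the errors appeared
chain = sample(demo(1.5), NUTS(), 100)
```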
Intentionally, the error messages and stacktraces are not included here, since Enzyme support is not part of this PR.