-
Notifications
You must be signed in to change notification settings - Fork 219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Check that the correct AD backend is being used #2291
Conversation
c2561b7
to
c0cab10
Compare
@torfjelde pointed out in https://github.com/TuringLang/Turing.jl/pull/2221/files#r1608405934 that there's a similarity between this and I think this is a great idea, but I chose not to do it in this PR because
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #2291 +/- ##
==========================================
- Coverage 85.82% 85.77% -0.06%
==========================================
Files 24 24
Lines 1623 1617 -6
==========================================
- Hits 1393 1387 -6
Misses 230 230 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code looks good to me.
I don't think comparing the types of values in varinfo is particularly elegant, but also I don't see other ways to do it with Context
s.
Maybe @yebai, @torfjelde, or @devmotion can give a quick scan? |
Pull Request Test Coverage Report for Build 10268798219Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
Agree on lack of elegance, better proposals welcome. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the idea very much:) Seems like something we should definitively test. And agree that we shouldn't do the CallbackContext
in this PR.
I left a comment on a potentially way we can check Zygote.jl, Tapir.jl, and Enzyme.jl too.
), | ||
# TODO(mhauru) Zygote.Dual is actually the same as ForwardDiff.Dual, so can't | ||
# distinguish between the two. | ||
Turing.AutoZygote => (Zygote.Dual,), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah there's no way you can do this here unfortunately 😕
Possibly crazy idea, but it might make sense to specifically overload the adjoint computation for a given backend which doesn't use types, e.g.
Zygote.@adjoint function check_adtype(...)
# Should only be hit if we're using Zygote.jl.
...
end
? Could do the same with Enzyme.jl and Tapir.jl, both of which would suffer from the same issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this, and implemented it. Turns out there's Zygote.hook
for exactly these sorts of things.
I thought a bit about switching to doing all checks like this, and giving up on the element type approach, but that seemed like it would get more complicated than is worth at this point. We now catch all cases except if one uses ForwardDiff when Zygote is expected, which seems sufficient to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very, very nice @mhauru :)
The only CI failure is a tiny, trivial loss of code coverage. @torfjelde, are you happy to have this merged as is? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very, very nice @mhauru :)
Introduces tests that check that the ADType we specify is actually the one being used. This alleviates my worry that we might mess up passing the ADType around and accidentally fall back on the default.
I've added tests for calling
sample
andmaximum_likelihood/maximum_a_posteriori
. Any other user-facing functions we should check to make sure they, too, use theiradtype
arguments correctly?Closes #2235