Improve adjoint for product and zip #1489
Update: since we have adjoints for
This mostly LGTM. Are you able to add tests which check that one or more of the gradient computation, pullback or `productfunc` are type stable? With the effort put into this, we should make sure inference keeps working!
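For reference, a minimal sketch of how such an inference test is commonly written with `Test.@inferred`. The functions `stable_relu` and `unstable_relu` are hypothetical stand-ins, not Zygote code; in the real test suite the wrapped call would presumably be the gradient or pullback under test.

```julia
using Test

# Hypothetical example functions, used only to illustrate @inferred.
# Type-stable: both branches return a value of the same type as x.
stable_relu(x) = x > 0 ? x : zero(x)

# Type-unstable for Float64 input: returns either a Float64 or an Int.
unstable_relu(x) = x > 0 ? x : 0

@testset "inference" begin
    # @inferred returns the call's value when the inferred return type
    # matches the actual one, so it composes with @test.
    @test @inferred(stable_relu(1.0)) == 1.0
    # @inferred raises an error when inference yields a wider type.
    @test_throws ErrorException @inferred(unstable_relu(1.0))
end
```

The `@test_throws` line is a way of documenting a known inference failure, which may be useful if only some of the code paths are expected to infer.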
Sure, I'll try and get a test written soon, although testing inference can be tricky. Ideally Additionally, I can update the
I've finished adding inference and correctness tests for
Update: In general,
julia> Zygote.gradient(x -> sum(prod, x), 1:5)
([1.0, 1.0, 1.0, 1.0, 1.0],)
julia> Zygote.gradient(x -> sum(prod, zip(x)), 1:5)
(nothing,)
Any thoughts on how to improve handling this?
Correction:
Although it does use ChainRules' projection machinery sometimes, Zygote overall doesn't do projection quite the same way because it predates ChainRules. The legacy projection machinery we do have can be rather inconsistent. In this particular case however, it looks like the input type is causing null gradients?
It makes sense that an integer range would be considered non-differentiable, but it would be good to confirm Zygote is doing this for the right reason and not because of some bug. Either way, if you can't figure out
Thank you for the context about projection in Zygote. I'm happy to keep it as is so that this PR is as non-breaking as possible. Otherwise the work on
As for the observation of null gradients for integer ranges, it appears to have nothing to do with
I'd have to understand where the decision is being made, but I think it's safe to leave it to a follow-up.
It turns out the answer is easier than I thought: Line 252 in 54f1e80
This may have been to support `for x in 1:N ...`.
Zygote often throws away gradients of a UnitRange, here: Line 252 in 54f1e80
It's not enforced by projection, so things that hit other rules such as
I have no memory of why, but when initially writing these rules, attaching them to the uppercase constructor rather than the lowercase function somehow made more cases work. There are tests here, but fewer than I thought.
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
I've rebased this branch onto master and resolved the last issue I was concerned about. Looks like the CI is mostly good, although I'm not sure if the DynamicPPL failure is related.
LGTM, a couple of things to touch up before merging.
src/lib/array.jl
Outdated
@@ -169,8 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
_tryaxes(::Number) = Val(-1)
Instead of a `Val`, maybe `nothing` or `missing` would be a more appropriate sentinel value here?
Sure, or even the number itself
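To make the suggestion concrete, here is a rough sketch of dispatching on the number itself instead of a `Val(-1)` sentinel. The names `axes_or_sentinel` and `restore_shape` are hypothetical and are not the definitions in `src/lib/array.jl`; the array method also omits the padding that the real `_restore` performs.

```julia
# Hypothetical helpers illustrating the sentinel choice discussed above.
axes_or_sentinel(x) = axes(x)                 # arrays: remember their axes
axes_or_sentinel(x::Tuple) = Val(length(x))   # tuples: remember their length
axes_or_sentinel(x::Number) = x               # numbers: the value itself is the marker

# Arrays: reshape the gradient back to the remembered axes if needed.
restore_shape(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(collect(dx), ax)
# Tuples: rebuild a tuple of the remembered length.
restore_shape(dx, ::Val{N}) where {N} = ntuple(i -> dx[i], Val(N))
# Numbers: unwrap the single gradient entry.
restore_shape(dx, ::Number) = only(dx)

restore_shape([2.0], axes_or_sentinel(3.0))                   # a scalar
restore_shape([1.0, 2.0], axes_or_sentinel((1, 2)))           # a 2-tuple
restore_shape([1.0, 2.0, 3.0, 4.0], axes_or_sentinel(ones(2, 2)))  # a 2×2 matrix
```

Because dispatch happens on the type of the marker, no special value such as `-1` has to be reserved.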
Remove `Val` from `ntuple`s where constant propagation occurs
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
This is a great contribution for a tricky set of rules, thanks @lxvm!
Thanks to everyone for the helpful support!
Hi,
I've returned to my first contribution in #1170 since I noticed I couldn't differentiate w.r.t. `Iterators.product`s that have a number as an iterator. This PR adds a test and fixes the issue while also improving the inferrability of the adjoint.
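As an illustration of the case in question, the sketch below uses only Base to show a number acting as an iterator inside `Iterators.product`, and checks the expected derivative by finite differences. The function `f` and the data `ys` are made up for this example; the commented-out `Zygote.gradient` call is the kind of call this PR is meant to support and is not executed here.

```julia
# Made-up data for the illustration.
ys = [1.0, 3.0]

# A Number is iterable in Julia, so it can be one factor of a product iterator.
f(x) = sum(prod, Iterators.product(x, ys))   # equals x * sum(ys)

value = f(2.0)

# Central finite difference as a reference derivative; f is linear in x,
# so this should agree with sum(ys) up to rounding.
h = 1e-6
fd = (f(2.0 + h) - f(2.0 - h)) / (2h)
@assert isapprox(fd, sum(ys); atol = 1e-6)

# With Zygote installed, the corresponding call would be (not run here):
# using Zygote
# Zygote.gradient(f, 2.0)
```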