RFC: more efficient ∇getindex
#962
In the longer term this will be resolved once we support ChainRules inplaceable thunks.
We can define a custom rule for
On my computer, here are the benchmarks:
Notable differences from yours in the top post: this PR makes the collect example way faster than what you saw, by an order of magnitude. The current master timing for the eval poly example is way worse for me. This is great. We should have been doing this for ages.
Those times look great, thanks. I should try on 1.6 + rosetta too, maybe there are other bugs involved.
Good point, hadn't thought about nested AD at all. There appear to be no tests, so perhaps that can be looked at in another PR. Nested
Is that correct in general? Also, could we replicate this with
Is what correct in general?
Which bit?
Should we be concerned about race conditions? @DhairyaLGandhi and @MikeInnes know more about this than I do (I know basically nothing), but I would imagine that if you have parallel code somewhere, this modification could be problematic if Zygote winds up trying to accumulate to the same location at the same time (I'm not sure whether this can ever happen in practice)? My impression is that Zygote currently works fine with threads / parallelism generally -- would this change that?
I don't think so. I think we can assume it is safe and wait for evidence to the contrary (though debugging race conditions sucks).
Times with rosetta as promised:
Also some here: #905 (comment) Edit -- without mutation, same computer:
It turns out that JuliaLang/julia#905 is the way to produce the sort of bug I was a bit worried about. Something is doing
Yeah fair. Good catch. Presumably this came out of if a Enforcing this rule against aliasing would add some copies but not as many as in-place accumulation would solve. I think there might be something in the literature about this and a
Yes that might be enough. I don't know the literature (obviously!) but it's hard to then picture it going wrong. Would there be cases where this hurts, as well as ones where it helps overall? This policy of no aliasing would probably be necessary for any of the in-place updates planned in ChainRules, too, right? Not just If that were to be the policy, you could surely hack in a test to check for aliasing, here, I mean pirate something during testing. ChainRules could easily do that too. Just making the update vararg like
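To make the aliasing concern concrete, here is a minimal sketch of one way in-place accumulation could go wrong, and how a preventative copy avoids it. The helper names `accum!` and `plus_pullback` are hypothetical and exist only for this illustration; this is not Zygote's code and does not reproduce the exact case from the linked issue.

```julia
# Hypothetical helper, for illustration only: in-place accumulation
# that reuses the memory of `a` instead of allocating a new array.
accum!(a::Array, b) = (a .+= b; a)

# A pullback that hands back the *same* array for both inputs,
# as a rule for (x, y) -> x + y might do to avoid copies.
plus_pullback(dz) = (dz, dz)

dz = [1.0, 1.0]
dx, dy = plus_pullback(dz)       # dx, dy and dz all share memory

accum!(dx, [10.0, 10.0])         # meant to update only the gradient for x
# dy has changed too, because it aliases dx

# With a no-aliasing policy, each gradient owns its memory:
dx2, dy2 = map(copy, plus_pullback([1.0, 1.0]))
accum!(dx2, [10.0, 10.0])
# dy2 is unaffected
```

The copies are the cost mentioned above: they are wasted whenever nothing is later accumulated into the copied array.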
OK, I think I see how to ensure broadcasting takes preventative copies. For times on this page it only costs 10% off the fastest unsafe way, but in #905 (comment) it saves no memory, and is slower than master.
I have updated my timings in #962 (comment)
Yeah it is hard to picture it going wrong, and putting differential memory addresses into a correspondence in quantity with primal feels right. I don't know the literature that well either, not really.
Yes, cases where you don't do anything with the output that you just copied.
I think you mean not just
Yeah. We could hack that into
Re this PR as it stands, one comment is that it does a bit less to preserve the types of arrays, e.g.:
julia> using StaticArrays
julia> gradient(x -> x[1], SA[1,2,3])[1]
3-element Zygote.OneElement{Int64, 1, Tuple{Int64}, Tuple{SOneTo{3}}} with indices SOneTo(3):
1
0
0
julia> gradient(x -> x[1] + x[2], SA[1,2,3])[1]
3-element SizedVector{3, Int64, Vector{Int64}} with indices SOneTo(3):
1
1
0
compared to master:
julia> gradient(x -> x[1] + x[2], SA[1,2,3])[1] # better than a SizedVector on similar case
3-element MVector{3, Int64} with indices SOneTo(3):
1
1
0
julia> gradient(x -> sum(x .+ x'), SA[1,2,3])[1] # but many other cases don't preserve this
3×1 Matrix{Int64}:
6
6
6
julia> gradient(x -> sum(abs, x .+ x'), SA[1,2,3])[1] # and many simple things don't work at all
ERROR: MethodError: no method matching _mapfoldl(::typeof(identity), ::typeof(+), ::Tuple{Int64, Int64}, ::StaticArrays._InitialValue, ::Size{(3, 3)}, ::SMatrix{3, 3, Int64, 9})
I wondered a bit whether `OneElement` should capture the type of the array, but the method of `similar` I think you'd want to use when turning back into a dense array doesn't seem to exist, surprisingly?
julia> similar(typeof([1,2]), axes([1,2]))
2-element Vector{Int64}:
12
57
julia> similar(typeof(SA[1,2]), axes(SA[1,2]))
2-element MVector{2, Int64} with indices SOneTo(2):
4585456624
3
julia> similar(typeof([1,2]), Float64, axes([1,2])) # with eltype
ERROR: MethodError: no method matching similar(::Type{Vector{Int64}}, ::Type{Float64}, ::Tuple{Base.OneTo{Int64}})
julia> similar([1,2], Float64, axes([1,2])) # with an instance it's fine
2-element Vector{Float64}:
0.0
2.2655161774e-314
julia> similar(typeof(SA[1,2]), Float64, Size(SA[1,2])) # this exists, but not for Base types
2-element MVector{2, Float64} with indices SOneTo(2):
2.162767274e-314
0.0
Are you ready for this to be merged?
Yes, let's do this.
Failure on master ought to be fixed by FluxML/IRTools.jl#86
(Minor aesthetic changes in the review.)
Oh I see. I think we both hit the merge button at the same time, yours succeeded & mine gave a confusing error... which I tried to rebase to fix... but this is merged & can be closed.
Most of the work for static arrays is pretty orthogonal to the concerns of accumulation. It seems odd to make future work harder.
This is a small improvement to the gradient for `getindex` in the easiest case, by returning a very simple one-nonzero-element array.
It also changes how `accum` works: adding two of these will produce an `Array`, but accumulating the next one may as well mutate that. Maybe it is safe to do that in general, I am not 100% sure? Maybe CI will tell us? [It is not, and it didn't, see below]
Note that it will never produce a SparseArray. My take is that if you call getindex once then this `OneElement` is fine, and if you call it on every element in the array, then accumulating in an Array is optimal [This won't apply anymore]; in-between cases where you call it on 1% of the elements seem likely to be very rare, and sparse arrays add all sorts of complication.
And, are there `CuArray` concerns? `∇getindex` is careful about making a similar zero to write into, but this change is only for scalar indexing, which is (usually) disallowed anyway.
This doesn't give a huge speedup, but it does save a lot of memory. I'm not too sure what the "10089 allocations" here actually are, maybe there is some further trick to remove them? (Might also be useful for someone to time this on a different computer too -- these are M1 mac, which has unusual memory, and Julia 1.7 native.)
On the example from #644 -- discussion there had other comparable proposals:
Xref also JuliaLang/julia#365 (getindex) and JuliaLang/julia#905 (accumulation).