The `chunk` function is not differentiable on GPU #170
Comments
I found that this might be caused by Zygote.jl 0.6.67, because the problem goes away when I downgrade Zygote.jl to 0.6.66.
This looks like a bug in the CR rrule being used here, after Zygote deleted its rule. Any chance you can isolate it further, e.g. to a single `getindex` call?
I'm not sure what you expect. Can you elaborate on this? Then I'll test it soon.
Sorry, what I mean is isolating which call is actually failing. This is the relevant bit of the stacktrace:

```
[4] materialize!
  @ ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:46 [inlined]
[5] materialize!
  @ ./broadcast.jl:881 [inlined]
[6] ∇getindex!(dx::Vector{Union{ChainRulesCore.ZeroTangent, CuMatrix{Float32, CUDA.Mem.DeviceBuffer}, DenseCuMatrix{Float32, CUDA.Mem.DeviceBuffer}}}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, inds::Int64)
  @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:147
[7] ∇getindex(x::Vector{SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, inds::Int64)
  @ ChainRules ~/.julia/packages/ChainRules/Tvwnx/src/rulesets/Base/indexing.jl:89
```

Here `∇getindex` is being called on a `Vector` of `SubArray`s of a `CuArray`, and the in-place broadcast inside `∇getindex!` (frames [4]–[5]) is where it fails.
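One probe that could isolate that frame (my own sketch, not code from the thread) is to differentiate a lone `getindex` on a `Vector` of `CuArray` views, mirroring the argument types in frame [7]:

```julia
# Hypothetical isolation probe (not from the thread): a single getindex on a
# Vector of CuArray views, matching the signature in frame [7] above.
using CUDA, Zygote

y = cu(rand(Float32, 8, 3))
v = [view(y, 1:4, :), view(y, 5:8, :)]     # Vector{SubArray{...CuArray...}}
Zygote.gradient(vs -> sum(vs[1]), v)       # differentiates getindex(vs, 1)
```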
Thanks. I've tested two other cases that do more direct indexing without calling `chunk`:

```julia
function (model::Model)(x)
    y = model.layers.dense(x)
    # ERROR: LoadError: MethodError: no method matching parent(...
    #a, b = Flux.chunk(y, size = [4, 4], dims = 1)
    # this works
    #a, b = y[1:4,:], y[5:8,:]
    # this works
    a, b = view(y, 1:4, :), view(y, 5:8, :)
    sum(a + b)
end
```
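The `Model` type itself isn't shown in the thread; a minimal sketch of scaffolding that would let the snippet above run (my assumption: a plain struct holding a named tuple with a `dense` layer) could look like:

```julia
# Hypothetical scaffolding (not from the thread): a minimal Model whose
# `layers.dense` field matches what the forward pass above expects.
using Flux, CUDA

struct Model{L}
    layers::L
end
Flux.@functor Model   # let `gpu` recurse into the wrapper

# Forward pass repeated from the comment above, using the variant reported to work:
function (model::Model)(x)
    y = model.layers.dense(x)
    a, b = view(y, 1:4, :), view(y, 5:8, :)
    sum(a + b)
end

model = Model((dense = Dense(8 => 8),)) |> gpu
x = cu(randn(Float32, 8, 4))
gradient(x -> model(x), x)   # swap in the commented-out variants to compare
```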
I discovered that the following pattern doesn't work. I guess the lowering implicitly inserts some `getindex` calls:

```julia
a, b = [y[1:4,:], y[5:8,:]]
```
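A small sketch of that guess (my illustration, not from the thread): destructuring a `Vector` pulls each element out by position, and differentiating that per-element fetch appears to be what hits the `∇getindex` rule seen in the trace above.

```julia
# Illustration only: `a, b = v` on a Vector behaves like fetching v[1] and v[2],
# so differentiating this pattern exercises the rule for getindex on a Vector of arrays.
v = [collect(1.0:4.0), collect(5.0:8.0)]
a, b = v
@assert a == v[1] && b == v[2]
```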
So, I reduced the code to the following. As in the case above, the error disappears if I use Zygote.jl 0.6.66 instead of 0.6.67.

```julia
using CUDA, Zygote

function f(x)
    a, b = [x[1:4], x[5:8]]
    sum(a + b)
end

x = cu(randn(8))
@show f(x)
@show Zygote.gradient(f, x)
```
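For comparison, and consistent with the view-based variant reported to work earlier in the thread, writing the same computation without the intermediate `Vector` can serve as a workaround while the rule is fixed (my sketch, not from the thread):

```julia
# Workaround sketch (assumption, based on the "view works" report above):
# skip the intermediate Vector and slice the array directly.
using CUDA, Zygote

g(x) = sum(view(x, 1:4) + view(x, 5:8))
x = cu(randn(Float32, 8))
@show g(x)
@show Zygote.gradient(g, x)
```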
Thanks, that's helpful!
Thanks for the MWE, will follow up on the Zygote issue.
The example in the OP works fine on the latest versions of the packages.
I found that operations involving the `chunk` function are not differentiable on GPU. When I try to run this, I see the following error:
Full error message: log.txt
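The snippet the OP refers to is not included in this copy of the issue; a hedged reconstruction of the kind of call being described (package usage, names, and sizes are my own guesses) might look like:

```julia
# Hypothetical reconstruction (the original snippet is not shown above):
# differentiating a sum over GPU chunks of the kind the issue describes.
using CUDA, Flux, Zygote

x = cu(rand(Float32, 8, 3))
loss(x) = sum(sum.(Flux.chunk(x; size = 4, dims = 1)))
Zygote.gradient(loss, x)
```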
My environment is:
Manifest.toml and Project.toml are as follows (the file name extensions were changed for uploading):
Manifest.txt
Project.txt