diff --git a/src/lib/array.jl b/src/lib/array.jl
index d38751331..049a49f3d 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -30,25 +30,26 @@ end
 @adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
 @adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)
 
-_zero(xs::AbstractArray{<:Number}, T=float(eltype(xs))) = fill!(similar(xs, T), false)
-_zero(xs::AbstractArray, T=Any) = Union{Nothing, T}[nothing for x in xs]
-
 @adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)
 
 @adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)
 
 ∇getindex(x::AbstractArray, inds) = dy -> begin
-  if inds isa NTuple{<:Any,Integer}
+  if inds isa NTuple{<:Any, Integer}
     dx = _zero(x, typeof(dy))
     dx[inds...] = dy
   else
     dx = _zero(x, eltype(dy))
     dxv = view(dx, inds...)
-    dxv .+= _droplike(dy, dxv)
+    dxv .= accum.(dxv, _droplike(dy, dxv))
   end
-  (dx, map(_->nothing, inds)...)
+  return (dx, map(_->nothing, inds)...)
 end
 
+_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
+_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
+_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)
+
 _droplike(dy, dxv) = dy
 _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::AbstractVector) =
   dropdims(dy; dims=2)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 3d95f7f8d..8994f974d 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -138,33 +138,41 @@ end
 @test gradient(logdet, 2.0)[1] == 0.5
 
 @testset "getindex" begin
-  @test gradtest(x -> x[:,2,:], (3,4,5))
-  @test gradtest(x -> x[1:2,3:4], (3,4))
+  @test gradtest(x -> x[:, 2, :], (3, 4, 5))
+  @test gradtest(x -> x[1:2, 3:4], (3, 4))
 
   imat = [1 2; 3 4]
-  @test gradtest(x -> x[:,imat], (3,4))
-  @test gradtest(x -> x[:,[1,2,2]], (3,4))
+  @test gradtest(x -> x[:, imat], (3, 4))
+  @test gradtest(x -> x[:, [1, 2, 2]], (3, 4))
   irep = [1 2; 2 2]
-  @test gradtest(x -> x[1,irep], (3,4))
+  @test gradtest(x -> x[1, irep], (3, 4))
 
   # https://github.com/invenia/Nabla.jl/issues/139
   x = rand(3)
-  z = [1,2,3,3]
-  y(x,z) = dot(ones(4), x[z])
-  @test gradient(y, x,z) == ([1,1,2], nothing)
+  z = [1, 2, 3, 3]
+  y(x, z) = dot(ones(4), x[z])
+  @test gradient(y, x, z) == ([1, 1, 2], nothing)
 
   # https://github.com/FluxML/Zygote.jl/issues/376
   _, back = Zygote._pullback(x->x[1]*im, randn(2))
   @test back(1.0)[2] == [-im, 0]
 
   # _droplike
-  @test gradient(x -> sum(inv, x[1,:]'), ones(2,2)) == ([-1 -1; 0 0],)
-  @test gradient(x -> sum(inv, x[1:1,:]'), ones(2,2)) == ([-1 -1; 0 0],)
-  @test gradient(x -> sum(inv, transpose(view(x,1,:))), ones(2,2)) == ([-1 -1; 0 0],)
+  @test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
+  @test gradient(x -> sum(inv, x[1:1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
+  @test gradient(x -> sum(inv, transpose(view(x, 1, :))), ones(2, 2)) == ([-1 -1; 0 0],)
 
   # https://github.com/FluxML/Zygote.jl/issues/513
-  @test gradient(p -> sum(Float32[1,0] - p), [2,3]) == ([-1,-1],)
-  @test gradient(x -> sum(Float32[1,x] .+ x), 4) == (3.0f0,)
+  @test gradient(p -> sum(Float32[1, 0] - p), [2, 3]) == ([-1, -1],)
+  @test gradient(x -> sum(Float32[1, x] .+ x), 4) == (3.0f0,)
+
+  # Ensure that nothings work with numeric types.
+  _, back = Zygote.pullback(getindex, randn(4), [1])
+  @test back([nothing]) == (zeros(4), nothing)
+
+  # Ensure that nothings work with non-numeric types.
+  _, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
+  @test back([nothing]) == ([nothing for _ in 1:3], nothing)
 end
 
 @testset "view" begin
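
For context on the `_zero` change: when an indexed element receives a `nothing` cotangent, `typeof(dy)` / `eltype(dy)` comes out as `Nothing`, so the old `fill!(similar(xs, T), false)` path would try to build a `Vector{Nothing}` and fail to fill it; the new `T::Type{Nothing}` method falls back to a plain numeric zero buffer instead. The snippet below is a standalone sketch (not part of the patch) that copies the three new `_zero` methods so their dispatch can be tried without loading Zygote; the commented results reflect what the added tests expect.

```julia
# Copies of the three _zero methods added in the patch, for illustration only.
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)

x = randn(4)
_zero(x, Nothing)    # 4-element Vector{Float64} of zeros: a `nothing` cotangent contributes no gradient
_zero(x, Float64)    # numeric zero buffer with the requested eltype
_zero([rand(2) for _ in 1:3], Vector{Float64})  # Vector{Union{Nothing, Vector{Float64}}} filled with nothing
```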