Skip to content

Commit

Permalink
Merge #547
Browse files Browse the repository at this point in the history
547: Tweak getindex implementation + test corner cases r=willtebbutt a=willtebbutt

Resolves #546 

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
  • Loading branch information
bors[bot] and willtebbutt authored Apr 14, 2020
2 parents 17a017f + ac52149 commit 359e586
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
13 changes: 7 additions & 6 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,26 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)

_zero(xs::AbstractArray{<:Number}, T=float(eltype(xs))) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T=Any) = Union{Nothing, T}[nothing for x in xs]

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray, inds) = dy -> begin
if inds isa NTuple{<:Any,Integer}
if inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
dx = _zero(x, eltype(dy))
dxv = view(dx, inds...)
dxv .+= _droplike(dy, dxv)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
(dx, map(_->nothing, inds)...)
return (dx, map(_->nothing, inds)...)
end

_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)

_droplike(dy, dxv) = dy
_droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::AbstractVector) =
dropdims(dy; dims=2)
Expand Down
34 changes: 21 additions & 13 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,33 +138,41 @@ end
@test gradient(logdet, 2.0)[1] == 0.5

@testset "getindex" begin
@test gradtest(x -> x[:,2,:], (3,4,5))
@test gradtest(x -> x[1:2,3:4], (3,4))
@test gradtest(x -> x[:, 2, :], (3, 4, 5))
@test gradtest(x -> x[1:2, 3:4], (3, 4))

imat = [1 2; 3 4]
@test gradtest(x -> x[:,imat], (3,4))
@test gradtest(x -> x[:,[1,2,2]], (3,4))
@test gradtest(x -> x[:, imat], (3, 4))
@test gradtest(x -> x[:, [1, 2, 2]], (3, 4))
irep = [1 2; 2 2]
@test gradtest(x -> x[1,irep], (3,4))
@test gradtest(x -> x[1, irep], (3, 4))

# https://github.com/invenia/Nabla.jl/issues/139
x = rand(3)
z = [1,2,3,3]
y(x,z) = dot(ones(4), x[z])
@test gradient(y, x,z) == ([1,1,2], nothing)
z = [1, 2, 3, 3]
y(x, z) = dot(ones(4), x[z])
@test gradient(y, x, z) == ([1, 1, 2], nothing)

# https://github.com/FluxML/Zygote.jl/issues/376
_, back = Zygote._pullback(x->x[1]*im, randn(2))
@test back(1.0)[2] == [-im, 0]

# _droplike
@test gradient(x -> sum(inv, x[1,:]'), ones(2,2)) == ([-1 -1; 0 0],)
@test gradient(x -> sum(inv, x[1:1,:]'), ones(2,2)) == ([-1 -1; 0 0],)
@test gradient(x -> sum(inv, transpose(view(x,1,:))), ones(2,2)) == ([-1 -1; 0 0],)
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
@test gradient(x -> sum(inv, x[1:1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
@test gradient(x -> sum(inv, transpose(view(x, 1, :))), ones(2, 2)) == ([-1 -1; 0 0],)

# https://github.com/FluxML/Zygote.jl/issues/513
@test gradient(p -> sum(Float32[1,0] - p), [2,3]) == ([-1,-1],)
@test gradient(x -> sum(Float32[1,x] .+ x), 4) == (3.0f0,)
@test gradient(p -> sum(Float32[1, 0] - p), [2, 3]) == ([-1, -1],)
@test gradient(x -> sum(Float32[1, x] .+ x), 4) == (3.0f0,)

# Ensure that nothings work with numeric types.
_, back = Zygote.pullback(getindex, randn(4), [1])
@test back([nothing]) == (zeros(4), nothing)

# Ensure that nothings work with non-numeric types.
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
@test back([nothing]) == ([nothing for _ in 1:3], nothing)
end

@testset "view" begin
Expand Down

0 comments on commit 359e586

Please sign in to comment.