Skip to content

Commit

Permalink
Tweak getindex implementation + test corner cases
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Mar 18, 2020
1 parent f16d856 commit b3d0ef7
Show file tree
Hide file tree
Showing 3 changed files with 1,315 additions and 1,296 deletions.
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ version = "0.2.0"

[[CompilerSupportLibraries_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "b57c5d019367c90f234a7bc7e24ff0a84971da5d"
git-tree-sha1 = "aa832564f930a7fc9290972526908d01a35aefac"
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.2.0+1"
version = "0.3.0+0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
Expand Down
23 changes: 17 additions & 6 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,34 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = Zeros(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = Ones(sz), Δ->(nothing,)

_zero(xs::AbstractArray{<:Number}, T=float(eltype(xs))) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T=Any) = Union{Nothing, T}[nothing for x in xs]

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray, inds) = dy -> begin
if inds isa NTuple{<:Any,Integer}
if inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
dx = _zero(x, eltype(dy))
dxv = view(dx, inds...)
dxv .+= _droplike(dy, dxv)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
(dx, map(_->nothing, inds)...)
return (dx, map(_->nothing, inds)...)
end

function _zero(xs::AbstractArray{<:Number}, T::Type{Nothing})
return fill!(similar(xs), zero(eltype(xs)))
end

# function _zero(xs::AbstractArray{<:Number}, T=float(eltype(xs)))
function _zero(xs::AbstractArray{<:Number}, T)
return fill!(similar(xs, T), false)
end

# function _zero(xs::AbstractArray, T=Any)
function _zero(xs::AbstractArray, T)
return fill!(similar(xs, Union{Nothing, T}), nothing)
end

_droplike(dy, dxv) = dy
Expand Down
Loading

0 comments on commit b3d0ef7

Please sign in to comment.