Skip to content

Commit

Permalink
simple stack adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Feb 25, 2019
1 parent 1d886ad commit b560693
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

@adjoint Base.vect(xs...) = Base.vect(xs...), Δ ->...,)

@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)

Base.zero(xs::AbstractArray{Any}) = fill!(similar(xs), nothing)

@adjoint function getindex(xs::Array, i...)
Expand Down
21 changes: 20 additions & 1 deletion src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,26 @@ using Base: @get!

@nograd readline

@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)
# Gradient of AD stacks

grad_mut(::AbstractVector) = []

@adjoint! function _push!(a::Vector, x)
_push!(a, x), function (y)
dstk = grad_mut(__context__, a)
return (nothing, pop!(dstk))
end
end

@adjoint! function pop!(stk::Stack)
pop!(stk), function (Δ)
dstk = grad_mut(__context__, stk.data)
push!(dstk, Δ)
return
end
end

# Dictionaries

grad_mut(d::AbstractDict) = Dict()

Expand Down
8 changes: 8 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ D(f, x) = grad(f, x)[1]
if VERSION > v"1.2-"
@test D(x -> x*D(y -> x+y, 1), 1) == 1
@test D(x -> x*D(y -> x*y, 1), 4) == 8
@test_broken sin'''(1.0) == -cos(1.0)
end

f(x) = throw(DimensionMismatch("fubar"))
Expand Down Expand Up @@ -247,3 +248,10 @@ end
push!([], x)
return x
end

@test gradient(1) do x
stk = []
Zygote._push!(stk, x)
stk = Zygote.Stack(stk)
pop!(stk)
end == (1,)

0 comments on commit b560693

Please sign in to comment.