diff --git a/src/lib/array.jl b/src/lib/array.jl index 329676648..8c32faeaf 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -5,6 +5,8 @@ @adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,) +@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,) + Base.zero(xs::AbstractArray{Any}) = fill!(similar(xs), nothing) @adjoint function getindex(xs::Array, i...) diff --git a/src/lib/base.jl b/src/lib/base.jl index fd582f5b7..61e82a477 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -2,7 +2,26 @@ using Base: @get! @nograd readline -@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,) +# Gradient of AD stacks + +grad_mut(::AbstractVector) = [] + +@adjoint! function _push!(a::Vector, x) + _push!(a, x), function (y) + dstk = grad_mut(__context__, a) + return (nothing, pop!(dstk)) + end +end + +@adjoint! function pop!(stk::Stack) + pop!(stk), function (Δ) + dstk = grad_mut(__context__, stk.data) + push!(dstk, Δ) + return + end +end + +# Dictionaries grad_mut(d::AbstractDict) = Dict() diff --git a/test/features.jl b/test/features.jl index 6a054a2ef..1f815f692 100644 --- a/test/features.jl +++ b/test/features.jl @@ -159,6 +159,7 @@ D(f, x) = grad(f, x)[1] if VERSION > v"1.2-" @test D(x -> x*D(y -> x+y, 1), 1) == 1 @test D(x -> x*D(y -> x*y, 1), 4) == 8 + @test_broken sin'''(1.0) == -cos(1.0) end f(x) = throw(DimensionMismatch("fubar")) @@ -247,3 +248,10 @@ end push!([], x) return x end + +@test gradient(1) do x + stk = [] + Zygote._push!(stk, x) + stk = Zygote.Stack(stk) + pop!(stk) +end == (1,)