Skip to content

Commit

Permalink
rrule for stack (#681)
Browse files Browse the repository at this point in the history
* rrule for stack

* bump version

* extend rrule to muldim containers

* hope you don't mind me committing these

* import stack in tests

* cleanup

* Apply3 suggestions

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
  • Loading branch information
CarloLucibello and mcabbott authored Nov 11, 2022
1 parent a9a84ba commit 1597bcc
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.44.7"
version = "1.45.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -20,7 +20,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0, 4"
Compat = "3.46, 4.2"
FiniteDifferences = "0.12.20"
GPUArraysCore = "0.1.0"
IrrationalConstants = "0.1.1"
Expand Down
2 changes: 2 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import ChainRulesCore: rrule, frule
# Experimental:
using ChainRulesCore: derivatives_given_output

using Compat: stack

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

Expand Down
27 changes: 27 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,30 @@ function _extrema_dims(x, dims)
end
return y, extrema_pullback_dims
end

#####
##### `stack`
#####

function frule((_, ẋ), ::typeof(stack), x; dims::Union{Integer, Colon} = :)
return stack(x; dims), stack(ẋ; dims)
end

# Other iterable X also allowed, maybe this should be wider?
function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = :)
Y = stack(X; dims)
sdims = if dims isa Colon
N = ndims(Y) - ndims(X)
X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X))
else
dims
end
project = ProjectTo(X)
function stack_pullback(Δ)
dY = unthunk(Δ)
dY isa AbstractZero && return (NoTangent(), dY)
dX = collect(eachslice(dY; dims = sdims))
return (NoTangent(), project(reshape(dX, project.axes)))
end
return Y, stack_pullback
end
30 changes: 30 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,33 @@ end
B = hcat(A[:,:,1], A[:,:,1])
@test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1]
end

@testset "stack" begin
# vector container
xs = [rand(3, 4), rand(3, 4)]
test_frule(stack, xs)
test_frule(stack, xs; fkwargs=(dims=1,))

test_rrule(stack, xs, check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false)

# multidimensional container
ms = [rand(2,3) for _ in 1:4, _ in 1:5];

if VERSION > v"1.9-" # this needs new eachslice, not yet in Compat
test_rrule(stack, ms, check_inferred=false)
end
test_rrule(stack, ms, fkwargs=(dims=1,), check_inferred=false)
test_rrule(stack, ms, fkwargs=(dims=3,), check_inferred=false)

# non-array inner objects
ts = [Tuple(rand(3)) for _ in 1:4, _ in 1:2];

if VERSION > v"1.9-"
test_rrule(stack, ts, check_inferred=false)
end
test_rrule(stack, ts, fkwargs=(dims=1,), check_inferred=false)
test_rrule(stack, ts, fkwargs=(dims=2,), check_inferred=false)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils
using Adapt
using Base.Broadcast: broadcastable
using ChainRules
using ChainRules: stack
using ChainRulesCore
using ChainRulesTestUtils
using ChainRulesTestUtils: rand_tangent, _fdm
Expand Down

2 comments on commit 1597bcc

@mcabbott
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/72038

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.45.0 -m "<description of version>" 1597bcc5fcb9af7fab26d7505f680d2b7fe4d5d4
git push origin v1.45.0

Please sign in to comment.