Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve adjoint for product and zip #1489

Merged
merged 16 commits into from
Jan 19, 2024
73 changes: 48 additions & 25 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ end
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
function _unzip(tuples, ::Val{N}) where {N}
getters = ntuple(n -> StaticGetter{n}(), Val(N))
lxvm marked this conversation as resolved.
Show resolved Hide resolved
map(g -> map(g, tuples), getters)
end
function unzip(tuples)
N = length(first(tuples))
Expand Down Expand Up @@ -169,8 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
_tryaxes(::Number) = Val(-1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a Val, maybe nothing or missing would be a more appropriate sentinel value here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, or even the number itself

_restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax)
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(map(length, ax)) - length(dx))), ax)
_restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N)
_restore(dx, ::Val{-1}) = only(dx)

# Sometimes a pullback doesn't return a Tuple, but rather returns only a
# single nothing to say "all arguments have zero cotangent". This function is needed to
Expand Down Expand Up @@ -268,32 +272,51 @@ end
_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1

function productfunc(xs, dy)
@assert length(first(dy)) == length(xs)
ndim = map(Zygote._ndims, xs)
cdim = cumsum((1, ndim[begin:end-1]...))
getters = ntuple(n -> StaticGetter{n}(), Val(length(xs)))
lxvm marked this conversation as resolved.
Show resolved Hide resolved
map(first(dy), xs, cdim, getters) do dyn, x, cd, getter
dyn === nothing && return nothing
nd = _ndims(x)
dims = nd == 0 ? (:) : ntuple(i -> i<cd ? i : i+nd, Val(ndims(dy)-nd))
init = map(zero, dyn) # allows for tuples, which accum can add:
red = mapreduce(getter, accum, dy; dims, init)
return _project(x, nd == 0 ? red : reshape(red, axes(x)))
end
end

@adjoint function Iterators.product(xs...)
back(::AbstractArray{Nothing}) = nothing
back(dy::NamedTuple{(:iterators,)}) = dy.iterators
function back(dy::AbstractArray)
d = 1
ntuple(length(xs)) do n
nd = _ndims(xs[n])
dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
d += nd
first(dy)[n] === nothing && return nothing
init = zero.(first(dy)[n]) # allows for tuples, which accum can add:
red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
return _project(xs[n], reshape(red, axes(xs[n])))
end
product_pullback(::AbstractArray{Nothing}) = nothing
product_pullback(dy::NamedTuple{(:iterators,)}) = dy.iterators
product_pullback(dy::AbstractArray) = productfunc(xs, dy)
Iterators.product(xs...), product_pullback
end

@adjoint function Base.collect(p::Base.Iterators.ProductIterator)
collect_product_pullback(dy) = ((iterators=productfunc(p.iterators, dy),),)
return collect(p), collect_product_pullback
end

function zipfunc(xs, dy)
getters = ntuple(n -> StaticGetter{n}(), Val(length(xs)))
lxvm marked this conversation as resolved.
Show resolved Hide resolved
map(xs, getters) do x, getter
dx = map(getter, dy)
_project(x, _restore(dx, _tryaxes(x)))
end
Iterators.product(xs...), back
end

@adjoint function Iterators.Zip(xs)
axs = map(_tryaxes, xs) # same function used for map
back(dy::NamedTuple{(:is,)}) = tuple(dy.is)
back(dy::AbstractArray) = ntuple(length(xs)) do d
dx = map(StaticGetter{d}(), dy)
_project(xs[d], _restore(dx, axs[d]))
end |> tuple
Iterators.Zip(xs), back
@adjoint function Iterators.zip(xs...)
zip_pullback(::AbstractArray{Nothing}) = nothing
zip_pullback(dy::NamedTuple{(:is,)}) = dy.is
zip_pullback(dy::AbstractArray) = zipfunc(xs, dy)
Iterators.zip(xs...), zip_pullback
end

@adjoint function Base.collect(z::Base.Iterators.Zip)
collect_zip_pullback(dy::AbstractArray) = ((is=zipfunc(z.is, dy),),)
collect(z), collect_zip_pullback
end

# Reductions
Expand Down
52 changes: 52 additions & 0 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,36 @@ test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_
# This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170
@test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] ≈ [320, 320, 320, 320]
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] ≈ [320, 320, 320, 320]

# Numbers failed before https://github.com/FluxML/Zygote.jl/pull/1489
for p in (1.0, fill(1.0), [1.0])
@test gradient(p -> sum([x*q for q in p, x in 1:3]), p) == (6p,)
@test gradient(p -> sum(x*q for (q, x) in Iterators.product(p, 1:3)), p) == (6p,)
end

# inference would also fail before #1489
y, back = _pullback(Iterators.product, 1:5, fill(1))
@test @inferred back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 4.0, 5.0], fill(5.0))
end

@testset "adjoints of Iterators.zip" begin
y, back = _pullback(Iterators.zip, 1:5, 1:3, 1:2)
@test back(collect(y)) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [1.0, 2.0])
@test back([(nothing, j, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, nothing, [1.0, 2.0, 0.0], [1.0, 2.0])
@test back([(i, nothing, k) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], nothing, [1.0, 2.0])
@test back([(i, j, nothing) for (i,j,k) in zip(1:5, 1:3, 1:2)]) == (nothing, [1.0, 2.0, 0.0, 0.0, 0.0], [1.0, 2.0, 0.0], nothing)


@test gradient(x -> sum([y[2] * y[3] for y in Iterators.zip(x, x, x, x)]), [1,2,3,4])[1] ≈ [2, 4, 6, 8]
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.zip(x, x, x, x)), [1,2,3,4])[1] ≈ [2, 4, 6, 8]

for p in (1.0, fill(1.0), [1.0])
@test gradient(p_ -> sum(map(prod, Iterators.zip(p_, p))), p) == (p,)
@test gradient(p_ -> sum(x*q for (q, x) in Iterators.zip(p_, p)), p) == (p,)
end

y, back = _pullback(Iterators.zip, 1:5, fill(1))
@test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0))
end

@testset "collect" begin
Expand Down Expand Up @@ -45,6 +75,28 @@ end
g = gradient(d -> sum(x^2 for x in collect(d)), t)[1]
@test g === (2.0, 4.0)
end

@testset "Iterators.ProductIterator" begin
p = Iterators.product(1:3, 1:2)
g = gradient(p -> sum(prod, collect(p)), p)[1]
@test g == (iterators=(3ones(3), 6ones(2)),)

@test gradient(x -> sum(broadcast(prod, Iterators.product(x,x))), ones(4)) == (2*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x))), ones(4)) == (3*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x, x .^ 2))), ones(4)) == (3*4ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2, x .^ 2))), ones(4)) == (4*4ones(4),)
end

@testset "Iterators.Zip" begin
z = Iterators.zip(1:3, 1:2)
g = gradient(z -> sum(prod, collect(z)), z)[1]
@test g == (is=([1.0, 2.0, 0.0], [1.0, 2.0]),)

@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x))), ones(4)) == (2ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),)
end
end

@testset "dictionary comprehension" begin
Expand Down
Loading