Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix adjoint Iterators.product behavior with nothing #1170

Merged
merged 4 commits into from
Feb 23, 2022

Conversation

lxvm
Copy link
Contributor

@lxvm lxvm commented Feb 18, 2022

Hi,
This PR fixes the issue illustrated below where a DimesionMismatch occurs depending on the position of nothing arguments to the pullback of Iterators.product.

using Zygote

function aget()
    y, back = Zygote._pullback(Iterators.product, 1:5, 1:3, 1:2)
end

function atest0()
    y, back = aget()
    t0 = [(i, j, k) for i in 1:5, j in 1:3, k in 1:2]
    back(t0)
    # OK
end
function atest1()
    y, back = aget()
    t1 = [(nothing, j, k) for i in 1:5, j in 1:3, k in 1:2]
    back(t1)
    # ERROR: DimensionMismatch("new dimensions (3,) must be consistent with array size 5")
end
function atest2()
    y, back = aget()
    t2 = [(i, nothing, k) for i in 1:5, j in 1:3, k in 1:2]
    back(t2)
    # ERROR: DimensionMismatch("new dimensions (2,) must be consistent with array size 3")
end
function atest3()
    y, back = aget()
    t3 = [(i, j, nothing) for i in 1:5, j in 1:3, k in 1:2]
    back(t3)
    # OK
end
function atest4()
    y, back = aget()
    t4 = [(nothing, nothing, k) for i in 1:5, j in 1:3, k in 1:2]
    back(t4)
    # ERROR: DimensionMismatch("new dimensions (2,) must be consistent with array size 5")
end
function atest5()
    y, back = aget()
    t5 = [(nothing, j, nothing) for i in 1:5, j in 1:3, k in 1:2]
    back(t5)
    # ERROR: DimensionMismatch("new dimensions (3,) must be consistent with array size 5")
end
function atest6()
    y, back = aget()
    t6 = [(i, nothing, nothing) for i in 1:5, j in 1:3, k in 1:2]
    back(t6)
    # OK
end
function atest7()
    y, back = aget()
    t7 = [(nothing, nothing, nothing) for i in 1:5, j in 1:3, k in 1:2]
    back(t7)
    # OK
end

I hope this PR enforces the desired behavior of the pullback in this edge case, but I don't know enough about Zygote to understand where the nothing arguments are originating from (I noticed this bug while trying to differentiate some other code).
I can reference:

With this PR, the results of the tests defined above are:

julia> atest0()
(nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], [15.0, 30.0])

julia> atest1()
(nothing, nothing, [10.0, 20.0, 30.0], [15.0, 30.0])

julia> atest2()
(nothing, [6.0, 12.0, 18.0, 24.0, 30.0], nothing, [15.0, 30.0])

julia> atest3()
(nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], nothing)

julia> atest4()
(nothing, nothing, nothing, [15.0, 30.0])

julia> atest5()
(nothing, nothing, [10.0, 20.0, 30.0], nothing)

julia> atest6()
(nothing, [6.0, 12.0, 18.0, 24.0, 30.0], nothing, nothing)

julia> atest7()
(nothing, nothing, nothing, nothing)

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great, thanks!

Could you add (at least some of) these tests to the tests?

If you can invent a function whose gradient produces this error, that would also be nice to have in the tests.

src/lib/array.jl Outdated Show resolved Hide resolved
lxvm and others added 2 commits February 21, 2022 15:01
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@lxvm
Copy link
Contributor Author

lxvm commented Feb 22, 2022

Thank you for reviewing the change!
I've added some tests which should check for the expected behavior.

I can't seem to come up with a mwe of a gradient that leads to this error.
When I find some more time I will try to create one, because I'm not sure where the nothing came from and I believe it is causing my gradients to vanish.
Is this pull request adequate?

The context in which this error arose is in computing a gradient for an electromagnetic simulation (specifically this one: https://github.com/lxvm/DeltaRCWA.jl/blob/1c445ab8d72129350de275ab782283d7ebac221e/scripts/zygote.jl#L41) where I used a product iterator to construct a matrix (the stacktrace from the gradient pointed to here: https://github.com/lxvm/DeltaRCWA.jl/blob/1c445ab8d72129350de275ab782283d7ebac221e/src/sheets.jl#L312).
I am confused about the reason for the error.

@mcabbott
Copy link
Member

I see, the nothing is from here:

julia> gradient(k⃗ -> k⃗[2], (1,2,3))
((nothing, 1.0, nothing),)

and so a way to get the wrong answer is:

julia> gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])
([224.0, 288.0, 352.0, 416.0],)

julia> ForwardDiff.gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])
4-element Vector{Int64}:
 320
 320
 320
 320

julia> gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])  # different path, misses the rule?
([320.0, 320.0, 320.0, 320.0],)

(@v1.8) pkg> st Zygote
Status `~/.julia/environments/v1.8/Project.toml`
  [e88e6eb3] Zygote v0.6.34

Thanks for tracking this down!

@lxvm
Copy link
Contributor Author

lxvm commented Feb 22, 2022

Thank you for noticing those things!
I'll see if I can find out why my code calls the rule by using the debugger.

test/lib/array.jl Outdated Show resolved Hide resolved
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@lxvm
Copy link
Contributor Author

lxvm commented Feb 22, 2022

The last test looks good!

@mcabbott mcabbott merged commit 8b9d67d into FluxML:master Feb 23, 2022
@mcabbott
Copy link
Member

Thanks!

@lxvm lxvm mentioned this pull request Jan 2, 2024
2 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants