Issue with AD and SVD #387

Closed
AlexRobson opened this issue Mar 22, 2021 · 5 comments

Comments

@AlexRobson
Member

AlexRobson commented Mar 22, 2021

Originally opened in ParameterHandling: JuliaGaussianProcesses/ParameterHandling.jl#18

In ParameterHandling there is a method called nearest_orthogonal_matrix that is essentially this:

function nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real, Complex}})
    U, _, V = svd(X)
    return U * V'
end
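As a quick sanity check of what this function computes, here is a sketch that uses only the LinearAlgebra stdlib and was not part of the original report. For a tall X with full column rank, U * V' from the thin SVD is the orthogonal polar factor of X, i.e. the matrix with orthonormal columns that is closest to X in Frobenius norm. The input X below is arbitrary, and the definition is repeated so the snippet is self-contained.

```julia
using LinearAlgebra

# Same definition as above: orthogonal polar factor of X via the thin SVD.
function nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real, Complex}})
    U, _, V = svd(X)   # iterating an SVD yields U, S, V
    return U * V'
end

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # arbitrary 3×2 input with full column rank
Q = nearest_orthogonal_matrix(X)

# Q has the shape of X and orthonormal columns, so Q'Q is the 2×2 identity.
@show size(Q)
@show Q' * Q ≈ Matrix{Float64}(I, 2, 2)

# A matrix that already has orthonormal columns is its own nearest one.
@show nearest_orthogonal_matrix(Q) ≈ Q
```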

However, I'm running into problems using Zygote to AD through this. Below is an MWE that demonstrates the problem by comparing against FiniteDifferences. I also include a variant that should be identical in the forward pass and which appears to be correct in the backward pass; this may help diagnose the issue.

using Zygote
using ParameterHandling
using ParameterHandling: value
using LinearAlgebra
using FiniteDifferences
using Test

function test_ad(test_function, Δoutput, inputs...; atol=1e-7, rtol=1e-7)

    # Verify that the forwards-pass produces the correct answer.
    output, pb = Zygote.pullback(test_function, inputs...)
    @test output ≈ test_function(inputs...)

    # Compute the adjoints using AD and FiniteDifferences.
    dW_ad = pb(Δoutput)
    dW_fd = FiniteDifferences.j′vp(central_fdm(5, 1), test_function, Δoutput, inputs...)

    # Compare AD and FiniteDifferences results.
    @testset "$(typeof(test_function)) argument $n" for n in eachindex(inputs)
        @test dW_ad[n] ≈ dW_fd[n] atol=atol rtol=rtol
    end
end

function nearest_orthogonal_matrix_variant(X::StridedMatrix{<:Union{Real, Complex}})
    svd(X).U * svd(X).V'
end

# Confirm forward pass is the same
r = rand(3,2)
@test ParameterHandling.nearest_orthogonal_matrix(r) ≈
    nearest_orthogonal_matrix_variant(r)

# Fails at Expression: ≈(dW_ad[n], dW_fd[n], atol = atol, rtol = rtol)
test_ad(
    ParameterHandling.nearest_orthogonal_matrix,
    randn(3,2),
    randn(3,2),
)

# Passes
test_ad(
    nearest_orthogonal_matrix_variant,
    randn(3,2),
    randn(3,2),
)
  [26cc04aa] FiniteDifferences v0.12.2
  [2412ca09] ParameterHandling v0.3.1
  [e88e6eb3] Zygote v0.6.4
  [37e2e46d] LinearAlgebra
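For reference, the FiniteDifferences side of test_ad computes a vector-Jacobian product (j′vp). Below is a minimal stdlib-only sketch of the same idea; fd_vjp is a hypothetical helper, not part of FiniteDifferences.jl, and it uses a plain 2-point central difference with a fixed step h rather than the 5-point central_fdm(5, 1) stencil used above, so it should be expected to be less accurate.

```julia
using LinearAlgebra

# Hypothetical helper (not part of FiniteDifferences.jl): estimates the
# vector-Jacobian product of a matrix-valued function f at X, contracted with
# the output cotangent Δ, using a 2-point central difference per entry of X.
function fd_vjp(f, Δ::AbstractMatrix, X::AbstractMatrix{<:Real}; h = 1e-6)
    G = zeros(size(X))
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += h
        Xm = copy(X); Xm[i] -= h
        # Derivative of f along entry i, contracted with Δ (Frobenius inner product).
        G[i] = dot(Δ, (f(Xp) - f(Xm)) / (2h))
    end
    return G
end

# Usage on a linear map Y -> A * Y, whose exact vector-Jacobian product is A' * Δ.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
X = [0.5 -1.0; 2.0 0.25]
Δ = [1.0 0.0; 0.0 1.0; 1.0 -1.0]
@show fd_vjp(Y -> A * Y, Δ, X) ≈ A' * Δ
```

In the same way, fd_vjp(nearest_orthogonal_matrix, Δ, X) for a 3×2 X and Δ would give a rough reference cotangent that could be compared against the one returned by Zygote.pullback.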
@mzgubic
Member

mzgubic commented Mar 22, 2021

Note that Ȳ in

function svd_pullback(Ȳ::Composite)

here evaluates to

Composite{Any}(V = [-0.03019941519190618 -0.40107818732356243; 0.9178419801905596 1.4860412597607486],)

for the failing function, i.e. it does not include the U sensitivity as well. I suspect the issue might be in how Zygote accumulates the gradients with respect to the factorisation.

@willtebbutt
Member

Also note that this

function nearest_orthogonal_matrix_twostep(X::StridedMatrix{<:Union{Real, Complex}})
    # Inlining necessary for type inference for some reason.
    F = svd(X)
    return F.U * F.V'
end

also doesn't work properly (this was pointed out by Miha and Alex in the original thread).

@mzgubic
Member

mzgubic commented Mar 22, 2021

Oh, I think I got it, right in the belly of the beast:
https://github.com/FluxML/Zygote.jl/blob/890b6f57303fb983a815aec5fea92620c93a4bd3/src/lib/lib.jl#L213-L226

This function is not defined for literal_getproperty, which means the gradient does not get accumulated.

This kind of makes sense, since I believe that part was reorganised recently.

I will work it out more properly and submit a PR.
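To illustrate what the missing accumulation means, here is a hypothetical sketch in plain Julia; it is not Zygote's implementation, accum_cotangent is an illustrative name, and the field names U, S, V are only modelled on the Composite shown earlier. In F.U * F.V', the factorisation F is read twice, so each property access contributes a cotangent for F with only one field filled in. The two contributions have to be summed field by field; if the U contribution is dropped, the pullback only sees V, which matches the Composite{Any}(V = ...) observed in the failing case.

```julia
# Hypothetical illustration of field-wise cotangent accumulation; this is not
# Zygote's code, and the field names are only illustrative.
accum_cotangent(::Nothing, ::Nothing) = nothing
accum_cotangent(x, ::Nothing) = x
accum_cotangent(::Nothing, y) = y
accum_cotangent(x, y) = x + y

# Two partial cotangents for the same factorisation are summed field by field.
function accum_cotangent(a::NamedTuple{K}, b::NamedTuple{K}) where {K}
    return NamedTuple{K}(map(accum_cotangent, values(a), values(b)))
end

ΔU = [1.0 0.0; 0.0 1.0; 0.0 0.0]   # cotangent reaching F through F.U
ΔV = [0.5 0.5; -0.5 0.5]           # cotangent reaching F through F.V
from_U = (U = ΔU, S = nothing, V = nothing)
from_V = (U = nothing, S = nothing, V = ΔV)

# Both sensitivities survive: (U = ΔU, S = nothing, V = ΔV).
# Keeping only from_V would reproduce the symptom: a cotangent with V alone.
ΔF = accum_cotangent(from_U, from_V)
@show ΔF == (U = ΔU, S = nothing, V = ΔV)
```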

@mzgubic
Member

mzgubic commented Mar 22, 2021

Btw, we really need JuliaDiff/ChainRulesTestUtils.jl#114; it would make debugging these things so much easier.

@mzgubic
Member

mzgubic commented Mar 25, 2021

closed by FluxML/Zygote.jl#922

@mzgubic mzgubic closed this as completed Mar 25, 2021