Issue with AD when using orthogonal #18

Closed
AlexRobson opened this issue Mar 20, 2021 · 8 comments

@AlexRobson
Contributor

AlexRobson commented Mar 20, 2021

I'm encountering issues using Zygote to AD through orthogonal. I note that the present tests check that AD works but, as I understand it, they do not test for correctness.

This is an MWE that demonstrates the problem using FiniteDifferences. I also include a variant that should be identical in the forward pass and that appears to be correct in the backward pass. I'm not too familiar with the AD ecosystem, but it is possible this should be an issue in ChainRules, because that is where the rrules for SVD are defined. Perhaps the rules there are not quite capturing the algebra in nearest_orthogonal_matrix correctly?

using Zygote
using ParameterHandling
using ParameterHandling: value
using LinearAlgebra
using FiniteDifferences
using Test

function test_ad(test_function, Δoutput, inputs...; atol=1e-7, rtol=1e-7)

    # Verify that the forwards-pass produces the correct answer.
    output, pb = Zygote.pullback(test_function, inputs...)
    @test output ≈ test_function(inputs...)

    # Compute the adjoints using AD and FiniteDifferences.
    dW_ad = pb(Δoutput)
    dW_fd = FiniteDifferences.j′vp(central_fdm(5, 1), test_function, Δoutput, inputs...)

    # Compare AD and FiniteDifferences results.
    @testset "$(typeof(test_function)) argument $n" for n in eachindex(inputs)
        @test dW_ad[n] ≈ dW_fd[n] atol=atol rtol=rtol
    end
end

function nearest_orthogonal_matrix_variant(X::StridedMatrix{<:Union{Real, Complex}})
    svd(X).U * svd(X).V'
end

# Confirm forward pass is the same
r = rand(3,2)
@test ParameterHandling.nearest_orthogonal_matrix(r) ≈
    nearest_orthogonal_matrix_variant(r)

# Fails at Expression: ≈(dW_ad[n], dW_fd[n], atol = atol, rtol = rtol)
test_ad(
    ParameterHandling.nearest_orthogonal_matrix,
    randn(3,2),
    randn(3,2),
)

# Passes
test_ad(
    nearest_orthogonal_matrix_variant,
    randn(3,2),
    randn(3,2),
)
  [26cc04aa] FiniteDifferences v0.12.2
  [2412ca09] ParameterHandling v0.3.1
  [e88e6eb3] Zygote v0.6.4
  [37e2e46d] LinearAlgebra
@willtebbutt
Member

Oooo this is interesting.

Looking at our svd implementation here, it doesn't look like we account for the de-structured version where you immediately pull out the parameters.

@oxinabox do you have any idea how the de-structuring mechanism works? I've just taken a look at LinearAlgebra, but it's not at all obvious to me how U, s, V = svd(X) is implemented...

@mzgubic

mzgubic commented Mar 22, 2021

Could you elaborate on what you think the issue is?

function nearest_orthogonal_matrix_twostep(X::StridedMatrix{<:Union{Real, Complex}})
    # Inlining necessary for type inference for some reason.
    F = svd(X)
    return F.U * F.V'
end

also fails, so I am not sure how de-structuring is an issue?

@willtebbutt
Member

willtebbutt commented Mar 22, 2021

Oh, I hadn't realised that. I'm now thoroughly confused 😂

edit: so it sounds to me like some kind of an accumulation issue, or an issue when there are non-zero cotangents for multiple bits of the svd.

@mzgubic

mzgubic commented Mar 22, 2021

Yeah, same, it looks like the accumulation of gradients of F is broken?
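
For reference, a sketch of why an accumulation problem would show up in this function in particular. It assumes the real-valued case and writes the output cotangent as \bar{Y}; the symbols \bar{U}, \bar{V} and \bar{s} are notation introduced here, not names from ChainRules or Zygote. Because the output depends on both F.U and F.V, the pullback of svd receives non-zero cotangents for two fields of F at once, and they must be combined rather than one overwriting the other:

```latex
% Thin SVD of X and the function under test (real case assumed)
X = U \operatorname{diag}(s) V^\top, \qquad Y = U V^\top

% Differential of the output
dY = dU \, V^\top + U \, dV^\top

% Matching <\bar{Y}, dY> = <\bar{U}, dU> + <\bar{V}, dV> term by term:
\bar{U} = \bar{Y} V, \qquad \bar{V} = \bar{Y}^\top U, \qquad \bar{s} = 0
```

In the variant that calls svd(X) twice, each call only ever sees one non-zero field cotangent (\bar{U} for one call, \bar{V} for the other), which is consistent with that version passing while the single-call versions fail.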

@willtebbutt
Member

Yeah, that's the only thing I can come up with.

In any case, this issue definitely seems like it warrants a ChainRules issue. @AlexRobson could you please open a linked issue on ChainRules and copy over your MWE?

@AlexRobson
Contributor Author

nearest_orthogonal_matrix_twostep(X::StridedMatrix{<:Union{Real, Complex}})

Yeah, I had meant to edit in that I had tried this one too. Also, breaking it up into

nearest_orthogonal_matrix(F::SVD) = F.U * F.V'
nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real, Complex}}) = nearest_orthogonal_matrix(svd(X))

@AlexRobson
Contributor Author

In any case, this issue definitely seems like it warrants a ChainRules issue. @AlexRobson could you please open a linked issue on ChainRules and copy over your MWE?

Yep

@mzgubic

mzgubic commented Mar 25, 2021

Solved by FluxML/Zygote.jl#922

mzgubic closed this as completed Mar 25, 2021