Fixes gradient, Jacobian, Hessian, and vjp tests #2

Merged: 15 commits, Aug 2, 2021
Project.toml (3 changes: 2 additions & 1 deletion)
@@ -12,7 +12,8 @@ julia = "1"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
-test = ["Test", "FiniteDifferences"]
+test = ["Test", "FiniteDifferences", "Random"]
src/AbstractDifferentiation.jl (36 changes: 35 additions & 1 deletion)
@@ -54,6 +54,10 @@ end
function value_and_jacobian(ab::AbstractBackend, f, xs...)
    local value
    primalcalled = false
    if ab isa AbstractFiniteDifference
        value = primalvalue(ab, nothing, f, xs)
        primalcalled = true
    end
    jacs = jacobian(lowest(ab), (_xs...,) -> begin
        v = f(_xs...)
        if !primalcalled
@@ -62,11 +66,16 @@ function value_and_jacobian(ab::AbstractBackend, f, xs...)
        end
        return v
    end, xs...)

    return value, jacs
end
function value_and_hessian(ab::AbstractBackend, f, xs...)
    local value
    primalcalled = false
    if ab isa AbstractFiniteDifference
        value = primalvalue(ab, nothing, f, xs)
        primalcalled = true
    end
    hess = jacobian(secondlowest(ab), (_xs...,) -> begin
        v, g = value_and_gradient(lowest(ab), f, _xs...)
        if !primalcalled
@@ -145,6 +154,10 @@ function value_and_pushforward_function(
    @assert ds isa Tuple && length(ds) == length(xs)
    local value
    primalcalled = false
    if ab isa AbstractFiniteDifference
        value = primalvalue(ab, nothing, f, xs)
        primalcalled = true
    end
    pf = pushforward_function(lowest(ab), (_xs...,) -> begin
        vs = f(_xs...)
        if !primalcalled
@@ -199,6 +212,10 @@ function value_and_pullback_function(
    return (ws) -> begin
        local value
        primalcalled = false
        if ab isa AbstractFiniteDifference
            value = primalvalue(ab, nothing, f, xs)
            primalcalled = true
        end
        if ws === nothing
            vs = f(xs...)
            if !primalcalled
@@ -350,6 +367,19 @@ function define_pushforward_function_and_friends(fdef)
                pff(cols)
            end
        end
    elseif eltype(identity_like) <: AbstractMatrix
Member:
Could you please comment this part? It's not clear to me what's happening here.

Contributor Author:
This was the fix I added for the computation of the Hessian. identity_like and cols look as follows in that case:

identity_like = ([1.0 0.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0 0.0; 0.0 0.0 1.0 0.0 0.0; 0.0 0.0 0.0 1.0 0.0; 0.0 0.0 0.0 0.0 1.0],)
identity_like[1] = [1.0 0.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0 0.0; 0.0 0.0 1.0 0.0 0.0; 0.0 0.0 0.0 1.0 0.0; 0.0 0.0 0.0 0.0 1.0]
cols = [1.0, 0.0, 0.0, 0.0, 0.0]
cols = [0.0, 1.0, 0.0, 0.0, 0.0]
cols = [0.0, 0.0, 1.0, 0.0, 0.0]
cols = [0.0, 0.0, 0.0, 1.0, 0.0]
cols = [0.0, 0.0, 0.0, 0.0, 1.0]

So it was mainly to fix the input/output shapes for the additional pushforward that is used. I'll need to check in a bit more detail whether that function can be simplified. IIRC, without it I got dimension errors in the jvp of FiniteDifferences.jl because it would have pushed forward whole matrices like identity_like[1] instead of their columns.
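
To make the shapes concrete, here is a minimal sketch (illustrative only, not part of the diff) of how mapslices with dims=1 feeds one basis column at a time to the pushforward:

using LinearAlgebra

# identity_like as in the Hessian case above: a 1-tuple holding a 5x5 identity.
identity_like = (Matrix{Float64}(I, 5, 5),)

# mapslices with dims=1 visits one column (a basis vector) at a time, which is
# the shape the finite-difference jvp expects, rather than the whole matrix.
mapslices(identity_like[1], dims=1) do cols
    println(cols)  # [1.0, 0.0, 0.0, 0.0, 0.0], then [0.0, 1.0, 0.0, 0.0, 0.0], ...
    cols           # return the slice unchanged so mapslices can reassemble it
end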

Contributor Author:
But I think you are right that this needs some additional care. I think the current version will break down once we'd like to compute a tuple of Hessians.
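
For instance (a hypothetical case, not from this diff): with two array inputs, identity_like would hold one basis matrix per argument, but the branch below only walks the columns of identity_like[1]:

using LinearAlgebra

# Hypothetical two-input setting: one basis matrix per argument.
identity_like = (Matrix{Float64}(I, 5, 5), Matrix{Float64}(I, 3, 3))

# The fix below calls mapslices(identity_like[1], dims=1), so only the columns
# of the first basis matrix are pushed forward; identity_like[2] is ignored,
# which is why a tuple of Hessians would break down with the current version.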

        ret = hcat.(mapslices(identity_like[1], dims=1) do cols
            pf = pff((cols,))
            if typeof(pf) <: AbstractVector
                return (pf,)
            elseif typeof(pf) <: AbstractMatrix
                return (transpose(pf),)
            else
                return pf
            end
        end ...)
        return ret isa Tuple ? ret : (ret,)

    else
        return pff(identity_like)
    end
@@ -373,6 +403,10 @@ function define_pullback_function_and_friends(fdef)
                value_and_pbf(cols)[2]'
            end
        end
    elseif eltype(identity_like) <: AbstractMatrix
        return vcat.(mapslices(identity_like[1], dims=1) do cols
            adjoint.(value_and_pbf((cols,))[2])
        end ...)
    else
        return adjoint.(value_and_pbf(identity_like)[2])
    end
@@ -427,4 +461,4 @@ function zero_matrix_like(x)
    throw("The function `zero_matrix_like` is not defined for the type $(typeof(x)).")
end

end
end
test/runtests.jl (56 changes: 49 additions & 7 deletions)
@@ -1,5 +1,7 @@
using AbstractDifferentiation
using Test, FiniteDifferences, LinearAlgebra
using Random
Random.seed!(1234)

const FDM = FiniteDifferences

@@ -30,23 +32,49 @@ const fdm_backend3 = FDMBackend3()
AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
    return function (vs)
        # Supports only single output
-       @assert length(vs) == 1
-       return j′vp(ab.alg, f, vs[1], xs...)
+       if vs isa AbstractVector
+           return j′vp(ab.alg, f, vs, xs...)
+       else
+           @assert length(vs) == 1
+           return j′vp(ab.alg, f, vs[1], xs...)
+       end
    end
end

fder(x, y) = exp(y) * x + y * log(x)
dfderdx(x, y) = exp(y) + y * 1/x
dfderdy(x, y) = exp(y) * x + log(x)

fgrad(x, y) = prod(x) + sum(y ./ (1:length(y)))
dfgraddx(x, y) = prod(x)./x
dfgraddy(x, y) = one(eltype(y)) ./ (1:length(y))
dfgraddxdx(x, y) = prod(x)./(x*x') - Diagonal(diag(prod(x)./(x*x')))

function fjac(x, y)
    x + Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U) * y
end
dfjacdx(x, y) = I(length(x))
dfjacdy(x, y) = Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U)

# Jvp
jxvp(x,y,v) = dfjacdx(x,y)*v
jyvp(x,y,v) = dfjacdy(x,y)*v

# vJp
vJxp(x,y,v) = dfjacdx(x,y)'*v
vJyp(x,y,v) = dfjacdy(x,y)'*v

const xscalar = rand()
const yscalar = rand()

const xvec = rand(5)
const yvec = rand(5)

# to check if vectors get mutated
xvec2 = deepcopy(xvec)
yvec2 = deepcopy(yvec)

function test_fdm_derivatives(fdm_backend)
    der1 = AD.derivative(fdm_backend, fder, xscalar, yscalar)
    der2 = (
@@ -57,6 +85,8 @@ function test_fdm_derivatives(fdm_backend)
    valscalar, der3 = AD.value_and_derivative(fdm_backend, fder, xscalar, yscalar)
    @test valscalar == fder(xscalar, yscalar)
    @test der3 .- der1 == (0, 0)
    der_exact = (dfderdx(xscalar, yscalar), dfderdy(xscalar, yscalar))
    @test minimum(isapprox.(der_exact, der1, rtol=1e-10))
end

function test_fdm_gradients(fdm_backend)
@@ -66,6 +96,10 @@ function test_fdm_gradients(fdm_backend)
    valscalar, grad3 = AD.value_and_gradient(fdm_backend, fgrad, xvec, yvec)
    @test valscalar == fgrad(xvec, yvec)
    @test norm.(grad3 .- grad1) == (0, 0)
    grad_exact = (dfgraddx(xvec, yvec), dfgraddy(xvec, yvec))
    @test minimum(isapprox.(grad_exact, grad1, rtol=1e-10))
    @test xvec == xvec2
    @test yvec == yvec2
end

function test_fdm_jacobians(fdm_backend)
@@ -75,6 +109,10 @@ function test_fdm_jacobians(fdm_backend)
    valvec, jac3 = AD.value_and_jacobian(fdm_backend, fjac, xvec, yvec)
    @test valvec == fjac(xvec, yvec)
    @test norm.(jac3 .- jac1) == (0, 0)
    grad_exact = (dfjacdx(xvec, yvec), dfjacdy(xvec, yvec))
    @test minimum(isapprox.(grad_exact, jac1, rtol=1e-10))
    @test xvec == xvec2
    @test yvec == yvec2
end

function test_fdm_hessians(fdm_backend)
@@ -99,6 +137,12 @@ function test_fdm_hessians(fdm_backend)
    @test valscalar == fgrad(xvec, yvec)
    @test norm.(grad .- AD.gradient(fdm_backend, fhess, xvec)) == (0,)
    @test norm.(hess4 .- hess1) == (0,)
    @test dfgraddxdx(xvec, yvec) ≈ hess1[1] atol=1e-10
    @test xvec == xvec2
    @test yvec == yvec2
    fhess2 = x -> dfgraddx(x, yvec)
    hess5 = AD.jacobian(fdm_backend, fhess2, xvec)
    @test minimum(isapprox.(hess5, hess1, atol=1e-10))
end

function test_fdm_jvp(fdm_backend)
@@ -122,6 +166,9 @@ function test_fdm_j′vp(fdm_backend)
    valvec, pb3 = AD.value_and_pullback_function(fdm_backend, fjac, xvec, yvec)(w)
    @test valvec == fjac(xvec, yvec)
    @test norm.(pb3 .- pb1) == (0, 0)
    @test minimum(isapprox.(pb1, (vJxp(xvec, yvec, w), vJyp(xvec, yvec, w)), atol=1e-10))
    @test xvec == xvec2
    @test yvec == yvec2
end

@testset "AbstractDifferentiation.jl" begin
@@ -139,28 +186,23 @@ end
@testset "Jacobian" begin
test_fdm_jacobians(fdm_backend1)
test_fdm_jacobians(fdm_backend2)
# Errors
test_fdm_jacobians(fdm_backend3)
end
@testset "Hessian" begin
# Works but super slow
test_fdm_hessians(fdm_backend1)
# Errors
test_fdm_hessians(fdm_backend2)
# Errors
test_fdm_hessians(fdm_backend3)
end
@testset "jvp" begin
test_fdm_jvp(fdm_backend1)
# Errors
test_fdm_jvp(fdm_backend2)
# Errors
test_fdm_jvp(fdm_backend3)
end
@testset "j′vp" begin
test_fdm_j′vp(fdm_backend1)
test_fdm_j′vp(fdm_backend2)
# Errors
test_fdm_j′vp(fdm_backend3)
end
end