Skip to content

Commit

Permalink
Merge pull request #2 from frankschae/mt/interface
Browse files Browse the repository at this point in the history
Fixes gradient, Jacobian, Hessian, and vjp tests
  • Loading branch information
mohamed82008 authored Aug 2, 2021
2 parents da5aece + a9f1538 commit 347d978
Show file tree
Hide file tree
Showing 3 changed files with 514 additions and 37 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ julia = "1"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "FiniteDifferences"]
test = ["Test", "FiniteDifferences", "Random"]
206 changes: 183 additions & 23 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,32 @@ primalvalue(x::Tuple) = map(primalvalue, x)
primalvalue(x) = x

function derivative(ab::AbstractBackend, f, xs::Number...)
return getindex.(jacobian(lowest(ab), f, xs...), 1)
der = getindex.(jacobian(lowest(ab), f, xs...), 1)
if der isa Tuple
return der
else
return (der,)
end
end

function gradient(ab::AbstractBackend, f, xs...)
return adjoint.(jacobian(lowest(ab), f, xs...))
return reshape.(adjoint.(jacobian(lowest(ab), f, xs...)),size.(xs))
end
function jacobian(ab::AbstractBackend, f, xs...) end
function hessian(ab::AbstractBackend, f, xs...)
return jacobian(secondlowest(ab), (xs...,) -> begin
gradient(lowest(ab), f, xs...)
end, xs...)
xss = collect((xs...,))
counter = 0
# gradient returns tuple of gradient values with respect to inputs x,y ∈ xs
# Hessian is the Jacobian of the individual gradients, i.e., the tuple of matrices
# defined by ∂x∂x f, ∂y∂y f, in the case of a scalar valued function `f`.
hess = map((xs...,)) do x
counter += 1
_f = _x->f(setindex!(deepcopy(xss),_x,counter)...)
return jacobian(secondlowest(ab),(x,)-> begin
return gradient(lowest(ab), _f, x)
end, x)[1]
end
return hess
end

function value_and_derivative(ab::AbstractBackend, f, xs::Number...)
Expand All @@ -49,11 +65,15 @@ function value_and_derivative(ab::AbstractBackend, f, xs::Number...)
end
function value_and_gradient(ab::AbstractBackend, f, xs...)
value, jacs = value_and_jacobian(lowest(ab), f, xs...)
return value, adjoint.(jacs)
return value, reshape.(adjoint.(jacs),size.(xs))
end
function value_and_jacobian(ab::AbstractBackend, f, xs...)
local value
primalcalled = false
if lowest(ab) isa AbstractFiniteDifference
value = primalvalue(ab, nothing, f, xs)
primalcalled = true
end
jacs = jacobian(lowest(ab), (_xs...,) -> begin
v = f(_xs...)
if !primalcalled
Expand All @@ -62,11 +82,16 @@ function value_and_jacobian(ab::AbstractBackend, f, xs...)
end
return v
end, xs...)

return value, jacs
end
function value_and_hessian(ab::AbstractBackend, f, xs...)
local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primalvalue(ab, nothing, f, xs)
primalcalled = true
end
hess = jacobian(secondlowest(ab), (_xs...,) -> begin
v, g = value_and_gradient(lowest(ab), f, _xs...)
if !primalcalled
Expand Down Expand Up @@ -145,6 +170,10 @@ function value_and_pushforward_function(
@assert ds isa Tuple && length(ds) == length(xs)
local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primalvalue(ab, nothing, f, xs)
primalcalled = true
end
pf = pushforward_function(lowest(ab), (_xs...,) -> begin
vs = f(_xs...)
if !primalcalled
Expand Down Expand Up @@ -199,6 +228,10 @@ function value_and_pullback_function(
return (ws) -> begin
local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primalvalue(ab, nothing, f, xs)
primalcalled = true
end
if ws === nothing
vs = f(xs...)
if !primalcalled
Expand All @@ -224,63 +257,169 @@ struct LazyDerivative{B, F, X}
f::F
xs::X
end

function Base.:*(d::LazyDerivative, y)
return derivative(d.ab, d.f, d.xs...) * y
return derivative(d.backend, d.f, d.xs...) * y
end

function Base.:*(y, d::LazyDerivative)
return y * derivative(d.ab, d.f, d.xs...)
return y * derivative(d.backend, d.f, d.xs...)
end

function Base.:*(d::LazyDerivative, y::Union{Number,Tuple})
return derivative(d.backend, d.f, d.xs...) .* y
end

function Base.:*(y::Union{Number,Tuple}, d::LazyDerivative)
return y .* derivative(d.backend, d.f, d.xs...)
end

function Base.:*(d::LazyDerivative, y::AbstractArray)
return map((d)-> d*y, derivative(d.backend, d.f, d.xs...))
end

function Base.:*(y::AbstractArray, d::LazyDerivative)
return map((d)-> y*d, derivative(d.backend, d.f, d.xs...))
end


struct LazyGradient{B, F, X}
backend::B
f::F
xs::X
end
Base.:*(d::LazyGradient, y) = gradient(d.ab, d.f, d.xs...) * y
Base.:*(y, d::LazyGradient) = y * gradient(d.ab, d.f, d.xs...)
Base.:*(d::LazyGradient, y) = gradient(d.backend, d.f, d.xs...) * y
Base.:*(y, d::LazyGradient) = y * gradient(d.backend, d.f, d.xs...)

function Base.:*(d::LazyGradient, y::Union{Number,Tuple})
if d.xs isa Tuple
return gradient(d.backend, d.f, d.xs...) .* y
else
return gradient(d.backend, d.f, d.xs) .* y
end
end

function Base.:*(y::Union{Number,Tuple}, d::LazyGradient)
if d.xs isa Tuple
return y .* gradient(d.backend, d.f, d.xs...)
else
return y .* gradient(d.backend, d.f, d.xs)
end
end


struct LazyJacobian{B, F, X}
backend::B
f::F
xs::X
end

function Base.:*(d::LazyJacobian, ys)
return pushforward_function(d.ab, d.f, d.xs...)(ys)
if !(ys isa Tuple)
ys = (ys, )
end
if d.xs isa Tuple
vjp = pushforward_function(d.backend, d.f, d.xs...)(ys)
else
vjp = pushforward_function(d.backend, d.f, d.xs)(ys)
end
if vjp isa Tuple
return vjp
else
return (vjp,)
end
end

function Base.:*(ys, d::LazyJacobian)
if ys isa Tuple
ya = adjoint.(ys)
else
ya = adjoint(ys)
end
return pullback_function(d.ab, d.f, d.xs...)(ya)
if d.xs isa Tuple
return pullback_function(d.backend, d.f, d.xs...)(ya)
else
return pullback_function(d.backend, d.f, d.xs)(ya)
end
end

function Base.:*(d::LazyJacobian, ys::Number)
if d.xs isa Tuple
return jacobian(d.backend, d.f, d.xs...) .* ys
else
return jacobian(d.backend, d.f, d.xs) .* ys
end
end

function Base.:*(ys::Number, d::LazyJacobian)
if d.xs isa Tuple
return jacobian(d.backend, d.f, d.xs...) .* ys
else
return jacobian(d.backend, d.f, d.xs) .* ys
end
end


struct LazyHessian{B, F, X}
backend::B
f::F
xs::X
end

function Base.:*(d::LazyHessian, ys)
return pushforward_function(
secondlowest(d.ab),
(xs...,) -> gradient(lowest(d.ab), d.f, xs...),
d.xs...,
)(ys)
if !(ys isa Tuple)
ys = (ys, )
end

if d.xs isa Tuple
return pushforward_function(
secondlowest(d.backend),
(xs...,) -> gradient(lowest(d.backend), d.f, xs...), d.xs...,)(ys)
else
return pushforward_function(
secondlowest(d.backend),
(xs,) -> gradient(lowest(d.backend), d.f, xs),d.xs,)(ys)
end
end

function Base.:*(ys, d::LazyHessian)
if ys isa Tuple
ya = adjoint.(ys)
else
ya = adjoint(ys)
end
return pullback_function(
secondlowest(d.ab),
(xs...,) -> gradient(lowest(d.ab), d.f, xs...),
d.xs...,
)(ya)
if d.xs isa Tuple
return pullback_function(
secondlowest(d.backend),
(xs...,) -> gradient(lowest(d.backend), d.f, xs...),
d.xs...,
)(ya)
else
return pullback_function(
secondlowest(d.backend),
(xs,) -> gradient(lowest(d.backend), d.f, xs)[1],
d.xs,
)(ya)
end
end

function Base.:*(d::LazyHessian, ys::Number)
if d.xs isa Tuple
return hessian(d.backend, d.f, d.xs...).*ys
else
return hessian(d.backend, d.f, d.xs).*ys
end
end

function Base.:*(ys::Number, d::LazyHessian)
if d.xs isa Tuple
return ys.*hessian(d.backend, d.f, d.xs...)
else
return ys.*hessian(d.backend, d.f, d.xs)
end
end


function lazyderivative(ab::AbstractBackend, f, xs::Number...)
return LazyDerivative(ab, f, xs)
end
Expand Down Expand Up @@ -350,6 +489,20 @@ function define_pushforward_function_and_friends(fdef)
pff(cols)
end
end
elseif eltype(identity_like) <: AbstractMatrix
# needed for the computation of the Hessian and Jacobian
ret = hcat.(mapslices(identity_like[1], dims=1) do cols
# cols loop over basis states
pf = pff((cols,))
if typeof(pf) <: AbstractVector
# to make the hcat. work / get correct matrix-like, non-flat output dimension
return (pf, )
else
return pf
end
end ...)
return ret isa Tuple ? ret : (ret,)

else
return pff(identity_like)
end
Expand All @@ -373,6 +526,13 @@ function define_pullback_function_and_friends(fdef)
value_and_pbf(cols)[2]'
end
end
elseif eltype(identity_like) <: AbstractMatrix
# needed for Hessian computation:
# value is a (grad,). Then, identity_like is a (matrix,).
# cols loops over columns of the matrix
return vcat.(mapslices(identity_like[1], dims=1) do cols
adjoint.(value_and_pbf((cols,))[2])
end ...)
else
return adjoint.(value_and_pbf(identity_like)[2])
end
Expand Down Expand Up @@ -427,4 +587,4 @@ function zero_matrix_like(x)
throw("The function `zero_matrix_like` is not defined for the type $(typeof(x)).")
end

end
end
Loading

0 comments on commit 347d978

Please sign in to comment.