Skip to content

Commit

Permalink
specialize on input function on a few more wrapper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC authored and jrevels committed Apr 6, 2018
1 parent 04c751f commit 4a62165
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/apiutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ end

@inline static_dual_eval(::Type{T}, f, x::SArray) where {T} = f(dualize(T, x))

function vector_mode_dual_eval(f, x, cfg::Union{JacobianConfig,GradientConfig})
function vector_mode_dual_eval(f::F, x, cfg::Union{JacobianConfig,GradientConfig}) where {F}
xdual = cfg.duals
seed!(xdual, x, cfg.seeds)
return f(xdual)
end

function vector_mode_dual_eval(f!, y, x, cfg::JacobianConfig)
function vector_mode_dual_eval(f!::F, y, x, cfg::JacobianConfig) where {F}
ydual, xdual = cfg.duals
seed!(xdual, x, cfg.seeds)
seed!(ydual, y)
Expand Down
6 changes: 3 additions & 3 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Compute `∇f` evaluated at `x` and store the result(s) in `result`, assuming `f
This method assumes that `isa(f(x), Real)`.
"""
function gradient!(result::Union{AbstractArray,DiffResult}, f, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK, F}
CHK && checktag(T, f, x)
if chunksize(cfg) == length(x)
vector_mode_gradient!(result, f, x, cfg)
Expand Down Expand Up @@ -92,13 +92,13 @@ end
# vector mode #
###############

function vector_mode_gradient(f, x, cfg::GradientConfig{T}) where {T}
function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F}
ydual = vector_mode_dual_eval(f, x, cfg)
result = similar(x, valtype(ydual))
return extract_gradient!(T, result, ydual)
end

function vector_mode_gradient!(result, f, x, cfg::GradientConfig{T}) where {T}
function vector_mode_gradient!(result, f::F, x, cfg::GradientConfig{T}) where {T, F}
ydual = vector_mode_dual_eval(f, x, cfg)
result = extract_gradient!(T, result, ydual)
return result
Expand Down

2 comments on commit 4a62165

@axsk
Copy link

@axsk axsk commented on 4a62165 Sep 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouln't these changes also be applied to the non-mutating gradient() and all jacobian calls as well? c.f. #516

@j-fu
Copy link
Contributor

@j-fu j-fu commented on 4a62165 Dec 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.