From 4a621653fc717f8501df29ecba2854e657f43ba1 Mon Sep 17 00:00:00 2001 From: Kristoffer Carlsson Date: Wed, 4 Apr 2018 21:48:56 +0200 Subject: [PATCH] specialize on input function on a few more wrapper functions --- src/apiutils.jl | 4 ++-- src/gradient.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/apiutils.jl b/src/apiutils.jl index 222e7bcb..eec39653 100644 --- a/src/apiutils.jl +++ b/src/apiutils.jl @@ -29,13 +29,13 @@ end @inline static_dual_eval(::Type{T}, f, x::SArray) where {T} = f(dualize(T, x)) -function vector_mode_dual_eval(f, x, cfg::Union{JacobianConfig,GradientConfig}) +function vector_mode_dual_eval(f::F, x, cfg::Union{JacobianConfig,GradientConfig}) where {F} xdual = cfg.duals seed!(xdual, x, cfg.seeds) return f(xdual) end -function vector_mode_dual_eval(f!, y, x, cfg::JacobianConfig) +function vector_mode_dual_eval(f!::F, y, x, cfg::JacobianConfig) where {F} ydual, xdual = cfg.duals seed!(xdual, x, cfg.seeds) seed!(ydual, y) diff --git a/src/gradient.jl b/src/gradient.jl index e9a280c5..8721af6d 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -29,7 +29,7 @@ Compute `∇f` evaluated at `x` and store the result(s) in `result`, assuming `f This method assumes that `isa(f(x), Real)`. """ -function gradient!(result::Union{AbstractArray,DiffResult}, f, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK} +function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK, F} CHK && checktag(T, f, x) if chunksize(cfg) == length(x) vector_mode_gradient!(result, f, x, cfg) @@ -92,13 +92,13 @@ end # vector mode # ############### -function vector_mode_gradient(f, x, cfg::GradientConfig{T}) where {T} +function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F} ydual = vector_mode_dual_eval(f, x, cfg) result = similar(x, valtype(ydual)) return extract_gradient!(T, result, ydual) end -function vector_mode_gradient!(result, f, x, cfg::GradientConfig{T}) where {T} +function vector_mode_gradient!(result, f::F, x, cfg::GradientConfig{T}) where {T, F} ydual = vector_mode_dual_eval(f, x, cfg) result = extract_gradient!(T, result, ydual) return result