Cannot deduce type for QuadGK #1599
Offhand I don't know what the correct reverse rule is, but here's some infrastructure for adding one:
using EnzymeCore, QuadGK
function EnzymeCore.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f::Const, segs::Annotation{T}...; kws...) where {RT, T}
prims = map(x->x.val, segs)
res = ofunc.val(f.val, prims...; kws...)  # splat the endpoints: quadgk(f, a, b, ...)
retres = if EnzymeRules.needs_primal(config)
res
else
nothing
end
dres = if EnzymeRules.width(config) == 1
map(zero, res)  # res is the tuple (I, E), so zero each component
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
map(zero, res)
end
end
cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
dres
else
nothing
end
return EnzymeCore.EnzymeRules.AugmentedReturn{
EnzymeCore.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
EnzymeCore.EnzymeRules.needs_shadow(config) ? (EnzymeCore.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeCore.EnzymeRules.width(config), eltype(RT)}) : Nothing,
typeof(cache)
}(retres, dres, cache)
end
function EnzymeCore.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Const, segs::Annotation{T}...; kws...) where {T}
# TODO: compute the vJp and return one derivative per argument
error("reverse rule for quadgk is not implemented yet")
end
@stevengj, since it looks like you're a code owner, do you have any insights, and/or would you want to work through the rule together?
Is the differentiate-then-discretize approximation acceptable (i.e. neglecting the error in the quadrature rule)? Then the Jacobian with respect to a parameter of the integrand is just the integral of the Jacobian of the integrand or, in the case of a reverse rule, the integral of the vJp of the integrand. I think this is the approach used by Integrals.jl. My student @lxvm worked on this, so I'm cc'ing him.
If you want the exact "discretize-then-differentiate" Jacobian (including the quadrature/discretization error), i.e. the exact derivative of the approximate integral up to roundoff error, then the simple approach won't work because […]. One slick approach to "discretize-then-differentiate" would be to call […]
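As a minimal sketch of the differentiate-then-discretize approach, assuming a scalar parameter `p` and a hand-written partial derivative (the integrand `f(x, p) = exp(p*x)` and the helper names `dfdp` and `dI_dp_then_discretize` are illustrative, not from the thread), the derivative of the integral is just a second, independently adaptive `quadgk` call:

```julia
using QuadGK

# Illustrative integrand f(x, p) = exp(p*x) and its partial derivative with respect to p.
f(x, p) = exp(p * x)
dfdp(x, p) = x * exp(p * x)

# Differentiate-then-discretize: d/dp ∫_a^b f(x, p) dx = ∫_a^b ∂f/∂p(x, p) dx,
# computed by a separate adaptive quadgk call that chooses its own subintervals.
function dI_dp_then_discretize(a, b, p)
    dI, _ = quadgk(x -> dfdp(x, p), a, b)
    return dI
end

p = 2.0
dI = dI_dp_then_discretize(0.0, 1.0, p)
exact = (exp(p) * (p - 1) + 1) / p^2   # closed form of ∫_0^1 x e^{p x} dx
```

Because the second call adapts on its own, its subintervals and error tolerance are chosen for ∂f/∂p rather than for f, so the result approximates the derivative of the exact integral, not the exact derivative of the computed estimate.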
I think if we write a rule (which makes sense here numerically anyway, given the different quadrature points mentioned above), we might as well go for the discretize-then-differentiate solution. One difference here, at least in Enzyme reverse mode, is that the original result must be computed in the forward pass (and may never be computed at all if the compiler tells Enzyme it isn't needed and sets needs_primal to false), so I'm not sure how to fuse both into one quadgk call. Moreover, wouldn't it be more stable to have two distinct quadgk calls that pick different points? Or am I misunderstanding you and/or the magnitude of the performance implications?
(Similarly, for the derivative with respect to an endpoint of the integration domain, there is a simple differentiate-then-discretize rule using the fundamental theorem of calculus. The discretize-then-differentiate rule is more complicated, but can also be implemented by augmenting the integrand.)
Another question is whether you want something specific to QuadGK, or a more generic method for Integrals.jl. Unfortunately, supporting multiple backends in Integrals.jl may be harder, since not all solvers support specifying a custom error […]
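For the endpoint case, here is a minimal sketch of the differentiate-then-discretize rule from the fundamental theorem of calculus, ∂I/∂b = f(b) and ∂I/∂a = -f(a), checked against a central finite difference of the adaptive integral. The integrand `h`, the helper names `endpoint_grads` and `integral`, and the step size are illustrative assumptions:

```julia
using QuadGK

h(x) = sin(x) + x^2   # illustrative smooth integrand

# Fundamental theorem of calculus: endpoint derivatives need no extra quadrature.
endpoint_grads(h, a, b) = (-h(a), h(b))   # (∂I/∂a, ∂I/∂b)

a, b = 0.5, 2.0
da, db = endpoint_grads(h, a, b)

# Central finite-difference check using the adaptive integral itself.
δ = 1e-5
integral(a, b) = quadgk(h, a, b; rtol=1e-12)[1]
db_fd = (integral(a, b + δ) - integral(a, b - δ)) / (2δ)
da_fd = (integral(a + δ, b) - integral(a - δ, b)) / (2δ)
```

The rule costs only two integrand evaluations, regardless of how many subintervals the adaptive integration used.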
It can't pick the points until it is actually computing the integral, since it adaptively looks at error estimates of the integral computed so far.
Unless the quadgk rule would be obnoxious, I think it makes sense to add one for quadgk now, and possibly for Integrals.jl later down the line. The reason is that we've already seen various packages that call quadgk directly rather than through Integrals.jl hit issues as a result. We do something similar elsewhere: we support the various Julia solvers directly inside Enzyme, and also have a rule within the SciML solver packages.
If you have to have two separate calls for the original result and the vJp, then you either have to pay the price of evaluating the integrand twice, add a new API to QuadGK that allows it to cache all of the integrand points and weights, or accept the differentiate-then-discretize approximation. Actually, it's not crazy to cache all of the integrand points and weights, since QuadGK effectively already saves this information (it builds up a heap of subintervals and knows the quadrature rule for each subinterval). In fact, QuadGK's […]
Well, we can indeed save arbitrary data, so this should be doable if quadgk had the API.
I implemented a new API for this in JuliaMath/QuadGK.jl#108. Now, if you have a function call (I, E) = quadgk(...), you can replace it with:
I, E, segbuf = quadgk_segbuf(...)
(There is generally no extra cost to this, since QuadGK computes the segbuf internally anyway.) Then, on a subsequent call to quadgk, even with a different integrand vJp, you can pass:
quadgk(vJp, ...; ..., eval_segbuf=segbuf, maxevals=0)
and it will evaluate the new integrand using exactly the same quadrature rule (the same subintervals). This should make it fairly easy and efficient to implement the exact derivative (discretize-then-differentiate) of the integral estimate I with respect to parameters of the integrand, or with respect to the endpoints.
(Is there an optional dependency that we can add to QuadGK.jl, analogous to ChainRulesCore.jl, in order to add the AD rules directly to the package?)
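A minimal sketch of the `quadgk_segbuf` / `eval_segbuf` workflow described above, assuming a QuadGK release that includes JuliaMath/QuadGK.jl#108 and an illustrative integrand `exp(p*x)` (the names `f`, `dfdp`, and `p` are assumptions for the example). The derivative integrand is evaluated on exactly the subintervals chosen for the original integral, and `maxevals=0` prevents any further refinement:

```julia
using QuadGK

p = 2.0
f(x) = exp(p * x)         # illustrative integrand with parameter p
dfdp(x) = x * exp(p * x)  # its partial derivative with respect to p

# Forward pass: adaptive integration that also returns the subinterval buffer.
I, E, segbuf = quadgk_segbuf(f, 0.0, 1.0)

# Reverse pass: reuse the same subintervals (no adaptive refinement) for ∂f/∂p.
# With the rule fixed, this is the exact derivative of the discrete estimate I.
dI, _ = quadgk(dfdp, 0.0, 1.0; eval_segbuf=segbuf, maxevals=0)

exact = (exp(p) * (p - 1) + 1) / p^2
```

In an Enzyme reverse rule, `segbuf` would be the piece of data stored on the tape in the augmented primal and consumed in the reverse pass.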
Yeah, we can just add a dependency on EnzymeCore.
(Sent by email in reply to Steven G. Johnson's comment of Sun, Jul 21, 2024, 3:30 PM.)
I think @stevengj has the right idea with adding a segbuf API to QuadGK.jl, because with a fixed quadrature rule (in this case a composite, adaptive quadrature rule) the discretize-then-differentiate rules are very simple. Although calculus offers very convenient differentiate-then-discretize rules, it is not obvious how to automatically select absolute error tolerances for the adaptive integration of the vJp or Jvp, which as functions may behave differently from the original integrand and, if you think of them as dimensionful quantities, also have different units (this point is related to Steven's earlier point about choosing a […]).
So is the following the current proposed scheme? […]
My understanding is that we won't be using the original function on the second call, since we don't want to pay the price of integrating it again. Instead, the proposal is to implement a custom reverse rule for […]
In particular, suppose that the original integral is […]. Instead, if we evaluate […]. For a reverse rule, we want the vJp […], in which case, assuming I haven't made any algebra errors, you will get the exact result if you plug the new integrand into the old quadrature rule. ([…])
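The following is a standard derivation (not a reconstruction of the formulas missing from the comment above) of why the new integrand plugged into the old rule gives the exact vJp. With the subintervals and Gauss-Kronrod nodes held fixed, the estimate is a weighted sum that is linear in the integrand. The notation (segments [a_j, b_j], reference nodes t_k and weights ω_k on [-1, 1], cotangent v) is assumed for illustration:

```latex
% Fixed composite rule on segments [a_j, b_j] with reference nodes t_k, weights \omega_k on [-1, 1]
x_{jk} = \frac{a_j + b_j}{2} + \frac{b_j - a_j}{2}\, t_k, \qquad
w_{jk} = \frac{b_j - a_j}{2}\, \omega_k, \qquad
\tilde I(p) = \sum_j \sum_k w_{jk}\, f(x_{jk}, p)

% With the segments fixed, the vJp is the same rule applied to v^T \partial f / \partial p
v^{\mathsf T} \frac{\partial \tilde I}{\partial p}
  = \sum_j \sum_k w_{jk}\, v^{\mathsf T} \frac{\partial f}{\partial p}(x_{jk}, p)
```

Derivatives with respect to the outer endpoints pick up extra terms, because the nodes x_{jk} and weights w_{jk} then depend on a and b through the segment endpoints.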
Sorry, thinking about this more, I think we should probably choose differentiate-then-discretize rather than the other way round. This is mostly a matter of convention: do people call the package expecting it to sum over assumed point boundaries (which would make discretize-then-differentiate the right solution), or do they assume that it will integrate generically, with some possible integration error (in which case differentiate-then-discretize is correct)? Both have their use cases, but on reflection, differentiating first feels like the more reasonable default for users (who don't, say, intend to AD the relative error of the discretization itself). Maybe we add an option to specify the behavior and choose accordingly.
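One way such an option could look, as a hypothetical sketch only (the function `integral_derivative` and its `mode` keyword are illustrative, not an existing QuadGK or Enzyme API). The derivative integrand `df` is either evaluated on the primal's saved subintervals or integrated adaptively on its own. For brevity the sketch recomputes the primal's segbuf; a real reverse rule would take it from the forward-pass cache instead:

```julia
using QuadGK

# Hypothetical helper: derivative of ∫_a^b f(x) dx with respect to a parameter, given ∂f/∂p as `df`.
#   :discretize_then_differentiate  reuse the primal subintervals (exact derivative of the estimate)
#   :differentiate_then_discretize  adaptively integrate `df` independently
function integral_derivative(f, df, a, b; mode::Symbol=:discretize_then_differentiate, kws...)
    if mode === :discretize_then_differentiate
        _, _, segbuf = quadgk_segbuf(f, a, b; kws...)
        return quadgk(df, a, b; kws..., eval_segbuf=segbuf, maxevals=0)[1]
    elseif mode === :differentiate_then_discretize
        return quadgk(df, a, b; kws...)[1]
    else
        throw(ArgumentError("unknown mode $mode"))
    end
end

p = 1.5
d1 = integral_derivative(x -> exp(p * x), x -> x * exp(p * x), 0.0, 1.0)
d2 = integral_derivative(x -> exp(p * x), x -> x * exp(p * x), 0.0, 1.0;
                         mode=:differentiate_then_discretize)
```

For a well-converged smooth integrand like this one, both modes agree to within the quadrature tolerance. They differ mainly in cost and in which error they differentiate.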
Although this sounds like an optimization, especially in reverse mode where you may have to cache temporary values of the integrand at all points of evaluation, it could be worthwhile. If I had to implement this in Zygote, I would write something like this:
using QuadGK, ChainRulesCore, Zygote, LinearAlgebra
function ChainRulesCore.rrule(::typeof(quadgk), f, a, b; norm=norm, order=7, kws...)
# adaptively integrate on the reference interval [-1, 1] to obtain the subintervals
I, E, segbuf = quadgk_segbuf(x -> f((b-a)*x/2+(a+b)/2), -1, 1; norm, order, kws...)
x, w, wg = QuadGK.cachedrule(typeof((a+b)/2), order)
# re-evaluate the fixed rule on the mapped subintervals, differentiably in f, a, and b
I_ab, back = Zygote.pullback(f, a, b) do f, a, b
s, c = (b-a)/2, (a+b)/2
sum(QuadGK.evalrule(f, seg.a*s+c, seg.b*s+c, x, w, wg, norm).I for seg in segbuf)
end
# the cotangent Δ is for the returned tuple (I, E); only the integral is differentiated
return (I_ab, E*(b-a)/2), Δ -> (NoTangent(), back(Δ[1])...)
end
Hopefully the rule I wrote above lets the AD framework perform the calculations you detailed.
In the example above, I already have. Also, if someone specifies breakpoints, isn't it usually because the integrand is non-differentiable at that point?
I think that would be good to build in from the beginning, since both behaviors may be desirable.
If the quadrature is sufficiently converged that you can use differentiate-then-discretize, then you should also be able to use discretize-then-differentiate. The discretize-then-differentiate approach requires more support from QuadGK, but it should be more efficient. Not only do you not have to worry about tolerances, as @lxvm points out, but it also requires fewer function evaluations (usually about half as many) since it skips the adaptive subdivision steps.
This rule won’t work for infinite limits, in-place integrands, or batched integrands, for example.
Even if the user imagines that they are solving the problem exactly and is not thinking about the derivative of the error terms, it is often better to AD the discretization error too if it is practical to do so, e.g. if you use this inside an optimization algorithm (which will expect the derivative to predict the first-order change in the computed function). The caveat is that, for adaptive quadrature, if they are updating the integrand parameters then the adaptive quadrature mesh will often change. This effectively introduces small discontinuities into the function that will screw up optimization if they become too large, whether you use differentiate-then-discretize or vice versa. So for adaptive algorithms, in practice, the user will have to ensure that it is sufficiently converged. (But at this point you can use either AD scheme, as I mentioned above, and discretize-then-differentiate can actually be more efficient in principle.)
It might be worth sticking with the differentiate-then-discretize approach for derivatives with respect to the integration endpoints, however, since in that case the analytical rule is vastly cheaper and easier to compute, and it also eliminates many of the complications that arise with the discretize-then-differentiate approach for multi-point intervals. So maybe a hybrid scheme: differentiate-then-discretize for the endpoints, but discretize-then-differentiate for parameters of the integrand. This way we get the best of both worlds. (And the user shouldn't care, provided their integral is sufficiently converged.)
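A sketch of the hybrid scheme under illustrative assumptions (a scalar parameter `p`, a hand-written ∂f/∂p, and a hypothetical helper name `hybrid_gradient`). The parameter derivative reuses the primal subintervals via `eval_segbuf`, while the endpoint derivatives come from the fundamental theorem of calculus at the cost of two integrand evaluations:

```julia
using QuadGK

# Hypothetical helper returning (I, ∂I/∂a, ∂I/∂b, ∂I/∂p) for I(a, b, p) = ∫_a^b f(x, p) dx.
function hybrid_gradient(f, dfdp, a, b, p)
    I, _, segbuf = quadgk_segbuf(x -> f(x, p), a, b)
    # Parameter: discretize-then-differentiate on the saved subintervals.
    dp, _ = quadgk(x -> dfdp(x, p), a, b; eval_segbuf=segbuf, maxevals=0)
    # Endpoints: differentiate-then-discretize via the fundamental theorem of calculus.
    return I, -f(a, p), f(b, p), dp
end

I, da, db, dp = hybrid_gradient((x, p) -> exp(p * x), (x, p) -> x * exp(p * x), 0.0, 1.0, 2.0)
```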
using EnzymeCore, QuadGK
function EnzymeCore.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f::Const, segs::Annotation{T}...; kws...) where {RT, T}
prims = map(x->x.val, segs)
res = ofunc.val(f.val, prims...; kws...)  # splat the endpoints: quadgk(f, a, b, ...)
retres = if EnzymeRules.needs_primal(config)
res
else
nothing
end
dres = if EnzymeRules.width(config) == 1
map(zero, res)  # res is the tuple (I, E), so zero each component
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
map(zero, res)
end
end
cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
dres
else
nothing
end
return EnzymeCore.EnzymeRules.AugmentedReturn{
EnzymeCore.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
EnzymeCore.EnzymeRules.needs_shadow(config) ? (EnzymeCore.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeCore.EnzymeRules.width(config), eltype(RT)}) : Nothing,
typeof(cache)
}(retres, dres, cache)
end
function EnzymeCore.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Const, segs::Annotation{T}...; kws...) where {T}
res = ofunc.val(EnzymeCore.autodiff(Reverse, f.val, segs...); kws...)
ntuple(Val(length(segs))) do i
Base.@_inline_meta
if segs[i] isa Const
nothing
elseif EnzymeCore.EnzymeRules.width(config) == 1
dres * res[i]
else
ntuple(Val(EnzymeCore.EnzymeRules.width(config))) do j
Base.@_inline_meta
dres * res[i][j]
end
end
end
end
Draft 2 (missing […]):
using Revise, Enzyme, QuadGK, EnzymeCore, LinearAlgebra
function EnzymeCore.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T<:Real}
prims = map(x->x.val, segs)
I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...)
res = (I, E)  # define res unconditionally; the shadow below needs it even when the primal is not returned
retres = if EnzymeRules.needs_primal(config)
res
else
nothing
end
dres = if !EnzymeCore.EnzymeRules.needs_shadow(config)
nothing
elseif EnzymeRules.width(config) == 1
map(zero, res)  # zero each component of (I, E)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
map(zero, res)
end
end
cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
dres
else
nothing
end
cache2 = segbuf, cache
return EnzymeCore.EnzymeRules.AugmentedReturn{
EnzymeCore.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
EnzymeCore.EnzymeRules.needs_shadow(config) ? (EnzymeCore.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeCore.EnzymeRules.width(config), eltype(RT)}) : Nothing,
typeof(cache2)
}(retres, dres, cache2)
end
function call(f, x)
f(x)
end
struct ClosureVector{F}
f::F
end
function Base.:+(a::ClosureVector, b::ClosureVector)
return a
# throw(AssertionError("todo +"))
end
function Base.:-(a::ClosureVector, b::ClosureVector)
return a+(-1*b)
end
function Base.:*(a::Number, b::ClosureVector)
return b
# throw(AssertionError("todo +"))
end
function Base.:*(a::ClosureVector, b::Number)
return b*a
end
function EnzymeCore.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f, segs::Annotation{T}...; kws...) where {T<:Real}
# res = ofunc.val(EnzymeCore.autodiff(Reverse, f.val, segs...); kws...)
df = if f isa Const
nothing
else
segbuf = cache[1]
fwd, rev = EnzymeCore.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
tape, prim, shad = fwd(Const(call), f, Const(x))
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
return ClosureVector(drev[1][1])
end
_df.f
end
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
return (df, # f
dsegs1,
ntuple(i -> nothing, Val(length(segs)-2))...,
dsegsn)
end
MWE:
Output: https://gist.githubusercontent.com/mhauru/25e6fa41671b94cb392b9df01dd8f821/raw/3d175b8d8a08216df8b2d6696b0635b430f3ecc0/QuadGK_cannot_deduce_type