Add ChainRules.jl definitions #148

Closed
ChrisRackauckas opened this issue Dec 24, 2019 · 15 comments

Comments

@ChrisRackauckas
Member

Just like the new Zygote @adjoint, we can define how to lower solve to different choices of forward mode and adjoint definitions. I might need help from @oxinabox here.

@ChrisRackauckas
Member Author

@YingboMa since you have ChainRules experience, could you take a crack at this?

@YingboMa
Member

Yeah, I can take a crack after the break.

@oxinabox
Contributor

What does "different choices" mean?

As in multiple ways to compute the same pullback/pushforward?
That is something I have been wondering about for a while:
whether it is needed, and if so, how to support it.

@ChrisRackauckas
Member Author

As in multiple ways to compute the same pullback/pushforward?

Yes, with very different memory and performance characteristics. The way we handle it is that we have a keyword argument for what algorithm to use, then that drops down to a helper function to dispatch on it:

https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/sensitivity_interface.jl#L1-L3

So as long as kwargs are supported I think we're good? Forward sensitivities are harder.
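
The keyword-plus-helper pattern described above can be sketched in plain Julia. This is only an illustration under stated assumptions: the names PullbackStrategy, StoreIntermediates, Recompute, and sinsq_with_pullback are hypothetical and are not part of DiffEqSensitivity.jl or ChainRules.jl. The helper exists because Julia does not dispatch on keyword arguments, so the chosen algorithm is forwarded as a positional argument.

```julia
# Hypothetical strategy types (illustrative names, not from any package).
abstract type PullbackStrategy end
struct StoreIntermediates <: PullbackStrategy end  # caches sin.(x) from the forward pass
struct Recompute <: PullbackStrategy end           # recomputes sin.(x) in the reverse pass

# User-facing entry point for y = sum(sin.(x) .^ 2).
# The keyword selects the algorithm; the helper dispatches on its type.
function sinsq_with_pullback(x::AbstractVector;
                             strategy::PullbackStrategy = StoreIntermediates())
    return _sinsq_with_pullback(strategy, x)
end

# More memory: the closure keeps the intermediate array s alive.
function _sinsq_with_pullback(::StoreIntermediates, x)
    s = sin.(x)
    pullback(ybar) = ybar .* 2 .* s .* cos.(x)
    return sum(abs2, s), pullback
end

# Less memory, more work: sin.(x) is evaluated again inside the pullback.
function _sinsq_with_pullback(::Recompute, x)
    pullback(ybar) = ybar .* 2 .* sin.(x) .* cos.(x)
    return sum(abs2, sin.(x)), pullback
end

x = [0.1, 0.2, 0.3]
y1, pb1 = sinsq_with_pullback(x)                          # default strategy
y2, pb2 = sinsq_with_pullback(x; strategy = Recompute())  # explicit choice
```

Both strategies return the same value and the same gradient; only the memory and compute trade-off differs, which is the kind of choice the keyword argument is meant to expose.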

But I realized we cannot support all operations on the solution type: sol.k holds intermediate variables that would be too costly to track in the adjoint sensitivity implementations, yet they are required for the continuous interpolation. So my idea is to introduce concretesolve, which is Array(solve(...)) and adapted forms. That would be where we put @adjoint, and it would generalize diffeq_adjoint/diffeq_rd, etc. @YingboMa what do you think?

@oxinabox
Contributor

Yeah, that will work.
I wonder if we should have a general pattern for this in ChainRules.
cc @willtebbutt

@willtebbutt
Contributor

I'm not sure we really need to formalise anything, @oxinabox. I assume that the kwarg approach here is sufficient for your needs, @ChrisRackauckas?

Maybe it would be sufficient to provide examples of this in the ChainRulesCore docs? Perhaps specify some conventions?

@ChrisRackauckas
Member Author

This is sufficient for me, as long as I can do the same kind of style with a ChainRules.jl definition.

@willtebbutt
Contributor

This is sufficient for me, as long as I can do the same kind of style with a ChainRules.jl definition.

Yeah -- I don't know how we would really stop you haha.

@ChrisRackauckas
Member Author

Oh, I just didn't know I could do kwargs and splatting now. This is gonna be easy then.

@YingboMa
Member

... so my idea is to introduce concretesolve, which is Array(solve(...)) and adapted forms. That would be where we put @adjoint, and it would generalize diffeq_adjoint/diffeq_rd, etc.

Yeah, I think that is a good idea.

@ChrisRackauckas
Member Author

It should all be set up now. Can't really test it until an AD can use it though, but it's a start :)

@oxinabox
Contributor

Can't really test it until an AD can use it though, but it's a start :)

Check out this branch of Zygote, and comment out your ZygoteRules
FluxML/Zygote.jl#366

@ChrisRackauckas
Member Author

I'll wait until we can really add tests. But you know I'll be the first adopter of the final form :)

@oxinabox
Contributor

oxinabox commented Dec 27, 2019

But you know I'll be the first adopter of the final form :)

TBH, this is pretty close to the final form as far as the user-facing side is concerned (barring optional convenience macros).
There might be some more changes on the AD-facing side.

@ChrisRackauckas
Member Author

Yeah, that's why I felt the need to set it up. I'm 90% sure that when an AD turns ChainRules on, it'll just work. There's a 10% chance of a typo, and handling the frule likely needs one more detail worked out (it currently cannot handle du0, so I require that to be nothing).
