-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ChainRules #366
Add ChainRules #366
Conversation
@MikeInnes any idea why this segfaults? |
It could be a Zygote compiler issue, but given that you haven't touched that, the second most likely option is that it's just a stackoverflow. |
What makes you think stackoverflow? |
IME, Julia tends to segfault due to stack overflows. Doesn't happen every time, but often. |
It works! 1 of the blackslists can be removed immediately with JuliaDiff/ChainRules.jl#124 |
e47440a
to
74396b4
Compare
310ae70
to
bc8ac2c
Compare
I'm concerned about the performance implications of this, in its current form. In particular we have to be pretty careful to manage the call stack; recursing many times into Also, presumably many of Zygote's current adjoints are now redundant, given this PR. It'd be nice to remove those to show the overall impact on Zygote's codebase (hopefully a significant net improvement). |
Which bits do you think are adding levels? How every you count it, adding deeper integation and usiing ChainRules types (rather than
Easiest way to do would be to add convience functions to ChainRulesCore I'ld rather not be doing that alone, and I'ld rather not have that blocking this PR. |
If the user-defined call tree is Previously, it would have been |
Where can I learn about this type inference hack? |
I don't think it's documented, but you just set a flag on the codeinfo, as here. I think manual inlining / having one uber-function for |
Re inference rather than using the intended if 2 problems:
So I guess I have to put the whole thing including the decision of if to use a ChainRule or source2source But if I do that, don't I end up not having the compile know about what functions are called, Is there a way to set Maybe I could overload https://github.com/JuliaLang/julia/blob/a63f2e9b26751ef2d1522fa2634ee9d56db8528d/base/compiler/utilities.jl#L138| |
Maybe could go the otherway and make |
3e95942
to
951dac6
Compare
@MikeInnes @dhairyagandhi96 how are we going to resolve this? |
951dac6
to
7087cbd
Compare
This is much nicer now. |
Update docs/src/adjoints.md Co-Authored-By: Nick Robinson <npr251@gmail.com> fix typo in docs delete debug printing
linkl top chainrules issue about fastmath
fix typo Pin IRTools to 0.3.2 because FluxML/IRTools.jl#58
Co-authored-by: Pietro Vertechi <pietro.vertechi@protonmail.com> Update docs/src/adjoints.md
Co-authored-by: AzamatB <aberdysh@gmail.com>
make kwargs work Update src/compiler/chainrules.jl Update src/compiler/chainrules.jl and chainrules kwarg tests
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Decide if is kwfunc at compile time.
…es (#1) * ChainRules pullbacks always have 1 input JuliaDiff/ChainRulesCore.jl#152 * swap to version of chainrules that don't use multiarg pullbacks * update tests * make so don't need custom rule anymore * add comment * Update src/compiler/chainrules.jl Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
25b9aa6
to
06ad874
Compare
06ad874
to
41f4c17
Compare
bors r+ |
Build succeeded: |
This replaces #291
The bits from that OP that still matter
Step 1) Change Zygote to check for chainrules before doing its normal stuff,
and adapt the stuff it gets back from chainrules to play nice with Zygote's expectations
Step 2) adapt Zygote more deeply, so it can take full advantage of thunks etc.
This PR is Step 1.
TODO: workout why this seems to segfault for me.