-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Overhaul Rules #30
Overhaul Rules #30
Conversation
061197f
to
e11c001
Compare
All the accumulation stuff needs to be rewritten still. |
e238a8d
to
e061404
Compare
I have not delted the AbstractRules yet, as I am yet to workout the story for I guess that will block that PR, but I will work that through as I finish JuliaDiff/ChainRules.jl#91 Good the review this now though, |
Cool. I will add that to this PR tomorrow. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some small typos in the docstrings, that caught my eye
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job! a bunch of pretty small comments:
look forward to reviewing once the other changes are in!
@@ -59,6 +137,9 @@ For examples, see ChainRulesCore' `rules` directory. | |||
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) | |||
""" | |||
macro scalar_rule(call, maybe_setup, partials...) | |||
############################################################################ | |||
# Setup: normalizing input form etc | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be broken up into functions? I'd love for this to not be 100 lines long...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i still feel the same
But if it is not easy / does not make sense to you, @oxinabox , then that's fine by me too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wahooo! Thanks 🎉
src/differentials.jl
Outdated
|
||
#### | ||
""" | ||
differential(𝒟::Type, der) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this should take primal
and conjugate
as arguments, and depending on 𝒟
return either Wirtinger
or their sum? I think that would make it more clear to rule authors, that when you create a Wirtinger
, you usually also want this fall-through behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could also be useful to check, whether conjugate isa Zero
here, and unwrap Wirtinger
if that's the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this should take primal and conjugate as arguments, and depending on 𝒟 return either Wirtinger or their sum?
I'm not sure I understand the advantag of that?
With this we have
differential(𝒟, Wirtinger(primal, conjugatge))
which seems fine.
What is the advantage of
differential(𝒟, primal, conjugatge)
?
It could also be useful to check, whether conjugate isa Zero here, and unwrap Wirtinger if that's the case.
Maybe.
Maybe even iszero(conjugate)
to to get constants like 0
Maybe even for inputs that are scalar:
iszero(der) && return Zero()
and similar for One()
We should discuss that kind thing in an issue and make a follow up PR for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the advantage of
differential(𝒟, primal, conjugatge)
?
I just don't know if differential
is the best name for this function, since it takes a differential and returns a, sometimes different and more efficient, differential again. I would expect a function called differential
to work more like a constructor. Or do we maybe want to call this wirtinger
and have it take 𝒟
, primal
, and conjugate
, since this probably corresponds better to what it does right now? But I also wouldn't feel too strongly just leaving this for now, since I'm also struggling to find a better name for this function.
Maybe.
Maybe eveniszero(conjugate)
to to get constants like0
Maybe even for inputs that are scalar:
iszero(der) && return Zero()
and similar forOne()
I'm not quite convinced we benefit from introducing dispatch based on value here, wouldn't this also cause problem on GPUs? But this is definitely an issue for another day.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about we rename it to refine_differential
?
I think this actually also interacts with #8 since we will want to apply something recursively.
I have a concern about this decision. Does it mean that, for example, the reverse rule of function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
return A * B, Rule(Ȳ -> (Ȳ * B', A' * Ȳ))
end (which is like what Zygote.jl is doing at the moment) instead of the current definition function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ))
end where you can compute the derivative w.r.t different argument separately? Wouldn't it be a huge performance loss when large constant arrays are participated in the computation of the intermediate variables that depend on the variables ("trainable parameters") with which the derivatives are taken? For example, in the Generative Adversarial Network (GAN) setting, I think it would be a big issue when taking derivative I'm by no means an AD or ML specialist so I may be missing something. It would be great if you can clarify that my concern is invalid. |
Short answer: don't worry we solve this with Full answer:@tkf a very reasonable concern. And one I used too have
becomes: In partner PR, this is one of the ones I've already updated
That it basically is is a differential that defers computation until it is used.
then since |
@oxinabox Thanks a lot! I appreciate the full explanation. I should have checked the partner PR. |
It is a really important question |
make real scalar rules work. correct @scalarrule forward rule return Wirtinger scalar working work WirtingerRule test as a test of @scalar_rule Fix spelling Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com> Oxford Comma Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com> spelling Co-Authored-By: Nick Robinson <npr251@gmail.com> docstring for propagator_name spelling Co-Authored-By: Nick Robinson <npr251@gmail.com>
error ratehr than Assert cleanup Update src/rule_definition_tools.jl Co-Authored-By: Nick Robinson <npr251@gmail.com> Add more complex Wirtinger Scalar Rule Test
update accumulate to work on differentials
Co-Authored-By: Curtis Vogt <curtis.vogt@gmail.com>
spelling is hard
zero the storage inplace
This reverts commit 85b5bf9.
5995297
to
3656389
Compare
Rebased, and squashed some of them. Normally I am hesitant to squash during PR review but this has had a lot of review so far, |
All tests (except inegration tests) are passing. Shuffle rebasing is hard, not sure if worth it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM
Good work!
I have felt handful of tiny comment :)
- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)` | ||
- The expression being a constant | ||
- The expression being itself a `thunk` | ||
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This entire section is great
Shall we move it to a page in the docs?
(Otherwise one can just use a normal `Thunk`). | ||
|
||
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; | ||
and destroy its inplacability. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and destroy its inplacability. | |
and destroy its ability to work inplace. |
Base.conj(x::Thunk) = @thunk(conj(extern(x))) | ||
# The real reason we have this: | ||
accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) | ||
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. | |
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ .*= false)) # zero it, then add to it. |
Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place, | ||
storing the result in `Δ`. | ||
|
||
Note: this function may not actually store the result in `Δ` if `Δ` is immutable, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: this function may not actually store the result in `Δ` if `Δ` is immutable, | |
!!! note | |
this function may not actually store the result in `Δ` if `Δ` is immutable, |
""" | ||
store!(Δ, ∂) | ||
|
||
Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. | |
Stores `∂` in `Δ` overwriting whatever was in `Δ` before. |
src/rule_definition_tools.jl
Outdated
Returns the expression for the propagation of | ||
the input gradient `Δs` though the partials `∂s`. | ||
|
||
𝒟 is an expression that when evaluated returns the type-of the input domain. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
𝒟 is an expression that when evaluated returns the type-of the input domain. | |
𝒟 is an expression that when evaluated returns the type of the input domain. |
src/rule_definition_tools.jl
Outdated
function standard_propagation_expr(Δs, ∂s) | ||
# This is basically Δs ⋅ ∂s | ||
|
||
# Notice: the thunking of `∂s[i] (potentially) saves us some computation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Notice: the thunking of `∂s[i] (potentially) saves us some computation | |
# Notice: the thunking of `∂s[i]` (potentially) saves us some computation |
src/rule_definition_tools.jl
Outdated
# Notice: the thunking of `∂s[i] (potentially) saves us some computation | ||
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon | ||
# as the pullback is evaluated | ||
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes :) Is it worth opening an issue (to stare hard at this and figure out if all is well)?
@@ -59,6 +137,9 @@ For examples, see ChainRulesCore' `rules` directory. | |||
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) | |||
""" | |||
macro scalar_rule(call, maybe_setup, partials...) | |||
############################################################################ | |||
# Setup: normalizing input form etc | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i still feel the same
But if it is not easy / does not make sense to you, @oxinabox , then that's fine by me too
Co-Authored-By: Nick Robinson <npr251@gmail.com>
2bd9b6c
to
e51ff80
Compare
This is a very mighty PR.
DNE()
as not functors) return value for allrrules
, to represent the derviative w.r.t internals of closures/functors, and similar demands an extra input argument at the start of a call tofrule
(ignored for all current cases as not functors) Differentiating with respect to a function #22frule
/rrule
now return a 1 propagator (pushforward/pullback) that returms a tuple of partials, rather than 1 propagator per partial 1 AbstractRule Per Partial, vs 1 AbstractRule returning a tuple of Differentials (one per partial) #38AbstractRule
subtypes are no longer used anywhere Remove Rule (or maybe all AbstractRules) and treat functions as Rules #39@scalar_rule
automatically names pullbacks/pushforwards. Below is what that looks like in Julia Master (with new improved display for gensymed names)Does not look quiet as nice for 1.0 but still useful
It has a corresponding PR to ChainRules.jl
JuliaDiff/ChainRules.jl#91
This is the main blocker for FluxML/Zygote.jl#291