Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul Rules #30

Merged
merged 23 commits into from
Sep 17, 2019
Merged

Overhaul Rules #30

merged 23 commits into from
Sep 17, 2019

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Aug 27, 2019

This is a very mighty PR.

julia> _, pushforward = frule(sin, 1)
(0.8414709848078965, ChainRules.var"##75#sin_pushforward#55"{Int64}(1))

Does not look quiet as nice for 1.0 but still useful

julia> _, pushforward = frule(sin, 1)
(0.8414709848078965, getfield(ChainRules, Symbol("##75#sin_pushforward#55")){Int64}(1))

It has a corresponding PR to ChainRules.jl

JuliaDiff/ChainRules.jl#91

This is the main blocker for FluxML/Zygote.jl#291

@oxinabox oxinabox changed the title WIP Derivative wrt function WIP: Derivative wrt function Aug 27, 2019
@oxinabox
Copy link
Member Author

All the accumulation stuff needs to be rewritten still.

@oxinabox oxinabox changed the title WIP: Derivative wrt function Overhaul Rules Sep 2, 2019
@oxinabox
Copy link
Member Author

oxinabox commented Sep 2, 2019

I have not delted the AbstractRules yet, as I am yet to workout the story for store!
and for 2 arg rules that know how to update!.

I guess that will block that PR, but I will work that through as I finish JuliaDiff/ChainRules.jl#91
and need them.
Which I also want done before merging this.

Good the review this now though,
it is a big PR and the vast majority of changes are in.

@simeonschaub
Copy link
Member

@oxinabox I've created a test case for mixed Wirtinger derivatives here. This should help making sure, that this works correctly.

@oxinabox
Copy link
Member Author

oxinabox commented Sep 2, 2019

Cool. I will add that to this PR tomorrow.
At least to check all the returned types are right.

Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some small typos in the docstrings, that caught my eye

src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/differentials.jl Outdated Show resolved Hide resolved
src/differentials.jl Outdated Show resolved Hide resolved
src/differentials.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/differentials.jl Outdated Show resolved Hide resolved
Copy link
Contributor

@nickrobinson251 nickrobinson251 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job! a bunch of pretty small comments:

look forward to reviewing once the other changes are in!

@@ -59,6 +137,9 @@ For examples, see ChainRulesCore' `rules` directory.
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
"""
macro scalar_rule(call, maybe_setup, partials...)
############################################################################
# Setup: normalizing input form etc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be broken up into functions? I'd love for this to not be 100 lines long...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i still feel the same

But if it is not easy / does not make sense to you, @oxinabox , then that's fine by me too

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wahooo! Thanks 🎉

test/rules.jl Outdated Show resolved Hide resolved
test/rules.jl Outdated Show resolved Hide resolved
src/rule_types.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/differentials.jl Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved

####
"""
differential(𝒟::Type, der)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should take primal and conjugate as arguments, and depending on 𝒟 return either Wirtinger or their sum? I think that would make it more clear to rule authors, that when you create a Wirtinger, you usually also want this fall-through behavior.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could also be useful to check, whether conjugate isa Zero here, and unwrap Wirtinger if that's the case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should take primal and conjugate as arguments, and depending on 𝒟 return either Wirtinger or their sum?

I'm not sure I understand the advantag of that?

With this we have
differential(𝒟, Wirtinger(primal, conjugatge))
which seems fine.
What is the advantage of
differential(𝒟, primal, conjugatge)
?

It could also be useful to check, whether conjugate isa Zero here, and unwrap Wirtinger if that's the case.

Maybe.
Maybe even iszero(conjugate) to to get constants like 0

Maybe even for inputs that are scalar:
iszero(der) && return Zero()
and similar for One()

We should discuss that kind thing in an issue and make a follow up PR for it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the advantage of
differential(𝒟, primal, conjugatge)
?

I just don't know if differential is the best name for this function, since it takes a differential and returns a, sometimes different and more efficient, differential again. I would expect a function called differential to work more like a constructor. Or do we maybe want to call this wirtinger and have it take 𝒟, primal, and conjugate, since this probably corresponds better to what it does right now? But I also wouldn't feel too strongly just leaving this for now, since I'm also struggling to find a better name for this function.

Maybe.
Maybe even iszero(conjugate) to to get constants like 0

Maybe even for inputs that are scalar:
iszero(der) && return Zero()
and similar for One()

I'm not quite convinced we benefit from introducing dispatch based on value here, wouldn't this also cause problem on GPUs? But this is definitely an issue for another day.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we rename it to refine_differential ?
I think this actually also interacts with #8 since we will want to apply something recursively.

src/differentials.jl Outdated Show resolved Hide resolved
test/rules.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/differentials.jl Outdated Show resolved Hide resolved
src/differentials.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
@tkf
Copy link

tkf commented Sep 4, 2019

I have a concern about this decision. Does it mean that, for example, the reverse rule of * becomes something like (computationally equivalent to)

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, Rule(Ȳ -> (Ȳ * B', A' * Ȳ))
end

(which is like what Zygote.jl is doing at the moment) instead of the current definition

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, (Rule(Ȳ ->* B'), Rule(Ȳ -> A' * Ȳ))
end

where you can compute the derivative w.r.t different argument separately?

Wouldn't it be a huge performance loss when large constant arrays are participated in the computation of the intermediate variables that depend on the variables ("trainable parameters") with which the derivatives are taken? For example, in the Generative Adversarial Network (GAN) setting, I think it would be a big issue when taking derivative (d/dp) D(G(p)) w.r.t the parameter p of the generator G while treating the parameters of discriminator D as constant. I also noted the concerns in other similar situations here FluxML/Zygote.jl#323 (comment).

I'm by no means an AD or ML specialist so I may be missing something. It would be great if you can clarify that my concern is invalid.

@oxinabox
Copy link
Member Author

oxinabox commented Sep 4, 2019

Short answer: don't worry we solve this with Thunks.


Full answer:

@tkf a very reasonable concern. And one I used too have
until I understood why whe have the Thunk differential.
(#18)
We have an issue onpen about documenting that a bit more.
#46

... the current definition

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, (Rule(Ȳ ->* B'), Rule(Ȳ -> A' * Ȳ))
end

becomes: In partner PR, this is one of the ones I've already updated

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    return A * B, Ȳ -> (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ))
end

@thunk(Ȳ * B') is just a shorthand for Thunk(()->Ȳ * B')
Thunk source: definition, and math

That it basically is is a differential that defers computation until it is used.
So if it is never used then the wrapped computation is never computed
For examole:

Y, pullback = rrule(*, A, B)
_, dA_diff, dB_diff = pullback(One()
dB = extern(dB_diff)

then since dA_diff was never externed the Ȳ * B' is never evaluated.

@tkf
Copy link

tkf commented Sep 4, 2019

@oxinabox Thanks a lot! I appreciate the full explanation. I should have checked the partner PR.

@oxinabox
Copy link
Member Author

oxinabox commented Sep 4, 2019

@oxinabox Thanks a lot! I appreciate the full explanation. I should have checked the partner PR.

It is a really important question

This was referenced Sep 5, 2019
oxinabox and others added 17 commits September 17, 2019 11:10
make real scalar rules work.

correct @scalarrule forward rule return

Wirtinger scalar working

work WirtingerRule test as a test of @scalar_rule

Fix spelling

Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com>
Oxford Comma

Co-Authored-By: simeonschaub <simeondavidschaub99@gmail.com>
spelling

Co-Authored-By: Nick Robinson <npr251@gmail.com>
docstring for propagator_name
spelling

Co-Authored-By: Nick Robinson <npr251@gmail.com>
error ratehr than Assert
cleanup

Update src/rule_definition_tools.jl

Co-Authored-By: Nick Robinson <npr251@gmail.com>
Add more complex Wirtinger Scalar Rule Test
update accumulate to work on differentials
Co-Authored-By: Curtis Vogt <curtis.vogt@gmail.com>
zero the storage inplace
@oxinabox
Copy link
Member Author

Rebased, and squashed some of them.
Going to try shuffling the commits to squash it down some more.

Normally I am hesitant to squash during PR review but this has had a lot of review so far,
so making each commit distinct in purpose seem apppropriate now

@oxinabox
Copy link
Member Author

oxinabox commented Sep 17, 2019

All tests (except inegration tests) are passing.

Shuffle rebasing is hard, not sure if worth it
Might squash thing into a single commit at the end?

Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Contributor

@nickrobinson251 nickrobinson251 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM

Good work!

I have felt handful of tiny comment :)

- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
- The expression being a constant
- The expression being itself a `thunk`
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This entire section is great

Shall we move it to a page in the docs?

(Otherwise one can just use a normal `Thunk`).

Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
and destroy its inplacability.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
and destroy its inplacability.
and destroy its ability to work inplace.

Base.conj(x::Thunk) = @thunk(conj(extern(x)))
# The real reason we have this:
accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ)
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
store!(Δ, ∂::InplaceableThunk) =.add!((Δ.*=false)) # zero it, then add to it.
store!(Δ, ∂::InplaceableThunk) =.add!((Δ .*= false)) # zero it, then add to it.

Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place,
storing the result in `Δ`.

Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
!!! note
this function may not actually store the result in `Δ` if `Δ` is immutable,

"""
store!(Δ, ∂)

Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Stores ``, in `Δ`, overwriting what ever was in `Δ` before.
Stores `` in `Δ` overwriting whatever was in `Δ` before.

Returns the expression for the propagation of
the input gradient `Δs` though the partials `∂s`.

𝒟 is an expression that when evaluated returns the type-of the input domain.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
𝒟 is an expression that when evaluated returns the type-of the input domain.
𝒟 is an expression that when evaluated returns the type of the input domain.

function standard_propagation_expr(Δs, ∂s)
# This is basically Δs ⋅ ∂s

# Notice: the thunking of `∂s[i] (potentially) saves us some computation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Notice: the thunking of `∂s[i] (potentially) saves us some computation
# Notice: the thunking of `∂s[i]` (potentially) saves us some computation

# Notice: the thunking of `∂s[i] (potentially) saves us some computation
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
# as the pullback is evaluated
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes :) Is it worth opening an issue (to stare hard at this and figure out if all is well)?

@@ -59,6 +137,9 @@ For examples, see ChainRulesCore' `rules` directory.
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
"""
macro scalar_rule(call, maybe_setup, partials...)
############################################################################
# Setup: normalizing input form etc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i still feel the same

But if it is not easy / does not make sense to you, @oxinabox , then that's fine by me too

src/rules.jl Outdated Show resolved Hide resolved
oxinabox and others added 2 commits September 17, 2019 16:55
Co-Authored-By: Nick Robinson <npr251@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants