Where to find basic arithmetic operator derivatives and broadcasted versions #399

Open
Cvikli opened this issue Apr 18, 2021 · 5 comments

@Cvikli

Cvikli commented Apr 18, 2021

Hey,

I'm trying to redefine @scalar_rule and some of the macros to generate the appropriate code for our symbolic differentiation.

I tried googling and reading the source code, but I didn't see where to find the pullbacks for

:(+)(::AbstractArray, ::AbstractArray) ...
:(.+)(::AbstractArray, ::AbstractArray) ...
... -,*,/,^...

and so on.

I know it is easy to do, but the library looks really nice and I don't understand where the basic arithmetic is defined. I found * in some cases, but I couldn't find Array*Array either.

What am I missing? I couldn't figure out how Zygote applies the chain rules without these arithmetic rules. Can you help me?

@nickrobinson251
Contributor

nickrobinson251 commented Apr 18, 2021

The frules and rrules for scalar functions are mostly defined with the @scalar_rule macro.

Rules for scalar + and * are in the src/rulesets/base/fastmath_able.jl file.

See

```julia
@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (one(x) / y, -(Ω / y))
```
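To make the `@scalar_rule` lines above concrete, here is a rough plain-Julia sketch of the reverse-mode pullbacks they describe for real inputs: each pullback multiplies the incoming cotangent `ΔΩ` by the listed partial derivatives, where `Ω` is the primal output. It avoids ChainRulesCore so it runs on its own; the names `add_rrule`, `sub_rrule`, `div_rrule` and `fdiff` are hypothetical and this is not the code the macro actually generates.

```julia
# Hand-written equivalents of the three @scalar_rule lines, for real inputs.
# All function names here are hypothetical, for illustration only.

# x + y: both partials are 1, so the cotangent passes through unchanged.
function add_rrule(x::Number, y::Number)
    add_pullback(ΔΩ) = (ΔΩ * one(x), ΔΩ * one(y))
    return x + y, add_pullback
end

# x - y: the partials are (1, -1).
function sub_rrule(x::Number, y::Number)
    sub_pullback(ΔΩ) = (ΔΩ, -ΔΩ)
    return x - y, sub_pullback
end

# x / y: the partials are (1/y, -Ω/y), where Ω = x / y is the primal output.
function div_rrule(x::Number, y::Number)
    Ω = x / y
    div_pullback(ΔΩ) = (ΔΩ * (one(x) / y), ΔΩ * -(Ω / y))
    return Ω, div_pullback
end

# Central finite difference, used to check a partial derivative numerically.
fdiff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

x, y = 3.0, 2.0
Ω, pb = div_rrule(x, y)
x̄, ȳ = pb(1.0)
println((Ω, x̄, ȳ))                                     # (1.5, 0.5, -0.75)
println((fdiff(t -> t / y, x), fdiff(t -> x / t, y)))  # ≈ (0.5, -0.75)
```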

and
```julia
# product rule requires special care for arguments where `mul` is non-commutative
function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
    # Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more
    # accurate on machines with FMA instructions, since there are only two
    # rounding operations, one in `muladd/fma` and the other in `*`.
    ∂xy = muladd.(Δx, y, x .* Δy)
    return x * y, ∂xy
end

function rrule(::typeof(*), x::Number, y::Number)
    function times_pullback(ΔΩ)
        return (NO_FIELDS, ΔΩ * y', x' * ΔΩ)
    end
    return x * y, times_pullback
end
```
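The adjoints (`y'`, `x'`) in the pullback above only matter for complex inputs. The following sketch, again plain Julia with hypothetical names (`times_frule`, `times_rrule`, `complex_grad`), checks numerically that `ΔΩ * y'` and `x' * ΔΩ` match the gradient of the real-valued loss `real(conj(ΔΩ) * x * y)`, assuming the convention that a complex gradient is packed as ∂L/∂Re(z) + im * ∂L/∂Im(z).

```julia
# Plain-Julia check of the scalar `*` rules, without ChainRulesCore.
# `times_frule`, `times_rrule` and `complex_grad` are hypothetical names.

# Forward mode: the tangent of x*y is Δx*y + x*Δy, fused with muladd.
times_frule(Δx, Δy, x, y) = (x * y, muladd(Δx, y, x * Δy))

# Reverse mode: for complex numbers the cotangents are ΔΩ*conj(y) and
# conj(x)*ΔΩ; for real numbers the adjoints are no-ops.
function times_rrule(x::Number, y::Number)
    times_pullback(ΔΩ) = (ΔΩ * y', x' * ΔΩ)
    return x * y, times_pullback
end

# Finite-difference gradient of a real-valued L at a complex point z,
# packed as ∂L/∂Re(z) + im * ∂L/∂Im(z).
function complex_grad(L, z; h = 1e-6)
    d_re = (L(z + h) - L(z - h)) / (2h)
    d_im = (L(z + im * h) - L(z - im * h)) / (2h)
    return d_re + im * d_im
end

x, y, ΔΩ = 1.0 + 2.0im, 3.0 - 1.0im, 0.5 + 0.25im
_, pb = times_rrule(x, y)
x̄, ȳ = pb(ΔΩ)

# Losses whose gradients the pullback should reproduce.
L_x(t) = real(conj(ΔΩ) * t * y)
L_y(t) = real(conj(ΔΩ) * x * t)
println(isapprox(x̄, complex_grad(L_x, x); atol = 1e-6))  # true
println(isapprox(ȳ, complex_grad(L_y, y); atol = 1e-6))  # true
```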

Rules for Array functions are in rulesets/base/array.jl or rulesets/base/arraymath.jl (roughly matching where the functions live in Julia Base). Some Array functionality comes from the LinearAlgebra standard library, so its rules are defined in src/rulesets/LinearAlgebra/.

The rules for Array*Array are here

```julia
#####
##### `*`
#####

function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    function times_pullback(Ȳ)
        return (
            NO_FIELDS,
            InplaceableThunk(
                @thunk(Ȳ * B'),
                X̄ -> mul!(X̄, Ȳ, B', true, true)
            ),
            InplaceableThunk(
                @thunk(A' * Ȳ),
                X̄ -> mul!(X̄, A', Ȳ, true, true)
            )
        )
    end
    return A * B, times_pullback
end

function rrule(
    ::typeof(*),
    A::AbstractVector{<:CommutativeMulNumber},
    B::AbstractMatrix{<:CommutativeMulNumber},
)
    function times_pullback(Ȳ)
        @assert size(B, 1) === 1  # otherwise primal would have failed.
        return (
            NO_FIELDS,
            InplaceableThunk(
                @thunk(Ȳ * vec(B')),
                X̄ -> mul!(X̄, Ȳ, vec(B'), true, true)
            ),
            InplaceableThunk(
                @thunk(A' * Ȳ),
                X̄ -> mul!(X̄, A', Ȳ, true, true)
            )
        )
    end
    return A * B, times_pullback
end
```
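For matrices, the pullback above boils down to `Ā = Ȳ * B'` and `B̄ = A' * Ȳ`. A minimal sketch without thunks (hypothetical names `matmul_rrule` and `fd_grad`) checks this against finite differences of the scalar loss `sum(W .* (A * B))`, with `W` playing the role of `Ȳ`, and shows that the 5-argument `mul!` used in the in-place branch accumulates into its first argument.

```julia
using LinearAlgebra  # for the 5-argument mul!

# Plain-Julia version of the Array*Array pullback, without thunks.
# `matmul_rrule` and `fd_grad` are hypothetical names.
function matmul_rrule(A::AbstractVecOrMat, B::AbstractVecOrMat)
    function times_pullback(Ȳ)
        Ā = Ȳ * B'   # same shape as A
        B̄ = A' * Ȳ   # same shape as B
        return Ā, B̄
    end
    return A * B, times_pullback
end

# Central finite-difference gradient of a scalar function L at the array X.
# The losses below are linear in X, so a fairly large step h is accurate.
function fd_grad(L, X; h = 1e-3)
    G = similar(X)
    for i in eachindex(X)
        E = zero(X)
        E[i] = h
        G[i] = (L(X + E) - L(X - E)) / (2h)
    end
    return G
end

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]           # 3×2
B = [1.0 0.0 2.0 -1.0; 0.5 1.0 0.0 3.0]   # 2×4
W = reshape(collect(1.0:12.0), 3, 4)      # stands in for Ȳ, same shape as A*B

Y, pb = matmul_rrule(A, B)
Ā, B̄ = pb(W)

# The gradient of sum(W .* (A * B)) should equal the pullback applied to W.
println(isapprox(Ā, fd_grad(X -> sum(W .* (X * B)), A); atol = 1e-6))  # true
println(isapprox(B̄, fd_grad(X -> sum(W .* (A * X)), B); atol = 1e-6))  # true

# mul!(C, X, Y, α, β) overwrites C with X*Y*α + C*β; with α = β = true it adds
# X*Y to C, which is what the in-place branch of each InplaceableThunk does.
acc = ones(size(A))
mul!(acc, W, B', true, true)
println(acc ≈ ones(size(A)) + W * B')  # true
```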

@Cvikli
Author

Cvikli commented Apr 18, 2021

Thank you for the detailed answer!

For me the unclear part is how * is handled for bigger arrays (3D, 4D, ...). But as I typed this again I realised that there is no * operation between higher-dimensional arrays; what I am looking for is .*, .^, etc., which work between bigger arrays. What are the frules and rrules for the broadcasted versions of these functions?

One more quick question: I see * but only see @scalar_rule for + and -. Where does Zygote get the information for the array cases?

@simeonschaub
Member

Zygote doesn't use ChainRules for handling broadcasting at all; everything is defined here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl. We'd first need to solve JuliaDiff/ChainRulesCore.jl#68 before we can define such rules in ChainRules.
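Since ChainRules does not define broadcasting rules here, the following is a rough sketch of the idea behind a reverse rule for `x .* y`, which also shows how N-dimensional arrays can be handled: the cotangent is multiplied elementwise by the other argument, then summed over every dimension along which that input was broadcast. As far as I can tell, the linked `broadcast.jl` performs a similar reduction, but this is not its code; `unbroadcast_to` and `bcast_times_rrule` are hypothetical names.

```julia
# Sketch of a reverse rule for broadcasted multiplication `x .* y` in plain
# Julia. `unbroadcast_to` and `bcast_times_rrule` are hypothetical names.

# Reduce a cotangent `dy` to the shape of `x` by summing over each dimension
# along which `x` was broadcast (size 1 in `x`, or beyond `ndims(x)`).
function unbroadcast_to(x::AbstractArray, dy::AbstractArray)
    dims = Tuple(d for d in 1:ndims(dy) if size(x, d) == 1 && size(dy, d) != 1)
    s = isempty(dims) ? dy : sum(dy; dims = dims)
    return reshape(s, size(x))
end
# A scalar input was broadcast over everything, so sum the whole cotangent.
unbroadcast_to(x::Number, dy) = sum(dy)

# Elementwise, d(x*y)/dx = y and d(x*y)/dy = x.
function bcast_times_rrule(x, y)
    Ω = x .* y
    function bcast_times_pullback(dΩ)
        return unbroadcast_to(x, dΩ .* y), unbroadcast_to(y, dΩ .* x)
    end
    return Ω, bcast_times_pullback
end

x = [1.0, 2.0, 3.0]   # length-3 vector, broadcast along dimension 2
y = [10.0 20.0]       # 1×2 matrix, broadcast along dimension 1
Ω, pb = bcast_times_rrule(x, y)   # Ω is 3×2
x̄, ȳ = pb(ones(3, 2))
println(x̄)   # [30.0, 30.0, 30.0]
println(ȳ)   # [6.0 6.0]
```

The same reduction works for any number of dimensions: for example, a 1×3×1 input broadcast against a 2×3×4 cotangent is summed over dimensions 1 and 3.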

@Cvikli
Author

Cvikli commented Apr 18, 2021

It is really tricky code; I just can't understand how broadcasting is called/overloaded in each dotted case. I also don't see the cases for N-dimensional Array types. But I will try to read more of their code.

So my only remaining question is: where are the array versions of + and - in ChainRules, or am I missing something here?

@ToucheSir
Contributor

ToucheSir commented Sep 4, 2021

> So the only question I have where are the array versions of the +, - in ChainRules, or am I missing something here?

See

```julia
#####
##### Negation (Unary -)
#####

function rrule(::typeof(-), x::AbstractArray)
    function negation_pullback(ȳ)
        return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ))
    end
    return -x, negation_pullback
end

#####
##### Addition (Multiarg `+`)
#####

function rrule(::typeof(+), arrs::AbstractArray...)
    y = +(arrs...)
    arr_axs = map(axes, arrs)
    function add_pullback(dy)
        return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
    end
    return y, add_pullback
end
```
I'm not sure whether there's a dedicated rrule for -(A::AbstractArray, B::AbstractArray) or whether it falls through to broadcasting.
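The array rules above can be sketched in plain Julia without `NoTangent()` or thunks. Assuming there is no dedicated rule for binary `-` on arrays, the hypothetical `array_sub_rrule` below shows what one would compute: it mirrors `add_pullback`, except that the second argument receives the negated cotangent. All names here are hypothetical, not ChainRules API.

```julia
# Plain-Julia sketch of the array rules; all names are hypothetical.

# Unary `-`: the cotangent is negated.
function array_neg_rrule(x::AbstractArray)
    negation_pullback(ȳ) = (-ȳ,)
    return -x, negation_pullback
end

# Multi-argument `+`: every input receives the same cotangent, reshaped to
# that input's axes.
function array_add_rrule(arrs::AbstractArray...)
    y = +(arrs...)
    arr_axs = map(axes, arrs)
    add_pullback(dy) = map(ax -> reshape(dy, ax), arr_axs)
    return y, add_pullback
end

# Hypothetical dedicated binary `-`: like `+`, but the second input gets -dy.
function array_sub_rrule(A::AbstractArray, B::AbstractArray)
    ax_A, ax_B = axes(A), axes(B)
    sub_pullback(dy) = (reshape(dy, ax_A), reshape(-dy, ax_B))
    return A - B, sub_pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.5 0.5; 0.5 0.5]
Y, pb = array_sub_rrule(A, B)
Ā, B̄ = pb(ones(2, 2))
println(Ā)   # [1.0 1.0; 1.0 1.0]
println(B̄)   # [-1.0 -1.0; -1.0 -1.0]
```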

Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants