Skip to content

Commit

Permalink
include functor and move code style to the top
Browse files Browse the repository at this point in the history
  • Loading branch information
Miha Zgubic committed Jun 7, 2021
1 parent cfc1bb2 commit c9f9168
Showing 1 changed file with 91 additions and 75 deletions.
166 changes: 91 additions & 75 deletions docs/src/writing_good_rules.md
Original file line number Diff line number Diff line change
@@ -1,58 +1,31 @@
# On writing good `rrule` / `frule` methods

## Use `Type{T}`, not `typeof(T)`, to define rules for constructors
## Code Style

To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
For example, the `rrule` signature would be like:
Use named local functions for the `pullback` in an `rrule`.

```julia
function rrule(::typeof(foo), args...; kwargs...)
...
return y, foo_pullback
# good:
function rrule(::typeof(foo), x)
Y = foo(x)
function foo_pullback(Ȳ)
return NoTangent(), bar(Ȳ)
end
return Y, foo_pullback
end
```

But to define an `rrule` for a constructor for a _type_ `T` we need to be careful to dispatch only on `Type{T}`.

For example, the `rrule` signature for a constructor would be like:
#== output
julia> rrule(foo, 2)
(4, var"#foo_pullback#11"())
==#

```julia
function rrule(::Type{T}, args...; kwargs...)
...
return y, T_pullback
# bad:
function rrule(::typeof(foo), x)
return foo(x), x̄ -> (NoTangent(), bar(x̄))
end
```

In particular, be careful not to use `typeof(T)` here.
Because `typeof(T)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors.

You can check which to use with `Core.Typeof`:

```julia
julia> function foob end
foob (generic function with 0 methods)

julia> typeof(foob)
typeof(foob)

julia> Core.Typeof(foob)
typeof(foob)

julia> abstract type AbstractT end

julia> struct ExampleT <: AbstractT end

julia> typeof(AbstractT)
DataType

julia> typeof(ExampleT)
DataType

julia> Core.Typeof(AbstractT)
Type{AbstractT}

julia> Core.Typeof(ExampleT)
Type{ExampleT}
#== output:
julia> rrule(foo, 2)
(4, var"##9#10"())
==#
```

## Use `ZeroTangent()` as the return value
Expand Down Expand Up @@ -90,6 +63,77 @@ Examples being:
- There is only one derivative being returned, so from the fact that the user called
`frule`/`rrule` they clearly will want to use that one.

## Structs: constructors and functors

To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
For example, the `rrule` signature would be like:

```julia
function rrule(::typeof(foo), args...; kwargs...)
...
return y, foo_pullback
end
```

For a struct `Bar`,
```julia
struct Bar
a::Float64
end

(bar::Bar)(x, y) = return bar.a + x + y # functor
```
we can define an `frule`/`rrule` for the `Bar` constructor(s), as well as any `Bar` [functors](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects).

To define an `rrule` for a constructor for a _type_ `Bar` we need to be careful to dispatch only on `Type{Bar}`.
For example, the `rrule` signature for a `Bar` constructor would be like:
```julia
function ChainRulesCore.rrule(::Type{Bar}, a)
...
return Bar(a), Bar_pullback
end
```

In particular, be careful not to use `typeof(Bar)` here.
Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors.

You can check which to use with `Core.Typeof`:

```julia
julia> function foo end
foo (generic function with 0 methods)

julia> typeof(foo)
typeof(foo)

julia> Core.Typeof(foob)
typeof(foo)

julia> typeof(Bar)
DataType

julia> Core.Typeof(Bar)
Type{Bar}

julia> abstract type AbstractT end

julia> typeof(AbstractT)
DataType

julia> Core.Typeof(AbstractT)
Type{AbstractT}
```

For the functor, use `bar::Bar`, i.e.

```julia
function ChainRulesCore.rrule(bar::Bar, x, y)
...
return bar(x, y), Bar_pullback
end
```


## Use `@not_implemented` appropriately

One can use [`@not_implemented`](@ref) to mark missing differentials.
Expand All @@ -107,34 +151,6 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/160

Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead).

## Code Style

Use named local functions for the `pullback` in an `rrule`.

```julia
# good:
function rrule(::typeof(foo), x)
Y = foo(x)
function foo_pullback(Ȳ)
return NoTangent(), bar(Ȳ)
end
return Y, foo_pullback
end
#== output
julia> rrule(foo, 2)
(4, var"#foo_pullback#11"())
==#

# bad:
function rrule(::typeof(foo), x)
return foo(x), x̄ -> (NoTangent(), bar(x̄))
end
#== output:
julia> rrule(foo, 2)
(4, var"##9#10"())
==#
```

While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
This makes it a lot simpler to debug from the stacktrace.

Expand Down

0 comments on commit c9f9168

Please sign in to comment.