Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mistake in non-tuple output message #177

Merged
merged 9 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/fix_doctests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: fix_doctests
on:
pull_request:
jobs:
doctests:
name: Fix doctests (Julia ${{ matrix.julia-version }} - ${{ github.event_name }})
runs-on: ubuntu-latest
strategy:
matrix:
julia-version: [1.6]
steps:
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: actions/checkout@v1
- name: Fix doctests
shell: julia --project=docs/ {0}
run: |
using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()
using Documenter
using ChainRulesTestUtils
doctest(ChainRulesTestUtils, fix=true)
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
# don't push changes to Manifest in suggestions, as it removes `path=..`
run(`git restore docs/Manifest.toml`)
- uses: reviewdog/action-suggester@v1
if: github.event_name == 'pull_request'
with:
tool_name: Documenter (fix doctests)
fail_on_error: true
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.11"
version = "0.7.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
16 changes: 8 additions & 8 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
git-tree-sha1 = "dbc9aae1227cfddaa9d2552f3ecba5b641f6cce9"
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.4"
version = "0.10.5"
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

[[ChainRulesTestUtils]]
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
path = ".."
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.9"
version = "0.7.12"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -47,9 +47,9 @@ version = "0.8.5"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
git-tree-sha1 = "5acbebf1be22db43589bc5aa1bb5fcc378b17780"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.26.3"
version = "0.27.0"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
Expand All @@ -62,10 +62,10 @@ uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.13"

[[IOCapture]]
deps = ["Logging"]
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
deps = ["Logging", "Random"]
git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.1.1"
version = "0.2.2"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
Documenter = "0.26"
Documenter = "0.27"
julia = "1"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ makedocs(;
],
strict=true,
checkdocs=:exports,
)
)

const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
deploydocs(; repo=repo, push_preview=true)
26 changes: 15 additions & 11 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ For information about ChainRules, including how to write rules, refer to the gen
## Canonical example

Let's suppose a custom transformation has been defined
```jldoctest ex; output = false
```jldoctest ex
function two2three(x1::Float64, x2::Float64)
return 1.0, 2.0*x1, 3.0*x2
end
Expand All @@ -21,7 +21,7 @@ end
two2three (generic function with 1 method)
```
along with the `frule`
```jldoctest ex; output = false
```jldoctest ex
using ChainRulesCore

function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
Expand All @@ -33,7 +33,7 @@ end

```
and `rrule`
```jldoctest ex; output = false
```jldoctest ex
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
Expand All @@ -55,29 +55,31 @@ They can be used for any type and number of inputs and outputs.
The call will test the `frule` for function `f` at the point `x` in the domain.
Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.

```jldoctest ex; output = false
```jldoctest ex
julia> using ChainRulesTestUtils;

julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 6 6
Comment on lines 62 to 63

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Documenter (fix doctests)] reported by reviewdog 🐶

Suggested change
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 6 6


```

### Testing the `rrule`

[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.

```jldoctest ex; output = false
```jldoctest ex
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 7 7
test_rrule: two2three on Float64,Float64 | 8 8
Comment on lines 74 to +75

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Documenter (fix doctests)] reported by reviewdog 🐶

Suggested change
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 7 7
test_rrule: two2three on Float64,Float64 | 8 8


```

## Scalar example

For functions with a single argument and a single output, such as e.g. ReLU,
```jldoctest ex; output = false
```jldoctest ex
function relu(x::Real)
return max(0, x)
end
Expand All @@ -86,7 +88,7 @@ end
relu (generic function with 1 method)
```
with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
```jldoctest ex; output = false
```jldoctest ex
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)

# output
Expand All @@ -95,14 +97,16 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro

`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
call.
```jldoctest ex; output = false
```jldoctest ex
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 9 9
test_scalar: relu at 0.5 | 10 10
Comment on lines 102 to +103

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Documenter (fix doctests)] reported by reviewdog 🐶

Suggested change
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 9 9
test_scalar: relu at 0.5 | 10 10



julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 9 9
test_scalar: relu at -0.5 | 10 10
Comment on lines 107 to +108

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Documenter (fix doctests)] reported by reviewdog 🐶

Suggested change
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 9 9
test_scalar: relu at -0.5 | 10 10


```

## Testing constructors and functors (callable objects)
Expand Down
11 changes: 8 additions & 3 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,14 @@ function test_rrule(

check_inferred && _test_inferred(pullback, ȳ)
ad_cotangents = pullback(ȳ)
ad_cotangents isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
msg = "The pullback should return 1 cotangent for the primal and each primal input."
@test_msg msg length(ad_cotangents) == 1 + length(args)
@test_msg(
"The pullback must return a Tuple (∂self, ∂args...)",
ad_cotangents isa Tuple
)
@test_msg(
"The pullback should return 1 cotangent for the primal and each primal input.",
length(ad_cotangents) == length(primals)
)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
9 changes: 9 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
@test fails(() -> test_frule(foo, 2.1, 2.1))
@test fails(() -> test_rrule(foo, 21.0, 32.0))
end

@testset "rrule not returning a tuple" begin
bar(x, y) = x + 3y
function ChainRulesCore.rrule(::typeof(bar), x, y)
bar_pullback(dy) = dy
return bar(x,y), bar_pullback
end
@test fails(() -> test_rrule(bar, 21.0, 32.0))
end
end

@testset "structs" begin
Expand Down