diff --git a/.github/workflows/fix_doctests.yml b/.github/workflows/fix_doctests.yml new file mode 100644 index 00000000..5ca114d1 --- /dev/null +++ b/.github/workflows/fix_doctests.yml @@ -0,0 +1,29 @@ +name: fix_doctests +on: + pull_request: +jobs: + doctests: + name: Fix doctests (Julia ${{ matrix.julia-version }} - ${{ github.event_name }}) + runs-on: ubuntu-latest + strategy: + matrix: + julia-version: [1.6] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + - uses: actions/checkout@v1 + - name: Fix doctests + shell: julia --project=docs/ {0} + run: | + using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate() + using Documenter + using ChainRulesTestUtils + doctest(ChainRulesTestUtils, fix=true) + # don't push changes to Manifest in suggestions, as it removes `path=..` + run(`git restore docs/Manifest.toml`) + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: Documenter (fix doctests) + fail_on_error: true diff --git a/Project.toml b/Project.toml index 0fd77dc5..af31d6b0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.7.11" +version = "0.7.12" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index ddc12e8c..0a79baf2 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba" +git-tree-sha1 = "dbc9aae1227cfddaa9d2552f3ecba5b641f6cce9" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.4" +version = "0.10.5" [[ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] path = ".." uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.7.9" +version = "0.7.12" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -47,9 +47,9 @@ version = "0.8.5" [[Documenter]] deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649" +git-tree-sha1 = "5acbebf1be22db43589bc5aa1bb5fcc378b17780" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.26.3" +version = "0.27.0" [[Downloads]] deps = ["ArgTools", "LibCURL", "NetworkOptions"] @@ -62,10 +62,10 @@ uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" version = "0.12.13" [[IOCapture]] -deps = ["Logging"] -git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59" +deps = ["Logging", "Random"] +git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.1.1" +version = "0.2.2" [[InteractiveUtils]] deps = ["Markdown"] diff --git a/docs/Project.toml b/docs/Project.toml index 6594306b..48d5ec85 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,5 +4,5 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] -Documenter = "0.26" +Documenter = "0.27" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index 4bf53d76..06827ad1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,7 +12,7 @@ makedocs(; ], strict=true, checkdocs=:exports, -) + ) const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git" deploydocs(; repo=repo, push_preview=true) diff --git a/docs/src/index.md b/docs/src/index.md index 4c75dacd..398e375a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -12,7 +12,7 @@ For information about ChainRules, including how to write rules, refer to the gen ## Canonical example Let's suppose a custom transformation has been defined -```jldoctest ex; output = false +```jldoctest ex function two2three(x1::Float64, x2::Float64) return 1.0, 2.0*x1, 3.0*x2 end @@ -21,7 +21,7 @@ end two2three (generic function with 1 method) ``` along with the `frule` -```jldoctest ex; output = false +```jldoctest ex using ChainRulesCore function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2) @@ -33,7 +33,7 @@ end ``` and `rrule` -```jldoctest ex; output = false +```jldoctest ex function ChainRulesCore.rrule(::typeof(two2three), x1, x2) y = two2three(x1, x2) function two2three_pullback(Ȳ) @@ -55,12 +55,13 @@ They can be used for any type and number of inputs and outputs. The call will test the `frule` for function `f` at the point `x` in the domain. Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero. -```jldoctest ex; output = false +```jldoctest ex julia> using ChainRulesTestUtils; julia> test_frule(two2three, 3.33, -7.77); Test Summary: | Pass Total test_frule: two2three on Float64,Float64 | 6 6 + ``` ### Testing the `rrule` @@ -68,16 +69,17 @@ test_frule: two2three on Float64,Float64 | 6 6 [`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`. The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain. -```jldoctest ex; output = false +```jldoctest ex julia> test_rrule(two2three, 3.33, -7.77); Test Summary: | Pass Total -test_rrule: two2three on Float64,Float64 | 7 7 +test_rrule: two2three on Float64,Float64 | 8 8 + ``` ## Scalar example For functions with a single argument and a single output, such as e.g. ReLU, -```jldoctest ex; output = false +```jldoctest ex function relu(x::Real) return max(0, x) end @@ -86,7 +88,7 @@ end relu (generic function with 1 method) ``` with the `frule` and `rrule` defined with the help of `@scalar_rule` macro -```jldoctest ex; output = false +```jldoctest ex @scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x) # output @@ -95,14 +97,16 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro `test_scalar` function is provided to test both the `frule` and the `rrule` with a single call. -```jldoctest ex; output = false +```jldoctest ex julia> test_scalar(relu, 0.5); Test Summary: | Pass Total -test_scalar: relu at 0.5 | 9 9 +test_scalar: relu at 0.5 | 10 10 + julia> test_scalar(relu, -0.5); Test Summary: | Pass Total -test_scalar: relu at -0.5 | 9 9 +test_scalar: relu at -0.5 | 10 10 + ``` ## Testing constructors and functors (callable objects) diff --git a/src/testers.jl b/src/testers.jl index b82ada71..da8391e3 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -244,9 +244,14 @@ function test_rrule( check_inferred && _test_inferred(pullback, ȳ) ad_cotangents = pullback(ȳ) - ad_cotangents isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.") - msg = "The pullback should return 1 cotangent for the primal and each primal input." - @test_msg msg length(ad_cotangents) == 1 + length(args) + @test_msg( + "The pullback must return a Tuple (∂self, ∂args...)", + ad_cotangents isa Tuple + ) + @test_msg( + "The pullback should return 1 cotangent for the primal and each primal input.", + length(ad_cotangents) == length(primals) + ) # Correctness testing via finite differencing. # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 diff --git a/test/testers.jl b/test/testers.jl index d82866bf..d1ca8017 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -548,6 +548,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end @test fails(() -> test_frule(foo, 2.1, 2.1)) @test fails(() -> test_rrule(foo, 21.0, 32.0)) end + + @testset "rrule not returning a tuple" begin + bar(x, y) = x + 3y + function ChainRulesCore.rrule(::typeof(bar), x, y) + bar_pullback(dy) = dy + return bar(x,y), bar_pullback + end + @test fails(() -> test_rrule(bar, 21.0, 32.0)) + end end @testset "structs" begin