diff --git a/docs/make.jl b/docs/make.jl index 06827ad..0bc7152 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -10,8 +10,8 @@ makedocs(; "ChainRulesTestUtils" => "index.md", "API" => "api.md", ], - strict=true, checkdocs=:exports, + # doctest=:fix ) const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git" diff --git a/docs/src/index.md b/docs/src/index.md index 0d1355c..77f6c8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,3 +1,6 @@ +```@meta +DocTestFilters = [r"[0-9\.]+s",r"isapprox\(.*\)"] +``` # ChainRulesTestUtils [![CI](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/workflows/CI/badge.svg?branch=main)](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/actions?query=workflow%3ACI) @@ -38,12 +41,12 @@ end # output ``` -and `rrule` +and `rrule` which contains a mistake in the first cotangent ```jldoctest ex function ChainRulesCore.rrule(::typeof(two2three), x1, x2) y = two2three(x1, x2) function two2three_pullback(Ȳ) - return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3]) + return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3]) end return y, two2three_pullback end @@ -65,23 +68,28 @@ Keep this in mind when testing discontinuous rules for functions like [ReLU](htt julia> using ChainRulesTestUtils; julia> test_frule(two2three, 3.33, -7.77); -Test Summary: | Pass Total -test_frule: two2three on Float64,Float64 | 6 6 +Test Summary: | Pass Total Time +test_frule: two2three on Float64,Float64 | 6 6 2.7s ``` ### Testing the `rrule` -[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`. +[`test_rrule`](@ref) takes in the function `f`, and primal inputs `x`. The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain. ```jldoctest ex julia> test_rrule(two2three, 3.33, -7.77); -Test Summary: | Pass Total -test_rrule: two2three on Float64,Float64 | 9 9 - +test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24 + Expression: isapprox(actual, expected; kwargs...) + Problem: cotangent for input 2, Float64 + Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9) +[...] ``` +The output of the test indicates to us the cause of the failure under "Problem:" with the expected (`rrule` derived) and actual finite difference results. +The Problem lies with the cotangent corresponding to input 2 of `rrule`, which is the first cotangent as expected. + ## Scalar example For functions with a single argument and a single output, such as e.g. ReLU, @@ -105,13 +113,13 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro call. ```jldoctest ex julia> test_scalar(relu, 0.5); -Test Summary: | Pass Total -test_scalar: relu at 0.5 | 11 11 +Test Summary: | Pass Total Time +test_scalar: relu at 0.5 | 12 12 1.2s julia> test_scalar(relu, -0.5); -Test Summary: | Pass Total -test_scalar: relu at -0.5 | 11 11 +Test Summary: | Pass Total Time +test_scalar: relu at -0.5 | 12 12 0.0s ``` diff --git a/src/testers.jl b/src/testers.jl index 7725a67..67dbbf3 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -213,7 +213,7 @@ function test_rrule( res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...})) y_ad, pullback = res y = call(primals...) - test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct + test_approx(y_ad, y, "Failed primal value check"; isapprox_kwargs...) # make sure primal is correct ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent @@ -231,7 +231,8 @@ function test_rrule( # Correctness testing via finite differencing. is_ignored = isa.(accum_cotangents, NoTangent) fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored) - foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args... + msgs = ntuple(i->"cotangent for input $i, $(summary(fd_cotangents[i]))", length(fd_cotangents)) + foreach(accum_cotangents, ad_cotangents, fd_cotangents, msgs) do args... _test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...) end @@ -282,7 +283,7 @@ function _is_inferrable(f, args...; kwargs...) end """ - _test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...) + _test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent[, msg]; kwargs...) Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and approximately equal to the cotangent `fd_cotangent` obtained with finite differencing. @@ -290,6 +291,8 @@ approximately equal to the cotangent `fd_cotangent` obtained with finite differe If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable, `ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well. +If a msg string is given, it is emmited on test failure. + # Keyword arguments - If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if its content can be inferred. @@ -298,22 +301,23 @@ If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-diff function _test_cotangent( accum_cotangent, ad_cotangent, - fd_cotangent; + fd_cotangent, + msg=""; check_inferred=true, kwargs..., ) ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent) # The main test of the actual derivative being correct: - test_approx(ad_cotangent, fd_cotangent; kwargs...) + test_approx(ad_cotangent, fd_cotangent, msg; kwargs...) _test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...) end # we marked the argument as non-differentiable -function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...) +function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent, msg=""; kwargs...) @test ad_cotangent isa NoTangent end -function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...) +function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent, msg=""; kwargs...) error( "The pullback in the rrule should use NoTangent()" * " rather than ZeroTangent() for non-perturbable arguments." @@ -322,7 +326,8 @@ end function _test_cotangent( ::NoTangent, ad_cotangent::ChainRulesCore.NotImplemented, - ::NoTangent; + ::NoTangent, + msg=""; kwargs..., ) # this situation can occur if a cotangent is not implemented and @@ -332,6 +337,6 @@ function _test_cotangent( # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217 @test_broken ad_cotangent isa NoTangent end -function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...) +function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent, msg=""; kwargs...) error("cotangent obtained with finite differencing has to be NoTangent()") end