added helpful failure messages for tests #288

Merged
merged 5 commits into from
Dec 15, 2023
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -10,8 +10,8 @@ makedocs(;
"ChainRulesTestUtils" => "index.md",
"API" => "api.md",
],
strict=true,
checkdocs=:exports,
# doctest=:fix
)

const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
32 changes: 20 additions & 12 deletions docs/src/index.md
@@ -1,3 +1,6 @@
```@meta
DocTestFilters = [r"[0-9\.]+s",r"isapprox\(.*\)"]
```
# ChainRulesTestUtils

[![CI](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/workflows/CI/badge.svg?branch=main)](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/actions?query=workflow%3ACI)
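
The two `DocTestFilters` regexes in the `@meta` block can be tried out in isolation. The sketch below assumes that a doctest filter works by deleting every regex match from both the expected and the evaluated output before the two are compared; `apply_filters` is a hypothetical helper written for this illustration, not a Documenter function.

```julia
# The two filters from the `@meta` block.
filters = [r"[0-9\.]+s", r"isapprox\(.*\)"]

# Hypothetical helper: delete every match of every filter from `text`.
apply_filters(text) = foldl((acc, f) -> replace(acc, f => ""), filters; init=text)

# Timings differ from run to run, so the first filter strips them.
expected = "test_frule: two2three on Float64,Float64 | 6 6 2.7s"
evaluated = "test_frule: two2three on Float64,Float64 | 6 6 0.4s"
@assert apply_filters(expected) == apply_filters(evaluated)

# Finite-difference values carry numerical noise, so the second filter
# strips the whole `isapprox(...)` call.
line = "Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)"
@assert apply_filters(line) == "Evaluated: "
```

Under that assumption, only the parts of the output that are stable across machines take part in the comparison.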
@@ -38,12 +41,12 @@
# output

```
and `rrule`
and `rrule` which contains a mistake in the first cotangent
```jldoctest ex
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3])
end
return y, two2three_pullback
end
@@ -61,27 +64,32 @@
The call will test the `frule` for function `f` at the point `x` in the domain.
Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.
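
Why both sides of zero matter can be seen with a hand-rolled check. This is a minimal sketch in plain Julia: `relu`, `relu_derivative`, and `central_difference` are hypothetical names defined here for illustration, and the fixed-step central difference is a simplification rather than the finite-differencing method the package uses.

```julia
# Hypothetical ReLU definition; the one used in the docs is not shown in this diff.
relu(x) = max(zero(x), x)

# Hand-written derivative, as a rule for `relu` might compute it.
relu_derivative(x) = x > 0 ? one(x) : zero(x)

# Central finite difference with a fixed step `h` (a simplification).
central_difference(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Check one point on each side of the kink at zero.
for x in (0.5, -0.5)
    fd = central_difference(relu, x)
    @assert isapprox(relu_derivative(x), fd; atol=1e-6)
end
```

A derivative that was wrong on only one branch would pass at one of the two points, which is why a single test point is not enough here.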

Check failure on line 67 in docs/src/index.md (GitHub Actions / Documentation): doctest failure in src/index.md:67-73 for the subexpression `test_frule(two2three, 3.33, -7.77);`. The evaluated output has the columns `Pass Total` with the row `test_frule: two2three on Float64,Float64 | 6 6`, while the expected output has the columns `Pass Total Time` with the row `test_frule: two2three on Float64,Float64 | 6 6 2.7s`.

```jldoctest ex
julia> using ChainRulesTestUtils;

julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 6 6
Test Summary: | Pass Total Time
test_frule: two2three on Float64,Float64 | 6 6 2.7s

```

### Testing the `rrule`

[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
[`test_rrule`](@ref) takes in the function `f`, and primal inputs `x`.
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.

Check failure on line 81 in docs/src/index.md (GitHub Actions / Documentation): doctest failure in src/index.md:81-88 for the subexpression `test_rrule(two2three, 3.33, -7.77);`. The expected output reports `Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24` with `Problem: cotangent for input 2, Float64`, followed by `[...]`. The evaluated output reports `Test Failed at /home/runner/work/ChainRulesTestUtils.jl/ChainRulesTestUtils.jl/src/check_result.jl:24` with the same `Problem:` line, followed by a stack trace that passes through `test_approx` (src/check_result.jl:24), `_test_cotangent` (src/testers.jl:312), and `test_rrule` (src/testers.jl:205); the trace is cut off in the annotation.

```jldoctest ex
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 9 9

test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: cotangent for input 2, Float64
Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)
[...]
```

The test output names the cause of the failure under "Problem:" and shows, under "Evaluated:", the actual (`rrule`-derived) value followed by the expected finite-difference value.
The problem lies with the cotangent for input 2 of the `rrule`. Input 1 is the function `two2three` itself, so this is the first cotangent, as expected.
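
The two numbers on the "Evaluated:" line can be checked against the deliberate mistake in the pullback, which multiplies `Ȳ[2]` by `2.1` where the earlier version used `2.0`. The sketch below assumes the first number is the `rrule` result and the second the finite-difference result, following the argument order of `test_approx(ad_cotangent, fd_cotangent, msg; ...)` and `isapprox(actual, expected; ...)` shown in this diff.

```julia
actual = -4.032                # first value on the "Evaluated:" line
expected = -3.840000000001641  # second value on the "Evaluated:" line

# The comparison fails with the tolerances printed in the failure message.
@assert !isapprox(actual, expected; rtol=1.0e-9, atol=1.0e-9)

# The ratio of the two values recovers the ratio of the coefficients, 2.1 / 2.0.
ratio = actual / expected
@assert isapprox(ratio, 2.1 / 2.0; rtol=1.0e-9)
```

The ratio of about 1.05 points at the coefficient `2.1` in the pullback rather than at numerical noise in the finite-difference estimate.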

## Scalar example

For functions with a single argument and a single output, such as e.g. ReLU,
@@ -103,15 +111,15 @@

The `test_scalar` function is provided to test both the `frule` and the `rrule` with a single
call.
Check failure on line 114 in docs/src/index.md (GitHub Actions / Documentation), reported twice: doctest failure in src/index.md:114-123. For the subexpression `test_scalar(relu, 0.5);`, the evaluated output is `test_scalar: relu at 0.5 | 11 11` under the columns `Pass Total`, while the expected output is `test_scalar: relu at 0.5 | 12 12 1.2s` under `Pass Total Time`. For the subexpression `test_scalar(relu, -0.5);`, the evaluated output is `test_scalar: relu at -0.5 | 11 11`, while the expected output is `test_scalar: relu at -0.5 | 12 12 0.0s`.

```jldoctest ex
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 11 11
Test Summary: | Pass Total Time
test_scalar: relu at 0.5 | 12 12 1.2s


julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 11 11
Test Summary: | Pass Total Time
test_scalar: relu at -0.5 | 12 12 0.0s

```

23 changes: 14 additions & 9 deletions src/testers.jl
@@ -213,7 +213,7 @@
res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...}))
y_ad, pullback = res
y = call(primals...)
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
test_approx(y_ad, y, "Failed primal value check"; isapprox_kwargs...) # make sure primal is correct

ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

@@ -231,7 +231,8 @@
# Correctness testing via finite differencing.
is_ignored = isa.(accum_cotangents, NoTangent)
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args...
msgs = ntuple(i->"cotangent for input $i, $(summary(fd_cotangents[i]))", length(fd_cotangents))
foreach(accum_cotangents, ad_cotangents, fd_cotangents, msgs) do args...
_test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...)
end
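
The message construction in this hunk can be tried on its own. This is a small sketch with stand-in values: the tuple `fd_cotangents` below is made up for illustration (the real one comes from `_make_j′vp_call`), and `report` is a hypothetical collector that replaces the actual test call.

```julia
# Stand-in values; `nothing` plays the role of an ignored first input.
fd_cotangents = (nothing, -3.84, 1.5)

# One message per input, built the same way as in the hunk above.
msgs = ntuple(
    i -> "cotangent for input $i, $(summary(fd_cotangents[i]))",
    length(fd_cotangents),
)

# `foreach` over several collections passes one element of each per call,
# so every cotangent is paired with its own message.
report = String[]
foreach(fd_cotangents, msgs) do cotangent, msg
    push!(report, msg)
end

@assert msgs[2] == "cotangent for input 2, Float64"
```

The second message matches the `Problem:` line in the documented failure output, and the index counts every entry of the tuple, including the first one.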

@@ -282,14 +283,16 @@
end

"""
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent[, msg]; kwargs...)

Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and
approximately equal to the cotangent `fd_cotangent` obtained with finite differencing.

If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable,
`ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.

If a `msg` string is given, it is emitted on test failure.

# Keyword arguments
- If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if
its content can be inferred.
@@ -298,22 +301,23 @@
function _test_cotangent(
accum_cotangent,
ad_cotangent,
fd_cotangent;
fd_cotangent,
msg="";
check_inferred=true,
kwargs...,
)
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

# The main test of the actual derivative being correct:
test_approx(ad_cotangent, fd_cotangent; kwargs...)
test_approx(ad_cotangent, fd_cotangent, msg; kwargs...)
_test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...)
end

# we marked the argument as non-differentiable
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...)
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent, msg=""; kwargs...)
@test ad_cotangent isa NoTangent
end
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...)
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent, msg=""; kwargs...)
error(
"The pullback in the rrule should use NoTangent()" *
" rather than ZeroTangent() for non-perturbable arguments."
@@ -322,7 +326,8 @@
function _test_cotangent(
::NoTangent,
ad_cotangent::ChainRulesCore.NotImplemented,
::NoTangent;
::NoTangent,
msg="";
kwargs...,
)
# this situation can occur if a cotangent is not implemented and
@@ -332,6 +337,6 @@
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken ad_cotangent isa NoTangent
end
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...)
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent, msg=""; kwargs...)

Check warning on line 340 in src/testers.jl (Codecov / codecov/patch): added line #L340 was not covered by tests
error("cotangent obtained with finite differencing has to be NoTangent()")
end
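
Every `_test_cotangent` method in this diff gains `msg` as an optional fourth positional argument, which keeps existing three-argument calls working. The sketch below shows that pattern with a hypothetical stand-in, `check`, which returns a string instead of running tests; `Nothing` plays the role of `NoTangent` so that the example needs no packages.

```julia
# Generic method: `msg` defaults to "", so three-argument calls still work.
check(accum, ad, fd, msg=""; kwargs...) = isempty(msg) ? "compared" : "compared: $msg"

# More specific method for an ignored argument, mirroring the `::NoTangent` methods.
# It accepts `msg` too, so dispatch does not depend on whether a message is passed.
check(::Nothing, ad, ::Nothing, msg=""; kwargs...) = "ignored"

@assert check(1.0, 2.0, 2.0) == "compared"
@assert check(1.0, 2.0, 2.0, "cotangent for input 2, Float64"; rtol=1e-9) ==
    "compared: cotangent for input 2, Float64"
@assert check(nothing, 2.0, nothing, "cotangent for input 1, Nothing") == "ignored"
```

If the specialised method did not accept `msg`, a four-argument call with an ignored input would fall through to the generic method, which is presumably why every method in the diff gains the argument.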