diff --git a/test/chainrules.jl b/test/chainrules.jl index 0ab7cdb..6023ac3 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -7,39 +7,38 @@ using ChainRulesTestUtils y = randn(n) X = randn(n, 3) Y = randn(n, 3) + Xrep = repeat(x, 1, 3) + Yrep = repeat(y, 1, 3) @testset for metric in (SqEuclidean(), Euclidean()) - @testset "different arguments" begin - # Single evaluation - test_rrule(metric ⊢ NoTangent(), x, y) + # Single evaluation + test_rrule(metric ⊢ NoTangent(), x, y) + test_rrule(metric ⊢ NoTangent(), x, x) + for A in (X, Xrep) # Column-wise distance - test_rrule(colwise, metric ⊢ NoTangent(), X, Y) - - # Pairwise distances - test_rrule(pairwise, metric ⊢ NoTangent(), X) - test_rrule(pairwise, metric ⊢ NoTangent(), X; fkwargs=(dims=1,)) - test_rrule(pairwise, metric ⊢ NoTangent(), X; fkwargs=(dims=2,)) - test_rrule(pairwise, metric ⊢ NoTangent(), X, Y) - test_rrule(pairwise, metric ⊢ NoTangent(), X, Y; fkwargs=(dims=1,)) - test_rrule(pairwise, metric ⊢ NoTangent(), X, Y; fkwargs=(dims=2,)) - end - - # check numerical issues if distances are zero - @testset "equal arguments" begin - # Single evaluation - test_rrule(metric ⊢ NoTangent(), x, x) - - # Column-wise distance - test_rrule(colwise, metric ⊢ NoTangent(), X, X) + test_rrule(colwise, metric ⊢ NoTangent(), A, A) # Pairwise distances # Finite differencing yields impressively inaccurate derivatives for `Euclidean`, # see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206 - kwargs = metric isa Euclidean ? (rtol = 1e-4,) : () - test_rrule(pairwise, metric ⊢ NoTangent(), X, X; kwargs...) - test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=1,), kwargs...) - test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=2,), kwargs...) + kwargs = metric isa Euclidean ? (rtol=1e-3, atol=1e-3) : () + test_rrule(pairwise, metric ⊢ NoTangent(), A; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=2,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=2,), kwargs...) + + for B in (Y, Yrep) + # Column-wise distance + test_rrule(colwise, metric ⊢ NoTangent(), A, B) + + # Pairwise distances + test_rrule(pairwise, metric ⊢ NoTangent(), A, B) + test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=1,)) + test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=2,)) + end end end end