Skip to content

Commit

Permalink
Add tests with matrices of repeated columns
Browse files Browse the repository at this point in the history
  • Loading branch information
David Widmann committed Sep 8, 2023
1 parent 02ba67b commit 2e8563a
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,38 @@ using ChainRulesTestUtils
y = randn(n)
X = randn(n, 3)
Y = randn(n, 3)
Xrep = repeat(x, 1, 3)
Yrep = repeat(y, 1, 3)

@testset for metric in (SqEuclidean(), Euclidean())
@testset "different arguments" begin
# Single evaluation
test_rrule(metric NoTangent(), x, y)
# Single evaluation
test_rrule(metric NoTangent(), x, y)
test_rrule(metric NoTangent(), x, x)

for A in (X, Xrep)
# Column-wise distance
test_rrule(colwise, metric NoTangent(), X, Y)

# Pairwise distances
test_rrule(pairwise, metric NoTangent(), X)
test_rrule(pairwise, metric NoTangent(), X; fkwargs=(dims=1,))
test_rrule(pairwise, metric NoTangent(), X; fkwargs=(dims=2,))
test_rrule(pairwise, metric NoTangent(), X, Y)
test_rrule(pairwise, metric NoTangent(), X, Y; fkwargs=(dims=1,))
test_rrule(pairwise, metric NoTangent(), X, Y; fkwargs=(dims=2,))
end

# check numerical issues if distances are zero
@testset "equal arguments" begin
# Single evaluation
test_rrule(metric NoTangent(), x, x)

# Column-wise distance
test_rrule(colwise, metric NoTangent(), X, X)
test_rrule(colwise, metric NoTangent(), A, A)

# Pairwise distances
# Finite differencing yields impressively inaccurate derivatives for `Euclidean`,
# see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206
kwargs = metric isa Euclidean ? (rtol = 1e-4,) : ()
test_rrule(pairwise, metric NoTangent(), X, X; kwargs...)
test_rrule(pairwise, metric NoTangent(), X, X; fkwargs=(dims=1,), kwargs...)
test_rrule(pairwise, metric NoTangent(), X, X; fkwargs=(dims=2,), kwargs...)
kwargs = metric isa Euclidean ? (rtol=1e-3, atol=1e-3) : ()
test_rrule(pairwise, metric NoTangent(), A; kwargs...)
test_rrule(pairwise, metric NoTangent(), A; fkwargs=(dims=1,), kwargs...)
test_rrule(pairwise, metric NoTangent(), A; fkwargs=(dims=2,), kwargs...)
test_rrule(pairwise, metric NoTangent(), A, A; kwargs...)
test_rrule(pairwise, metric NoTangent(), A, A; fkwargs=(dims=1,), kwargs...)
test_rrule(pairwise, metric NoTangent(), A, A; fkwargs=(dims=2,), kwargs...)

for B in (Y, Yrep)
# Column-wise distance
test_rrule(colwise, metric NoTangent(), A, B)

# Pairwise distances
test_rrule(pairwise, metric NoTangent(), A, B)
test_rrule(pairwise, metric NoTangent(), A, B; fkwargs=(dims=1,))
test_rrule(pairwise, metric NoTangent(), A, B; fkwargs=(dims=2,))
end
end
end
end

0 comments on commit 2e8563a

Please sign in to comment.