diff --git a/Project.toml b/Project.toml index c2c895e46..599325e12 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.47" +version = "0.6.48" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/lib/distances.jl b/src/lib/distances.jl index b49a2f74c..ee39e9de6 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -79,11 +79,16 @@ end end @adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2) - D, back = pullback(X -> pairwise(SqEuclidean(), X; dims = dims), X) - D .= sqrt.(D) - return D, function(Δ) - Δ = Δ ./ (2 .* max.(D, eps(eltype(D)))) - Δ[diagind(Δ)] .= 0 - return (nothing, first(back(Δ))) + + _conditional(d, δ) = d > δ ? sqrt(d) : zero(d) + + function _pairwise_euclidean(X) + δ = eps(eltype(X))^2 + D2 = pairwise(SqEuclidean(), X; dims=dims) + return _conditional.(D2, δ) end + D, back = pullback(_pairwise_euclidean, X) + + _pairwise_pullback(Δ) = (nothing, back(Δ)...) + return D, _pairwise_pullback end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index e37e0ea15..3330d5927 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1192,6 +1192,19 @@ end # This is impressively inaccurate, but at least it doesn't produce a NaN. @test first(Δ_fd) ≈ first(pb(Δ)) atol=1e-3 rtol=1e-3 end + + @testset "repeated X" begin + Δ = randn(P, P) + X = repeat(randn(rng, D), 1, P) + + Δ_fd = FiniteDifferences.j′vp( + FiniteDifferences.central_fdm(5, 1), X -> pairwise(metric, X; dims=2), Δ, X + ) + _, pb = Zygote.pullback(X -> pairwise(metric, X; dims=2), X) + + # This is impressively inaccurate, but at least it doesn't produce a NaN. + @test first(Δ_fd) ≈ first(pb(Δ)) atol=1e-3 rtol=1e-3 + end end @testset "binary pairwise - X and Y close" begin