From 59e621b89a3ad409940dcd2aea5485bfa55fb2c7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 31 Jan 2023 01:13:41 +0100 Subject: [PATCH 1/7] Add weak dependency on ChainRulesCore --- Project.toml | 13 +++- ext/ChainRulesCoreExt.jl | 145 +++++++++++++++++++++++++++++++++++++++ test/chainrules.jl | 45 ++++++++++++ test/runtests.jl | 5 ++ 4 files changed, 206 insertions(+), 2 deletions(-) create mode 100644 ext/ChainRulesCoreExt.jl create mode 100644 test/chainrules.jl diff --git a/Project.toml b/Project.toml index df3bee8b..45571e40 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Distances" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.7" +version = "0.10.8" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -8,15 +8,24 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +ChainRulesCoreExt = "ChainRulesCore" + [compat] +ChainRulesCore = "1" StatsAPI = "1" julia = "1" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["OffsetArrays", "Random", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "OffsetArrays", "Random", "Test", "Unitful"] diff --git a/ext/ChainRulesCoreExt.jl b/ext/ChainRulesCoreExt.jl new file mode 100644 index 00000000..e064f894 --- /dev/null +++ b/ext/ChainRulesCoreExt.jl @@ -0,0 +1,145 @@ +module ChainRulesCoreExt + +using Distances + +import ChainRulesCore + +const CRC = ChainRulesCore + +## SqEuclidean + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, dist::SqEuclidean, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) + Ω = dist(x, y) + + function SqEuclidean_pullback(ΔΩ) + x̄ = (2 * CRC.unthunk(ΔΩ)) .* (x .- y) + return CRC.NoTangent(), x̄, -x̄ + end + + return Ω, SqEuclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(colwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}) + Ω = colwise(dist, X, Y) + + function colwise_SqEuclidean_pullback(ΔΩ) + X̄ = 2 .* CRC.unthunk(ΔΩ)' .* (X .- Y) + return CRC.NoTangent(), CRC.NoTangent(), X̄, -X̄ + end + + return Ω, colwise_SqEuclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X; dims=dims) + + function pairwise_SqEuclidean_X_pullback(ΔΩ) + Δ = CRC.unthunk(ΔΩ) + A = Δ .+ transpose(Δ) + X̄ = if dims == 1 + 2 .* (sum(A; dims=2) .* X .- A * X) + else + 2 .* (X .* sum(A; dims=1) .- X * A) + end + return CRC.NoTangent(), CRC.NoTangent(), X̄ + end + + return Ω, pairwise_SqEuclidean_X_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X, Y; dims=dims) + + function pairwise_SqEuclidean_X_Y_pullback(ΔΩ) + Δ = CRC.unthunk(ΔΩ) + Δt = transpose(Δ) + X̄ = if dims == 1 + 2 .* (sum(Δ; dims=2) .* X .- Δ * Y) + else + 2 .* (X .* sum(Δt; dims=1) .- Y * Δt) + end + Ȳ = if dims == 1 + 2 .* (sum(Δt; dims=2) .* Y .- Δt * X) + else + 2 .* (Y .* sum(Δ; dims=1) .- X * Δ) + end + return CRC.NoTangent(), CRC.NoTangent(), X̄, Ȳ + end + + return Ω, pairwise_SqEuclidean_X_Y_pullback +end + +## Euclidean + +_normalize(x::Real, nrm::Real) = iszero(nrm) && !isnan(x) ? one(x / nrm) : x / nrm + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, dist::Euclidean, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) + Ω = dist(x, y) + + function Euclidean_pullback(ΔΩ) + x̄ = _normalize(CRC.unthunk(ΔΩ), Ω) .* (x .- y) + return CRC.NoTangent(), x̄, -x̄ + end + + return Ω, Euclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(colwise), dist::Euclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}) + Ω = colwise(dist, X, Y) + + function colwise_Euclidean_pullback(ΔΩ) + X̄ = _normalize.(CRC.unthunk(ΔΩ)', Ω') .* (X .- Y) + return CRC.NoTangent(), CRC.NoTangent(), X̄, -X̄ + end + + return Ω, colwise_Euclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X; dims=dims) + + function pairwise_Euclidean_X_pullback(ΔΩ) + Δ = CRC.unthunk(ΔΩ) + A = _normalize.(Δ .+ transpose(Δ), Ω) + X̄ = if dims == 1 + sum(A; dims=2) .* X .- A * X + else + X .* sum(A; dims=1) .- X * A + end + return CRC.NoTangent(), CRC.NoTangent(), X̄ + end + + return Ω, pairwise_Euclidean_X_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X, Y; dims=dims) + + function pairwise_Euclidean_X_Y_pullback(ΔΩ) + Δ = _normalize.(CRC.unthunk(ΔΩ), Ω) + Δt = transpose(Δ) + X̄ = if dims == 1 + sum(Δ; dims=2) .* X .- Δ * Y + else + X .* sum(Δt; dims=1) .- Y * Δt + end + Ȳ = if dims == 1 + sum(Δt; dims=2) .* Y .- Δt * X + else + Y .* sum(Δ; dims=1) .- X * Δ + end + return CRC.NoTangent(), CRC.NoTangent(), X̄, Ȳ + end + + return Ω, pairwise_Euclidean_X_Y_pullback +end + +end # module \ No newline at end of file diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 00000000..14ff3b30 --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,45 @@ +using ChainRulesCore +using ChainRulesTestUtils + +@testset "ChainRulesCore extension" begin + n = 4 + x = randn(n) + y = randn(n) + X = randn(n, 3) + Y = randn(n, 3) + + @testset for metric in (SqEuclidean(), Euclidean()) + @testset "different arguments" begin + # Single evaluation + test_rrule(metric ⊢ NoTangent(), x, y) + + # Column-wise distance + test_rrule(colwise, metric ⊢ NoTangent(), X, Y) + + # Pairwise distances + test_rrule(pairwise, metric ⊢ NoTangent(), X) + test_rrule(pairwise, metric ⊢ NoTangent(), X; fkwargs=(dims=1,)) + test_rrule(pairwise, metric ⊢ NoTangent(), X; fkwargs=(dims=2,)) + test_rrule(pairwise, metric ⊢ NoTangent(), X, Y) + test_rrule(pairwise, metric ⊢ NoTangent(), X, Y; fkwargs=(dims=1,)) + test_rrule(pairwise, metric ⊢ NoTangent(), X, Y; fkwargs=(dims=2,)) + end + + # check numerical issues if distances are zero + @testset "equal arguments" begin + # Single evaluation + test_rrule(metric ⊢ NoTangent(), x, x) + + # Column-wise distance + test_rrule(colwise, metric ⊢ NoTangent(), X, X) + + # Pairwise distances + # Finite differencing yields impressively inaccurate derivatives for `Euclidean`, + # see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206 + kwargs = metric isa Euclidean ? (rtol = 1e-5,) : () + test_rrule(pairwise, metric ⊢ NoTangent(), X, X; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=2,), kwargs...) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8cbeff8f..28049a67 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,3 +11,8 @@ using Unitful.DefaultSymbols include("F64.jl") include("test_dists.jl") + +# Test ChainRules definitions on Julia versions that support weak dependencies +if isdefined(Base, :get_extension) + include("chainrules.jl") +end \ No newline at end of file From 0f35fb6fa25c3449752cb4a6536ef53391a4522f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 31 Jan 2023 02:21:51 +0100 Subject: [PATCH 2/7] Adjust tolerance --- test/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 14ff3b30..0ab7cdbe 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -36,7 +36,7 @@ using ChainRulesTestUtils # Pairwise distances # Finite differencing yields impressively inaccurate derivatives for `Euclidean`, # see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206 - kwargs = metric isa Euclidean ? (rtol = 1e-5,) : () + kwargs = metric isa Euclidean ? (rtol = 1e-4,) : () test_rrule(pairwise, metric ⊢ NoTangent(), X, X; kwargs...) test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=1,), kwargs...) test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=2,), kwargs...) From 06929e0b8bc8d42c930ac38c2b82c1eb44b8e070 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 16 Feb 2023 20:37:34 +0100 Subject: [PATCH 3/7] Rename extension --- Project.toml | 2 +- ext/{ChainRulesCoreExt.jl => DistancesChainRulesCoreExt.jl} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename ext/{ChainRulesCoreExt.jl => DistancesChainRulesCoreExt.jl} (99%) diff --git a/Project.toml b/Project.toml index 45571e40..0d6695a7 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [extensions] -ChainRulesCoreExt = "ChainRulesCore" +DistancesChainRulesCoreExt = "ChainRulesCore" [compat] ChainRulesCore = "1" diff --git a/ext/ChainRulesCoreExt.jl b/ext/DistancesChainRulesCoreExt.jl similarity index 99% rename from ext/ChainRulesCoreExt.jl rename to ext/DistancesChainRulesCoreExt.jl index e064f894..1cce8497 100644 --- a/ext/ChainRulesCoreExt.jl +++ b/ext/DistancesChainRulesCoreExt.jl @@ -1,4 +1,4 @@ -module ChainRulesCoreExt +module DistancesChainRulesCoreExt using Distances From 98ec25a34f69e65536d5c524a7374ed506a3bb1b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 1 Jul 2023 11:58:59 +0200 Subject: [PATCH 4/7] Update ext/DistancesChainRulesCoreExt.jl --- ext/DistancesChainRulesCoreExt.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ext/DistancesChainRulesCoreExt.jl b/ext/DistancesChainRulesCoreExt.jl index 1cce8497..10d89bb2 100644 --- a/ext/DistancesChainRulesCoreExt.jl +++ b/ext/DistancesChainRulesCoreExt.jl @@ -8,7 +8,12 @@ const CRC = ChainRulesCore ## SqEuclidean -function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, dist::SqEuclidean, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) +function CRC.rrule( + ::CRC.RuleConfig{>:CRC.HasReverseMode}, + dist::SqEuclidean, + x::AbstractVector{<:Real}, + y::AbstractVector{<:Real} +) Ω = dist(x, y) function SqEuclidean_pullback(ΔΩ) From 2e8563aa30dbb0568abc40ea4013a6f56cf4e49a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 8 Sep 2023 07:33:18 +0200 Subject: [PATCH 5/7] Add tests with matrices of repeated columns --- test/chainrules.jl | 49 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 0ab7cdbe..6023ac3e 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -7,39 +7,38 @@ using ChainRulesTestUtils y = randn(n) X = randn(n, 3) Y = randn(n, 3) + Xrep = repeat(x, 1, 3) + Yrep = repeat(y, 1, 3) @testset for metric in (SqEuclidean(), Euclidean()) - @testset "different arguments" begin - # Single evaluation - test_rrule(metric ⊢ NoTangent(), x, y) + # Single evaluation + test_rrule(metric ⊢ NoTangent(), x, y) + test_rrule(metric ⊢ NoTangent(), x, x) + for A in (X, Xrep) # Column-wise distance - test_rrule(colwise, metric ⊢ NoTangent(), X, Y) - - # Pairwise distances - test_rrule(pairwise, metric ⊢ NoTangent(), X) - test_rrule(pairwise, metric ⊢ NoTangent(), X; fkwargs=(dims=1,)) - test_rrule(pairwise, metric ⊢ NoTangent(), X; fkwargs=(dims=2,)) - test_rrule(pairwise, metric ⊢ NoTangent(), X, Y) - test_rrule(pairwise, metric ⊢ NoTangent(), X, Y; fkwargs=(dims=1,)) - test_rrule(pairwise, metric ⊢ NoTangent(), X, Y; fkwargs=(dims=2,)) - end - - # check numerical issues if distances are zero - @testset "equal arguments" begin - # Single evaluation - test_rrule(metric ⊢ NoTangent(), x, x) - - # Column-wise distance - test_rrule(colwise, metric ⊢ NoTangent(), X, X) + test_rrule(colwise, metric ⊢ NoTangent(), A, A) # Pairwise distances # Finite differencing yields impressively inaccurate derivatives for `Euclidean`, # see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206 - kwargs = metric isa Euclidean ? (rtol = 1e-4,) : () - test_rrule(pairwise, metric ⊢ NoTangent(), X, X; kwargs...) - test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=1,), kwargs...) - test_rrule(pairwise, metric ⊢ NoTangent(), X, X; fkwargs=(dims=2,), kwargs...) + kwargs = metric isa Euclidean ? (rtol=1e-3, atol=1e-3) : () + test_rrule(pairwise, metric ⊢ NoTangent(), A; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=2,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=2,), kwargs...) + + for B in (Y, Yrep) + # Column-wise distance + test_rrule(colwise, metric ⊢ NoTangent(), A, B) + + # Pairwise distances + test_rrule(pairwise, metric ⊢ NoTangent(), A, B) + test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=1,)) + test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=2,)) + end end end end From 1845d013c4e941ccc650d76338786fb24cbeb3e6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 8 Sep 2023 11:36:01 +0200 Subject: [PATCH 6/7] Use StableRNGs to fix spurious test failures --- Project.toml | 3 ++- test/chainrules.jl | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 653f67f5..507b366f 100644 --- a/Project.toml +++ b/Project.toml @@ -27,8 +27,9 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "OffsetArrays", "Random", "SparseArrays", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "OffsetArrays", "Random", "SparseArrays", "StableRNGs", "Test", "Unitful"] diff --git a/test/chainrules.jl b/test/chainrules.jl index 6023ac3e..2bc44ea4 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,12 +1,14 @@ using ChainRulesCore using ChainRulesTestUtils +using StableRNGs @testset "ChainRulesCore extension" begin n = 4 - x = randn(n) - y = randn(n) - X = randn(n, 3) - Y = randn(n, 3) + rng = StableRNG(100) + x = randn(rng, n) + y = randn(rng, n) + X = randn(rng, n, 3) + Y = randn(rng, n, 3) Xrep = repeat(x, 1, 3) Yrep = repeat(y, 1, 3) From 559e6d132863aa057e662b81ae6248daa220b7f4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 6 Oct 2023 00:08:00 +0200 Subject: [PATCH 7/7] Update test/runtests.jl --- test/runtests.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 28049a67..c7b60cea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,9 @@ include("F64.jl") include("test_dists.jl") # Test ChainRules definitions on Julia versions that support weak dependencies -if isdefined(Base, :get_extension) +# Support for extensions was added in +# https://github.com/JuliaLang/julia/commit/93587d7c1015efcd4c5184e9c42684382f1f9ab2 +# https://github.com/JuliaLang/julia/pull/47695 +if VERSION >= v"1.9.0-alpha1.18" include("chainrules.jl") end \ No newline at end of file