diff --git a/Project.toml b/Project.toml index db31493..507b366 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Distances" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.9" +version = "0.10.10" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -9,21 +9,27 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [extensions] +DistancesChainRulesCoreExt = "ChainRulesCore" DistancesSparseArraysExt = "SparseArrays" [compat] +ChainRulesCore = "1" StatsAPI = "1" julia = "1" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["OffsetArrays", "Random", "SparseArrays", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "OffsetArrays", "Random", "SparseArrays", "StableRNGs", "Test", "Unitful"] diff --git a/ext/DistancesChainRulesCoreExt.jl b/ext/DistancesChainRulesCoreExt.jl new file mode 100644 index 0000000..10d89bb --- /dev/null +++ b/ext/DistancesChainRulesCoreExt.jl @@ -0,0 +1,150 @@ +module DistancesChainRulesCoreExt + +using Distances + +import ChainRulesCore + +const CRC = ChainRulesCore + +## SqEuclidean + +function CRC.rrule( + ::CRC.RuleConfig{>:CRC.HasReverseMode}, + dist::SqEuclidean, + x::AbstractVector{<:Real}, + y::AbstractVector{<:Real} +) + Ω = dist(x, y) + + function SqEuclidean_pullback(ΔΩ) + x̄ = (2 * CRC.unthunk(ΔΩ)) .* (x .- y) + return CRC.NoTangent(), x̄, -x̄ + end + + return Ω, SqEuclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(colwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}) + Ω = colwise(dist, X, Y) + + function colwise_SqEuclidean_pullback(ΔΩ) + X̄ = 2 .* CRC.unthunk(ΔΩ)' .* (X .- Y) + return CRC.NoTangent(), CRC.NoTangent(), X̄, -X̄ + end + + return Ω, colwise_SqEuclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X; dims=dims) + + function pairwise_SqEuclidean_X_pullback(ΔΩ) + Δ = CRC.unthunk(ΔΩ) + A = Δ .+ transpose(Δ) + X̄ = if dims == 1 + 2 .* (sum(A; dims=2) .* X .- A * X) + else + 2 .* (X .* sum(A; dims=1) .- X * A) + end + return CRC.NoTangent(), CRC.NoTangent(), X̄ + end + + return Ω, pairwise_SqEuclidean_X_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::SqEuclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X, Y; dims=dims) + + function pairwise_SqEuclidean_X_Y_pullback(ΔΩ) + Δ = CRC.unthunk(ΔΩ) + Δt = transpose(Δ) + X̄ = if dims == 1 + 2 .* (sum(Δ; dims=2) .* X .- Δ * Y) + else + 2 .* (X .* sum(Δt; dims=1) .- Y * Δt) + end + Ȳ = if dims == 1 + 2 .* (sum(Δt; dims=2) .* Y .- Δt * X) + else + 2 .* (Y .* sum(Δ; dims=1) .- X * Δ) + end + return CRC.NoTangent(), CRC.NoTangent(), X̄, Ȳ + end + + return Ω, pairwise_SqEuclidean_X_Y_pullback +end + +## Euclidean + +_normalize(x::Real, nrm::Real) = iszero(nrm) && !isnan(x) ? one(x / nrm) : x / nrm + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, dist::Euclidean, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) + Ω = dist(x, y) + + function Euclidean_pullback(ΔΩ) + x̄ = _normalize(CRC.unthunk(ΔΩ), Ω) .* (x .- y) + return CRC.NoTangent(), x̄, -x̄ + end + + return Ω, Euclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(colwise), dist::Euclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}) + Ω = colwise(dist, X, Y) + + function colwise_Euclidean_pullback(ΔΩ) + X̄ = _normalize.(CRC.unthunk(ΔΩ)', Ω') .* (X .- Y) + return CRC.NoTangent(), CRC.NoTangent(), X̄, -X̄ + end + + return Ω, colwise_Euclidean_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X; dims=dims) + + function pairwise_Euclidean_X_pullback(ΔΩ) + Δ = CRC.unthunk(ΔΩ) + A = _normalize.(Δ .+ transpose(Δ), Ω) + X̄ = if dims == 1 + sum(A; dims=2) .* X .- A * X + else + X .* sum(A; dims=1) .- X * A + end + return CRC.NoTangent(), CRC.NoTangent(), X̄ + end + + return Ω, pairwise_Euclidean_X_pullback +end + +function CRC.rrule(::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; dims::Union{Nothing,Integer}=nothing) + dims = Distances.deprecated_dims(dims) + dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)")) + Ω = pairwise(dist, X, Y; dims=dims) + + function pairwise_Euclidean_X_Y_pullback(ΔΩ) + Δ = _normalize.(CRC.unthunk(ΔΩ), Ω) + Δt = transpose(Δ) + X̄ = if dims == 1 + sum(Δ; dims=2) .* X .- Δ * Y + else + X .* sum(Δt; dims=1) .- Y * Δt + end + Ȳ = if dims == 1 + sum(Δt; dims=2) .* Y .- Δt * X + else + Y .* sum(Δ; dims=1) .- X * Δ + end + return CRC.NoTangent(), CRC.NoTangent(), X̄, Ȳ + end + + return Ω, pairwise_Euclidean_X_Y_pullback +end + +end # module \ No newline at end of file diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 0000000..2bc44ea --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,46 @@ +using ChainRulesCore +using ChainRulesTestUtils +using StableRNGs + +@testset "ChainRulesCore extension" begin + n = 4 + rng = StableRNG(100) + x = randn(rng, n) + y = randn(rng, n) + X = randn(rng, n, 3) + Y = randn(rng, n, 3) + Xrep = repeat(x, 1, 3) + Yrep = repeat(y, 1, 3) + + @testset for metric in (SqEuclidean(), Euclidean()) + # Single evaluation + test_rrule(metric ⊢ NoTangent(), x, y) + test_rrule(metric ⊢ NoTangent(), x, x) + + for A in (X, Xrep) + # Column-wise distance + test_rrule(colwise, metric ⊢ NoTangent(), A, A) + + # Pairwise distances + # Finite differencing yields impressively inaccurate derivatives for `Euclidean`, + # see https://github.com/FluxML/Zygote.jl/blob/45bf883491d2b52580d716d577e2fa8577a07230/test/gradcheck.jl#L1206 + kwargs = metric isa Euclidean ? (rtol=1e-3, atol=1e-3) : () + test_rrule(pairwise, metric ⊢ NoTangent(), A; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A; fkwargs=(dims=2,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=1,), kwargs...) + test_rrule(pairwise, metric ⊢ NoTangent(), A, A; fkwargs=(dims=2,), kwargs...) + + for B in (Y, Yrep) + # Column-wise distance + test_rrule(colwise, metric ⊢ NoTangent(), A, B) + + # Pairwise distances + test_rrule(pairwise, metric ⊢ NoTangent(), A, B) + test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=1,)) + test_rrule(pairwise, metric ⊢ NoTangent(), A, B; fkwargs=(dims=2,)) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8cbeff8..c7b60ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,3 +11,11 @@ using Unitful.DefaultSymbols include("F64.jl") include("test_dists.jl") + +# Test ChainRules definitions on Julia versions that support weak dependencies +# Support for extensions was added in +# https://github.com/JuliaLang/julia/commit/93587d7c1015efcd4c5184e9c42684382f1f9ab2 +# https://github.com/JuliaLang/julia/pull/47695 +if VERSION >= v"1.9.0-alpha1.18" + include("chainrules.jl") +end \ No newline at end of file