From e33d8461b55a0b5e4c87afa5cfbc99e16e931e80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 16:52:01 +0100 Subject: [PATCH 01/14] Changing Distances adjoints to ChainRules syntax --- src/lib/distances.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index b49a2f74c..215513de0 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -1,6 +1,6 @@ using .Distances -@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector) +function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y function sqeuclidean(Δ::Real) x̄ = (2 * Δ) .* δ @@ -9,14 +9,14 @@ using .Distances return sum(abs2, δ), sqeuclidean end -@adjoint function colwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) +function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) return colwise(s, x, y), function (Δ::AbstractVector) x̄ = 2 .* Δ' .* (x .- y) return nothing, x̄, -x̄ end end -@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2) +function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2) if dims==1 return pairwise(s, x, y; dims=1), ∇pairwise(s, transpose(x), transpose(y), transpose) else @@ -31,7 +31,7 @@ end return (nothing, f(x̄), f(ȳ)) end -@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix; dims::Int=2) +function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2) if dims==1 return pairwise(s, x; dims=1), ∇pairwise(s, transpose(x), transpose) else @@ -46,7 +46,7 @@ end return (nothing, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f) end -@adjoint function (::Euclidean)(x::AbstractVector, y::AbstractVector) +function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) D = x .- y δ = sqrt(sum(abs2, D)) function euclidean(Δ::Real) @@ -56,7 +56,7 @@ end return δ, euclidean end -@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix) +function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix) d = colwise(s, x, y) return d, function (Δ::AbstractVector) x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y) @@ -64,7 +64,7 @@ end end end -@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) +function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) # Modify the forwards-pass slightly to ensure stability on the reverse. function _pairwise_euclidean(X, Y) @@ -78,7 +78,7 @@ end end end -@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2) +function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) D, back = pullback(X -> pairwise(SqEuclidean(), X; dims = dims), X) D .= sqrt.(D) return D, function(Δ) From 3827063ebc6404bc5c2284dd7453521f9bf9e9e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 18:13:18 +0100 Subject: [PATCH 02/14] Imported ChainRules and changed nothing to NO_FIELDS --- src/lib/distances.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index 215513de0..469b32381 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -1,4 +1,5 @@ using .Distances +using .ChainRules: NO_FIELDS, rrule function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y @@ -12,7 +13,7 @@ end function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) return colwise(s, x, y), function (Δ::AbstractVector) x̄ = 2 .* Δ' .* (x .- y) - return nothing, x̄, -x̄ + return NO_FIELDS, x̄, -x̄ end end @@ -28,7 +29,7 @@ end function(Δ) x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ)) ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ) - return (nothing, f(x̄), f(ȳ)) + return (NO_FIELDS, f(x̄), f(ȳ)) end function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2) @@ -43,7 +44,7 @@ end function(Δ) d1 = Diagonal(vec(sum(Δ; dims=1))) d2 = Diagonal(vec(sum(Δ; dims=2))) - return (nothing, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f) + return (NO_FIELDS, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f) end function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) @@ -60,7 +61,7 @@ function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMa d = colwise(s, x, y) return d, function (Δ::AbstractVector) x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y) - return nothing, x̄, -x̄ + return NO_FIELDS, x̄, -x̄ end end @@ -74,7 +75,7 @@ function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMa D, back = pullback(_pairwise_euclidean, X, Y) return D, function(Δ) - return (nothing, back(Δ)...) + return (NO_FIELDS, back(Δ)...) end end @@ -84,6 +85,6 @@ function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) return D, function(Δ) Δ = Δ ./ (2 .* max.(D, eps(eltype(D)))) Δ[diagind(Δ)] .= 0 - return (nothing, first(back(Δ))) + return (NO_FIELDS, first(back(Δ))) end end From e9580921fbb57fbfb17dca94fb660014c3591475 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 18:35:21 +0100 Subject: [PATCH 03/14] Update distances.jl --- src/lib/distances.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index 469b32381..2a333adff 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -1,5 +1,5 @@ using .Distances -using .ChainRules: NO_FIELDS, rrule +import .ChainRules: NO_FIELDS, rrule function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y From 647b091b1f28855b77931d29c0ae8d23db86d3c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 22 Mar 2021 19:29:27 +0100 Subject: [PATCH 04/14] Missing NO_FIELDS --- src/lib/distances.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index 2a333adff..b572aa2d0 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -13,7 +13,7 @@ end function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) return colwise(s, x, y), function (Δ::AbstractVector) x̄ = 2 .* Δ' .* (x .- y) - return NO_FIELDS, x̄, -x̄ + return NO_FIELDS, NO_FIELDS, x̄, -x̄ end end @@ -29,7 +29,7 @@ end function(Δ) x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ)) ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ) - return (NO_FIELDS, f(x̄), f(ȳ)) + return NO_FIELDS, NO_FIELDS, f(x̄), f(ȳ) end function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2) @@ -44,7 +44,7 @@ end function(Δ) d1 = Diagonal(vec(sum(Δ; dims=1))) d2 = Diagonal(vec(sum(Δ; dims=2))) - return (NO_FIELDS, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f) + return NO_FIELDS, NO_FIELDS, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f end function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) @@ -52,7 +52,7 @@ function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) δ = sqrt(sum(abs2, D)) function euclidean(Δ::Real) x̄ = ifelse(iszero(δ), D, (Δ / δ) .* D) - return x̄, -x̄ + return NO_FIELDS, x̄, -x̄ end return δ, euclidean end @@ -61,7 +61,7 @@ function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMa d = colwise(s, x, y) return d, function (Δ::AbstractVector) x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y) - return NO_FIELDS, x̄, -x̄ + return NO_FIELDS, NO_FIELDS, x̄, -x̄ end end @@ -75,7 +75,7 @@ function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMa D, back = pullback(_pairwise_euclidean, X, Y) return D, function(Δ) - return (NO_FIELDS, back(Δ)...) + return (NO_FIELDS, NO_FIELDS, back(Δ)...) end end @@ -85,6 +85,6 @@ function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) return D, function(Δ) Δ = Δ ./ (2 .* max.(D, eps(eltype(D)))) Δ[diagind(Δ)] .= 0 - return (NO_FIELDS, first(back(Δ))) + return (NO_FIELDS, NO_FIELDS, first(back(Δ))) end end From 18f56377d34818bc97432b2ce80d38422793c487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 26 Mar 2021 15:42:37 +0100 Subject: [PATCH 05/14] Missing NO_FIELDS --- src/lib/distances.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index b572aa2d0..b6fd51fcb 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -5,7 +5,7 @@ function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y function sqeuclidean(Δ::Real) x̄ = (2 * Δ) .* δ - return x̄, -x̄ + return NO_FIELDS, x̄, -x̄ end return sum(abs2, δ), sqeuclidean end From 73dd9d95258425e322e5b8f293f2f8b8340a3884 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 26 Mar 2021 17:22:57 +0100 Subject: [PATCH 06/14] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 432cee9c1..5511fb278 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.4" +version = "0.6.7" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 06cf2849607b408310628a47485ef111ce4afb19 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 30 Apr 2021 09:24:30 +0200 Subject: [PATCH 07/14] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 57d032d12..68d7366a9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.7" +version = "0.6.10" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 248c1e43e501df26d82690f6dd3742de44b04845 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Wed, 20 Apr 2022 22:44:49 +0200 Subject: [PATCH 08/14] pullback -> rrule_via_ad --- src/lib/distances.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index b6fd51fcb..2d8ac05a9 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -65,22 +65,22 @@ function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMa end end -function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) # Modify the forwards-pass slightly to ensure stability on the reverse. function _pairwise_euclidean(X, Y) δ = eps(promote_type(eltype(X), eltype(Y)))^2 return sqrt.(max.(pairwise(SqEuclidean(), X, Y; dims=dims), δ)) end - D, back = pullback(_pairwise_euclidean, X, Y) + D, back = rrule_via_ad(config, _pairwise_euclidean, X, Y) return D, function(Δ) return (NO_FIELDS, NO_FIELDS, back(Δ)...) end end -function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) - D, back = pullback(X -> pairwise(SqEuclidean(), X; dims = dims), X) +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) + D, back = rrule_via_ad(config, X -> pairwise(SqEuclidean(), X; dims = dims), X) D .= sqrt.(D) return D, function(Δ) Δ = Δ ./ (2 .* max.(D, eps(eltype(D)))) From 6070bd797f67157400b39cc0ed9612ec978a4d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 21 Apr 2022 13:07:33 +0200 Subject: [PATCH 09/14] Apply suggestions --- src/lib/distances.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index 2d8ac05a9..aff5e70d9 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -1,11 +1,11 @@ using .Distances -import .ChainRules: NO_FIELDS, rrule +import .ChainRules: NoTangent, rrule, rrule_via_ad function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y function sqeuclidean(Δ::Real) x̄ = (2 * Δ) .* δ - return NO_FIELDS, x̄, -x̄ + return NoTangent(), x̄, -x̄ end return sum(abs2, δ), sqeuclidean end @@ -13,7 +13,7 @@ end function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) return colwise(s, x, y), function (Δ::AbstractVector) x̄ = 2 .* Δ' .* (x .- y) - return NO_FIELDS, NO_FIELDS, x̄, -x̄ + return NoTangent(), NoTangent(), x̄, -x̄ end end @@ -29,7 +29,7 @@ end function(Δ) x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ)) ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ) - return NO_FIELDS, NO_FIELDS, f(x̄), f(ȳ) + return NoTangent(), NoTangent(), f(x̄), f(ȳ) end function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2) @@ -44,7 +44,7 @@ end function(Δ) d1 = Diagonal(vec(sum(Δ; dims=1))) d2 = Diagonal(vec(sum(Δ; dims=2))) - return NO_FIELDS, NO_FIELDS, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f + return NoTangent(), NoTangent(), x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f end function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) @@ -52,7 +52,7 @@ function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) δ = sqrt(sum(abs2, D)) function euclidean(Δ::Real) x̄ = ifelse(iszero(δ), D, (Δ / δ) .* D) - return NO_FIELDS, x̄, -x̄ + return NoTangent(), x̄, -x̄ end return δ, euclidean end @@ -61,11 +61,11 @@ function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMa d = colwise(s, x, y) return d, function (Δ::AbstractVector) x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y) - return NO_FIELDS, NO_FIELDS, x̄, -x̄ + return NoTangent(), NoTangent(), x̄, -x̄ end end -function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) +function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) # Modify the forwards-pass slightly to ensure stability on the reverse. function _pairwise_euclidean(X, Y) @@ -75,16 +75,16 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(pairwise), ::Eucli D, back = rrule_via_ad(config, _pairwise_euclidean, X, Y) return D, function(Δ) - return (NO_FIELDS, NO_FIELDS, back(Δ)...) + return (NoTangent(), NoTangent(), back(Δ)...) end end -function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) +function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2) D, back = rrule_via_ad(config, X -> pairwise(SqEuclidean(), X; dims = dims), X) D .= sqrt.(D) return D, function(Δ) Δ = Δ ./ (2 .* max.(D, eps(eltype(D)))) Δ[diagind(Δ)] .= 0 - return (NO_FIELDS, NO_FIELDS, first(back(Δ))) + return (NoTangent(), NoTangent(), first(back(Δ))) end end From c5226682cf3b208ab13b25f933c1333b2ee78d9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 22 Apr 2022 10:02:30 +0200 Subject: [PATCH 10/14] Apply suggestions from code review Co-authored-by: David Widmann --- src/lib/distances.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index aff5e70d9..dabc75f51 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -1,7 +1,7 @@ using .Distances import .ChainRules: NoTangent, rrule, rrule_via_ad -function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) +function rrule(::ZygoteRuleConfig, ::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y function sqeuclidean(Δ::Real) x̄ = (2 * Δ) .* δ @@ -10,14 +10,14 @@ function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector) return sum(abs2, δ), sqeuclidean end -function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) +function rrule(::ZygoteRuleConfig, ::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) return colwise(s, x, y), function (Δ::AbstractVector) x̄ = 2 .* Δ' .* (x .- y) return NoTangent(), NoTangent(), x̄, -x̄ end end -function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2) +function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2) if dims==1 return pairwise(s, x, y; dims=1), ∇pairwise(s, transpose(x), transpose(y), transpose) else @@ -32,7 +32,7 @@ end return NoTangent(), NoTangent(), f(x̄), f(ȳ) end -function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2) +function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2) if dims==1 return pairwise(s, x; dims=1), ∇pairwise(s, transpose(x), transpose) else @@ -47,7 +47,7 @@ end return NoTangent(), NoTangent(), x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f end -function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) +function rrule(::ZygoteRuleConfig, ::Euclidean, x::AbstractVector, y::AbstractVector) D = x .- y δ = sqrt(sum(abs2, D)) function euclidean(Δ::Real) @@ -57,7 +57,7 @@ function rrule(::Euclidean, x::AbstractVector, y::AbstractVector) return δ, euclidean end -function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix) +function rrule(::ZygoteRuleConfig, ::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix) d = colwise(s, x, y) return d, function (Δ::AbstractVector) x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y) From 0ec0ee06e24b948d06aea856d5f4132b78d57f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 24 Jan 2023 15:05:32 +0100 Subject: [PATCH 11/14] Try different approach --- src/lib/distances.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index a5c3f333a..3e0c556b0 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -75,7 +75,7 @@ function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X: return _sqrt_if_positive.(D2, δ) end D, back = rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y) - pairwise_Euclidean_rrule(Δ) = (NoTangent(), back(Δ)...) + pairwise_Euclidean_rrule = back return D, pairwise_Euclidean_rrule end @@ -87,6 +87,6 @@ function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X: return _sqrt_if_positive.(D2, δ) end D, back = rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X) - pairwise_Euclidean_rrule(Δ) = (NoTangent(), back(Δ)...) + pairwise_Euclidean_rrule = back return D, pairwise_Euclidean_rrule end From 22c264cad76e43279bbbe6945d21d047c6c80871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 24 Jan 2023 17:59:05 +0100 Subject: [PATCH 12/14] Update distances.jl --- src/lib/distances.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index 3e0c556b0..e9d562660 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -74,9 +74,7 @@ function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X: δ = eps(eltype(D2)) return _sqrt_if_positive.(D2, δ) end - D, back = rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y) - pairwise_Euclidean_rrule = back - return D, pairwise_Euclidean_rrule + return rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y) end function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix; dims=2) @@ -86,7 +84,5 @@ function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X: δ = eps(eltype(D2)) return _sqrt_if_positive.(D2, δ) end - D, back = rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X) - pairwise_Euclidean_rrule = back - return D, pairwise_Euclidean_rrule + return rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X) end From 45ee7bf87207b67ad33f04fe4906258c1e2a5067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 25 Jan 2023 13:45:13 +0100 Subject: [PATCH 13/14] Better rrule names --- src/lib/distances.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index e9d562660..c40674ce7 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -3,18 +3,19 @@ import .ChainRules: NoTangent, rrule, rrule_via_ad function rrule(::ZygoteRuleConfig, ::SqEuclidean, x::AbstractVector, y::AbstractVector) δ = x .- y - function sqeuclidean(Δ::Real) + function sqeuclidean_rrule(Δ::Real) x̄ = (2 * Δ) .* δ return NoTangent(), x̄, -x̄ end - return sum(abs2, δ), sqeuclidean + return sum(abs2, δ), sqeuclidean_rrule end function rrule(::ZygoteRuleConfig, ::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) - return colwise(s, x, y), function (Δ::AbstractVector) + function colwise_SqEuclidean_rrule(Δ::AbstractVector) x̄ = 2 .* Δ' .* (x .- y) return NoTangent(), NoTangent(), x̄, -x̄ end + return colwise(s, x, y), colwise_SqEuclidean_rrule end function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2) @@ -26,7 +27,7 @@ function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::Abstra end ∇pairwise(s, x, y, f) = - function(Δ) + function pairwise_sqeuclidean_rrule(Δ) x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ)) ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ) return NoTangent(), NoTangent(), f(x̄), f(ȳ) @@ -41,7 +42,7 @@ function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::Abstra end ∇pairwise(s, x, f) = - function(Δ) + function_pairwise_sqeuclidean(Δ) d1 = Diagonal(vec(sum(Δ; dims=1))) d2 = Diagonal(vec(sum(Δ; dims=2))) return NoTangent(), NoTangent(), x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f @@ -50,19 +51,20 @@ end function rrule(::ZygoteRuleConfig, ::Euclidean, x::AbstractVector, y::AbstractVector) D = x .- y δ = sqrt(sum(abs2, D)) - function euclidean(Δ::Real) + function euclidean_rrule(Δ::Real) x̄ = ifelse(iszero(δ), D, (Δ / δ) .* D) return NoTangent(), x̄, -x̄ end - return δ, euclidean + return δ, euclidean_rrule end function rrule(::ZygoteRuleConfig, ::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix) d = colwise(s, x, y) - return d, function (Δ::AbstractVector) + function colwise_Euclidean_rrule(Δ::AbstractVector) x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y) return NoTangent(), NoTangent(), x̄, -x̄ end + return d, colwise_Euclidean_rrule end _sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d) From 4b9477ed8717c97f312f1bdc4b964ba87e191b21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 27 Jan 2023 00:23:44 +0100 Subject: [PATCH 14/14] Update distances.jl --- src/lib/distances.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index c40674ce7..e87b39659 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -42,7 +42,7 @@ function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::Abstra end ∇pairwise(s, x, f) = - function_pairwise_sqeuclidean(Δ) + function pairwise_sqeuclidean_rrule(Δ) d1 = Diagonal(vec(sum(Δ; dims=1))) d2 = Diagonal(vec(sum(Δ; dims=2))) return NoTangent(), NoTangent(), x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f