Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing Distances adjoints to ChainRules syntax #923

Closed
wants to merge 17 commits into from
Closed
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.6"
version = "0.6.7"
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
33 changes: 17 additions & 16 deletions src/lib/distances.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
using .Distances
import .ChainRules: NO_FIELDS, rrule
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
theogf marked this conversation as resolved.
Show resolved Hide resolved
δ = x .- y
function sqeuclidean(Δ::Real)
x̄ = (2 * Δ) .* δ
return x̄, -x̄
return NO_FIELDS, x̄, -x̄
end
return sum(abs2, δ), sqeuclidean
end

@adjoint function colwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix)
function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return colwise(s, x, y), function (Δ::AbstractVector)
x̄ = 2 .* Δ' .* (x .- y)
return nothing, x̄, -x̄
return NO_FIELDS, NO_FIELDS, x̄, -x̄
end
end

@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2)
function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2)
theogf marked this conversation as resolved.
Show resolved Hide resolved
if dims==1
return pairwise(s, x, y; dims=1), ∇pairwise(s, transpose(x), transpose(y), transpose)
else
Expand All @@ -28,10 +29,10 @@ end
function(Δ)
x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ))
ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ)
return (nothing, f(x̄), f(ȳ))
return NO_FIELDS, NO_FIELDS, f(x̄), f(ȳ)
end

@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix; dims::Int=2)
function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2)
theogf marked this conversation as resolved.
Show resolved Hide resolved
if dims==1
return pairwise(s, x; dims=1), ∇pairwise(s, transpose(x), transpose)
else
Expand All @@ -43,28 +44,28 @@ end
function(Δ)
d1 = Diagonal(vec(sum(Δ; dims=1)))
d2 = Diagonal(vec(sum(Δ; dims=2)))
return (nothing, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f)
return NO_FIELDS, NO_FIELDS, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f
end

@adjoint function (::Euclidean)(x::AbstractVector, y::AbstractVector)
function rrule(::Euclidean, x::AbstractVector, y::AbstractVector)
theogf marked this conversation as resolved.
Show resolved Hide resolved
D = x .- y
δ = sqrt(sum(abs2, D))
function euclidean(Δ::Real)
x̄ = ifelse(iszero(δ), D, (Δ / δ) .* D)
return x̄, -x̄
return NO_FIELDS, x̄, -x̄
end
return δ, euclidean
end

@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
theogf marked this conversation as resolved.
Show resolved Hide resolved
d = colwise(s, x, y)
return d, function (Δ::AbstractVector)
x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y)
return nothing, x̄, -x̄
return NO_FIELDS, NO_FIELDS, x̄, -x̄
end
end

@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
theogf marked this conversation as resolved.
Show resolved Hide resolved

# Modify the forwards-pass slightly to ensure stability on the reverse.
function _pairwise_euclidean(X, Y)
Expand All @@ -74,16 +75,16 @@ end
D, back = pullback(_pairwise_euclidean, X, Y)

return D, function(Δ)
return (nothing, back(Δ)...)
return (NO_FIELDS, NO_FIELDS, back(Δ)...)
end
end

@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2)
D, back = pullback(X -> pairwise(SqEuclidean(), X; dims = dims), X)
D .= sqrt.(D)
return D, function(Δ)
Δ = Δ ./ (2 .* max.(D, eps(eltype(D))))
Δ[diagind(Δ)] .= 0
return (nothing, first(back(Δ)))
return (NO_FIELDS, NO_FIELDS, first(back(Δ)))
end
end