Skip to content

Commit

Permalink
Add optimization for discrete uniform distributions of equal size (#17)
Browse files Browse the repository at this point in the history
* Add optimization for discrete uniform distributions of equal size

* Update test/exact.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Add optimization for `ot_plan`

* Fix test

* Add compat entry for FillArrays

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
devmotion and github-actions[bot] committed Dec 21, 2021
1 parent dcda744 commit ada5933
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 97 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "ExactOptimalTransport"
uuid = "24df6009-d856-477c-ac5c-91f668376b31"
authors = ["JuliaOptimalTransport"]
version = "0.1.1"
version = "0.1.2"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
Distances = "0.9.0, 0.10"
Distributions = "0.24, 0.25"
FillArrays = "0.12"
MathOptInterface = "0.9"
PDMats = "0.10, 0.11"
QuadGK = "2"
Expand Down
1 change: 1 addition & 0 deletions src/ExactOptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ExactOptimalTransport
using Distances
using MathOptInterface
using Distributions
using FillArrays
using PDMats
using QuadGK
using StatsBase: StatsBase
Expand Down
99 changes: 56 additions & 43 deletions src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,30 +263,38 @@ a sparse matrix.
See also: [`ot_cost`](@ref), [`emd`](@ref)
"""
function ot_plan(_, μ::DiscreteNonParametric, ν::DiscreteNonParametric)
# unpack the probabilities of the two distributions
# Unpack the probabilities of the two distributions
# Note: support of `DiscreteNonParametric` is sorted
μprobs = probs(μ)
νprobs = probs(ν)

# create the iterator
# note: support of `DiscreteNonParametric` is sorted
iter = Discrete1DOTIterator(μprobs, νprobs)

# create arrays for the indices of the two histograms and the optimal flow between the
# corresponding points
n = length(iter)
I = Vector{Int}(undef, n)
J = Vector{Int}(undef, n)
W = Vector{Base.promote_eltype(μprobs, νprobs)}(undef, n)

# compute the sparse optimal transport plan
@inbounds for (idx, (i, j, w)) in enumerate(iter)
I[idx] = i
J[idx] = j
W[idx] = w
T = Base.promote_eltype(μprobs, νprobs)

return if μprobs isa FillArrays.AbstractFill &&
νprobs isa FillArrays.AbstractFill &&
length(μprobs) == length(νprobs)
# Special case: discrete uniform distributions of the same "size"
k = length(μprobs)
sparse(1:k, 1:k, T(first(μprobs)), k, k)
else
# Generic case
# Create the iterator
iter = Discrete1DOTIterator(μprobs, νprobs)

# create arrays for the indices of the two histograms and the optimal flow between the
# corresponding points
n = length(iter)
I = Vector{Int}(undef, n)
J = Vector{Int}(undef, n)
W = Vector{T}(undef, n)

# compute the sparse optimal transport plan
@inbounds for (idx, (i, j, w)) in enumerate(iter)
I[idx] = i
J[idx] = j
W[idx] = w
end
sparse(I, J, W, length(μprobs), length(νprobs))
end
γ = sparse(I, J, W, length(μprobs), length(νprobs))

return γ
end

"""
Expand All @@ -305,45 +313,50 @@ A pre-computed optimal transport `plan` may be provided.
See also: [`ot_plan`](@ref), [`emd2`](@ref)
"""
function ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; plan=nothing)
return _ot_cost(c, μ, ν, plan)
end

# compute cost from scratch if no plan is provided
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, ::Nothing)
# unpack the probabilities of the two distributions
# Extract support and probabilities of discrete distributions
# Note: support of `DiscreteNonParametric` is sorted
μsupport = support(μ)
νsupport = support(ν)
μprobs = probs(μ)
νprobs = probs(ν)

return if μprobs isa FillArrays.AbstractFill &&
νprobs isa FillArrays.AbstractFill &&
length(μprobs) == length(νprobs)
# Special case: discrete uniform distributions of the same "size"
# In this case we always just compute `sum(c.(μsupport, νsupport))` and scale it
# We use pairwise summation and avoid allocations
# (https://github.com/JuliaLang/julia/pull/31020)
T = Base.promote_eltype(μprobs, νprobs)
T(first(μprobs)) *
sum(Broadcast.instantiate(Broadcast.broadcasted(c, μsupport, νsupport)))
else
# Generic case
_ot_cost(c, μsupport, μprobs, νsupport, νprobs, plan)
end
end

# compute cost from scratch if no plan is provided
function _ot_cost(c, μsupport, μprobs, νsupport, νprobs, ::Nothing)
# create the iterator
# note: support of `DiscreteNonParametric` is sorted
iter = Discrete1DOTIterator(μprobs, νprobs)

# compute the cost
μsupport = support(μ)
νsupport = support(ν)
cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter)

return cost
return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter)
end

# if a sparse plan is provided, we just iterate through the non-zero entries
function _ot_cost(
c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan::SparseMatrixCSC
)
function _ot_cost(c, μsupport, _, νsupport, _, plan::SparseMatrixCSC)
# extract non-zero flows
I, J, W = findnz(plan)

# compute the cost
μsupport = support(μ)
νsupport = support(ν)
cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W))

return cost
return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W))
end

# fallback: compute cost matrix (probably often faster to compute cost from scratch)
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan)
return dot(plan, StatsBase.pairwise(c, support(μ), support(ν)))
function _ot_cost(c, μsupport, _, νsupport, _, plan)
return dot(plan, StatsBase.pairwise(c, μsupport, νsupport))
end

################
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end
"""
discretemeasure(
support::AbstractVector,
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
probs::AbstractVector{<:Real}=FillArrays.Fill(inv(length(support)), length(support)),
)
Construct a finite discrete probability measure with `support` and corresponding
Expand Down Expand Up @@ -42,13 +42,13 @@ using KernelFunctions
"""
function discretemeasure(
support::AbstractVector{<:Real},
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)),
)
return DiscreteNonParametric(support, probs)
end
function discretemeasure(
support::AbstractVector,
probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)),
probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)),
)
return FiniteDiscreteMeasure{typeof(support),typeof(probs)}(support, probs)
end
Expand Down
122 changes: 72 additions & 50 deletions test/exact.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ExactOptimalTransport

using Distances
using FillArrays
using PythonOT: PythonOT
using Tulip
using MathOptInterface
Expand Down Expand Up @@ -110,56 +111,77 @@ Random.seed!(100)
end

@testset "discrete case" begin
# random source and target marginal
m = 30
μprobs = normalize!(rand(m), 1)
μsupport = randn(m)
μ = DiscreteNonParametric(μsupport, μprobs)

n = 50
νprobs = normalize!(rand(n), 1)
νsupport = randn(n)
ν = DiscreteNonParametric(νsupport, νprobs)

# compute OT plan
γ = @inferred(ot_plan(euclidean, μ, ν))
@test γ isa SparseMatrixCSC
@test size(γ) == (m, n)
@test vec(sum(γ; dims=2)) ≈ μ.p
@test vec(sum(γ; dims=1)) ≈ ν.p

# consistency checks
I, J, W = findnz(γ)
@test all(w > zero(w) for w in W)
@test sum(W) ≈ 1
@test sort(unique(I)) == 1:m
@test sort(unique(J)) == 1:n
@test sort(I .+ J) == 2:(m + n)

# compute OT cost
c = @inferred(ot_cost(euclidean, μ, ν))

# compare with computation with explicit cost matrix
# DiscreteNonParametric sorts the support automatically, here we have to sort
# manually
C = pairwise(Euclidean(), μsupport', νsupport'; dims=2)
c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer())
@test c2 ≈ c rtol = 1e-5

# compare with POT
# disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds
# error
# @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
# @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")

# do not use the probabilities of μ and ν to ensure that the provided plan is
# used
μ2 = DiscreteNonParametric(μsupport, reverse(μprobs))
ν2 = DiscreteNonParametric(νsupport, reverse(νprobs))
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ))
@test c2 ≈ c
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
@test c2 ≈ c
# different random sources and target marginals:
# non-uniform + different size, uniform + different size, uniform + equal size
for (μ, ν) in (
(
DiscreteNonParametric(randn(30), normalize!(rand(30), 1)),
DiscreteNonParametric(randn(50), normalize!(rand(50), 1)),
),
(
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
DiscreteNonParametric(randn(50), Fill(1 / 50, 50)),
),
(
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
DiscreteNonParametric(randn(30), Fill(1 / 30, 30)),
),
)
# extract support, probabilities, and "size"
μsupport = support(μ)
μprobs = probs(μ)
m = length(μprobs)

νsupport = support(ν)
νprobs = probs(ν)
n = length(νprobs)

# compute OT plan
γ = @inferred(ot_plan(euclidean, μ, ν))
@test γ isa SparseMatrixCSC
@test size(γ) == (m, n)
@test vec(sum(γ; dims=2)) ≈ μ.p
@test vec(sum(γ; dims=1)) ≈ ν.p

# consistency checks
I, J, W = findnz(γ)
@test all(w > zero(w) for w in W)
@test sum(W) ≈ 1
@test sort(unique(I)) == 1:m
@test sort(unique(J)) == 1:n
@test sort(I .+ J) == if μprobs isa Fill && νprobs isa Fill && m == n
# Optimized version for special case (discrete uniform + equal size)
2:2:(m + n)
else
# Generic case (not optimized)
2:(m + n)
end

# compute OT cost
c = @inferred(ot_cost(euclidean, μ, ν))

# compare with computation with explicit cost matrix
# DiscreteNonParametric sorts the support automatically, here we have to sort
# manually
C = pairwise(Euclidean(), μsupport', νsupport'; dims=2)
c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer())
@test c2 ≈ c rtol = 1e-5

# compare with POT
# disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds
# error
# @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")
# @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean")

# do not use the probabilities of μ and ν to ensure that the provided plan is
# used
μ2 = DiscreteNonParametric(μsupport, reverse(μprobs))
ν2 = DiscreteNonParametric(νsupport, reverse(νprobs))
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ))
@test c2 ≈ c
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
@test c2 ≈ c
end
end
end

Expand Down

2 comments on commit ada5933

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/50953

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" ada5933503a1806fd4279bea8c258d01c0a20094
git push origin v0.1.2

Please sign in to comment.