diff --git a/Project.toml b/Project.toml index 7849880..38e9331 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "ExactOptimalTransport" uuid = "24df6009-d856-477c-ac5c-91f668376b31" authors = ["JuliaOptimalTransport"] -version = "0.1.1" +version = "0.1.2" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] Distances = "0.9.0, 0.10" Distributions = "0.24, 0.25" +FillArrays = "0.12" MathOptInterface = "0.9" PDMats = "0.10, 0.11" QuadGK = "2" diff --git a/src/ExactOptimalTransport.jl b/src/ExactOptimalTransport.jl index ea642e6..aba117e 100644 --- a/src/ExactOptimalTransport.jl +++ b/src/ExactOptimalTransport.jl @@ -3,6 +3,7 @@ module ExactOptimalTransport using Distances using MathOptInterface using Distributions +using FillArrays using PDMats using QuadGK using StatsBase: StatsBase diff --git a/src/exact.jl b/src/exact.jl index 82ea71d..12515a8 100644 --- a/src/exact.jl +++ b/src/exact.jl @@ -263,30 +263,38 @@ a sparse matrix. See also: [`ot_cost`](@ref), [`emd`](@ref) """ function ot_plan(_, μ::DiscreteNonParametric, ν::DiscreteNonParametric) - # unpack the probabilities of the two distributions + # Unpack the probabilities of the two distributions + # Note: support of `DiscreteNonParametric` is sorted μprobs = probs(μ) νprobs = probs(ν) - - # create the iterator - # note: support of `DiscreteNonParametric` is sorted - iter = Discrete1DOTIterator(μprobs, νprobs) - - # create arrays for the indices of the two histograms and the optimal flow between the - # corresponding points - n = length(iter) - I = Vector{Int}(undef, n) - J = Vector{Int}(undef, n) - W = Vector{Base.promote_eltype(μprobs, νprobs)}(undef, n) - - # compute the sparse optimal transport plan - @inbounds for (idx, (i, j, w)) in enumerate(iter) - I[idx] = i - J[idx] = j - W[idx] = w + T = Base.promote_eltype(μprobs, νprobs) + + return if μprobs isa FillArrays.AbstractFill && + νprobs isa FillArrays.AbstractFill && + length(μprobs) == length(νprobs) + # Special case: discrete uniform distributions of the same "size" + k = length(μprobs) + sparse(1:k, 1:k, T(first(μprobs)), k, k) + else + # Generic case + # Create the iterator + iter = Discrete1DOTIterator(μprobs, νprobs) + + # create arrays for the indices of the two histograms and the optimal flow between the + # corresponding points + n = length(iter) + I = Vector{Int}(undef, n) + J = Vector{Int}(undef, n) + W = Vector{T}(undef, n) + + # compute the sparse optimal transport plan + @inbounds for (idx, (i, j, w)) in enumerate(iter) + I[idx] = i + J[idx] = j + W[idx] = w + end + sparse(I, J, W, length(μprobs), length(νprobs)) end - γ = sparse(I, J, W, length(μprobs), length(νprobs)) - - return γ end """ @@ -305,45 +313,50 @@ A pre-computed optimal transport `plan` may be provided. See also: [`ot_plan`](@ref), [`emd2`](@ref) """ function ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; plan=nothing) - return _ot_cost(c, μ, ν, plan) -end - -# compute cost from scratch if no plan is provided -function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, ::Nothing) - # unpack the probabilities of the two distributions + # Extract support and probabilities of discrete distributions + # Note: support of `DiscreteNonParametric` is sorted + μsupport = support(μ) + νsupport = support(ν) μprobs = probs(μ) νprobs = probs(ν) + return if μprobs isa FillArrays.AbstractFill && + νprobs isa FillArrays.AbstractFill && + length(μprobs) == length(νprobs) + # Special case: discrete uniform distributions of the same "size" + # In this case we always just compute `sum(c.(μsupport .- νsupport))` and scale it + # We use pairwise summation and avoid allocations + # (https://github.com/JuliaLang/julia/pull/31020) + T = Base.promote_eltype(μprobs, νprobs) + T(first(μprobs)) * + sum(Broadcast.instantiate(Broadcast.broadcasted(c, μsupport, νsupport))) + else + # Generic case + _ot_cost(c, μsupport, μprobs, νsupport, νprobs, plan) + end +end + +# compute cost from scratch if no plan is provided +function _ot_cost(c, μsupport, μprobs, νsupport, νprobs, ::Nothing) # create the iterator - # note: support of `DiscreteNonParametric` is sorted iter = Discrete1DOTIterator(μprobs, νprobs) # compute the cost - μsupport = support(μ) - νsupport = support(ν) - cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter) - - return cost + return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in iter) end # if a sparse plan is provided, we just iterate through the non-zero entries -function _ot_cost( - c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan::SparseMatrixCSC -) +function _ot_cost(c, μsupport, _, νsupport, _, plan::SparseMatrixCSC) # extract non-zero flows I, J, W = findnz(plan) # compute the cost - μsupport = support(μ) - νsupport = support(ν) - cost = sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W)) - - return cost + return sum(w * c(μsupport[i], νsupport[j]) for (i, j, w) in zip(I, J, W)) end # fallback: compute cost matrix (probably often faster to compute cost from scratch) -function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan) - return dot(plan, StatsBase.pairwise(c, support(μ), support(ν))) +function _ot_cost(c, μsupport, _, νsupport, _, plan) + return dot(plan, StatsBase.pairwise(c, μsupport, νsupport)) end ################ diff --git a/src/utils.jl b/src/utils.jl index 454265a..f7dc47f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,7 +12,7 @@ end """ discretemeasure( support::AbstractVector, - probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)), + probs::AbstractVector{<:Real}=FillArrays.Fill(inv(length(support)), length(support)), ) Construct a finite discrete probability measure with `support` and corresponding @@ -42,13 +42,13 @@ using KernelFunctions """ function discretemeasure( support::AbstractVector{<:Real}, - probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)), + probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)), ) return DiscreteNonParametric(support, probs) end function discretemeasure( support::AbstractVector, - probs::AbstractVector{<:Real}=fill(inv(length(support)), length(support)), + probs::AbstractVector{<:Real}=Fill(inv(length(support)), length(support)), ) return FiniteDiscreteMeasure{typeof(support),typeof(probs)}(support, probs) end diff --git a/test/exact.jl b/test/exact.jl index 26826a8..e3e5ada 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -1,6 +1,7 @@ using ExactOptimalTransport using Distances +using FillArrays using PythonOT: PythonOT using Tulip using MathOptInterface @@ -110,56 +111,77 @@ Random.seed!(100) end @testset "discrete case" begin - # random source and target marginal - m = 30 - μprobs = normalize!(rand(m), 1) - μsupport = randn(m) - μ = DiscreteNonParametric(μsupport, μprobs) - - n = 50 - νprobs = normalize!(rand(n), 1) - νsupport = randn(n) - ν = DiscreteNonParametric(νsupport, νprobs) - - # compute OT plan - γ = @inferred(ot_plan(euclidean, μ, ν)) - @test γ isa SparseMatrixCSC - @test size(γ) == (m, n) - @test vec(sum(γ; dims=2)) ≈ μ.p - @test vec(sum(γ; dims=1)) ≈ ν.p - - # consistency checks - I, J, W = findnz(γ) - @test all(w > zero(w) for w in W) - @test sum(W) ≈ 1 - @test sort(unique(I)) == 1:m - @test sort(unique(J)) == 1:n - @test sort(I .+ J) == 2:(m + n) - - # compute OT cost - c = @inferred(ot_cost(euclidean, μ, ν)) - - # compare with computation with explicit cost matrix - # DiscreteNonParametric sorts the support automatically, here we have to sort - # manually - C = pairwise(Euclidean(), μsupport', νsupport'; dims=2) - c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer()) - @test c2 ≈ c rtol = 1e-5 - - # compare with POT - # disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds - # error - # @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean") - # @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean") - - # do not use the probabilities of μ and ν to ensure that the provided plan is - # used - μ2 = DiscreteNonParametric(μsupport, reverse(μprobs)) - ν2 = DiscreteNonParametric(νsupport, reverse(νprobs)) - c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ)) - @test c2 ≈ c - c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ))) - @test c2 ≈ c + # different random sources and target marginals: + # non-uniform + different size, uniform + different size, uniform + equal size + for (μ, ν) in ( + ( + DiscreteNonParametric(randn(30), normalize!(rand(30), 1)), + DiscreteNonParametric(randn(50), normalize!(rand(50), 1)), + ), + ( + DiscreteNonParametric(randn(30), Fill(1 / 30, 30)), + DiscreteNonParametric(randn(50), Fill(1 / 50, 50)), + ), + ( + DiscreteNonParametric(randn(30), Fill(1 / 30, 30)), + DiscreteNonParametric(randn(30), Fill(1 / 30, 30)), + ), + ) + # extract support, probabilities, and "size" + μsupport = support(μ) + μprobs = probs(μ) + m = length(μprobs) + + νsupport = support(ν) + νprobs = probs(ν) + n = length(νprobs) + + # compute OT plan + γ = @inferred(ot_plan(euclidean, μ, ν)) + @test γ isa SparseMatrixCSC + @test size(γ) == (m, n) + @test vec(sum(γ; dims=2)) ≈ μ.p + @test vec(sum(γ; dims=1)) ≈ ν.p + + # consistency checks + I, J, W = findnz(γ) + @test all(w > zero(w) for w in W) + @test sum(W) ≈ 1 + @test sort(unique(I)) == 1:m + @test sort(unique(J)) == 1:n + @test sort(I .+ J) == if μprobs isa Fill && νprobs isa Fill && m == n + # Optimized version for special case (discrete uniform + equal size) + 2:2:(m + n) + else + # Generic case (not optimized) + 2:(m + n) + end + + # compute OT cost + c = @inferred(ot_cost(euclidean, μ, ν)) + + # compare with computation with explicit cost matrix + # DiscreteNonParametric sorts the support automatically, here we have to sort + # manually + C = pairwise(Euclidean(), μsupport', νsupport'; dims=2) + c2 = emd2(μprobs, νprobs, C, Tulip.Optimizer()) + @test c2 ≈ c rtol = 1e-5 + + # compare with POT + # disabled currently since https://github.com/PythonOT/POT/issues/169 causes bounds + # error + # @test γ ≈ POT.emd_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean") + # @test c ≈ POT.emd2_1d(μ.support, ν.support; a=μ.p, b=μ.p, metric="euclidean") + + # do not use the probabilities of μ and ν to ensure that the provided plan is + # used + μ2 = DiscreteNonParametric(μsupport, reverse(μprobs)) + ν2 = DiscreteNonParametric(νsupport, reverse(νprobs)) + c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=γ)) + @test c2 ≈ c + c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ))) + @test c2 ≈ c + end end end