Skip to content

Commit

Permalink
Faster filldist() (#227)
Browse files Browse the repository at this point in the history
* fix testset_zygote_broken()

define vars used by error()

* logpdf(arraydist): use mapreduce

* logpdf(filldist): use mapreduce

* remove filldist(Zygote) from broken

* improve mapreduce

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* improve mapreduce() invocation

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* tests: exclude Chernoff from Zygote filldist tests

* simplify mapreduce -> sum

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* explicitly broadcast

since it looks like `mapreduce()` still allocates

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* require ChainRulesTestUtils >= 1.9.2

some graident tests require test_approx(::Array{<:Array}, ::Zero)

* _flat_logpdf(): explicit lazy broadcasting

* filldist tests: enable Skellam

* use product_distribution() to fix deprecation

* eliminate unnecessary intermediate var

* replace some anonymous funcs with Base.Fix1

* replace sum(lambda, zip(...)) with lazy broadcast

* Update src/arraydist.jl

* Update test/ad/distributions.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 8, 2022
1 parent dc92604 commit 02ca329
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 22 deletions.
17 changes: 8 additions & 9 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# return sum(((d, xi),) -> logpdf(d, xi), zip(dist.dists, x))
# Broadcasting here breaks Tracker for some reason
return sum(map(logpdf, dist.dists, x))
# Lazy broadcast to avoid allocations and use pairwise summation
return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x)))
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
Expand All @@ -52,16 +51,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
end

function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
return sum(((di, xi),) -> logpdf(di, xi), zip(dist.dists, eachcol(x)))
return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x))))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init)
end
12 changes: 5 additions & 7 deletions src/filldist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,19 @@ end
function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
# Lazy broadcast to avoid allocations and use pairwise summation
return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x)))
else
return sum(map(x) do x
logpdf(dist, x)
end)
return sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x)))
end
end

function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
else
temp = map(x -> logpdf(dist, x), x)
return vec(sum(temp, dims = 1))
return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
ChainRulesTestUtils = "1.9.2"
Combinatorics = "1.0.2"
Distributions = "0.25.15"
FiniteDifferences = "0.11.3, 0.12"
Expand Down
4 changes: 1 addition & 3 deletions test/ad/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,7 @@
# PoissonBinomial fails with Zygote
# Matrix case does not work with Skellam:
# https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493
filldist_broken = if D <: Skellam
((d.broken..., :Zygote, :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff))
elseif D <: PoissonBinomial
filldist_broken = if D <: PoissonBinomial
((d.broken..., :Zygote), (d.broken..., :Zygote))
elseif D <: Chernoff
# Zygote is not broken with `filldist`
Expand Down
8 changes: 6 additions & 2 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,16 @@ function testset_zygote(distspec, unpack_x_θ, args...; kwargs...)
end
end

function testset_zygote_broken(args...; kwargs...)
function testset_zygote_broken(distspec, args...; kwargs...)
# don't show test errors - tests are known to be broken :)
testset = suppress_stdout() do
testset_zygote(args...; kwargs...)
testset_zygote(distspec, args...; kwargs...)
end

f = distspec.f
θ = distspec.θ
x = distspec.x

# change errors and fails to broken results, and count number of errors and fails
efs = errors_to_broken!(testset)

Expand Down

0 comments on commit 02ca329

Please sign in to comment.