Addressing performance issues with broadcasting #230

Closed
wants to merge 1 commit

Conversation

@torfjelde torfjelde (Member) commented Jan 15, 2023

This PR introduces a "lazy" version of arraydist which allows us to hack around performance issues with broadcasting over constructors in ReverseDiff and Zygote. Based on "insights" from TuringLang/Turing.jl#1934.
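For a quick sense of the API difference (a sketch based on the Slow/Fast examples below; logitp is a vector of logit parameters):

# Eager: broadcasting the constructor materializes a vector of distributions, which is
# essentially what the `Slow` case below does and what ReverseDiff/Zygote handle poorly.
d_eager = arraydist(BernoulliLogit.(logitp))

# Lazy (this PR): pass the constructor and its argument array separately, so the
# per-element distributions never have to be materialized (the `Fast` case below).
d_lazy = arraydist(BernoulliLogit, logitp)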

More specifically, the problem is broadcasting distribution constructors over parameter arrays. Here's an example of the result of the "lazy" arraydist:

julia> using Distributions, DistributionsAD, LogDensityProblems, LogDensityProblemsAD, ReverseDiff, BenchmarkTools

julia> # Examples.
       n = 1000;

julia> logitp = randn(n); x = rand(Bool, n);

julia> struct Slow{A}
           x::A
       end

julia> LogDensityProblems.logdensity(f::Slow, logitp) = sum(logpdf.(BernoulliLogit.(logitp), f.x))

julia> LogDensityProblems.dimension(f::Slow) = length(f.x)

julia> LogDensityProblems.capabilities(::Type{<:Slow}) = LogDensityProblems.LogDensityOrder{0}()

julia> struct Fast{A}
           x::A
       end

julia> LogDensityProblems.logdensity(f::Fast, logitp) = logpdf(arraydist(BernoulliLogit, logitp), f.x)

julia> LogDensityProblems.dimension(f::Fast) = length(f.x)

julia> LogDensityProblems.capabilities(::Type{<:Fast}) = LogDensityProblems.LogDensityOrder{0}()

julia> adbackend = :ReverseDiff
:ReverseDiff

julia> kwargs = if adbackend === :ReverseDiff
           (compile=Val(true),)
       else
           NamedTuple()
       end
(compile = Val{true}(),)

julia> ∂slow = ADgradient(adbackend, Slow(x); kwargs...)
ReverseDiff AD wrapper for Slow{Vector{Bool}}(Bool[0, 1, 1, 0, 0, 1, 1, 1, 1, 0  …  0, 0, 0, 0, 1, 1, 0, 0, 1, 1]) (compiled tape)

julia> ∂fast = ADgradient(adbackend, Fast(x); kwargs...)
ReverseDiff AD wrapper for Fast{Vector{Bool}}(Bool[0, 1, 1, 0, 0, 1, 1, 1, 1, 0  …  0, 0, 0, 0, 1, 1, 0, 0, 1, 1]) (compiled tape)

julia> @benchmark $(LogDensityProblems.logdensity_and_gradient)($∂slow, $logitp)
BenchmarkTools.Trial: 5650 samples with 1 evaluation.
 Range (min … max):  838.543 μs …   4.950 ms  ┊ GC (min … max): 0.00% … 80.79%
 Time  (median):     847.915 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   883.237 μs ± 306.428 μs  ┊ GC (mean ± σ):  2.82% ±  6.62%

  ▄▇█▆▅▄▃▃▂▂▂   ▁▂▂▁▁                                           ▁
  ████████████▇████████▇▆▆▅▅▆▃▆▅▅▆███▇▆▄▄▆▅▅▄▄▄▅▄▆▅▃▃▁▅▅▃▅▄▃▃▄▃ █
  839 μs        Histogram: log(frequency) by time       1.05 ms <

 Memory estimate: 406.23 KiB, allocs estimate: 13491.

julia> @benchmark $(LogDensityProblems.logdensity_and_gradient)($∂fast, $logitp)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  32.306 μs …  2.524 ms  ┊ GC (min … max): 0.00% … 98.21%
 Time  (median):     34.354 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   34.874 μs ± 25.074 μs  ┊ GC (mean ± σ):  0.71% ±  0.98%

  ▄█▆▂   ▃    ▁██▆▄▃▂▅▆▄▁▁ ▁▂▁     ▂      ▂▃       ▂▁▁        ▂
  █████▄▆██▆▅▅████████████████▇▇▆▅▆██▆▆▅▃▃███▇▆▆▅▆▆████▇▇▆▆▅▆ █
  32.3 μs      Histogram: log(frequency) by time      40.8 μs <

 Memory estimate: 7.97 KiB, allocs estimate: 2.

julia> using Zygote

julia> adbackend = :Zygote
:Zygote

julia> kwargs = if adbackend === :ReverseDiff
           (compile=Val(true),)
       else
           NamedTuple()
       end
NamedTuple()

julia> ∂slow = ADgradient(adbackend, Slow(x); kwargs...)
Zygote AD wrapper for Slow{Vector{Bool}}(Bool[0, 1, 1, 0, 0, 1, 1, 1, 1, 0  …  0, 0, 0, 0, 1, 1, 0, 0, 1, 1])

julia> ∂fast = ADgradient(adbackend, Fast(x); kwargs...)
Zygote AD wrapper for Fast{Vector{Bool}}(Bool[0, 1, 1, 0, 0, 1, 1, 1, 1, 0  …  0, 0, 0, 0, 1, 1, 0, 0, 1, 1])

julia> @benchmark $(LogDensityProblems.logdensity_and_gradient)($∂slow, $logitp)
BenchmarkTools.Trial: 4265 samples with 1 evaluation.
 Range (min … max):  1.076 ms …  10.724 ms  ┊ GC (min … max): 0.00% … 88.85%
 Time  (median):     1.106 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.170 ms ± 652.314 μs  ┊ GC (mean ± σ):  4.76% ±  7.43%

   █                                                           
  ▅██▇▆▄▅▆▆▆▆▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▁▁▂▂▂▂ ▃
  1.08 ms         Histogram: frequency by time        1.37 ms <

 Memory estimate: 536.72 KiB, allocs estimate: 14601.

julia> @benchmark $(LogDensityProblems.logdensity_and_gradient)($∂fast, $logitp)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  189.582 μs …   8.207 ms  ┊ GC (min … max): 0.00% … 96.49%
 Time  (median):     195.637 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   206.669 μs ± 194.132 μs  ┊ GC (mean ± σ):  2.77% ±  2.88%

   ▄▇█▇▄▄▆▆▃▂▁▂▃▅▅▄▃▂▂▂▂▁▁    ▁▁                                ▂
  ▆████████████████████████▇▇████▇▇███▇▇▇▇▇▇█▆▆▇▇▇▆▆▇▆▆▆▅▅▅▅▅▅▅ █
  190 μs        Histogram: log(frequency) by time        257 μs <

 Memory estimate: 68.69 KiB, allocs estimate: 473.

I also thought maybe this is where LazyArrays.jl could be useful, but preliminary attempts didn't end up being fruitful, so I'm uncertain.

@devmotion devmotion (Member) left a comment

I'm happy about the performance improvements but at the same time very unhappy about introducing additional type piracy in DistributionsAD. We have been removing more and more of the type piracy in DistributionsAD (with the ultimate goal of making the package completely obsolete at some point) and transferring fixes and distributions to other AD packages and Distributions. For instance, it would also be good to start adapting product_distribution and ProductDistribution and improving their AD compatibility and performance in Distributions instead of the custom filldist, arraydist, and all the custom structs in DistributionsAD. The logpdf type piracy in this package has also already caused problems (IIRC there is also at least one open issue).


make_logpdf_closure(::Type{D}) where {D} = (x, args...) -> logpdf(D(args...), x)

function Distributions.logpdf(dist::Product{<:Any,D,<:StructArrays.StructArray}, x::AbstractVector{<:Real}) where {D}
@devmotion (Member)

Product is deprecated.

@torfjelde (Member Author)

I am of course in complete agreement with you on everything you say here, but do you think such a "hacky" approach would be accepted into Distributions?
It seems to be an issue with either the AD implementations or Julia's type inference that should really be addressed, but that is unfortunately beyond my own capability and will probably require a lot of work.

That is, do you foresee another solution in the near- or even medium-term future? :/

@devmotion (Member)

Could we add performance and AD improvements to ProductDistributions with FillArrays and broadcast objects resulting from Broadcast.instantiate(Broadcast.broadcasted(....)) to Distributions? The main problem with the latter might be to infer the type of the distribution correctly without materializing everything.
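For illustration, the distribution type is in principle recoverable from the lazy object without materializing it (a rough sketch, not an existing Distributions API):

julia> using Distributions

julia> bc = Broadcast.instantiate(Broadcast.broadcasted(BernoulliLogit, randn(10)));

julia> bc.f  # the constructor is available on the `Broadcasted` object itself
BernoulliLogit

julia> Base.promote_op(bc.f, map(eltype, bc.args)...)  # inferred element distribution type
BernoulliLogit{Float64}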

@torfjelde (Member Author)

Could we add performance and AD improvements to ProductDistributions with FillArrays and broadcast objects resulting from Broadcast.instantiate(Broadcast.broadcasted(....)) to Distributions?

I'm not completely understanding what you mean here 😕 Do you mean we should make something like

product_distribution(Broadcast.broadcasted(BernoulliLogit, logitp))

work and have a similar implementation as in this PR?

@torfjelde (Member Author)

Maybe something like (I know Product is deprecated but it's a bit simpler than ProductDistribution):

using Distributions: UnivariateDistribution, MultivariateDistribution, ValueSupport

struct LazyProduct{
    S<:ValueSupport,
    T<:UnivariateDistribution{S},
    V,
} <: MultivariateDistribution{S}
    v::V
    function LazyProduct{S,T,V}(v::V) where {S<:ValueSupport,T<:UnivariateDistribution{S},V}
        return new{S,T,V}(v)
    end
end

function LazyProduct(
    v::V
) where {S<:ValueSupport,T<:UnivariateDistribution{S},V<:Broadcast.Broadcasted{<:Any,<:Any,Type{T}}}
    return LazyProduct{S, T, V}(v)
end

Base.length(d::LazyProduct) = length(d.v)
# The sample `eltype` is that of the element distribution type, which can be inferred
# from the wrapped `Broadcasted` object without materializing it.
function Base.eltype(d::LazyProduct)
    bc = d.v
    return eltype(Broadcast.combine_eltypes(bc.f, bc.args))
end

Distributions._rand!(rng::Distributions.AbstractRNG, d::LazyProduct, x::AbstractVector{<:Real}) =
    map!(Base.Fix1(rand, rng), x, d.v)

function Distributions._logpdf(d::LazyProduct, x::AbstractVector{<:Real})
    bc = d.v
    f = make_logpdf_closure(bc.f)
    return sum(f.(x, bc.args...))
end

This does the trick for ReverseDiff but Zygote complains (it tries to call LazyDistribution(::AbstractVector) at some point, so might just require a custom rule).

@devmotion (Member)

Basically yes, but for Broadcast.instantiate(Broadcast.broadcasted(...)), similar to how Base supports it in e.g. mapreduce, sum, etc. Often sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, y))) is a very performant alternative to sum(f(xi, yi) for (xi, yi) in zip(x, y)), sum(f.(x, y)), and sum(map(f, x, y)): it uses pairwise summation (in contrast to zip) and avoids materializing the array (in contrast to broadcast and map), which is why it's used e.g. in Distributions and StatsBase IIRC.

@torfjelde (Member Author) commented Jan 15, 2023

sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, y)))

This is much slower for ReverseDiff than actually allocating, i.e. sum(f.(x, y)) 😕 Maybe this could be fixed in ReverseDiff though.

EDIT: In the particular example above, we're talking 10X slower without compilation and 4X slower with compilation.

EDIT: Zygote handles instantiate(broadcasted(...)) the same as f.(...) it seems 👍

@devmotion (Member)

IMO Broadcast.instantiate(...) is so common (e.g. it's also used by LazyArrays: https://github.com/JuliaArrays/LazyArrays.jl/blob/09f382f8d4828d37eb11474963b671dfa95ea43d/src/lazymacro.jl#L114) that it's a bug if an AD system does not support it and these bugs should be fixed.

@torfjelde (Member Author)

Fair 😕

Speaking of LazyArrays: I was just looking at maybe using this for the product distribution. Then it just becomes a matter of doing product_distribution(BroadcastArray(BernoulliLogit, logitp)), etc.

@torfjelde (Member Author) commented Jan 15, 2023

Broadcast.instantiate(...)

Just for the record, ReverseDiff does support this. It's just that sum(instantiate(broadcasted(...))) seems, at first glance, to be slower than sum(f.(...)) directly.

@torfjelde (Member Author)

That is, we can just change

function Distributions._logpdf(
    dist::LazyVectorOfUnivariate,
    x::AbstractVector{<:Real},
)
    return sum(copy(logpdf.(dist.v, x)))
end

to something like

make_logpdf_closure(::Type{<:BroadcastVector{<:Any,Type{F}}}) where {F} = make_logpdf_closure(F)

function Distributions._logpdf(
    dist::LazyVectorOfUnivariate,
    x::AbstractVector{<:Real},
)
    f = DistributionsAD.make_logpdf_closure(typeof(dist.v))
    # TODO: Fix ReverseDiff performance on this.
    return sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, dist.v.args...)))
end

with

make_logpdf_closure(::Type{D}) where {D} = (x, args...) -> logpdf(D(args...), x)
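To spell out what the closure does, a small usage sketch (assuming BernoulliLogit as in the examples above):

# The closure takes the observation first, then the constructor arguments, so
# `f.(x, args...)` broadcasts both the construction and the logpdf evaluation.
f = make_logpdf_closure(BernoulliLogit)
f(true, 0.3) == logpdf(BernoulliLogit(0.3), true)  # true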

@torfjelde (Member Author)

Regarding ReverseDiff's issue with sum(instantiate(broadcasted(f, ...))) vs. sum(f.(...)): the former traces through the entire computation using TrackedReal (resulting in a long tape), while the latter results in a single ∇broadcast and a sum, i.e. the tape only has length 2 (and we end up using ForwardDiff for ∇broadcast).
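One rough way to see the difference (this peeks at ReverseDiff internals via the tape field, so it's purely illustrative):

using Distributions, ReverseDiff

x = rand(Bool, 10)
f(xi, l) = logpdf(BernoulliLogit(l), xi)

eager(logitp) = sum(f.(x, logitp))                                               # single ∇broadcast + sum
lazy(logitp)  = sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, logitp)))  # traced elementwise

length(ReverseDiff.GradientTape(eager, randn(10)).tape)  # stays small, independent of `length(logitp)`
length(ReverseDiff.GradientTape(lazy, randn(10)).tape)   # grows with `length(logitp)`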

@torfjelde (Member Author)

Regarding ReverseDiff's issue with sum(instantiate(broadcasted(f, ...))) vs. sum(f.(...)): the former traces through the entire computation using TrackedReal (resulting in a long tape), while the latter results in a single ∇broadcast and a sum, i.e. the tape only has length 2 (and we end up using ForwardDiff for ∇broadcast).

This seems like it's somewhat non-trivial to address 😕 The issue is that we want to allocate in the case where we're using ReverseDiff, since allocating allows us to vectorize using ForwardDiff.

@torfjelde (Member Author) commented Jan 15, 2023

Also, is instantiate(broadcasted(...)) actually the way to go? I saw this thread https://discourse.julialang.org/t/sum-mapreduce-and-broadcasted/15300 and tried the same benchmarks:

julia> using BenchmarkTools

julia> v = rand(100);

julia> f(v) = sum(v .* v')
f (generic function with 1 method)

julia> g(v) = mapreduce(identity, +, Broadcast.broadcasted(*, v, v'))
g (generic function with 1 method)

julia> h(v) = mapreduce(identity, +, Base.Broadcast.materialize(Broadcast.broadcasted(*, v, v')))
h (generic function with 1 method)

julia> @benchmark f($v)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
 Range (min … max):  5.846 μs … 401.951 μs  ┊ GC (min … max):  0.00% … 95.71%
 Time  (median):     6.519 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   8.044 μs ±  15.636 μs  ┊ GC (mean ± σ):  12.52% ±  6.39%

  ▃██▆▄▂▁   ▁▁▂                                               ▂
  ███████▇█▇██████▇▇▆▆▆▆▆▇▆▆▆▆▅▅▅▄▄▆▄▃▅▃▄▄▁▄▃▁▃▁▁▁▁▃▁▁▃▄▃▅▇▇▇ █
  5.85 μs      Histogram: log(frequency) by time      24.5 μs <

 Memory estimate: 78.17 KiB, allocs estimate: 2.

julia> @benchmark g($v)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  25.326 μs … 105.242 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     25.986 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   27.306 μs ±   4.657 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▅█▄▂▂ ▁▃▁▂ ▃ ▃ ▂                                             ▁
  █████▇████▅█▁█▃██▆█▁▄▁▁▃▁▄▁▃▃▁▁▁▁▁▁▁▁▄▄▄▄▆▅▅▄▆▅▅▅▄▆▄▆▅▅▄▄▄▄▄ █
  25.3 μs       Histogram: log(frequency) by time      52.6 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark h($v)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
 Range (min … max):  5.058 μs … 309.694 μs  ┊ GC (min … max):  0.00% … 91.66%
 Time  (median):     6.459 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   7.756 μs ±  14.452 μs  ┊ GC (mean ± σ):  12.16% ±  6.39%

     ▇█▆▄▁                                                    ▂
  ▅▄███████▆▄▄▃▆█▇▆▆▅▆▆▅▅▅▆▄▆▅▅▁▅▄▅▃▄▁▃▁▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▇▆▇ █
  5.06 μs      Histogram: log(frequency) by time      24.4 μs <

 Memory estimate: 78.17 KiB, allocs estimate: 2.

I.e. materializing is faster (at least in this scenario, which also involves a sum).

So this seems to indicate that if we want speed for sum, then f.(...) is preferable?

@devmotion (Member)

IMO performance issues are bugs as well since I think usually you should be able to be as fast (or slow) as regular broadcasting by materializing.

It would be better to support Broadcast.instantiate... in Distributions if possible than adding a dependency on LazyArrays there, I think.

In the example it seems you did not instantiate the broadcasting object. That's necessary for good performance in general.

@torfjelde (Member Author)

IMO performance issues are bugs as well since I think usually you should be able to be as fast (or slow) as regular broadcasting by materializing.

But isn't this an issue of how mapreduce accesses the lazy object? As in, of course we can make mapreduce on a Broadcasted just as fast by simply overloading mapreduce for Broadcasted to materialize before doing the actual computation. But if you don't materialize, then mapreduce has to run the computation for each element sequentially, no? Which is bound to be slower than performing the map first (which in this case is materialize taking advantage of stuff like SIMD)?

@torfjelde (Member Author)

In the example it seems you did not instantiate the broadcasting object. That's necessary for good performance in general.

Doesn't matter here:

julia> using BenchmarkTools

julia> v = rand(100);

julia> f(v) = sum(v .* v')
f (generic function with 1 method)

julia> g(v) = mapreduce(identity, +, Broadcast.instantiate(Broadcast.broadcasted(*, v, v')))
g (generic function with 1 method)

julia> h(v) = mapreduce(identity, +, Base.Broadcast.materialize(Broadcast.instantiate(Broadcast.broadcasted(*, v, v'))))
h (generic function with 1 method)

julia> @benchmark f($v)
BenchmarkTools.Trial: 10000 samples with 7 evaluations.
 Range (min … max):  4.927 μs … 64.068 μs  ┊ GC (min … max): 0.00% … 65.85%
 Time  (median):     7.004 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.705 μs ±  3.979 μs  ┊ GC (mean ± σ):  5.13% ±  8.93%

    ▄█▇▅▂▂▂▁                                                 ▂
  ▅▄██████████▅▃▄▄▃▁▁▃▁▁▁▁▃▁▃▁▁▁▁▁▁▅▇▇▆▆▁▃▄▃▁▁▁▁▁▃▁▁▁▁▁▁▃▄▅▆ █
  4.93 μs      Histogram: log(frequency) by time     36.3 μs <

 Memory estimate: 78.17 KiB, allocs estimate: 2.

julia> @benchmark g($v)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  13.625 μs … 60.730 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     14.411 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.894 μs ±  1.880 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▃      ▄▂     █▇     ▂▆▂      ▆▃      ▄▄       ▁▄▂       ▁▂ ▂
  █▇▁▁▁▁▇██▁▁▁▁▁██▄▁▁▁▁███▃▁▁▁▁▆██▁▁▃▃▁▁██▆▁▁▄▃▄▅███▅▃▄▁▃▁▁██ █
  13.6 μs      Histogram: log(frequency) by time      16.7 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark h($v)
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
 Range (min … max):  4.328 μs … 85.797 μs  ┊ GC (min … max): 0.00% … 90.22%
 Time  (median):     6.613 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.486 μs ±  5.520 μs  ┊ GC (mean ± σ):  7.09% ±  8.95%

   ▁█▆▂▂▂▁                                                   ▁
  ▆███████▇▆▅▄▃▄▃▃▁▁▁▁▁▁▅▇█▄▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅ █
  4.33 μs      Histogram: log(frequency) by time     52.2 μs <

 Memory estimate: 78.17 KiB, allocs estimate: 2.

@devmotion (Member)

But if you don't materialize, then mapreduce has to run the computation for each element sequentially, no? Which is bound to be slower than performing the map first (which in this case is materialize taking advantage of stuff like SIMD)?

In both cases the same pairwise algorithm is used: https://github.com/JuliaLang/julia/blob/0371bf44bf6bfd6ee9fbfc32d478c2ff4c97b08b/base/reduce.jl#L426-L449 and https://github.com/JuliaLang/julia/blob/0371bf44bf6bfd6ee9fbfc32d478c2ff4c97b08b/base/reduce.jl#L251-L275 It just computes the required entries of the array or the broadcasting object when they are required. If there's an optimized way for computing the array (e.g. by hitting BLAS calls, I assume), then the allocations might be acceptable and still result in better performance - but the summation itself uses the same code in both cases.


I just ran the following on Julia 1.8.5:

julia> using BenchmarkTools

julia> A = vcat(1f0, fill(1f-8, 10^8));

julia> B = ones(Float32, length(A));

julia> f1(A, B) = reduce(+, A .* B);

julia> f2(A, B) = sum(A .* B);

julia> g1(A, B) = reduce(+, Broadcast.instantiate(Broadcast.broadcasted(*, A, B)));

julia> g2(A, B) = sum(Broadcast.instantiate(Broadcast.broadcasted(*, A, B)));

julia> h1(A, B) = sum(a * b for (a, b) in zip(A, B));

julia> h2(A, B) = sum(zip(A, B)) do (a, b)
           return a * b
       end;

julia> f1(A, B)
1.9999989f0

julia> f2(A, B)
1.9999989f0

julia> g1(A, B)
1.9999989f0

julia> g2(A, B)
1.9999989f0

julia> h1(A, B)
1.0f0

julia> h2(A, B)
1.0f0

julia> @btime f1($A, $B);
  185.125 ms (2 allocations: 381.47 MiB)

julia> @btime f2($A, $B);
  190.276 ms (2 allocations: 381.47 MiB)

julia> @btime g1($A, $B);
  46.049 ms (0 allocations: 0 bytes)

julia> @btime g2($A, $B);
  45.674 ms (0 allocations: 0 bytes)

julia> @btime h1($A, $B);
  143.151 ms (0 allocations: 0 bytes)

julia> @btime h2($A, $B);
  141.105 ms (0 allocations: 0 bytes)

Generally, it does not matter if we use sum (which ends up calling mapreduce(identity, Base.add_sum, ...)) or reduce(+, ...) (ends up calling mapreduce(identity, +, ...)).
It shows that in this example summing the non-materialized broadcast object is indeed faster than regular broadcast with intermediate allocation and faster than using zip. Moreover, with zip a completely wrong result is returned whereas both with regular arrays and lazy broadcasting the pairwise summation leads to a much more accurate result.


Regarding your benchmark:
I think it is mainly a question of the size of v - for larger arrays the lazy broadcasting is faster. I assume that with increasing array size potential benefits and optimizations of computing v .* v' directly are outweighed by the allocation cost:

julia> using BenchmarkTools

julia> v = rand(100);

julia> g1(v) = sum(Broadcast.instantiate(Broadcast.broadcasted(*, v, v')));

julia> g2(v) = reduce(+, Broadcast.instantiate(Broadcast.broadcasted(*, v, v')));

julia> g3(v) = mapreduce(identity, +, Broadcast.instantiate(Broadcast.broadcasted(*, v, v')));

julia> h1(v) = sum(v .* v');

julia> h2(v) = reduce(+, v .* v');

julia> h3(v) = mapreduce(identity, +, v .* v');

julia> @btime g1($v);
  16.961 μs (0 allocations: 0 bytes)

julia> @btime g2($v);
  15.575 μs (0 allocations: 0 bytes)

julia> @btime g3($v);
  15.670 μs (0 allocations: 0 bytes)

julia> @btime h1($v);
  6.872 μs (2 allocations: 78.17 KiB)

julia> @btime h2($v);
  7.173 μs (2 allocations: 78.17 KiB)

julia> @btime h3($v);
  6.849 μs (2 allocations: 78.17 KiB)

julia> v = rand(1_000);

julia> @btime g1($v);
  1.626 ms (0 allocations: 0 bytes)

julia> @btime g2($v);
  1.469 ms (0 allocations: 0 bytes)

julia> @btime g3($v);
  1.469 ms (0 allocations: 0 bytes)

julia> @btime h1($v);
  1.030 ms (2 allocations: 7.63 MiB)

julia> @btime h2($v);
  1.031 ms (2 allocations: 7.63 MiB)

julia> @btime h3($v);
  1.035 ms (2 allocations: 7.63 MiB)

julia> v = rand(10_000);

julia> @btime g1($v);
  162.012 ms (0 allocations: 0 bytes)

julia> @btime g2($v);
  160.191 ms (0 allocations: 0 bytes)

julia> @btime g3($v);
  160.172 ms (0 allocations: 0 bytes)

julia> @btime h1($v);
  303.014 ms (2 allocations: 762.94 MiB)

julia> @btime h2($v);
  309.235 ms (2 allocations: 762.94 MiB)

julia> @btime h3($v);
  316.549 ms (2 allocations: 762.94 MiB)

@torfjelde (Member Author)

Ah, interesting! But what do we do about the AD issues? In particular, how do we tell ReverseDiff to hit the faster path, i.e. not trace through the entire thing using TrackedReal?

@torfjelde (Member Author)

Btw, related to all of this: JuliaArrays/LazyArrays.jl#232

@devmotion (Member)

Tracing should be fine if everything is lazy, shouldn't it? There's not much happening there in the forward pass. I would assume the problem is rather that for most functions (such as sum) efficient execution in the reverse pass is only implemented for TrackedArray but not for Broadcasted{TrackedStyle}. Similar to how efficient mapreduce/sum/prod etc. for lazy Broadcasted objects in Base required JuliaLang/julia#31020.
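Concretely, the missing piece would be something along these lines (a hypothetical sketch, not actual ReverseDiff code, and simplified to ignore keyword arguments like dims and init):

using ReverseDiff
using Base.Broadcast: Broadcasted, materialize

# Hypothetical specialization: route `sum` over a lazy broadcast with tracked arguments
# through the existing fused broadcast machinery instead of tracing it element-by-element.
function Base.sum(bc::Broadcasted{ReverseDiff.TrackedStyle})
    # `materialize` hits ReverseDiff's ∇broadcast rule (one tape instruction), and `sum` of
    # the resulting TrackedArray is a second instruction, rather than one per element.
    return sum(materialize(bc))
end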

@torfjelde (Member Author)

Tracing should be fine if everything is lazy, shouldn't it?

But tracing using TrackedReal will generally be slower here than if we use Dual, no? In particular since Dual makes use of rules from DiffRules.

@torfjelde (Member Author)

In particular here, we're looking at a full trace through everything vs. a tape containing only two statements (∇broadcast and sum), so traversing the tape will be much faster (and forward mode should be faster since it accumulates both the value and the gradient in the same pass rather than requiring one forward and one backward pass).

@devmotion (Member)

But that's another issue in ReverseDiff, isn't it? Similar to Zygote it should define adjoints mainly for broadcasted, not broadcast.

@devmotion (Member)

But tracing using TrackedReal will generally be slower here than if we use Dual, no? In particular since Dual makes use of rules from DiffRules.

In broadcasting, using ForwardDiff is probably unavoidable (and Zygote and Tracker do the same), but IMO in general ReverseDiff uses Dual too often in places where it shouldn't. For instance, why would you want to work with Duals for handling the rules in DiffRules if you can just implement them directly?

@torfjelde (Member Author)

But that's another issue in ReverseDiff, isn't it? Similar to Zygote it should define adjoints mainly for broadcasted, not broadcast

Yep! I'm just saying that I'm personally not familiar enough with ReverseDiff to address this issue (and I'm worried it's going to take too much time before someone gets around to it 😕), and given the performance difference I'm not a big fan of just "leaving it" as is. We can always do something hacky like

_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D

function Distributions._logpdf(
    dist::LazyVectorOfUnivariate,
    x::AbstractVector{<:Real},
)
    # TODO: Implement chain rule for `LazyArray` constructor to support Zygote.
    f = DistributionsAD.make_closure(logpdf, _inner_constructor(typeof(dist.v)))
    args = dist.v.args
    # Check the individual broadcast arguments; `istracked` on the tuple itself would not detect tracked elements.
    return if ReverseDiff.istracked(x) || any(ReverseDiff.istracked, args)
        sum(f.(x, args...))
    else
        sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))
    end
end

but that ain't particularly nice..

@torfjelde (Member Author)

Closing in favour of #231

@torfjelde torfjelde closed this Jan 16, 2023
@yebai yebai deleted the torfjelde/lazy-arraydist branch January 22, 2023 21:31