Performance regression for BernoulliLogit #1934

Open
torfjelde opened this issue Jan 10, 2023 · 51 comments

Comments

@torfjelde
Member

I was just playing around a bit with https://github.com/torfjelde/TuringBenchmarking.jl and noticed a sudden change in the runtime described in the README (the example model is suddenly 16x slower for gradient evaluation for ReverseDiff with compiled mode).

I eventually narrowed it down to #1892 being the cause, i.e. the performance of the following model:

@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(theta[p] - beta[i]), y))

    return (; theta, beta)
end

absolutely tanks for ReverseDiff when we use the implementation of BernoulliLogit from Distributions.jl 😕

On Turing@0.21.12:

┌ Info: Turing.jl
│   run(suite) =2-element BenchmarkTools.BenchmarkGroup:
│      tags: []
│      "linked" => 3-element BenchmarkTools.BenchmarkGroup:
│         tags: []
│         "evaluation" => Trial(1.333 ms)
│         "Turing.Essential.ReverseDiffAD{true}()" => Trial(1.752 ms)
│         "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(174.759 ms)
│      "not_linked" => 3-element BenchmarkTools.BenchmarkGroup:
│         tags: []
│         "evaluation" => Trial(1.339 ms)
│         "Turing.Essential.ReverseDiffAD{true}()" => Trial(1.796 ms)
└         "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(169.376 ms)

while on Turing@0.21.13

┌ Info: Turing.jl
│   run(suite) =2-element BenchmarkTools.BenchmarkGroup:
│      tags: []
│      "linked" => 3-element BenchmarkTools.BenchmarkGroup:
│         tags: []
│         "evaluation" => Trial(554.568 μs)
│         "Turing.Essential.ReverseDiffAD{true}()" => Trial(16.418 ms)
│         "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(140.508 ms)
│      "not_linked" => 3-element BenchmarkTools.BenchmarkGroup:
│         tags: []
│         "evaluation" => Trial(554.415 μs)
│         "Turing.Essential.ReverseDiffAD{true}()" => Trial(16.445 ms)
└         "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(139.849 ms)

Given that evaluation and ForwardDiff are both faster in the latter case, it's clearly an "issue" with ReverseDiff, but at the same time this is such a significant perf hit that it makes me a bit uncomfortable to just "leave it in" there 😕

Thoughts? @devmotion

@devmotion
Member

I don't have an immediate answer but generally the implementation in Distributions is much more specialized and hence more efficient than the previous BernoulliLogit in Turing (which fell back to BinomialLogit). So it seems there's a more ReverseDiff-specific issue here... Is there some type instability somewhere? Some type inference issue?

@torfjelde
Member Author

Do you know of a good way to check this for ReverseDiff?

@torfjelde
Member Author

It seems strange to me since AFAIK ForwardDiff is also used for broadcasting in ReverseDiff, no? So it's weird that ForwardDiff perf improves but ReverseDiff doesn't.

@devmotion
Member

Maybe the branches in the new logpdf code kill performance with ReverseDiff?

@torfjelde
Member Author

That's what I was thinking too, so I tried the following impl to no avail:

function Distributions.logpdf(d::BernoulliLogit, x::Real)
    return (1 - x) * Distributions.logfailprob(d) + x * Distributions.logsuccprob(d)
end

perf is still bad.

@devmotion
Member

No, I meant without two calls of log1pexp. Ie. something like

logpdf(d::BernoulliLogit, x::Bool) = -log1pexp(x ? -d.logitp : d.logitp)
function logpdf(d::BernoulliLogit, x::Real)
    logitp = d.logitp
    z = -log1pexp(x == 0 ? logitp : -logitp)
    return insupport(d, x) ? z : oftype(z, -Inf)
end

@torfjelde
Member Author

Unfortunately doesn't help 😕

@devmotion
Member

😥 What happens if you implement the gradient of the logpdf function for ReverseDiff?

@devmotion
Member

Or simpler: If you do not go through Distributions but define the logpdf directly as a separate function and broadcast that one?

@devmotion
Member

I guess, for debugging it could also be useful to inspect the tape that ReverseDiff creates with the different implementations.

@torfjelde
Member Author

Just commenting to let you know I've seen the comments and I'm planning on having a go at it at some point, but right now I have some more pressing TODOs so need to put this on the backlog for a bit 😕

@torfjelde
Member Author

torfjelde commented Jan 12, 2023

> I guess, for debugging it could also be useful to inspect the tape that ReverseDiff creates with the different implementations.

I wrote a small package to check this (https://github.com/torfjelde/ReverseDiffDebugUtils.jl) and AFAICT, they're the same 😕

Distributions@0.25.76 (which runs in ~1.4ms):
[Image: graph of the recorded ReverseDiff tape]

Distributions@0.25.80 (which runs in ~18ms):
[Image: graph of the recorded ReverseDiff tape]

So seems like it has to be something in the reverse pass?

EDIT: Well, if they're the same or not is of course dependent on whether the broadcast instructions are actually broadcasting the same functions, which they of course aren't 🤦

EDIT 2: Added hacky capability of inferring the broadcasted functions, and they're indeed the same still.

@torfjelde
Member Author

Think I found a clue: with Distributions@0.25.80 ReverseDiffAD{false} is just as fast as ReverseDiffAD{true}, i.e. compilation doesn't help for some reason!

@torfjelde
Member Author

While on Distributions@0.25.76 there's a bit of a slow-down from ~1.4ms to ~1.9ms

@torfjelde
Member Author

Probably is a type-instability somewhere then?

@torfjelde
Member Author

Profiling it, it becomes clear that ReverseDiff.special_forward_exec!(inst) is the issue where inst is the ∇broadcast that is different between the two implementations.
The reverse pass (ReverseDiff.ReverseExecutor) takes up almost nothing of the 17ms runtime.

@torfjelde
Member Author

Changing the model to:

@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    tmp = BernoulliLogit.(theta[p] - beta[i])
    Turing.@addlogprob! sum(logpdf.(tmp, y))

    return (; theta, beta)
end

to avoid nested broadcasting, we get

julia> @benchmark $(LogDensityProblems.logdensity_and_gradient)($∂ℓ, $θ)
BenchmarkTools.Trial: 1323 samples with 1 evaluation.
 Range (min … max):  3.588 ms …   5.117 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.741 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.773 ms ± 166.545 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

when compiling, which is muuuuch better than the 17ms from before.

Now, without compilation:

julia> @benchmark $(LogDensityProblems.logdensity_and_gradient)($∂ℓ, $θ)
BenchmarkTools.Trial: 453 samples with 1 evaluation.
 Range (min … max):   8.556 ms … 39.326 ms  ┊ GC (min … max):  0.00% … 71.77%
 Time  (median):      9.103 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   11.044 ms ±  6.140 ms  ┊ GC (mean ± σ):  13.86% ± 16.86%

which is also better.

@torfjelde
Member Author

torfjelde commented Jan 13, 2023

Finally, using

logpdf_bernoulli_logit(logitp, x) = x == 0 ? StatsFuns.logistic(-logitp) : StatsFuns.logistic(logitp)

logpdf_bernoulli_logit(logitp, x::Bool) = x ? StatsFuns.logistic(logitp) : StatsFuns.logistic(-logitp)


@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    Turing.@addlogprob! sum(logpdf_bernoulli_logit.(theta[p] - beta[i], y))

    return (; theta, beta)
end

we get

julia> suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(256.576 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(457.752 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(126.160 ms)
  "not_linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(256.936 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(457.365 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(126.680 ms)

which is significantly better (it's also 3x the speed of Stan).

It is really annoying that logpdf broadcasting is costing this much though 😕

@devmotion
Member

devmotion commented Jan 13, 2023

The logpdf in your last comment is wrong though, isn't it? At least it doesn't match the one discussed above.
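
The mismatch can be checked numerically. Below is a minimal sketch in plain Julia, assuming hand-written stand-ins (`logistic` and `log1pexp_naive` are hypothetical helpers defined here, not the StatsFuns or LogExpFunctions implementations). It shows that `logistic(±logitp)` is the pdf, while `-log1pexp(∓logitp)` is the matching logpdf.

```julia
# Hypothetical stand-ins so the sketch needs no packages.
logistic(x) = inv(one(x) + exp(-x))
log1pexp_naive(x) = log1p(exp(x))   # not numerically robust for large x

# pdf of a Bernoulli with success probability logistic(logitp)
pdf_bernoulli_logit(logitp, x::Bool) = x ? logistic(logitp) : logistic(-logitp)

# logpdf of the same distribution: log(logistic(z)) == -log1pexp(-z)
logpdf_bernoulli_logit(logitp, x::Bool) = -log1pexp_naive(x ? -logitp : logitp)

# The two agree once the log is taken.
for logitp in (-2.0, 0.0, 1.5), x in (false, true)
    @assert logpdf_bernoulli_logit(logitp, x) ≈ log(pdf_bernoulli_logit(logitp, x))
end
```

A benchmark that uses the pdf form skips the `log1p` and `exp` pair entirely, so its timings are not comparable to the logpdf ones.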

@torfjelde
Member Author

Uhm yes 🤦 I mixed the logpdf and pdf impls. With the correct one we're only at:

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(549.488 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(3.246 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(133.501 ms)
  "not_linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(550.383 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(3.834 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(134.578 ms)

@torfjelde
Member Author

torfjelde commented Jan 13, 2023

Replacing LogExpFunctions.log1pexp without all the conditional statements, i.e. just log1p(exp(x)), results in

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(492.951 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(2.993 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(136.426 ms)
  "not_linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(496.187 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(4.443 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(135.160 ms)

so a slight improvement but not sufficient.

EDIT: Seems to have been a fluke; doesn't seem to actually matter.

@devmotion
Member

#1934 (comment) doesn't matter either?

@torfjelde
Member Author

Nah 😕

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(375.343 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(4.449 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(125.817 ms)
          "Turing.Essential.ZygoteAD()" => Trial(33.744 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(378.344 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(3.554 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(120.604 ms)
          "Turing.Essential.ZygoteAD()" => Trial(34.406 ms)

@torfjelde
Member Author

Okay, I think I've found the issue, or at least the explanation.

In the one with Distributions@0.25.76 the broadcast results in a simdloop, which makes it so that:

  1. Forward pass is much faster.
  2. Reverse pass is dirt cheap.

So I decided the only logical thing to do is to add more broadcasting in the hopes that this, for some reason, would trigger similar things after the BernoulliLogit change, i.e. I did this:

julia> logpdf_bernoulli_logit(logitp, x) = -log1pexp(x == 0 ? logitp : -logitp)
logpdf_bernoulli_logit (generic function with 2 methods)

julia> logpdf_bernoulli_logit(logitp, x::Bool) = -log1pexp(x ? -logitp : logitp)
logpdf_bernoulli_logit (generic function with 2 methods)

julia> # performant model
       @model function irt(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           Turing.@addlogprob! sum(logpdf_bernoulli_logit.(theta[p] .- beta[i], y))

           return (; theta, beta)
       end
irt (generic function with 2 methods)

julia> # Instantiate
       model = irt(y, i, p);

julia> suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}(), TuringBenchmarking.ZygoteAD()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(380.079 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(840.097 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(100.826 ms)
          "Turing.Essential.ZygoteAD()" => Trial(33.692 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(379.018 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(833.076 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(100.951 ms)
          "Turing.Essential.ZygoteAD()" => Trial(33.791 ms)

julia> # performant model
       @model function irt(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           Turing.@addlogprob! sum(logpdf_bernoulli_logit.(theta[p] - beta[i], y))  # don't broadcast `-`

           return (; theta, beta)
       end
irt (generic function with 2 methods)

julia> # Instantiate
       model = irt(y, i, p);

julia> suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}(), TuringBenchmarking.ZygoteAD()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(380.220 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(3.054 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(124.792 ms)
          "Turing.Essential.ZygoteAD()" => Trial(33.673 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(380.720 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(3.575 ms)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(124.956 ms)
          "Turing.Essential.ZygoteAD()" => Trial(33.465 ms)

i.e. not broadcasting over the Distribution but replacing it with a simpler method, and also broadcasting over -, results in an implementation even faster than the original one on Distributions@0.25.76.

Which is even more annoying! So just the right amount of broadcasting leads to great performance, but if you do too little or too much, you're screwed.
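
The difference between the two spellings can be seen in plain Base Julia, before any AD is involved. The sketch below is only an illustration under stated assumptions: `f` is a hypothetical stand-in for a scalar logpdf, and the arrays are ordinary `Vector`s rather than ReverseDiff tracked arrays, so it shows what each expression lowers to and says nothing about how ReverseDiff records it.

```julia
using Base.Broadcast: broadcasted, Broadcasted, materialize

f(d, y) = y ? d : -d                 # hypothetical stand-in for a scalar logpdf
a = [1.0, 2.0, 3.0]
b = [0.5, 0.5, 0.5]
y = [true, false, true]

# `f.(a .- b, y)` lowers to one nested, lazy broadcast: a single fused kernel.
fused = broadcasted(f, broadcasted(-, a, b), y)

# `f.(a - b, y)` first materializes `a - b`, then broadcasts `f` over the result.
unfused = broadcasted(f, a - b, y)

@assert fused.args[1] isa Broadcasted        # subtraction is still lazy
@assert unfused.args[1] isa Vector{Float64}  # subtraction already happened
@assert materialize(fused) == materialize(unfused)
```

Both forms give the same values; what changes is which function object the broadcast machinery (and hence an AD package that intercepts broadcasts) gets to see.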

@devmotion
Member

Yeah, broadcasting performance issues and gotchas are about the worst... Countless hours that went into these things in the SciML ecosystem as well.

@torfjelde
Member Author

But are there any "guidelines" or just general advice on how to:

  1. Identify these.
  2. Fix them.

? 😕

@devmotion
Member

The usual workflow I experienced was that someone notices performance issues and then one starts debugging and finally notices that's broadcasting related (e.g., there was (is?) also a limit after how many broadcasting operations performance completely degrades, I'll try to dig up the relevant issues).

> replace it with a simpler method

BTW that's what DistributionsAD does for many univariate distributions with flatten: https://github.com/TuringLang/DistributionsAD.jl/blob/master/src/flatten.jl It's used mainly/only in filldist but maybe it would be useful more generally. Even though in principle ideally it would not be needed.

@torfjelde
Member Author

> BTW that's what DistributionsAD does for many univariate distributions with flatten

Woah, I was completely unaware of this! And yeah this might be very helpful.

@torfjelde
Member Author

That indeed does wonders:

julia> # Using `DistributionsAD.flatten` to address performance.
       using Distributions, DistributionsAD

julia> """
           get_logpdf_expr(Tdist)

       Return a flattened method for computing the logpdf of `Tdist`.
       """
       function get_logpdf_expr(Tdist)
           x = gensym()
           fnames = fieldnames(Tdist)
           func = Expr(:->, 
                       Expr(:tuple, fnames..., x), 
                       Expr(:block,
                            Expr(:call, :logpdf,
                                 Expr(:call, :($Tdist), fnames...),
                                 x,
                                 )
                            )
                       )
           return :(flatten(::Type{<:$Tdist}) = $func)
       end
get_logpdf_expr

julia> # 1. Use `flatten` to extract a, well, flattened `logpdf`.
       eval(get_logpdf_expr(BernoulliLogit))
flatten (generic function with 1 method)

julia> # 2. [OPTIONAL] Use `StructArrays.jl` to avoid the initial call to the constructor entirely.

       # 3. Define a "fast" logpdf method.
       @generated function fast_logpdf(
           dist::Product{V,D,<:StructVector{<:Any,<:NamedTuple{names}}},
           x::AbstractArray
       ) where {V,D<:UnivariateDistribution,names}
           # Get the flatten expression.
           f = flatten(D)

           args = [:(dist.v.$n) for n in names]
           return :(sum($f.($(args...), x)))
       end
fast_logpdf (generic function with 2 methods)

julia> # 4. Convenience method for constructing `StructArray` without 
       function DistributionsAD.arraydist(::Type{D}, args...) where {D<:Distribution}
           return DistributionsAD.arraydist(D, NamedTuple{fieldnames(D)}(args))
       end

julia> DistributionsAD.arraydist(::Type{D}; args...) where {D<:Distribution} = DistributionsAD.arraydist(D, NamedTuple(args))

julia> function DistributionsAD.arraydist(::Type{D}, args::NamedTuple) where {D<:Distribution}
           # TODO: Use `purename`?
           return DistributionsAD.arraydist(StructArray{D}(args))
       end

julia> # 5. Type-piracy so we can make use of `~`.
       function Distributions.logpdf(dist::Product{<:Any,<:UnivariateDistribution,<:StructVector}, x::AbstractVector{<:Real})
           return fast_logpdf(dist, x)
       end

julia> @model function irt_vroom(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           y ~ arraydist(BernoulliLogit, theta[p] - beta[i])

           return (; theta, beta)
       end
irt_vroom (generic function with 2 methods)

julia> model = irt_vroom(y, i, p);

julia> suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(389.573 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(747.912 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(127.035 ms)
  "not_linked" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(391.116 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(745.925 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(126.951 ms)

@torfjelde
Member Author

Note that the usage of StructArray to not have to go through the initial constructor is necessary (unless there's another way of avoiding the constructor), otherwise we degrade back to ~3ms as before.

@torfjelde
Member Author

This also improves Zygote-perf 20-fold.

@devmotion
Member

But why is the constructor so slow? It's the most simple struct one can come up with: https://github.com/JuliaStats/Distributions.jl/blob/d21c5a3d2386910b586cd9da188721f313073570/src/univariate/discrete/bernoullilogit.jl#L19-L21 Is it just that ReverseDiff is inherently bad with handling anything else than scalars or arrays?

@torfjelde
Member Author

> Is it just that ReverseDiff is inherently bad with handling anything else than scalars or arrays?

Well, I have no idea but from this it does seem that if you have a constructor in a broadcasting statement, then you want to hide this from ReverseDiff.jl. That is, make it broadcast (param, x) -> logpdf(dist(param), x) instead of logpdf.(dist.(param), x). I'm assuming these two are recorded differently on the tape, e.g. the former records a broadcast of the closure while the latter records a fused broadcast of all the methods involved?
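
As a minimal sketch of the two spellings, assuming a hypothetical `ToyLogit` struct and `toy_logpdf` function in place of `BernoulliLogit` and `logpdf` (so it runs without Distributions.jl or ReverseDiff.jl):

```julia
# Hypothetical stand-in for BernoulliLogit: one parametric field.
struct ToyLogit{T<:Real}
    logitp::T
end

# Hypothetical stand-in for logpdf(::BernoulliLogit, ::Bool).
toy_logpdf(d::ToyLogit, x::Bool) = -log1p(exp(x ? -d.logitp : d.logitp))

params = [-1.0, 0.0, 2.0]
xs = [true, false, true]

# Constructor appears as its own broadcasted call.
nested = toy_logpdf.(ToyLogit.(params), xs)

# Constructor hidden inside a closure; the broadcast only sees one function.
hidden = broadcast((p, x) -> toy_logpdf(ToyLogit(p), x), params, xs)

@assert nested == hidden
```

The values are identical; the sketch only shows the shape of the rewrite, not its effect on the recorded tape.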

@torfjelde
Member Author

Or rather than fully blaming ReverseDiff, maybe it's more the type-inference failing in these two scenarios once you involve TrackedArray and/or TrackedReal?

@torfjelde
Member Author

Because, as I said before, the two recorded tapes are the same with the exception of the one broadcast instruction.

I tried using both Cthulhu and JET and couldn't properly identify this failure of type-inference though (but maybe I should check again now that I'm a bit more knowledgeable about ReverseDiff's internals).

@torfjelde
Member Author

Nah, still no luck with Cthulhu and JET

@torfjelde
Member Author

So I used Infiltrator.jl to break at the end of ∇broadcast and I observe the following:

infil> @code_warntype broadcast(df, ReverseDiff.value.(targs)...)
MethodInstance for broadcast(::Df{Base.Broadcast.var"#10#12"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}, typeof(logpdf)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}, ::Vector{Float64}, ::Vector{Int64})
  from broadcast(f::Tf, As...) where Tf in Base.Broadcast at broadcast.jl:798
Static Parameters
  Tf = Df{Base.Broadcast.var"#10#12"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}, typeof(logpdf)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}
Arguments
  #self#::Core.Const(broadcast)
  f::Df{Base.Broadcast.var"#10#12"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}, typeof(logpdf)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}
  As::Tuple{Vector{Float64}, Vector{Int64}}
Body::Union{Vector, BitVector}
1 ─ %1 = Core.tuple(f)::Tuple{Df{Base.Broadcast.var"#10#12"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}, typeof(logpdf)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}}
│   %2 = Core._apply_iterate(Base.iterate, Base.Broadcast.broadcasted, %1, As)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Df{Base.Broadcast.var"#10#12"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}, typeof(logpdf)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}, Tuple{Vector{Float64}, Vector{Int64}}}
│   %3 = Base.Broadcast.materialize(%2)::Union{Vector, BitVector}
└──      return %3

for the "slow" version, and for the "fast" version:

infil> @code_warntype broadcast(df, ReverseDiff.value.(targs)...)
MethodInstance for broadcast(::Df{var"#53#54", DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}, ::Vector{Float64}, ::Vector{Int64})
  from broadcast(f::Tf, As...) where Tf in Base.Broadcast at broadcast.jl:798
Static Parameters
  Tf = Df{var"#53#54", DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}
Arguments
  #self#::Core.Const(broadcast)
  f::Df{var"#53#54", DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}
  As::Tuple{Vector{Float64}, Vector{Int64}}
Body::Vector{DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}}
1 ─ %1 = Core.tuple(f)::Tuple{Df{var"#53#54", DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}}
│   %2 = Core._apply_iterate(Base.iterate, Base.Broadcast.broadcasted, %1, As)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Df{var"#53#54", DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}, Tuple{Vector{Float64}, Vector{Int64}}}
│   %3 = Base.Broadcast.materialize(%2)::Vector{DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}}
└──      return %3

i.e. type-instability when broadcasting df for "slow" version and type-stability for df for "fast" version.

Tried removing the closure, i.e. replacing df with a wrapper-struct, but it doesn't help.

Seems like the function f itself is type-unstable?

infil> @descend ReverseDiff.splatcall(f, ReverseDiff.SVector(ReverseDiff.value.(map(first, targs))), untracked, inds)
splatcall(f, x::StaticArraysCore.SVector{N}, utargs::T, ::Val{tinds}) where {N, T<:Tuple, tinds} in ReverseDiff at /home/tor/.julia/packages/ReverseDiff/YkVxM/src/derivatives/broadcast.jl:111
    %0 = invoke splatcall(::#10#12{…},::SArray{…},::Tuple,::Val{…})::Any
111 1 ─ %1 = StaticArrays.getfield(x, :data)::Tuple{Float64, Float64}│╻╷  macro expansion
    │   %2 = Base.getfield(%1, 1, true)::Float64                     ││┃│  getindex
    │   %3 = StaticArrays.getfield(x, :data)::Tuple{Float64, Float64}││╻   getindex
    │   %4 = Base.getfield(%3, 2, true)::Float64                     │││╻   getindex
    │   %5 = Core.getfield(f, :makeargs)::Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}
    │   %6 = Core.getfield(%5, :f)::UnionAll                         │││╻   #16
    │   %7 = (%6)(%2)::Any                                           ││││
    │   %8 = (Distributions.logpdf)(%7, %4)::Any                     │││ 
    └──      return %8                                               ││  

Seems like the fact that BernoulliLogit is a UnionAll causes issues. It's worth pointing out that in the "old" version BernoulliLogit is a function rather than a UnionAll.
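
The effect can be reproduced outside ReverseDiff with a small sketch. Everything here is an assumption for illustration: `ToyLogit` is a hypothetical parametric struct standing in for `BernoulliLogit`, `ToyLogitF` is a plain function forwarding to it, and `Wrapper` mimics a closure that stores the callee in a field.

```julia
# Hypothetical parametric struct; the bare name `ToyLogit` is a UnionAll.
struct ToyLogit{T<:Real}
    logitp::T
end

# A plain function forwarding to the constructor; its type is a singleton.
ToyLogitF(x) = ToyLogit(x)

# Mimics a closure capturing the callee as a field.
struct Wrapper{F}
    f::F
end
call(w::Wrapper, x) = w.f(x)

@assert typeof(ToyLogit) === UnionAll
@assert typeof(Wrapper(ToyLogit)) === Wrapper{UnionAll}

# Inferred return type when the callee is only known to be "some UnionAll"
# versus when it is known to be the specific function ToyLogitF.
slow = only(Base.return_types(call, (Wrapper{UnionAll}, Float64)))
fast = only(Base.return_types(call, (typeof(Wrapper(ToyLogitF)), Float64)))
```

On recent Julia versions `slow` should come out as `Any` and `fast` as `ToyLogit{Float64}`: the field type `UnionAll` does not tell the compiler which type is being called, whereas `typeof(ToyLogitF)` identifies the function exactly.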

@torfjelde
Member Author

Even "worse", just hiding BernoulliLogit behind a plain function resolves the type-inference issue (for ReverseDiff.jl; Zygote.jl still struggles):

julia> BernoulliLogitF(x) = BernoulliLogit(x)
BernoulliLogitF (generic function with 1 method)

julia> @model function irt(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           Turing.@addlogprob! sum(logpdf.(BernoulliLogitF.(theta[p] - beta[i]), y))

           return (; theta, beta)
       end
irt (generic function with 2 methods)

julia> model = irt(y, i, p);

julia> suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}(), TuringBenchmarking.ZygoteAD()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(379.554 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(746.761 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(122.954 ms)
          "Turing.Essential.ZygoteAD()" => Trial(78.837 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(379.219 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(749.739 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(126.728 ms)
          "Turing.Essential.ZygoteAD()" => Trial(79.091 ms)

Zygote.jl still benefits from the arraydist(BernoulliLogit, ...) approach though:

julia> suite = TuringBenchmarking.make_turing_suite(
           model_vroom,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}(), TuringBenchmarking.ZygoteAD()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(388.298 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(748.069 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(122.363 ms)
          "Turing.Essential.ZygoteAD()" => Trial(1.495 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(388.359 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(748.285 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(121.801 ms)
          "Turing.Essential.ZygoteAD()" => Trial(1.493 ms)

@devmotion
Member

A problem with the generated functions (https://github.com/JuliaDiff/ReverseDiff.jl/blob/d522508aa6fea16e9716607cdd27d63453bb61e6/src/derivatives/broadcast.jl#L111)? A missing type parameter, leading to non-specialization of Function or Type somewhere?

@torfjelde
Member Author

Tried adding type-parameter and it doesn't do anything 😕

@torfjelde
Member Author

But can it specialize when BernoulliLogit is a UnionAll? If we had something like BernoulliLogit{Float64} then it would probably be okay

@torfjelde
Member Author

Just to make things even "clearer":

infil> @code_warntype f.makeargs.f
MethodInstance for getproperty(::Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}, ::Symbol)
  from getproperty(x, f::Symbol) in Base at Base.jl:38
Arguments
  #self#::Core.Const(getproperty)
  x::Base.Broadcast.var"#16#18"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#9#11"}, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}, UnionAll}
  f::Symbol
Body::Any
1 ─      nothing
│   %2 = Base.getfield(x, f)::Any
└──      return %2

@torfjelde
Member Author

torfjelde commented Jan 13, 2023

Brian Chen pointed out in Slack that Zygote fails this check, which causes the slowdown: https://github.com/FluxML/Zygote.jl/blob/c2f1794ca9da3088a2f3bfb0144c8bfc4dd89d9a/src/lib/broadcast.jl#L198-L199

In particular, the _dual_purefun check fails when BernoulliLogit is used, and so we don't end up hitting broadcast_forward.
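A rough sketch of why a purity test of this kind rejects a type constructor. `is_pure_fun`, `LogitDist`, and `make_scaler` are hypothetical names, and the assumption is that the check only accepts singleton subtypes of `Function`; this is not Zygote's actual source:

```julia
# Hypothetical stand-in for a "pure function" test; assumes only
# singleton subtypes of Function are accepted.
is_pure_fun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
is_pure_fun(::Type) = false

struct LogitDist{T<:Real}
    logitp::T
end

# A closure that captures a local value carries state, so it is not a singleton.
make_scaler(s) = x -> s * x
scaler = make_scaler(2.0)

pure_sin = is_pure_fun(typeof(sin))               # plain function
pure_closure = is_pure_fun(typeof(scaler))        # closure with captured state
pure_ctor = is_pure_fun(Core.Typeof(LogitDist))   # type constructor, not a Function
```

Under that assumption, broadcasting a constructor can never pass the test, whatever the element types are.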

@devmotion
Member

Regarding Zygote, I wonder if it's problematic for the compiler that the pullback is closed over the variable len of type Val in https://github.com/FluxML/Zygote.jl/blob/c2f1794ca9da3088a2f3bfb0144c8bfc4dd89d9a/src/lib/broadcast.jl#L206. One could check if making the pullback a callable struct with type parameter inclen(args) improves anything.
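A toy sketch of what a callable-struct pullback could look like, to make the suggestion concrete. `BroadcastPullback` and `make_back` are hypothetical names, and the argument count is simplified to the number of broadcast arguments; this is not how Zygote is written:

```julia
# Hypothetical callable struct; N is carried as a type parameter
# instead of being captured as a Val by a closure.
struct BroadcastPullback{N,P}
    pullbacks::P   # one pullback per broadcast element
end
BroadcastPullback{N}(pullbacks::P) where {N,P} = BroadcastPullback{N,P}(pullbacks)

function (pb::BroadcastPullback{N})(dy) where {N}
    # Each element pullback returns an N-tuple of partials.
    partials = map((back, dy_i) -> back(dy_i), pb.pullbacks, dy)
    # Regroup into one array per argument; Val(N) fixes the tuple length statically.
    return ntuple(i -> map(t -> t[i], partials), Val(N))
end

# Toy usage for f(x, y) = x * y.
make_back(x, y) = dy -> (dy * y, dy * x)
xs, ys = [1.0, 2.0], [3.0, 4.0]
pb = BroadcastPullback{2}(map(make_back, xs, ys))
dxs, dys = pb([1.0, 1.0])
```

Whether this helps the compiler in the real code path would still need to be measured.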

@torfjelde
Member Author

I think the fast Zygote version just hits broadcast_forward though. And it was pointed out that BernoulliLogitF doesn't help Zygote because the result of combine_eltypes will not be a subtype of Union{Real,Complex}, and so the check https://github.com/FluxML/Zygote.jl/blob/c2f1794ca9da3088a2f3bfb0144c8bfc4dd89d9a/src/lib/broadcast.jl#L198 is false.
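A small sketch of the element-type part of that condition. `numeric_eltype` and `LogitParam` are hypothetical, and `Base.Broadcast.combine_eltypes` is an internal Base function, so its behaviour may differ between Julia versions:

```julia
# Hypothetical struct standing in for a distribution type.
struct LogitParam{T<:Real}
    logitp::T
end

# Hypothetical helper mirroring only the element-type part of the check.
# `Base.Broadcast.combine_eltypes` is internal to Base.
function numeric_eltype(f, args...)
    T = Base.Broadcast.combine_eltypes(f, args)
    return T <: Union{Real,Complex} && isconcretetype(T)
end

xs = [0.5, -1.0, 2.0]
exp_ok = numeric_eltype(exp, xs)            # elements are Float64
ctor_ok = numeric_eltype(LogitParam, xs)    # elements are structs
```

Any broadcast whose result is an array of structs fails this test regardless of how the callable is wrapped, which matches the observation about BernoulliLogitF.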

@ToucheSir

Coincidentally, FluxML/Zygote.jl#1359 was filed today about type instability in ∇broadcasted. The culprit is not len, but the compiler failing to const prop i into an ntuple callback. As Tor said though, ideally you'd want to avoid hitting this fallback path in the first place.

@devmotion
Member

Sure, the other path would be ideal - but if the fallback path could be improved, it might make it less of an issue.

@ToucheSir

Hence FluxML/Zygote.jl#1360. I do think the fundamental issue is that we don't have a forward mode AD that understands complex structures like Distributions. Thus we're left with a half-dozen flattening/unflattening libraries and multiple reverse-mode ADs with the same performance cliffs in broadcasting.

@tansongchen

tansongchen commented Jan 14, 2023

forward mode AD that understands complex structures like Distributions

@ToucheSir I think the idea that forward-mode AD should support generic "inner" type is exciting. Could you elaborate on this point?

ForwardDiff requires the inner type to <: Real. ForwardDiff2 extends this by allowing the inner type to be arrays. TaylorDiff currently requires Numbers but will probably implement arrays in the future. So there are two questions:

  1. For generic inner type, I don't have a clear understanding of how to verify that it is "dualable" or "taylorable".
  2. Even if it is, since Julia concrete types don't subtype each other, we need to either
    • play with Cassette.jl, extend Julia call semantics and therefore "simulate" operator-overloading AD
    • go to source-code-transform AD

They are tricky but worth exploring, since in principle forward-mode AD should be as generic as reverse-mode AD. I'm pretty interested in exploring generic forward-mode AD in TaylorDiff.

@ToucheSir

It's nothing too exotic, just that one should be able to differentiate through code like Normal.(means, stds) efficiently. I see there was some attempt at making this work for ForwardDiff at JuliaDiff/ForwardDiff.jl#307, but that PR hasn't been touched for almost 5 years :(
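For illustration, a minimal sketch of the operator-overloading side of this, using a toy dual number. `Dual`, `ToyBernoulliLogit`, and `loglik` are hypothetical and are not the ForwardDiff or Distributions.jl implementations; the point is only that a struct parametric in `T<:Real` can carry duals through a broadcast constructor:

```julia
# Toy forward-mode dual number (hypothetical, not ForwardDiff.Dual).
struct Dual{T<:Real} <: Real
    val::T
    der::T
end

Base.:-(a::Dual) = Dual(-a.val, -a.der)
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)

function Base.exp(a::Dual)
    e = exp(a.val)
    return Dual(e, e * a.der)
end
Base.log1p(a::Dual) = Dual(log1p(a.val), a.der / (1 + a.val))

# Toy distribution parameterised by a logit; parametric so it can hold duals.
struct ToyBernoulliLogit{T<:Real}
    logitp::T
end

# Log-likelihood of a Bernoulli outcome given the logit.
loglik(d::ToyBernoulliLogit, y::Bool) =
    y ? -log1p(exp(-d.logitp)) : -log1p(exp(d.logitp))

theta = Dual(0.0, 1.0)                      # seed: differentiate w.r.t. theta
betas = [Dual(0.0, 0.0), Dual(0.0, 0.0)]    # constants carry a zero derivative
ys = [true, true]

dists = ToyBernoulliLogit.(theta .- betas)  # array of structs holding duals
total = sum(loglik.(dists, ys))
```

Here `total.val` is the log-likelihood and `total.der` its derivative with respect to `theta`. The difficulty discussed in this thread is that the intermediate array has struct elements, which the reverse-mode broadcast fast paths do not accept.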

  1. For generic inner type, I don't have a clear understanding of how to verify that it is "dualable" or "taylorable".

It's a good question and I don't have an answer. The only Julia AD I know of which can (potentially) do this right now is Enzyme, and that's pure SCT. Perhaps there's something you could do with type-level programming and internals like return_type to determine whether a particular struct is dualable/taylorable.

@tansongchen

Thanks for providing this PR and suggestions. It seems that handling generic inner type for forward mode AD (and similarly for reverse mode) more or less involves some SCT (at least some tweaks with Cassette). I will probably first do something with arrays before getting more general...
