Benchmarking expected_loglik #82

Open
theogf opened this issue Dec 1, 2021 · 11 comments

Comments

@theogf
Member

theogf commented Dec 1, 2021

So with Julia 1.7 (maybe 1.6 too?), there is an issue with Zygote gradients of the expected_loglik function because it hits a BLAS function.
I tried to rewrite it to make fewer allocations; here are the results:

using BenchmarkTools, Distributions, ApproximateGPs, IrrationalConstants, FastGaussQuadrature
# The new proposed version
function expected_loglik(
    gh::GaussHermite, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik
)
    xs, ws = gausshermite(gh.n_points)
    return mapreduce(+, q_f, y) do q, y
        μ = mean(q)
        σ = std(q)
        mapreduce(+, xs, ws) do x, w
            f = sqrt2 * σ * x + μ
            loglikelihood(lik(f), y) * w
        end
    end / sqrtπ
end
# The previous version
function expected_loglik_old(
    gh::GaussHermite, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik
)
    xs, ws = gausshermite(gh.n_points)
    fs = sqrt2 * std.(q_f) .* xs' .+ mean.(q_f)
    lls = loglikelihood.(lik.(fs), y)
    return sum(lls * ws) / sqrtπ  # Gauss-Hermite normalization is √π, not π
end
function evaluate_speed(N)
  gh = GaussHermite(100)
  lik = BernoulliLikelihood()
  y = rand(0:1, N)
  q_f = Normal.(randn(N), rand(N))
  @btime expected_loglik($gh, $y, $q_f, $lik)      # new mapreduce version (first timing per N)
  @btime expected_loglik_old($gh, $y, $q_f, $lik)  # old linear-algebra version (second timing per N)
end
for N in [10, 100, 500, 1000]
  @info N
  evaluate_speed(N)
end
[ Info: 10
  164.407 μs (192 allocations: 44.45 KiB)
  139.752 μs (87 allocations: 48.83 KiB)
[ Info: 100
  396.769 μs (1002 allocations: 146.42 KiB)
  312.959 μs (89 allocations: 191.52 KiB)
[ Info: 500
  1.433 ms (4602 allocations: 599.61 KiB)
  1.027 ms (89 allocations: 826.08 KiB)
[ Info: 1000
  2.736 ms (9102 allocations: 1.14 MiB)
  1.972 ms (89 allocations: 1.58 MiB)

So the old approach is faster but makes larger allocations. I don't actually know where all these allocations come from in the first approach; any clue?

@st-- st-- changed the title Benchmarking expec_loglik Benchmarking expected_loglik Dec 1, 2021
@st--
Member

st-- commented Dec 1, 2021

Larger allocations, but fewer of them (and only a constant number). A few large allocations are more efficient than lots of small ones, I suppose!

@st--
Member

st-- commented Dec 1, 2021

It might be helpful to see how it scales with the number of Gauss-Hermite points: 10 vs. 100 vs. 1000...
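A rough way to try this suggestion without ApproximateGPs is sketched below. It is a minimal sketch under several assumptions: gauss_hermite is a hypothetical helper that computes physicists' Gauss-Hermite nodes and weights with the Golub-Welsch eigenvalue method (standing in for FastGaussQuadrature.gausshermite), the Bernoulli likelihood with logistic link is written out by hand, ell_lazy and ell_matvec are hypothetical names mirroring the broadcasted and matrix-vector versions in this thread, and timings use @elapsed instead of BenchmarkTools, so they are only indicative.

```julia
using LinearAlgebra, Random

# Hypothetical helper: physicists' Gauss-Hermite nodes and weights via the
# Golub-Welsch method, standing in for FastGaussQuadrature.gausshermite.
function gauss_hermite(n)
    J = SymTridiagonal(zeros(n), sqrt.((1:n-1) ./ 2))  # Jacobi matrix of the monic Hermite recurrence
    F = eigen(J)
    return F.values, sqrt(π) .* F.vectors[1, :] .^ 2    # nodes, weights (weights sum to √π)
end

# Bernoulli log-likelihood with a logistic link, in an overflow-safe form.
log1pexp(f) = f > 0 ? f + log1p(exp(-f)) : log1p(exp(f))
bernoulli_loglik(f, y) = y * f - log1pexp(f)

# Nested lazy broadcast (mirrors the broadcasted version): no N × n temporaries.
function ell_lazy(xs, ws, y, μ, σ)
    bc = Broadcast.broadcasted(y, μ, σ) do yi, m, s
        sum(Broadcast.instantiate(Broadcast.broadcasted(xs, ws) do x, w
            bernoulli_loglik(sqrt(2) * s * x + m, yi) * w
        end))
    end
    return sum(Broadcast.instantiate(bc)) / sqrt(π)
end

# Matrix-vector version (mirrors the old version): builds an N × n matrix.
function ell_matvec(xs, ws, y, μ, σ)
    fs = sqrt(2) .* σ .* xs' .+ μ
    return sum(bernoulli_loglik.(fs, y) * ws) / sqrt(π)
end

# Best of 5 runs; the first run also pays the compilation cost.
besttime(g, xs, ws, y, μ, σ) = minimum([@elapsed(g(xs, ws, y, μ, σ)) for _ in 1:5])

Random.seed!(1)
N = 500
y = rand(0:1, N); μ = randn(N); σ = rand(N)
for n in (10, 100, 1000)
    xs, ws = gauss_hermite(n)
    println("n_points = $n: lazy $(besttime(ell_lazy, xs, ws, y, μ, σ)) s, ",
            "matvec $(besttime(ell_matvec, xs, ws, y, μ, σ)) s")
end
```

Both functions compute the same quadrature sum; only the order of summation and the temporaries differ, so any timing gap as n_points grows reflects the N × n_points matrix that the matrix-vector form allocates.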

@devmotion
Member

mapreduce with multiple arrays just computes the mapped array up front, i.e. it is defined as reduce(op, map(f, A...)), so every call allocates a temporary array: https://github.com/JuliaLang/julia/blob/d16f4806e9389dbc92c463efc5b96f67a7aebf22/base/reducedim.jl#L324-L325 (added in JuliaLang/julia#31532)

I think a better approach (although, according to my experiments in Distributions, it seems to break AD with basically all backends) is

sum(Broadcast.instantiate(Broadcast.broadcasted(op, args...)))

which also uses pairwise summation and is fast in recent Julia versions (JuliaLang/julia#31020).
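To see this difference in isolation, here is a minimal sketch comparing mapreduce over two arrays with a sum over an instantiated Broadcasted. The names mul, eager, lazy and measure are hypothetical, and the exact byte counts will depend on the Julia version; the point is only that the eager form allocates an array of length N while the lazy form should not.

```julia
# Hypothetical element-wise function; any f(a_i, b_i) would do.
mul(a, b) = a * b

# Eager: per the Base definition linked above, this builds map(mul, a, b) first.
eager(a, b) = mapreduce(mul, +, a, b)
# Lazy: reduce directly over an instantiated Broadcasted, no temporary array.
lazy(a, b) = sum(Broadcast.instantiate(Broadcast.broadcasted(mul, a, b)))

# Measure inside a function so that global-variable access is not counted.
measure(g, a, b) = @allocated g(a, b)

a = rand(10_000); b = rand(10_000)
measure(eager, a, b); measure(lazy, a, b)  # warm up: first calls include compilation
println("eager: ", measure(eager, a, b), " bytes")
println("lazy:  ", measure(lazy, a, b), " bytes")
```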

@devmotion
Member

> So with 1.7 (maybe 1.6 too?), there is an issue with Zygote gradients and the expec_loglik function

I'm not familiar with this package, but this sounds like it could (or maybe should) be fixed in Zygote or with a ChainRule definition?

@theogf
Member Author

theogf commented Dec 1, 2021

The really weird thing is that it seems to happen on a matrix-vector product (lls * ws in the old version)...

@devmotion
Member

(Completely unrelated: you could multiply by invsqrtπ instead of dividing by √π or sqrtπ.)
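For reference, the √π comes from the change of variables f = √2 σ x + μ that all versions in this thread use (written here for a generic integrand g(f), e.g. g(f) = log p(y | f)); the Gaussian density normalization 1/(√(2π) σ) times the Jacobian √2 σ leaves exactly 1/√π in front of the Gauss-Hermite sum:

$$
\mathbb{E}_{f \sim \mathcal{N}(\mu, \sigma^2)}\big[g(f)\big]
= \frac{1}{\sqrt{2\pi}\,\sigma} \int g(f)\, e^{-(f-\mu)^2 / (2\sigma^2)}\, df
= \frac{1}{\sqrt{\pi}} \int e^{-x^2}\, g\big(\sqrt{2}\,\sigma x + \mu\big)\, dx
\approx \frac{1}{\sqrt{\pi}} \sum_{i=1}^{n} w_i\, g\big(\sqrt{2}\,\sigma x_i + \mu\big),
$$

where (x_i, w_i) are the nodes and weights returned by gausshermite(n), so dividing by √π and multiplying by invsqrtπ are the same normalization.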

@devmotion
Member

It seems SciML noticed the same Zygote issue: SciML/SciMLSensitivity.jl#511 (comment)

@theogf
Member Author

theogf commented Dec 2, 2021

Follow-up:
I tried @devmotion's proposal with broadcasted:

function expected_loglik_david(
    gh::GaussHermite, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik
)
    xs, ws = gausshermite(gh.n_points)
    return sum(Broadcast.instantiate(
            Broadcast.broadcasted(q_f, y) do q, y
                μ = mean(q)
                σ = std(q)
                sum(Broadcast.instantiate(
                    Broadcast.broadcasted(xs, ws) do x, w
                        f = sqrt2 * σ * x + μ
                        loglikelihood(lik(f), y) * w
                    end
                ))
            end
        )) * invsqrtπ
end

And it definitely improves things a lot! Yet the linear algebra is always the fastest.

# Per N: new mapreduce version, David's broadcasted version (middle), old linear-algebra version
[ Info: 10
  165.688 μs (188 allocations: 44.08 KiB)
  158.399 μs (88 allocations: 32.47 KiB)
  154.454 μs (87 allocations: 48.55 KiB)
[ Info: 100
  411.624 μs (998 allocations: 146.06 KiB)
  351.009 μs (88 allocations: 32.47 KiB)
  307.118 μs (89 allocations: 191.22 KiB)
[ Info: 500
  1.505 ms (4598 allocations: 599.25 KiB)
  1.206 ms (88 allocations: 32.47 KiB)
  1.234 ms (89 allocations: 825.78 KiB)
[ Info: 1000
  2.882 ms (9098 allocations: 1.14 MiB)
  2.277 ms (88 allocations: 32.47 KiB)
  2.132 ms (89 allocations: 1.58 MiB)
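The allocation pattern in these results (constant for the broadcasted version, growing with N for the other two) can be reproduced with a stdlib-only sketch. Assumptions: a closed-form 3-point physicists' Gauss-Hermite rule (nodes 0 and ±√(3/2)) stands in for gausshermite(100), the Bernoulli likelihood with logistic link is written by hand, and ell_mapreduce, ell_lazy, ell_matvec and measure are hypothetical names that mirror the three versions above. Byte counts will differ from the BenchmarkTools numbers.

```julia
using Random

# Closed-form 3-point physicists' Gauss-Hermite rule (stand-in for gausshermite(n)).
const XS = [-sqrt(1.5), 0.0, sqrt(1.5)]
const WS = [sqrt(π) / 6, 2 * sqrt(π) / 3, sqrt(π) / 6]

log1pexp(f) = f > 0 ? f + log1p(exp(-f)) : log1p(exp(f))
bernoulli_loglik(f, y) = y * f - log1pexp(f)   # Bernoulli log-likelihood, logistic link

# mapreduce version: each call materializes its mapped values before reducing.
function ell_mapreduce(y, μ, σ)
    total = mapreduce(+, y, μ, σ) do yi, m, s
        mapreduce((x, w) -> bernoulli_loglik(sqrt(2) * s * x + m, yi) * w, +, XS, WS)
    end
    return total / sqrt(π)
end

# Nested lazy broadcast version: reduces without building intermediate arrays.
function ell_lazy(y, μ, σ)
    bc = Broadcast.broadcasted(y, μ, σ) do yi, m, s
        sum(Broadcast.instantiate(Broadcast.broadcasted(XS, WS) do x, w
            bernoulli_loglik(sqrt(2) * s * x + m, yi) * w
        end))
    end
    return sum(Broadcast.instantiate(bc)) / sqrt(π)
end

# Matrix-vector version: builds N × 3 matrices, then one mat-vec product.
function ell_matvec(y, μ, σ)
    lls = bernoulli_loglik.(sqrt(2) .* σ .* XS' .+ μ, y)
    return sum(lls * WS) / sqrt(π)
end

# Measure inside a function so that global-variable access is not counted.
measure(g, y, μ, σ) = @allocated g(y, μ, σ)

Random.seed!(0)
for N in (10, 100, 1000)
    y = rand(0:1, N); μ = randn(N); σ = rand(N)
    for g in (ell_mapreduce, ell_lazy, ell_matvec)
        measure(g, y, μ, σ)  # first call includes compilation; discard it
        println("N = $N, $(nameof(g)): $(measure(g, y, μ, σ)) bytes")
    end
end
```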

@theogf
Member Author

theogf commented Dec 2, 2021

Now, regarding AD: with Julia 1.7 and Zygote 0.6.32, everything seems to pass.

## Checking differentiation
using Zygote
N = 100
gh = GaussHermite(100)
lik = BernoulliLikelihood()
y = rand(0:1, N)
μ = randn(N)
σ = rand(N)
for f in [expected_loglik, expected_loglik_david, expected_loglik_old]
    # only checks that the gradient w.r.t. the means can be computed without error
    g = only(Zygote.gradient(x->f(gh, y, Normal.(x, σ), lik), μ))
end
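The loop above only checks that each gradient call runs. One way to also check the values, sketched here without Zygote so it stays self-contained, is to compare a hand-derived gradient with central finite differences. Assumptions: a closed-form 3-point Gauss-Hermite rule stands in for gausshermite(100), the Bernoulli likelihood with logistic link is written by hand, and ell, grad_μ and fd_grad_μ are hypothetical names; the same finite-difference comparison could be applied to the Zygote.gradient output.

```julia
using Random

# Closed-form 3-point physicists' Gauss-Hermite rule: nodes 0, ±√(3/2).
const XS = [-sqrt(1.5), 0.0, sqrt(1.5)]
const WS = [sqrt(π) / 6, 2 * sqrt(π) / 3, sqrt(π) / 6]

logistic(f) = 1 / (1 + exp(-f))
log1pexp(f) = f > 0 ? f + log1p(exp(-f)) : log1p(exp(f))
bernoulli_loglik(f, y) = y * f - log1pexp(f)   # log p(y | f), y ∈ {0, 1}

# Expected log-likelihood as a function of the means μ (σ held fixed).
ell(y, μ, σ) = sum(bernoulli_loglik.(sqrt(2) .* σ .* XS' .+ μ, y) * WS) / sqrt(π)

# Hand-derived gradient: ∂/∂f [y*f - log1pexp(f)] = y - logistic(f), and ∂f/∂μ_i = 1.
grad_μ(y, μ, σ) = ((y .- logistic.(sqrt(2) .* σ .* XS' .+ μ)) * WS) ./ sqrt(π)

# Central finite differences, one coordinate at a time.
function fd_grad_μ(y, μ, σ; h=1e-6)
    g = similar(μ)
    for i in eachindex(μ)
        e = zeros(length(μ)); e[i] = h
        g[i] = (ell(y, μ .+ e, σ) - ell(y, μ .- e, σ)) / (2h)
    end
    return g
end

Random.seed!(2)
N = 20
y = rand(0:1, N); μ = randn(N); σ = rand(N)
println("max abs difference: ", maximum(abs.(grad_μ(y, μ, σ) .- fd_grad_μ(y, μ, σ))))
```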

@devmotion
Member

> Yet the linear algebra is always the fastest

It seems broadcasted is the fastest for N = 500 and best allocation-wise in all examples?

@theogf
Member Author

theogf commented Jan 11, 2022

I'll make a PR to use broadcasted instead.
