Reverse-mode AD extremely slow for large number of observations #1642

Closed
anhi opened this issue Jun 16, 2021 · 14 comments
anhi commented Jun 16, 2021

When trying to optimize our Turing code, we experimented with the different AD engines. It seems as if the reverse-mode AD engines are extremely slow for large numbers of observations. Our original model has several hundred dimensions, but the effect can be demonstrated on this simple example:

using Turing, Distributions

setadbackend(:forwarddiff)

@model benchmark_model(x) = begin
    μ ~ TruncatedNormal(1, 2, 0.1, 10)
    σ ~ TruncatedNormal(1, 2, 0.1, 10)
    
    x .~ LogNormal(μ, σ)   
end

samples = rand(LogNormal(1.5, 0.5), 100000);

@time chains = Turing.sample(benchmark_model(samples), NUTS(0.65), 2000)

On my machine, this takes about 42 seconds: 41.868502 seconds (21.57 M allocations: 1.340 GiB, 1.49% gc time, 18.19% compilation time)

I understand that for such a simple model, forward-mode should be more efficient. But when switching to reverse-mode

setadbackend(:zygote)

@time chains = Turing.sample(benchmark_model(samples), NUTS(0.65), 2000)

it takes several hours on my machine to even arrive at an ETA, which starts out at several days. This seems a little excessive.

Are we doing anything wrong, or is reverse-mode just not usable for large numbers of observations?

torfjelde (Member)

Try filldist instead of .~. That should speed things up significantly.

Essentially, filldist means that the entire vector x is treated as one multivariate random variable, while .~ gives you length(x) univariate random variables. The former is going to be significantly faster.
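For concreteness, this is roughly what that suggestion looks like on the model above (the model name is just illustrative, not from the original comment):

@model benchmark_model_filldist(x) = begin
    μ ~ TruncatedNormal(1, 2, 0.1, 10)
    σ ~ TruncatedNormal(1, 2, 0.1, 10)

    # One length(x)-dimensional observation instead of length(x) univariate ones.
    x ~ filldist(LogNormal(μ, σ), length(x))
end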

yebai (Member) commented Jun 16, 2021

@torfjelde @mohamed82008 A side note, any reason that we shouldn't translate dot observe and dot assume into filldist automatically in DynamicPPL?

anhi (Author) commented Jun 16, 2021

@torfjelde thank you for the hint. I'm still a little confused: did you mean

x = filldist(LogNormal(μ, σ), 100000)

which is blazingly fast but yields wrong values for μ and σ, or

x ~ filldist(LogNormal(μ, σ), 100000)

which seems to have a very similar runtime to the original (still waiting for the ETA)?

Or is there another way to use filldist? I was a little confused by the documentation here.

mohamed82008 (Member)

@yebai dot broadcasting should be fast on observations, even GPU compatible most of the time. That is unless something changed recently.

mohamed82008 (Member)

@anhi try Zygote and ReverseDiff. You might have better luck with Zygote here because ReverseDiff's performance is a bit brittle depending on whether we use an array of tracked reals (slow) or a tracked array (fast).

torfjelde (Member) commented Jun 16, 2021

I meant the latter, i.e. the one with ~. I'm surprised it's not faster though 😕

Also, as @mohamed82008 pointed out this is for observations so my argument above doesn't actually hold.

try Zygote and ReverseDiff

It seems like he's using Zygote already?

mohamed82008 (Member)

It seems like he's using Zygote already?

Yes my bad, didn't see this. So try ReverseDiff then :)
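For reference, a minimal sketch of switching to ReverseDiff (assuming a Turing version where setrdcache enables the compiled/cached tape; some versions also need Memoization loaded for that):

using Turing, ReverseDiff
# using Memoization  # required for setrdcache(true) on some Turing versions

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)  # cache the tape; only safe if the model's control flow does not depend on the parameters

@time chains = Turing.sample(benchmark_model(samples), NUTS(0.65), 2000)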

torfjelde (Member)

One thing I've noticed in the past: if the function being mapped or broadcast contains if-statements (which LogNormal does), you get a pretty significant slowdown when using Zygote (broadcasting also often leads to type instability in this case).
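As a toy illustration (not from the thread) of the pattern in question, i.e. broadcasting a scalar logpdf with an internal support check, which Zygote then has to differentiate element by element:

using Zygote, Distributions

xs = rand(LogNormal(1.5, 0.5), 10_000)

# Broadcasting a branchy scalar logpdf over all observations; this is the
# pattern that tends to be slow (and often type-unstable) under Zygote.
loglik(μ, σ) = sum(logpdf.(LogNormal(μ, σ), xs))

Zygote.gradient(loglik, 1.0, 0.5)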

anhi (Author) commented Jun 16, 2021

It seems like he's using Zygote already?

Yes my bad, didn't see this. So try ReverseDiff then :)

:) ok, I'll try... The ETA for Zygote was ~22 hours, btw, while ForwardDiff with filldist took 61 seconds (a little longer than without filldist, which was ~42 seconds).

torfjelde (Member) commented Jun 16, 2021

Btw, one thing you can do if you want to go really fast is to use @addlogprob! and just compute the logpdf of LogNormal "by hand". Then you can also remove stuff like the if-statements that check whether it's inside the domain or not. That is, you can replace it with

logx = log.(x)
zval = @. (logx - μ) / σ # `StatsFuns.normlogpdf(μ, σ, x)` has an if-statement in it, so we circumvent this by computing the `zval` ourselves.
@addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)

which should be the same (maybe check this though), assuming you've done import StatsFuns somewhere.

This should be muuuch faster using Zygote.

anhi (Author) commented Jun 16, 2021

@torfjelde

ok, this is getting close...



import StatsFuns

@model benchmark_model_2(x) = begin
    μ ~ Normal(0.1, 10)
    σ ~ Normal(0.1, 10)   

    logx = log.(x)
    zval = @. (logx - μ) / σ
    @Turing.addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)
end

@time chains = Turing.sample(benchmark_model_2(samples), NUTS(0.65), 2000)

I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...

308.334925 seconds (8.34 G allocations: 381.370 GiB, 19.94% gc time, 11.96% compilation time)

so this is indeed much faster than the standard LogNormal implementation.

I'm running the same experiment with the TruncatedNormals again, and it seems similarly fast.

ReverseDiff seems to have similar problems to Zygote, and similar timings. But I'll try that again later.

anhi (Author) commented Jun 16, 2021

Using TruncatedNormals, the sampling took 268 seconds, which indeed looks much better than the several days I started out with. However, I just noticed that the sampling returns wrong values for σ. The data was generated with μ = 1.5 and σ = 0.5, but sampling with non-truncated normals returned a σ of 41, and the TruncatedNormal version a σ of 10. μ was close to 1.5 in both cases...

torfjelde (Member)

I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...

Probably unnecessary. This is specifically a problem when you're doing map or broadcasting, i.e. f.(x), over something, not so much an if-statement somewhere else in the code :)

However, I just noticed that the sampling returns wrong values for σ. The data was generated with μ = 1.5 and σ = 0.5, but sampling with non-truncated normals returned a σ of 41, and the TruncatedNormal version a σ of 10. μ was close to 1.5 in both cases...

Yeah sorry, this is because I made a mistake in the above (told yah it needed some checking 😅). It should be

sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)

I forgot the log-abs-det-jacobian term from computing the zval 👍
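Putting the pieces together, a sketch of the corrected model (the name is illustrative; the length(x) factor on the jacobian term assumes σ is a scalar, as in the models above):

import StatsFuns
using Turing

@model benchmark_model_3(x) = begin
    μ ~ TruncatedNormal(1, 2, 0.1, 10)
    σ ~ TruncatedNormal(1, 2, 0.1, 10)

    logx = log.(x)
    zval = @. (logx - μ) / σ
    # logpdf(LogNormal(μ, σ), xᵢ) = normlogpdf(zᵢ) - log(σ) - log(xᵢ), summed over all observations.
    Turing.@addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)
end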

yebai (Member) commented Jan 27, 2023

Closed in favour of #1934

yebai closed this as completed on Jan 27, 2023