optimize sample_topp by filtering out small value elements up front #276

Merged (1 commit) on Aug 14, 2023

Conversation

jrudolph
Contributor

@jrudolph jrudolph commented Aug 12, 2023

Refs #246

This works because, in the worst case, only one element is selected, so the remaining (n-1) elements have to split the remaining (1-topp) probability. Any probability smaller than (1-topp)/(n-1) can therefore never be selected and can be filtered out up front.

For example, with topp = 0.9 usually only 100-1000 tokens remain after filtering, which speeds up the subsequent sort and selection considerably.
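
To make the argument above concrete, here is a minimal C sketch of the filter followed by the usual sort-and-sample step. It is not the merged diff: the helper names (`filter_candidates`, `sample_topp_filtered`), the caller-provided `scratch` buffer and the `-1` return for degenerate input are assumptions made for illustration. It also assumes `probs` sums to 1 and `0 < topp < 1`.

```c
#include <assert.h>
#include <stdlib.h>

typedef struct {
    float prob;
    int index;
} ProbIndex;

// qsort comparator: descending by probability
static int compare_desc(const void *a, const void *b) {
    const ProbIndex *pa = (const ProbIndex *)a;
    const ProbIndex *pb = (const ProbIndex *)b;
    if (pa->prob > pb->prob) return -1;
    if (pa->prob < pb->prob) return 1;
    return 0;
}

// Copies every candidate with prob >= (1 - topp) / (n - 1) into out
// (capacity n) and returns how many were kept.
static int filter_candidates(const float *probs, int n, float topp, ProbIndex *out) {
    assert(n > 1);
    const float cutoff = (1.0f - topp) / (float)(n - 1);
    int kept = 0;
    for (int i = 0; i < n; i++) {
        if (probs[i] >= cutoff) {
            out[kept].prob = probs[i];
            out[kept].index = i;
            kept++;
        }
    }
    return kept;
}

// Top-p sampling over the filtered candidates only.
// coin is a uniform random number in [0, 1) supplied by the caller.
static int sample_topp_filtered(const float *probs, int n, float topp,
                                ProbIndex *scratch, float coin) {
    int kept = filter_candidates(probs, n, topp, scratch);
    if (kept == 0) return -1; // degenerate input only, e.g. topp < 1/n

    qsort(scratch, (size_t)kept, sizeof(ProbIndex), compare_desc);

    // find the shortest sorted prefix whose cumulative mass exceeds topp
    float cumulative = 0.0f;
    int last = kept - 1;
    for (int i = 0; i < kept; i++) {
        cumulative += scratch[i].prob;
        if (cumulative > topp) {
            last = i;
            break;
        }
    }

    // sample within that prefix, rescaling the coin to its total mass
    float r = coin * cumulative;
    float cdf = 0.0f;
    for (int i = 0; i <= last; i++) {
        cdf += scratch[i].prob;
        if (r < cdf) return scratch[i].index;
    }
    return scratch[last].index; // guard against rounding errors
}
```

The sort now runs over `kept` elements instead of the full vocabulary, which is where the speedup would come from.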

(In llama2.scala, I improved on this further by avoiding the sort in most cases. The distribution looks like a power law, so only very few elements are ultimately selected. Iteratively scanning the array for the next-best element (a kind of selection sort) while keeping track of the cumulative p seems to be a slightly better solution still.)
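
The scanning idea from llama2.scala could look roughly like this in C. This is a sketch of the general technique (a partial selection sort that stops early), not a port of the Scala code; the `Candidate` type and the function name `select_nucleus_by_scan` are hypothetical, and the input is assumed to be the already filtered candidate array.

```c
#include <assert.h>

typedef struct {
    float prob;
    int index;
} Candidate;

// Partial selection sort: repeatedly moves the largest remaining candidate
// to the front and stops as soon as the cumulative mass exceeds topp.
// On return, cands[0 .. result-1] holds the selected elements in descending
// order of probability and *mass_out holds their total mass.
static int select_nucleus_by_scan(Candidate *cands, int n, float topp, float *mass_out) {
    assert(n > 0);
    float cumulative = 0.0f;
    int selected = 0;
    while (selected < n) {
        // linear scan for the best candidate not selected yet
        int best = selected;
        for (int i = selected + 1; i < n; i++) {
            if (cands[i].prob > cands[best].prob) best = i;
        }
        Candidate tmp = cands[selected];
        cands[selected] = cands[best];
        cands[best] = tmp;

        cumulative += cands[selected].prob;
        selected++;
        if (cumulative > topp) break;
    }
    *mass_out = cumulative;
    return selected;
}
```

Each pass costs one scan over the remaining candidates, so this only pays off when few elements end up selected, which is the power-law case described above.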

@cgbur
Contributor

cgbur commented Aug 12, 2023

238 -> 509 tokens/s when sampling with temperature 1 and top-p 0.9. Really nice work! Did you notice much improvement with the linear scanning? I have not tried that yet in my port. I found the average number of logits per step is pretty drastically reduced already.

image

@rdentato rdentato mentioned this pull request Aug 12, 2023
@jrudolph
Contributor Author

Ah, didn't see #270 and #274 came first with the same idea :)

For fun, I ran a code generation model on llama2.scala for a while and asked it for suggestions: https://gist.github.com/jrudolph/fb7641ba2406de705c5499280783b55c

The suggested algorithms are sometimes comically bad, but some of the ideas seem interesting:

  • It suggested QuickSelect (which I think degenerates somewhat for the long-tail distributions we have here, though you could choose clever pivots; the filtering in this PR is basically just the first step of a QuickSelect)
  • Gathering a histogram in a first run to quickly figure out an upper bound of elements to keep
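
One possible reading of the histogram idea, sketched in C. The power-of-two bucket layout, `NUM_BUCKETS` and the function name `histogram_threshold` are my assumptions; the gist does not prescribe them. A first pass accumulates probability mass per bucket, and a walk over the buckets from largest to smallest probabilities yields a threshold such that the elements at or above it contain the whole top-p set.

```c
#include <assert.h>

#define NUM_BUCKETS 32

// Returns a threshold t such that the elements with prob >= t contain the
// top-p nucleus. Bucket b covers probabilities in [2^-b, 2^(1-b)); the last
// bucket collects everything smaller.
static float histogram_threshold(const float *probs, int n, float topp) {
    assert(n > 0);
    float mass[NUM_BUCKETS] = {0.0f};

    // first pass: accumulate probability mass per bucket
    for (int i = 0; i < n; i++) {
        float edge = 1.0f;
        int b = 0;
        // halve the edge until it drops to or below probs[i]; a tuned
        // version would likely read the bucket off the float exponent
        while (b < NUM_BUCKETS - 1 && probs[i] < edge) {
            edge *= 0.5f;
            b++;
        }
        mass[b] += probs[i];
    }

    // walk from the highest-probability bucket down until topp is covered
    float cumulative = 0.0f;
    float edge = 1.0f;
    for (int b = 0; b < NUM_BUCKETS - 1; b++) {
        cumulative += mass[b];
        if (cumulative > topp) return edge; // lower edge of bucket b
        edge *= 0.5f;
    }
    return 0.0f; // nucleus reaches into the last bucket: keep everything
}
```

Only elements at or above the returned threshold would then need to be sorted or scanned.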

@jrudolph
Contributor Author

> Did you notice much improvement with the linear scanning?

I think that after the filtering it doesn't matter much any more. After all, the speed improvements are only needed for small models anyway (since sampling speed depends only on vocabulary size, regardless of model size).

@jrudolph
Contributor Author

> I think that after the filtering it doesn't matter much any more. After all, the speed improvements are only needed for small models anyway (since sampling speed depends only on vocabulary size, regardless of model size).

OK, in Scala, scanning speeds up just the top-p selection step by roughly another 10x (but that's partly because the naive idiomatic sorting carries a high abstraction overhead due to boxing).

[info] TopPBenchmark.topP  filterAndScan  thrpt    5  28793.391 ± 5723.854  ops/s
[info] TopPBenchmark.topP  filterAndSort  thrpt    5   2531.715 ±   50.438  ops/s
[info] TopPBenchmark.topP        sorting  thrpt    5    132.780 ±    5.378  ops/s

@karpathy karpathy merged commit 4a2c375 into karpathy:master Aug 14, 2023
6 checks passed
@karpathy
Owner

Thank you for a nice PR!

@jrudolph
Contributor Author

Here's a small report on my experiments with different top-p algorithms: https://blog.virtual-void.net/2023/08/29/calculating-top-p/

vinhtran2611 pushed a commit to vinhtran2611/llama2.c that referenced this pull request Jan 20, 2024
optimize sample_topp by filtering out small value elements up front
@Majdoddin
Contributor

Please also consider #313 (constant cut-off).
The code can safely fall back to other algorithms if the constant cut-off doesn't work, but in practice that doesn't happen.


4 participants