
2) Optimize byte pair merge for small and big character sequences - 8.2s to 3.9s #76

Merged
merged 14 commits into from
Jan 2, 2024

Conversation

l0rinc (Contributor) commented Dec 22, 2023:

Continuing #75 - note that the first few commits are repeated here; they will be eliminated by rebase once the other PR is merged.

The original byte pair merge algorithm scales superlinearly for longer character sequences - e.g. a 20,000-character word (roughly 2,500 tokens) can take several seconds to tokenize.

For bigger character sequences we switch to a linearithmic algorithm at around 500 characters (below which the current algorithm is faster).

The change also includes an optimization for counting tokens only, for when the tokens themselves aren't needed.
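The crossover described above can be sketched as a simple length-based dispatch. This is a minimal illustration, not the library's actual code: the `strategyFor` helper and class name are hypothetical, and only the ~500-byte threshold comes from the discussion.

```java
// Hypothetical sketch of dispatching between the two merge strategies by piece length.
// The 500-byte crossover is taken from the PR description; names are illustrative.
public final class MergeDispatch {
    static final int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = 500; // approximate crossover point

    static String strategyFor(int pieceLength) {
        // Below the threshold the simple quadratic scan wins on constant factors;
        // above it, the tree-based linearithmic merge avoids superlinear blowup.
        return pieceLength < VERY_LARGE_TOKENIZER_BYTE_THRESHOLD ? "array-scan" : "rank-tree";
    }

    public static void main(String[] args) {
        System.out.println(strategyFor(20));     // array-scan
        System.out.println(strategyFor(20_000)); // rank-tree
    }
}
```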

Before (i.e. assuming #75 was merged):

Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  8.273 ± 0.116   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  9.762 ± 0.101   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  9.790 ± 0.056   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  8.750 ± 0.065   s/op

After:

Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  4.547 ± 0.056   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  3.944 ± 0.031   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.427 ± 0.065   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.375 ± 0.062   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.073 ± 0.063   s/op

Please review commit-by-commit so the changes make sense.

[screenshot: commit list]

@l0rinc l0rinc requested a review from tox-p as a code owner December 22, 2023 15:53
@l0rinc l0rinc changed the title Optimize byte pair merge for small and big character lists Optimize byte pair merge for small and big character sequences Dec 22, 2023
private final Pattern pattern;
private final TokenEncoder encoder;
private final SpecialEncoder specialEncoder;
private final Map<Integer, byte[]> encodedToDecoded;
l0rinc (Contributor, Author) commented Dec 22, 2023:

TokenEncoder is only used for encoding now - so we can eliminate the types

Contributor (reviewer) replied:

I would prefer to keep (this explicitly typed) encodedToDecoded map inside the modified TokenEncoder, to be consistent with the SpecialEncoder structure.

As far as I can tell, the only performance implication is that we lose the benefit of caching decoding in a single map and therefore have to make two map lookups for special tokens, which is negligible (especially since encoding special characters is unsupported anyway).

l0rinc (Contributor, Author) replied Dec 29, 2023:

In my original over-optimized version there were multiple TokenEncoders: when the byte count was < Long.BYTES, we stored the bytes in a primitive long and used a primitive map with a long key instead. That enabled squeezing out the last few drops, since short tokens are the most common ones, and they're a lot faster once no byte arrays are involved. But I haven't committed that yet, so I should probably merge it back into TokenEncoder for now and maybe split it out again in PR #79.
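The "short tokens as primitive longs" idea mentioned above can be sketched as follows. This is a hypothetical illustration, not the PR #79 code: the `packShort` name and the leading-1 length marker are assumptions.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: pack up to Long.BYTES (8) bytes into a single long key,
// avoiding byte[] allocation and array hashing for the most common short tokens.
public final class LongKeyEncoder {
    static long packShort(byte[] bytes, int from, int to) {
        assert to - from <= Long.BYTES;
        long key = 1; // leading 1 disambiguates lengths (e.g. {0} vs {0, 0})
        for (int i = from; i < to; i++) {
            key = (key << 8) | (bytes[i] & 0xFF);
        }
        return key;
    }

    public static void main(String[] args) {
        Map<Long, Integer> ranks = new HashMap<>();
        byte[] token = "the".getBytes(java.nio.charset.StandardCharsets.UTF_8);
        ranks.put(packShort(token, 0, token.length), 42);
        System.out.println(ranks.get(packShort(token, 0, token.length))); // 42
    }
}
```

A primitive-specialized long-to-int map (e.g. from a collections library) would then avoid boxing the key as well.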

l0rinc (Contributor, Author):

Done

}
}

public static int getMinRankIndex(List<Integer> ranks) {
l0rinc (Contributor, Author):

When searching for the next minimum value, we've unrolled it to favor SIMD optimizations
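The unrolled minimum search can be sketched like this. It is a simplified stand-in for the PR's `getMinRankIndex`, under the assumption of a four-way unroll; the actual factor and details may differ.

```java
import java.util.List;

// Sketch of an unrolled minimum-rank search (assumed four-way unroll); the straight-line
// body gives the JIT adjacent, independent comparisons it can group or vectorize.
public final class MinRank {
    public static int getMinRankIndex(List<Integer> ranks) {
        int minRankIndex = -1;
        int minRank = Integer.MAX_VALUE;
        int i = 0;
        int length = ranks.size();
        for (int end = length - 3; i < end; i += 4) { // four elements per iteration
            int r0 = ranks.get(i), r1 = ranks.get(i + 1), r2 = ranks.get(i + 2), r3 = ranks.get(i + 3);
            if (r0 < minRank) { minRankIndex = i;     minRank = r0; }
            if (r1 < minRank) { minRankIndex = i + 1; minRank = r1; }
            if (r2 < minRank) { minRankIndex = i + 2; minRank = r2; }
            if (r3 < minRank) { minRankIndex = i + 3; minRank = r3; }
        }
        for (; i < length; i++) { // tail for lengths not divisible by 4
            int r = ranks.get(i);
            if (r < minRank) { minRankIndex = i; minRank = r; }
        }
        return minRankIndex;
    }

    public static void main(String[] args) {
        System.out.println(getMinRankIndex(List.of(5, 3, 9, 1, 7))); // 3
    }
}
```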

Contributor (reviewer) replied:

Out of curiosity: have you benchmarked this unrolling separately? I was under the impression that this kind of optimization (loop unrolling) is best left to the compiler.

l0rinc (Contributor, Author):

It is still left to the JIT compiler - I'm not using the Vector API, just making it simpler for the JIT to group similar instructions. If you run the benchmarks yourself, you can tell me whether it reproduces.
And yes, I have benchmarked everything separately (though I haven't committed every single benchmark); this was the fastest variant.

Lőrinc added 10 commits December 29, 2023 21:23
* When searching for the next minimum value, the loop is unrolled to favor SIMD optimizations.
* PieceIndexToRank is removed; we only store the rank, since we no longer delete next values (which avoids copying every subsequent value).
* Since minimums are replaced with sentinels, previous and next indexes are replaced by iteration.
* The encoders map is split by input byte array size, so that we only ever query small maps.
* Iteration stops before the last minimum is MAX_RANK by keeping track of merge results - saving one minimum search at the end.
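The sentinel idea from the notes above can be sketched as follows. This is an illustrative reconstruction (names like `markMerged` and `nextLive` are assumptions): rather than removing a merged element, which shifts every subsequent value, the slot is overwritten with MAX_RANK and skipped during iteration.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of sentinel-based "deletion": merged positions are marked with MAX_RANK
// in O(1) instead of being removed, and iteration simply skips sentinels.
public final class SentinelMerge {
    static final int MAX_RANK = Integer.MAX_VALUE - 1;

    static void markMerged(List<Integer> ranks, int i) {
        ranks.set(i, MAX_RANK); // constant-time "delete", no element shifting
    }

    // Find the next live (non-sentinel) index after i, replacing explicit next pointers.
    static int nextLive(List<Integer> ranks, int i) {
        for (int j = i + 1; j < ranks.size(); j++) {
            if (ranks.get(j) != MAX_RANK) return j;
        }
        return -1;
    }

    public static void main(String[] args) {
        List<Integer> ranks = new ArrayList<>(List.of(4, 2, 8, 6));
        markMerged(ranks, 1);
        System.out.println(nextLive(ranks, 0)); // 2
    }
}
```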

Before:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  8.947 ± 0.109   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  9.419 ± 0.082   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  9.365 ± 0.073   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  8.403 ± 0.080   s/op

After:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  7.313 ± 0.031   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  7.242 ± 0.027   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  7.742 ± 0.054   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  7.748 ± 0.121   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  7.017 ± 0.110   s/op
Before:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  7.242 ± 0.027   s/op

After:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  6.885 ± 0.049   s/op
We're storing the ranks in a red-black tree of trees.
Getting the minimum rank is basically constant time: we group by the rank itself (since the same rank can occur multiple times) and pop the first entry, which represents the first occurrence.
After a merge we remove the node (also a basically constant-time operation).
We also count the remaining valid ranks for the stopping condition.
To know the previous and next values, everything is stored in a RankNode that we update after finding the minimum via the tree.
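The "tree of trees" described above can be sketched with nested TreeMaps (Java's red-black tree). This is a minimal stand-in: the real RankNode also tracks its neighbours, and the method here is illustrative.

```java
import java.util.TreeMap;

// Sketch of the rank index: outer TreeMap keyed by rank, inner TreeMap keyed by
// position, so the minimal rank and its first occurrence fall out of two
// firstEntry() calls. RankNode is a minimal stand-in for the PR's node type.
public final class RankTree {
    static final class RankNode {
        final int rank, index;
        RankNode(int rank, int index) { this.rank = rank; this.index = index; }
    }

    // Build the index from (rank, position) pairs, return "rank@position" of the minimum.
    static String minOf(int[][] pairs) {
        TreeMap<Integer, TreeMap<Integer, RankNode>> rankMap = new TreeMap<>();
        for (int[] p : pairs) {
            rankMap.computeIfAbsent(p[0], r -> new TreeMap<>()).put(p[1], new RankNode(p[0], p[1]));
        }
        // Equal ranks are grouped; the inner tree's firstEntry() is the first occurrence.
        RankNode min = rankMap.firstEntry().getValue().firstEntry().getValue();
        return min.rank + "@" + min.index;
    }

    public static void main(String[] args) {
        System.out.println(minOf(new int[][]{{7, 0}, {3, 1}, {3, 2}, {9, 3}})); // 3@1
    }
}
```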

Before:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  7.372 ± 0.063   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  6.885 ± 0.049   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  7.846 ± 0.051   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  7.850 ± 0.051   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  7.006 ± 0.066   s/op

After:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  4.592 ± 0.055   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  4.215 ± 0.036   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.598 ± 0.063   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.569 ± 0.044   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.178 ± 0.128   s/op
It's simpler and in the current implementation it's basically just as fast.
@l0rinc l0rinc changed the title from "2) Optimize byte pair merge for small and big character sequences - 8.9s to 3.8s" to "2) Optimize byte pair merge for small and big character sequences - 8.9s to 3.7s" Dec 29, 2023
validRanks -= (newRank == MAX_RANK) ? 1 : -1;
TreeMap<Integer, RankNode> minNodes = rankMap.firstEntry().getValue();
for (int i = 0; i < minNodes.size(); i++) {
RankNode minNode = minNodes.firstEntry().getValue();
l0rinc (Contributor, Author):

If there are multiple tokens with the same rank, the minimum we've just found will also be the next minimum, so we can simply consume all the gathered ones.

public static final int DUMMY_RANK = Integer.MAX_VALUE;
public static final int MAX_RANK = Integer.MAX_VALUE - 1;
public final class TokenEncoder {
public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD";
l0rinc (Contributor, Author):

public should be fine; users should also be able to override this from the outside if they really need to - though I wouldn't expose it through the API.

@@ -14,22 +14,26 @@ class Cl100kBaseTest {

private static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);

Encoding getEncoding() {
return ENCODING;
l0rinc (Contributor, Author):

We have to be able to override it from a child test to control the initialization order.


import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;

class Cl100kLargeTokenizerTest extends Cl100kBaseTest {
l0rinc (Contributor, Author):

Checked via test coverage: the first test only exercises the fast array-based implementation, while this one only exercises the map-based one.

@l0rinc l0rinc requested a review from tox-p December 29, 2023 19:28
validRanks--;
}
removeNode(rankMap, nextNode);
}
removeNode(minNodes, rankMap, minNode);
l0rinc (Contributor, Author):

Minor optimization - we could probably do a few more tiny ones here, but this is a worst-case handler anyway.

validRanks--;
}
removeNode(rankMap, nextNode);
l0rinc (Contributor, Author) Dec 29, 2023:

Not sure what happened with the formatting here before :/

@@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 977, 72, 4864, 259, 3415, 73, 30
Quel est ton livre préféré?,"[2232, 301, 1826, 8941, 56984, 27389, 69, 68862, 30]","[2232, 301, 1826, 8941, 56984, 27389, 69, 68862, 30]"
Qual é a tua cor favorita?,"[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]","[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]"
Koja ti je omiljena boja?,"[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712, 5697, 30]","[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712]"
Melyik a kedvenc étel?,"[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]","[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]"
Melyik a kedvenc ételed?,"[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839, 30]","[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839]"
l0rinc (Contributor, Author) Dec 30, 2023:

The Hungarian sentence itself was incorrect: "Melyik a kedvenc étel?" lacks the possessive; "Melyik a kedvenc ételed?" ("What is your favorite food?") is the grammatical form.

[screenshot]

…iteration

After:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  4.547 ± 0.056   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  3.944 ± 0.031   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.427 ± 0.065   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.375 ± 0.062   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.073 ± 0.063   s/op
@l0rinc l0rinc changed the title from "2) Optimize byte pair merge for small and big character sequences - 8.9s to 3.7s" to "2) Optimize byte pair merge for small and big character sequences - 8.2s to 3.9s" Dec 30, 2023
@@ -150,7 +150,7 @@ public static void main(String[] args) throws Exception {
"'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'x",
"x",
"123",
"ő",
"a",
l0rinc (Contributor, Author):

Testing a slightly different setup here, since "aaaaaa" is still a single token - this makes for a more difficult case.

@@ -177,7 +177,7 @@ public static void main(String[] args) throws Exception {
}

var totalSize = calculateTotalFileSize(rootFolder);
if (totalSize != 99_945_290) {
if (totalSize != 99_925_295) {
l0rinc (Contributor, Author):

Re-ran the before/after benchmarks with this data.

assert minNode.rank != MAX_RANK;
TreeMap<Integer, RankNode> minNodes = rankMap.pollFirstEntry().getValue();
int firstIndex;
for (Entry<Integer, RankNode> entry = minNodes.firstEntry(); entry != null; entry = minNodes.ceilingEntry(firstIndex)) {
l0rinc (Contributor, Author) Dec 30, 2023:

Because of the tree structure, we actually store every instance of the same token, so once we find one of them we can be sure that the next few minimums are the same token - so we just iterate over those instead of removing them one by one (polling once and iterating until consumed).

@tox-p tox-p merged commit 14270de into knuddelsgmbh:main Jan 2, 2024
2 checks passed
@l0rinc l0rinc deleted the optimize-byte-pair-merge branch January 2, 2024 18:46