From 8b244b6dfbc9c8ee6817ace21a78796cc8b88376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 22 Dec 2023 10:49:28 +0100 Subject: [PATCH 01/14] Split out encodedToDecoded and SpecialEncoder from TokenEncoder --- .../knuddels/jtokkit/GptBytePairEncoding.java | 610 +++++++++--------- .../com/knuddels/jtokkit/SpecialEncoder.java | 39 ++ .../com/knuddels/jtokkit/TokenEncoder.java | 149 ++--- .../{Cl100kBaseTest.java => Cl100kTest.java} | 2 +- 4 files changed, 402 insertions(+), 398 deletions(-) create mode 100644 lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java rename lib/src/test/java/com/knuddels/jtokkit/{Cl100kBaseTest.java => Cl100kTest.java} (99%) diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 9bc2bd73..84e3f72a 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -5,321 +5,309 @@ import com.knuddels.jtokkit.api.GptBytePairEncodingParams; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; + /** * Implementation of the byte pair encoding algorithm as used by the OpenAI tiktoken tokenizer. */ -final class GptBytePairEncoding implements Encoding { - - private final String name; - private final Pattern pattern; - private final TokenEncoder encoder; - private final TokenEncoder specialTokensEncoder; - - /** - * Creates a new instance of {@link GptBytePairEncoding}. 
- * - * @param params the parameters to use for the encoding - */ - GptBytePairEncoding(final GptBytePairEncodingParams params) { - this.name = params.getName(); - this.pattern = params.getPattern(); - this.encoder = new TokenEncoder<>(params.getEncoder(), ImmutableByteArray::from); - this.specialTokensEncoder = new TokenEncoder<>(params.getSpecialTokensEncoder()); - } - - @Override - public List encode(final String text) { - return encodeInternal(text, null).getTokens(); - } - - @Override - public EncodingResult encode(final String text, final int maxTokens) { - return encodeInternal(text, maxTokens); - } - - private EncodingResult encodeInternal(final String text, final Integer maxTokens) { - if (text == null) { - return new EncodingResult(Collections.emptyList(), false); - } - - for (final String specialToken : specialTokensEncoder.getDecodedTokens()) { - if (text.contains(specialToken)) { - throw new UnsupportedOperationException("Encoding special tokens is not supported yet."); - } - } - - return encodeOrdinaryInternal(text, maxTokens); - } - - @Override - public List encodeOrdinary(final String text) { - return encodeOrdinaryInternal(text, null).getTokens(); - } - - @Override - public EncodingResult encodeOrdinary(final String text, final int maxTokens) { - return encodeOrdinaryInternal(text, maxTokens); - } - - private EncodingResult encodeOrdinaryInternal(final String text, final Integer maxTokens) { - if (text == null) { - return new EncodingResult(Collections.emptyList(), false); - } - - final List out = new ArrayList<>(); - final Matcher matcher = pattern.matcher(text); - int tokenCount = 0; - while (matcher.find() && maxTokenCountNotReached(maxTokens, tokenCount)) { - final ImmutableByteArray match = ImmutableByteArray.from(matcher.group()); - if (encoder.containsDecodedToken(match)) { - out.add(encoder.encode(match)); - tokenCount++; - } else { - final List tokensToAdd = bytePairMerge(match); - tokenCount += addTokens(out, tokensToAdd, maxTokens); - } 
- } - - if (maxTokens != null) { - // Make sure we didn't break the multibyte character - for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { - final List tokens = out.subList(0, out.size() - tokensToRemove); - final String decoded = decode(tokens); - if (text.startsWith(decoded)) { - // If decoded text is equal to the head of the original text, we can safely return the tokens - return new EncodingResult(tokens, text.length() > decoded.length()); - } - } - } - - return new EncodingResult(out, false); - } - - /** - * Adds tokens from 'tokensToAdd' to 'out' until either 'maxTokens' is reached or 'tokensToAdd' is exhausted. - * - * @return the number of tokens added to 'out' - */ - private int addTokens(final List out, final List tokensToAdd, final Integer maxTokens) { - if (maxTokens != null) { - final List sublist = tokensToAdd.subList(0, Math.min(maxTokens - out.size(), tokensToAdd.size())); - out.addAll(sublist); - return sublist.size(); - } - - out.addAll(tokensToAdd); - return tokensToAdd.size(); - } - - @Override - public int countTokens(final String text) { - return encode(text).size(); - } - - @Override - public int countTokensOrdinary(final String text) { - return encodeOrdinary(text).size(); - } - - @Override - public String decode(final List tokens) { - return new String(decodeBytes(tokens), StandardCharsets.UTF_8); - } - - @Override - public byte[] decodeBytes(final List tokens) { - final List out = new ArrayList<>(); - for (final int token : tokens) { - final byte[] decodedToken = decodeToken(token); - for (final byte b : decodedToken) { - out.add(b); - } - } - - final byte[] outArray = new byte[out.size()]; - for (int i = 0; i < out.size(); i++) { - outArray[i] = out.get(i); - } - return outArray; - } - - @Override - public String getName() { - return name; - } - - /* - * We use a custom implementation of the byte pair encoding algorithm as used by the OpenAI tokenizer. 
The - * piece is merged according to the merging rules provided by OpenAI. An example of the algorithm: - * - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: 4 3 7 2 13 inf inf - * - * We don't modify piece directly. We instead create a list of tuples (index, rank) where index is the start index - * of a byte pair and rank is it's merge rank. We call this list of tuples parts. The lowest rank is the byte pair - * that will be merged next. In the example above, the lowest rank is 2, so we merge the byte pair at index 3. - * To merge a byte pair at index i, we first update the ranks of the byte pairs that are affected by the merge, in this - * case the byte pair at index 2 and the byte pair at index 3. Then we remove the byte pair at index i + 1 from the list. - * In this case, this is the byte pair at index 4. - * - * piece: v e c to r - * index: 0 1 2 3 5 6 - * ranks: 4 3 5 9 inf inf - * - * We then repeat the process until there are no more byte pairs to merge, either because we have merged all byte pairs - * and parts.size() is 1, or because there are no more merging rules that apply to our tokens. Let's assume there are merging - * rules for "e + c", "to + r" and "v + ec": - * - * piece: v ec to r - * index: 0 1 3 5 6 - * ranks: 4 11 12 inf inf - * ^ - * - * piece: vec to r - * index: 0 3 5 6 - * ranks: inf 12 inf inf - * ^ - * - * piece: vec tor - * index: 0 3 6 - * ranks: inf inf inf - * - * We can extract the final tokens by simply taking piece.get(parts[0].index) until piece.get(parts[1].index - 1) - * and piece.get(parts[1].index) until piece.get(parts[2].index - 1). Analogously for more than two parts. - * Note that we do not actually modify the piece, but only the parts list. The above visualization is just for - * illustration purposes. 
- */ - private List bytePairMerge(final ImmutableByteArray piece) { - /* - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: inf inf inf inf inf inf inf - */ - final List parts = new ArrayList<>(); - for (int i = 0; i < piece.length() + 1; i++) { - parts.add(new PieceIndexToRank(i, Integer.MAX_VALUE)); - } - - /* - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: 4 3 7 2 13 inf inf - */ - for (int i = 0; i < parts.size() - 2; i++) { - final Optional rank = getRank(piece, parts, i, 0); - if (rank.isPresent()) { - parts.get(i).rank = rank.get(); - } - } - - while (parts.size() > 1) { - /* - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: 4 3 7 2 13 inf inf - * - * minRankIndex = 3 - * minRank = 2 - */ - int minRankIndex = 0; - int minRank = Integer.MAX_VALUE; - for (int i = 0; i < parts.size() - 1; i++) { - final int rank = parts.get(i).rank; - if (rank < minRank) { - minRank = rank; - minRankIndex = i; - } - } - - /* - * piece: v e c to r - * index: 0 1 2 3 5 6 - * ranks: 4 3 5 9 inf inf - */ - if (minRank != Integer.MAX_VALUE) { - // Note that we calculate the rank of the byte pairs at minRankIndex and minRankIndex - 1 before removing - // the merged byte pair. We use the skip parameter of the getRank function to calculate the rank of, in our - // example, "t" + "o" + "r" and "c" + "t" + "o". The assumption made in the OpenAI implementation is that - // removing first thrashes the cache, so it's better to calculate the rank of the byte pairs that are - // affected by the merge before removing the merged byte pair. I did not verify, if this is actually the - // case in java. 
- parts.get(minRankIndex).rank = getRank(piece, parts, minRankIndex, 1).orElse(Integer.MAX_VALUE); - if (minRankIndex > 0) { - parts.get(minRankIndex - 1).rank = getRank(piece, parts, minRankIndex - 1, 1).orElse(Integer.MAX_VALUE); - } - - parts.remove(minRankIndex + 1); - } else { - break; - } - } - - /* - * piece: vec tor - * index: 0 3 6 - * ranks: inf inf inf - */ - final List out = new ArrayList<>(); - for (int i = 0; i < parts.size() - 1; i++) { - out.add(encoder.encode(piece.getBytesBetween(parts.get(i).index, parts.get(i + 1).index))); - } - return out; - } - - private boolean maxTokenCountReached(final Integer maxTokenCount, final int tokenCount) { - return maxTokenCount != null && maxTokenCount.compareTo(tokenCount) <= 0; - } - - private boolean maxTokenCountNotReached(final Integer maxTokenCount, final int tokenCount) { - return !maxTokenCountReached(maxTokenCount, tokenCount); - } - - private Optional getRank( - final ImmutableByteArray piece, - final List parts, - final int startIndex, - final int skip - ) { - if (startIndex + skip + 2 >= parts.size()) { - return Optional.empty(); - } - - final int pieceStartIndex = parts.get(startIndex).index; - final int pieceEndIndex = parts.get(startIndex + skip + 2).index; - final ImmutableByteArray encoderIndex = piece.getBytesBetween(pieceStartIndex, pieceEndIndex); - - return encoder.encodeIfPresent(encoderIndex); - } - - private byte[] decodeToken(final int token) { - final Optional decodedToken = encoder.decodeIfPresent(token); - if (decodedToken.isPresent()) { - return decodedToken.get().getRawArray(); - } - - final Optional decodedSpecialToken = specialTokensEncoder.decodeIfPresent(token); - if (decodedSpecialToken.isPresent()) { - return decodedSpecialToken.get().getBytes(StandardCharsets.UTF_8); - } - - throw new IllegalArgumentException("Unknown token for decoding: " + token); - } - - private static class PieceIndexToRank { - private final int index; - private int rank; - - public PieceIndexToRank(final 
int index, final int rank) { - this.index = index; - this.rank = rank; - } - } +class GptBytePairEncoding implements Encoding { + private final String name; + private final Pattern pattern; + private final TokenEncoder encoder; + private final SpecialEncoder specialEncoder; + private final Map encodedToDecoded; + + /** + * Creates a new instance of {@link GptBytePairEncoding}. + * + * @param params the parameters to use for the encoding + */ + GptBytePairEncoding(GptBytePairEncodingParams params) { + this.name = params.getName(); + this.pattern = params.getPattern(); + this.encoder = new TokenEncoder(params.getEncoder(), ImmutableByteArray::from); + this.specialEncoder = new SpecialEncoder(params.getSpecialTokensEncoder()); + this.encodedToDecoded = new HashMap<>(params.getEncoder().size()); + params.getEncoder().forEach((k, v) -> encodedToDecoded.put(v, k)); + } + + @Override + public List encode(final String text) { + return encodeInternal(text, null).getTokens(); + } + + @Override + public EncodingResult encode(final String text, final int maxTokens) { + return encodeInternal(text, maxTokens); + } + + private EncodingResult encodeInternal(String text, Integer maxTokenCount) { + if (text == null) { + return new EncodingResult(emptyList(), false); + } + + specialEncoder.checkForSpecialTokens(text); + + return encodeOrdinaryInternal(text, maxTokenCount); + } + + @Override + public List encodeOrdinary(String text) { + return encodeOrdinaryInternal(text, null).getTokens(); + } + + @Override + public EncodingResult encodeOrdinary(String text, int maxTokens) { + return encodeOrdinaryInternal(text, maxTokens); + } + + private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokens) { + if (text == null) { + return new EncodingResult(Collections.emptyList(), false); + } + + List out = new ArrayList<>(); + Matcher matcher = pattern.matcher(text); + int tokenCount = 0; + while (matcher.find() && maxTokenCountNotReached(maxTokens, tokenCount)) { + 
ImmutableByteArray match = ImmutableByteArray.from(matcher.group()); + if (encoder.containsDecodedToken(match)) { + out.add(encoder.encode(match)); + tokenCount++; + } else { + List tokensToAdd = bytePairMerge(match); + tokenCount += addTokens(out, tokensToAdd, maxTokens); + } + } + + if (maxTokens != null) { + // Make sure we didn't break the multibyte character + for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { + List tokens = out.subList(0, out.size() - tokensToRemove); + String decoded = decode(tokens); + if (text.startsWith(decoded)) { + // If decoded text is equal to the head of the original text, we can safely return the tokens + return new EncodingResult(tokens, text.length() > decoded.length()); + } + } + } + + return new EncodingResult(out, false); + } + + /** + * Adds tokens from 'tokensToAdd' to 'out' until either 'maxTokens' is reached or 'tokensToAdd' is exhausted. + * + * @return the number of tokens added to 'out' + */ + private int addTokens(List out, List tokensToAdd, Integer maxTokens) { + if (maxTokens != null) { + List sublist = tokensToAdd.subList(0, Math.min(maxTokens - out.size(), tokensToAdd.size())); + out.addAll(sublist); + return sublist.size(); + } + + out.addAll(tokensToAdd); + return tokensToAdd.size(); + } + + @Override + public int countTokens(String text) { + return encode(text).size(); + } + + @Override + public int countTokensOrdinary(String text) { + return encodeOrdinary(text).size(); + } + + @Override + public String decode(List tokens) { + return new String(decodeBytes(tokens), StandardCharsets.UTF_8); + } + + @Override + public byte[] decodeBytes(List tokens) { + List out = new ArrayList<>(); + for (int token : tokens) { + byte[] decodedToken = decodeToken(token); + for (byte b : decodedToken) { + out.add(b); + } + } + + byte[] outArray = new byte[out.size()]; + for (int i = 0; i < out.size(); i++) { + outArray[i] = out.get(i); + } + return outArray; + } + + @Override + public String getName() { 
+ return name; + } + + /* + * We use a custom implementation of the byte pair encoding algorithm as used by the OpenAI tokenizer. The + * piece is merged according to the merging rules provided by OpenAI. An example of the algorithm: + * + * piece: v e c t o r + * index: 0 1 2 3 4 5 6 + * ranks: 4 3 7 2 13 inf inf + * + * We don't modify piece directly. We instead create a list of tuples (index, rank) where index is the start index + * of a byte pair and rank is it's merge rank. We call this list of tuples parts. The lowest rank is the byte pair + * that will be merged next. In the example above, the lowest rank is 2, so we merge the byte pair at index 3. + * To merge a byte pair at index i, we first update the ranks of the byte pairs that are affected by the merge, in this + * case the byte pair at index 2 and the byte pair at index 3. Then we remove the byte pair at index i + 1 from the list. + * In this case, this is the byte pair at index 4. + * + * piece: v e c to r + * index: 0 1 2 3 5 6 + * ranks: 4 3 5 9 inf inf + * + * We then repeat the process until there are no more byte pairs to merge, either because we have merged all byte pairs + * and parts.size() is 1, or because there are no more merging rules that apply to our tokens. Let's assume there are merging + * rules for "e + c", "to + r" and "v + ec": + * + * piece: v ec to r + * index: 0 1 3 5 6 + * ranks: 4 11 12 inf inf + * ^ + * + * piece: vec to r + * index: 0 3 5 6 + * ranks: inf 12 inf inf + * ^ + * + * piece: vec tor + * index: 0 3 6 + * ranks: inf inf inf + * + * We can extract the final tokens by simply taking piece.get(parts[0].index) until piece.get(parts[1].index - 1) + * and piece.get(parts[1].index) until piece.get(parts[2].index - 1). Analogously for more than two parts. + * Note that we do not actually modify the piece, but only the parts list. The above visualization is just for + * illustration purposes. 
+ */ + private List bytePairMerge(ImmutableByteArray piece) { + /* + * piece: v e c t o r + * index: 0 1 2 3 4 5 6 + * ranks: inf inf inf inf inf inf inf + */ + List parts = new ArrayList<>(); + for (int i = 0; i < piece.length() + 1; i++) { + parts.add(new PieceIndexToRank(i, Integer.MAX_VALUE)); + } + + /* + * piece: v e c t o r + * index: 0 1 2 3 4 5 6 + * ranks: 4 3 7 2 13 inf inf + */ + for (int i = 0; i < parts.size() - 2; i++) { + Optional rank = getRank(piece, parts, i, 0); + if (rank.isPresent()) { + parts.get(i).rank = rank.get(); + } + } + + while (parts.size() > 1) { + /* + * piece: v e c t o r + * index: 0 1 2 3 4 5 6 + * ranks: 4 3 7 2 13 inf inf + * + * minRankIndex = 3 + * minRank = 2 + */ + int minRankIndex = 0; + int minRank = Integer.MAX_VALUE; + for (int i = 0; i < parts.size() - 1; i++) { + int rank = parts.get(i).rank; + if (rank < minRank) { + minRank = rank; + minRankIndex = i; + } + } + + /* + * piece: v e c to r + * index: 0 1 2 3 5 6 + * ranks: 4 3 5 9 inf inf + */ + if (minRank != Integer.MAX_VALUE) { + // Note that we calculate the rank of the byte pairs at minRankIndex and minRankIndex - 1 before removing + // the merged byte pair. We use the skip parameter of the getRank function to calculate the rank of, in our + // example, "t" + "o" + "r" and "c" + "t" + "o". The assumption made in the OpenAI implementation is that + // removing first thrashes the cache, so it's better to calculate the rank of the byte pairs that are + // affected by the merge before removing the merged byte pair. I did not verify, if this is actually the + // case in java. 
+ parts.get(minRankIndex).rank = getRank(piece, parts, minRankIndex, 1).orElse(Integer.MAX_VALUE); + if (minRankIndex > 0) { + parts.get(minRankIndex - 1).rank = getRank(piece, parts, minRankIndex - 1, 1).orElse(Integer.MAX_VALUE); + } + + parts.remove(minRankIndex + 1); + } else { + break; + } + } + + /* + * piece: vec tor + * index: 0 3 6 + * ranks: inf inf inf + */ + List out = new ArrayList<>(); + for (int i = 0; i < parts.size() - 1; i++) { + out.add(encoder.encode(piece.getBytesBetween(parts.get(i).index, parts.get(i + 1).index))); + } + return out; + } + + private boolean maxTokenCountReached(Integer maxTokenCount, int tokenCount) { + return maxTokenCount != null && maxTokenCount.compareTo(tokenCount) <= 0; + } + + private boolean maxTokenCountNotReached(Integer maxTokenCount, int tokenCount) { + return !maxTokenCountReached(maxTokenCount, tokenCount); + } + + private Optional getRank( + ImmutableByteArray piece, + List parts, + int startIndex, + int skip + ) { + if (startIndex + skip + 2 >= parts.size()) { + return Optional.empty(); + } + + int pieceStartIndex = parts.get(startIndex).index; + int pieceEndIndex = parts.get(startIndex + skip + 2).index; + ImmutableByteArray encoderIndex = piece.getBytesBetween(pieceStartIndex, pieceEndIndex); + + return encoder.encodeIfPresent(encoderIndex); + } + + private byte[] decodeToken(int token) { + return requireNonNull(encodedToDecoded.computeIfAbsent(token, specialEncoder::decodeIfPresent), "Unknown token for decoding: " + token); + } + + private static class PieceIndexToRank { + private final int index; + private int rank; + + public PieceIndexToRank(int index, int rank) { + this.index = index; + this.rank = rank; + } + } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java new file mode 100644 index 00000000..1dd1fbad --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java @@ -0,0 +1,39 @@ +package 
com.knuddels.jtokkit; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static java.nio.charset.StandardCharsets.UTF_8; + +final class SpecialEncoder { + private static final String SPECIAL_START = "<|"; + private static final String SPECIAL_END = "|>"; + private final Map encodedToDecoded; + + public SpecialEncoder(Map encoder) { + this.encodedToDecoded = new ConcurrentHashMap<>(encoder.size()); + for (Map.Entry entry : encoder.entrySet()) { + String key = entry.getKey(); + Integer value = entry.getValue(); + + assert key.contains(SPECIAL_START) && key.contains(SPECIAL_END) : "Special tokens must contain <| and |> (but was " + key + ")"; + + encodedToDecoded.put(value, key); + } + } + + public byte[] decodeIfPresent(Integer encodedToken) { + String result = encodedToDecoded.get(encodedToken); + return result != null ? result.getBytes(UTF_8) : null; + } + + public void checkForSpecialTokens(String text) { + if (text.contains(SPECIAL_START) && text.contains(SPECIAL_END)) { + for (String specialToken : encodedToDecoded.values()) { + if (text.contains(specialToken)) { + throw new UnsupportedOperationException("Encoding special tokens is not supported yet."); + } + } + } + } +} diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 96d170b2..b5a91f82 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -1,103 +1,80 @@ package com.knuddels.jtokkit; -import java.util.*; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; import java.util.function.Function; /** * A TokenEncoder is used to encode and decode tokens. It is initialized with a map * that contains the decoded tokens as keys and the encoded tokens as values. The * TokenEncoder can then be used to encode and decode tokens. 
- * - * @param the type of the decoded tokens - * @param the type of the encoded tokens */ -final class TokenEncoder { +class TokenEncoder { - private final Map decodedToEncoded = new HashMap<>(); - private final Map encodedToDecoded = new HashMap<>(); + private final Map decodedToEncoded = new HashMap<>(); - /** - * Creates a new TokenEncoder with the given input map. The keys of the map are - * the decoded tokens and the values are the encoded tokens. - * - * @param input the input map - */ - public TokenEncoder(final Map input) { - this(input, Function.identity()); - } + /** + * Creates a new TokenEncoder with the given input map. The keys of the map are + * the decoded tokens and the values are the encoded tokens. + * + * @param input the input map + */ + public TokenEncoder(Map input) { + this(input, Function.identity()); + } - /** - * Creates a new TokenEncoder with the given input map. The keys of the map are - * the decoded tokens and the values are the encoded tokens. The keyMapper is - * applied to the keys of the input map before they are added to the internal - * maps. - * - * @param input the input map - * @param keyMapper the key mapper - */ - public TokenEncoder(final Map input, final Function keyMapper) { - for (final Map.Entry entry : input.entrySet()) { - final K key = keyMapper.apply(entry.getKey()); - final V value = entry.getValue(); - decodedToEncoded.put(key, value); - encodedToDecoded.put(value, key); - } - } + /** + * Creates a new TokenEncoder with the given input map. The keys of the map are + * the decoded tokens and the values are the encoded tokens. The keyMapper is + * applied to the keys of the input map before they are added to the internal + * maps. 
+ * + * @param input the input map + * @param keyMapper the key mapper + */ + public TokenEncoder(Map input, Function keyMapper) { + for (Map.Entry entry : input.entrySet()) { + ImmutableByteArray key = keyMapper.apply(entry.getKey()); + Integer value = entry.getValue(); + decodedToEncoded.put(key, value); + } + } - /** - * Checks if the given decoded token is contained in this encoder. - * - * @param decodedToken the decoded token - * @return true if the decoded token is contained in this encoder, false otherwise - */ - public boolean containsDecodedToken(final K decodedToken) { - return decodedToEncoded.containsKey(decodedToken); - } + /** + * Checks if the given decoded token is contained in this encoder. + * + * @param decodedToken the decoded token + * @return true if the decoded token is contained in this encoder, false otherwise + */ + public boolean containsDecodedToken(ImmutableByteArray decodedToken) { + return decodedToEncoded.containsKey(decodedToken); + } - /** - * Encodes the given decoded token. - * - * @param decodedToken the decoded token - * @return the encoded token - * @throws IllegalArgumentException if the decoded token is not contained in this encoder - */ - public V encode(final K decodedToken) { - final V encoded = decodedToEncoded.get(decodedToken); - if (encoded == null) { - throw new IllegalArgumentException("Unknown token for encoding: " + decodedToken); - } + /** + * Encodes the given decoded token. + * + * @param decodedToken the decoded token + * @return the encoded token + * @throws IllegalArgumentException if the decoded token is not contained in this encoder + */ + public Integer encode(ImmutableByteArray decodedToken) { + Integer encoded = decodedToEncoded.get(decodedToken); + if (encoded == null) { + throw new IllegalArgumentException("Unknown token for encoding: " + decodedToken); + } - return encoded; - } + return encoded; + } - /** - * Encodes the given decoded token if it is contained in this encoder. 
Otherwise, - * an empty optional is returned. - * - * @param decodedToken the decoded token - * @return the encoded token or an empty optional - */ - public Optional encodeIfPresent(final K decodedToken) { - return Optional.ofNullable(decodedToEncoded.get(decodedToken)); - } - - /** - * Decodes the given encoded token if it is contained in this encoder. Otherwise, - * an empty optional is returned. - * - * @param encodedToken the encoded token - * @return the decoded token or an empty optional - */ - public Optional decodeIfPresent(final V encodedToken) { - return Optional.ofNullable(encodedToDecoded.get(encodedToken)); - } - - /** - * Returns an unmodifiable set of all decoded tokens contained in this encoder. - * - * @return an unmodifiable set of all decoded tokens - */ - public Set getDecodedTokens() { - return Collections.unmodifiableSet(decodedToEncoded.keySet()); - } + /** + * Encodes the given decoded token if it is contained in this encoder. Otherwise, + * an empty optional is returned. 
+ * + * @param decodedToken the decoded token + * @return the encoded token or an empty optional + */ + public Optional encodeIfPresent(ImmutableByteArray decodedToken) { + return Optional.ofNullable(decodedToEncoded.get(decodedToken)); + } } diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kBaseTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java similarity index 99% rename from lib/src/test/java/com/knuddels/jtokkit/Cl100kBaseTest.java rename to lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index 645c5a26..d64b87fc 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kBaseTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -15,7 +15,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -class Cl100kBaseTest { +class Cl100kTest { private static final String PUNCTUATION = "'\".,?!:()"; private static final String LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZő你好ſ ½"; private static final String NUMBERS = "0123456789½"; From 9a035e16a2d4e521b46895a684be8ae45adaf6b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 22 Dec 2023 14:36:51 +0100 Subject: [PATCH 02/14] Optimize TokenEncoder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * When searching for the next minimum value, we've unrolled it to favor SIMD optimizations. 
* PieceIndexToRank is removed, we're only storing the rank, since we're not deleting next values anymore (which avoids copying every subsequent value) * Since we've replaced minimums with sentinels, previous and next indexes are replaced by iteration * The encoders map is split by input byte array size so that we're only querying small maps * Iteration stops before the last minimum is MAX_RANK by keeping track of merge results - resulting in one less minimum search at the end Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 8.947 ± 0.109 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 9.419 ± 0.082 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 9.365 ± 0.073 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 8.403 ± 0.080 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 7.313 ± 0.031 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 7.242 ± 0.027 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 7.742 ± 0.054 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 7.748 ± 0.121 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 7.017 ± 0.110 s/op --- .../jtokkit/SingleThreadedBenchmark.java | 11 + .../knuddels/jtokkit/GptBytePairEncoding.java | 224 +++------------- .../com/knuddels/jtokkit/TokenEncoder.java | 246 +++++++++++++----- 3 files changed, 223 insertions(+), 258 deletions(-) diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java index ae8f620e..896757af 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java @@ -1,11 +1,22 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import org.openjdk.jmh.annotations.Benchmark; 
import java.util.List; public class SingleThreadedBenchmark extends AbstractBenchmark { + @Benchmark + public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) { + var result = 0; + var encoding = state.cl100kBase; + for (var fileContent : state.fileContents) { + result += encoding.countTokens(fileContent); + } + return result; + } + @Override protected List> encodeAll(final Encoding encoding, final List fileContents) { return fileContents.stream() diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 84e3f72a..5ff296ab 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -4,11 +4,14 @@ import com.knuddels.jtokkit.api.EncodingResult; import com.knuddels.jtokkit.api.GptBytePairEncodingParams; -import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; @@ -30,20 +33,20 @@ class GptBytePairEncoding implements Encoding { GptBytePairEncoding(GptBytePairEncodingParams params) { this.name = params.getName(); this.pattern = params.getPattern(); - this.encoder = new TokenEncoder(params.getEncoder(), ImmutableByteArray::from); + this.encoder = new TokenEncoder(params.getEncoder()); this.specialEncoder = new SpecialEncoder(params.getSpecialTokensEncoder()); this.encodedToDecoded = new HashMap<>(params.getEncoder().size()); params.getEncoder().forEach((k, v) -> encodedToDecoded.put(v, k)); } @Override - public List encode(final String text) { + public List encode(String text) { return encodeInternal(text, null).getTokens(); } @Override - public EncodingResult 
encode(final String text, final int maxTokens) { - return encodeInternal(text, maxTokens); + public EncodingResult encode(String text, int maxTokenCount) { + return encodeInternal(text, maxTokenCount); } private EncodingResult encodeInternal(String text, Integer maxTokenCount) { @@ -62,33 +65,27 @@ public List encodeOrdinary(String text) { } @Override - public EncodingResult encodeOrdinary(String text, int maxTokens) { - return encodeOrdinaryInternal(text, maxTokens); + public EncodingResult encodeOrdinary(String text, int maxTokenCount) { + return encodeOrdinaryInternal(text, maxTokenCount); } - private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokens) { + private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokenCount) { if (text == null) { - return new EncodingResult(Collections.emptyList(), false); + return new EncodingResult(emptyList(), false); } List out = new ArrayList<>(); - Matcher matcher = pattern.matcher(text); - int tokenCount = 0; - while (matcher.find() && maxTokenCountNotReached(maxTokens, tokenCount)) { - ImmutableByteArray match = ImmutableByteArray.from(matcher.group()); - if (encoder.containsDecodedToken(match)) { - out.add(encoder.encode(match)); - tokenCount++; - } else { - List tokensToAdd = bytePairMerge(match); - tokenCount += addTokens(out, tokensToAdd, maxTokens); - } - } + int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, out); + assert maxTokenCount != null || tokenCount == out.size(); - if (maxTokens != null) { + if (maxTokenCount != null) { // Make sure we didn't break the multibyte character for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { - List tokens = out.subList(0, out.size() - tokensToRemove); + int size = out.size() - tokensToRemove; + List tokens = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + tokens.add(out.get(i)); + } String decoded = decode(tokens); if (text.startsWith(decoded)) { // If decoded text is equal to the head of the 
original text, we can safely return the tokens @@ -100,20 +97,14 @@ private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokens) { return new EncodingResult(out, false); } - /** - * Adds tokens from 'tokensToAdd' to 'out' until either 'maxTokens' is reached or 'tokensToAdd' is exhausted. - * - * @return the number of tokens added to 'out' - */ - private int addTokens(List out, List tokensToAdd, Integer maxTokens) { - if (maxTokens != null) { - List sublist = tokensToAdd.subList(0, Math.min(maxTokens - out.size(), tokensToAdd.size())); - out.addAll(sublist); - return sublist.size(); + int encodeOrdinaryInternal(String text, Integer maxTokenCount, List out) { + int tokenCount = 0; + List ranks = new ArrayList<>(); // reused to avoid allocations + for (Matcher matcher = pattern.matcher(text); (maxTokenCount == null || tokenCount < maxTokenCount) && matcher.find(); ) { + byte[] bytes = matcher.group().getBytes(UTF_8); + tokenCount += encoder.addTokensAndGetCount(maxTokenCount, bytes, out, ranks); } - - out.addAll(tokensToAdd); - return tokensToAdd.size(); + return tokenCount; } @Override @@ -128,7 +119,7 @@ public int countTokensOrdinary(String text) { @Override public String decode(List tokens) { - return new String(decodeBytes(tokens), StandardCharsets.UTF_8); + return new String(decodeBytes(tokens), UTF_8); } @Override @@ -153,161 +144,8 @@ public String getName() { return name; } - /* - * We use a custom implementation of the byte pair encoding algorithm as used by the OpenAI tokenizer. The - * piece is merged according to the merging rules provided by OpenAI. An example of the algorithm: - * - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: 4 3 7 2 13 inf inf - * - * We don't modify piece directly. We instead create a list of tuples (index, rank) where index is the start index - * of a byte pair and rank is it's merge rank. We call this list of tuples parts. The lowest rank is the byte pair - * that will be merged next. 
In the example above, the lowest rank is 2, so we merge the byte pair at index 3. - * To merge a byte pair at index i, we first update the ranks of the byte pairs that are affected by the merge, in this - * case the byte pair at index 2 and the byte pair at index 3. Then we remove the byte pair at index i + 1 from the list. - * In this case, this is the byte pair at index 4. - * - * piece: v e c to r - * index: 0 1 2 3 5 6 - * ranks: 4 3 5 9 inf inf - * - * We then repeat the process until there are no more byte pairs to merge, either because we have merged all byte pairs - * and parts.size() is 1, or because there are no more merging rules that apply to our tokens. Let's assume there are merging - * rules for "e + c", "to + r" and "v + ec": - * - * piece: v ec to r - * index: 0 1 3 5 6 - * ranks: 4 11 12 inf inf - * ^ - * - * piece: vec to r - * index: 0 3 5 6 - * ranks: inf 12 inf inf - * ^ - * - * piece: vec tor - * index: 0 3 6 - * ranks: inf inf inf - * - * We can extract the final tokens by simply taking piece.get(parts[0].index) until piece.get(parts[1].index - 1) - * and piece.get(parts[1].index) until piece.get(parts[2].index - 1). Analogously for more than two parts. - * Note that we do not actually modify the piece, but only the parts list. The above visualization is just for - * illustration purposes. 
- */ - private List bytePairMerge(ImmutableByteArray piece) { - /* - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: inf inf inf inf inf inf inf - */ - List parts = new ArrayList<>(); - for (int i = 0; i < piece.length() + 1; i++) { - parts.add(new PieceIndexToRank(i, Integer.MAX_VALUE)); - } - - /* - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: 4 3 7 2 13 inf inf - */ - for (int i = 0; i < parts.size() - 2; i++) { - Optional rank = getRank(piece, parts, i, 0); - if (rank.isPresent()) { - parts.get(i).rank = rank.get(); - } - } - - while (parts.size() > 1) { - /* - * piece: v e c t o r - * index: 0 1 2 3 4 5 6 - * ranks: 4 3 7 2 13 inf inf - * - * minRankIndex = 3 - * minRank = 2 - */ - int minRankIndex = 0; - int minRank = Integer.MAX_VALUE; - for (int i = 0; i < parts.size() - 1; i++) { - int rank = parts.get(i).rank; - if (rank < minRank) { - minRank = rank; - minRankIndex = i; - } - } - - /* - * piece: v e c to r - * index: 0 1 2 3 5 6 - * ranks: 4 3 5 9 inf inf - */ - if (minRank != Integer.MAX_VALUE) { - // Note that we calculate the rank of the byte pairs at minRankIndex and minRankIndex - 1 before removing - // the merged byte pair. We use the skip parameter of the getRank function to calculate the rank of, in our - // example, "t" + "o" + "r" and "c" + "t" + "o". The assumption made in the OpenAI implementation is that - // removing first thrashes the cache, so it's better to calculate the rank of the byte pairs that are - // affected by the merge before removing the merged byte pair. I did not verify, if this is actually the - // case in java. 
- parts.get(minRankIndex).rank = getRank(piece, parts, minRankIndex, 1).orElse(Integer.MAX_VALUE); - if (minRankIndex > 0) { - parts.get(minRankIndex - 1).rank = getRank(piece, parts, minRankIndex - 1, 1).orElse(Integer.MAX_VALUE); - } - - parts.remove(minRankIndex + 1); - } else { - break; - } - } - - /* - * piece: vec tor - * index: 0 3 6 - * ranks: inf inf inf - */ - List out = new ArrayList<>(); - for (int i = 0; i < parts.size() - 1; i++) { - out.add(encoder.encode(piece.getBytesBetween(parts.get(i).index, parts.get(i + 1).index))); - } - return out; - } - - private boolean maxTokenCountReached(Integer maxTokenCount, int tokenCount) { - return maxTokenCount != null && maxTokenCount.compareTo(tokenCount) <= 0; - } - - private boolean maxTokenCountNotReached(Integer maxTokenCount, int tokenCount) { - return !maxTokenCountReached(maxTokenCount, tokenCount); - } - - private Optional getRank( - ImmutableByteArray piece, - List parts, - int startIndex, - int skip - ) { - if (startIndex + skip + 2 >= parts.size()) { - return Optional.empty(); - } - - int pieceStartIndex = parts.get(startIndex).index; - int pieceEndIndex = parts.get(startIndex + skip + 2).index; - ImmutableByteArray encoderIndex = piece.getBytesBetween(pieceStartIndex, pieceEndIndex); - - return encoder.encodeIfPresent(encoderIndex); - } - private byte[] decodeToken(int token) { - return requireNonNull(encodedToDecoded.computeIfAbsent(token, specialEncoder::decodeIfPresent), "Unknown token for decoding: " + token); - } - - private static class PieceIndexToRank { - private final int index; - private int rank; - - public PieceIndexToRank(int index, int rank) { - this.index = index; - this.rank = rank; - } + byte[] decodedToken = encodedToDecoded.computeIfAbsent(token, specialEncoder::decodeIfPresent); + return requireNonNull(decodedToken, "Unknown token for decoding: " + token); } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java 
b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index b5a91f82..03998fde 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -1,80 +1,196 @@ package com.knuddels.jtokkit; -import java.util.HashMap; + +import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.function.Function; - -/** - * A TokenEncoder is used to encode and decode tokens. It is initialized with a map - * that contains the decoded tokens as keys and the encoded tokens as values. The - * TokenEncoder can then be used to encode and decode tokens. - */ -class TokenEncoder { - - private final Map decodedToEncoded = new HashMap<>(); - - /** - * Creates a new TokenEncoder with the given input map. The keys of the map are - * the decoded tokens and the values are the encoded tokens. - * - * @param input the input map - */ - public TokenEncoder(Map input) { - this(input, Function.identity()); +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; + +final class TokenEncoder { + public static final int DUMMY_RANK = Integer.MAX_VALUE; + public static final int MAX_RANK = Integer.MAX_VALUE - 1; + private final Map[] encoders; + private int length = 0; + + public TokenEncoder(Map encoder) { + if (!encoder.isEmpty()) { + TreeMap> tempEncoders = new TreeMap<>(); + encoder.forEach((k, v) -> { + length++; + ImmutableByteArray key = ImmutableByteArray.from(k); + tempEncoders.computeIfAbsent(k.length, integer -> new ConcurrentHashMap<>()).put(key, v); + }); + //noinspection unchecked + encoders = new ConcurrentHashMap[tempEncoders.lastKey() + 1]; + tempEncoders.forEach((k, v) -> encoders[k] = v); + } else { + //noinspection unchecked + encoders = new Map[0]; // for testing + } + } + + public static int getMinRankIndex(List ranks) { + int minRankIndex = -1; + int minRank = MAX_RANK; + + int i = 0; + int length = ranks.size() - 3; + for (; i < length - 2; i += 4) { // 
Unrolled loop + { + int r = ranks.get(i); + if (r < minRank) { + minRankIndex = i; + minRank = r; + } + } + { + int r = ranks.get(i + 1); + if (r < minRank) { + minRankIndex = i + 1; + minRank = r; + } + } + { + int r = ranks.get(i + 2); + if (r < minRank) { + minRankIndex = i + 2; + minRank = r; + } + } + { + int r = ranks.get(i + 3); + if (r < minRank) { + minRankIndex = i + 3; + minRank = r; + } + } + } + + for (; i <= length; i++) { + int r = ranks.get(i); + if (r < minRank) { + minRankIndex = i; + minRank = r; + } + } + + return minRankIndex; + } + + public static int getNextIndex(List ranks, int nextIndex) { + while (nextIndex < ranks.size() && ranks.get(nextIndex) == DUMMY_RANK) { + nextIndex++; + } + return nextIndex; + } + + public static int getPreviousIndex(List ranks, int previousIndex) { + while (previousIndex >= 0 && ranks.get(previousIndex) == DUMMY_RANK) { + previousIndex--; + } + return previousIndex; } - /** - * Creates a new TokenEncoder with the given input map. The keys of the map are - * the decoded tokens and the values are the encoded tokens. The keyMapper is - * applied to the keys of the input map before they are added to the internal - * maps. - * - * @param input the input map - * @param keyMapper the key mapper - */ - public TokenEncoder(Map input, Function keyMapper) { - for (Map.Entry entry : input.entrySet()) { - ImmutableByteArray key = keyMapper.apply(entry.getKey()); - Integer value = entry.getValue(); - decodedToEncoded.put(key, value); + public int addTokensAndGetCount(Integer maxTokenCount, byte[] utf8Bytes, List out, List ranks) { + ImmutableByteArray match = ImmutableByteArray.from(utf8Bytes); + int encoded = encode(match); + if (encoded != MAX_RANK) { + out.add(encoded); + return 1; + } else { + int length = match.length(); + return addTokensAndGetCount(maxTokenCount, out, ranks, match, length); } } - /** - * Checks if the given decoded token is contained in this encoder. 
- * - * @param decodedToken the decoded token - * @return true if the decoded token is contained in this encoder, false otherwise - */ - public boolean containsDecodedToken(ImmutableByteArray decodedToken) { - return decodedToEncoded.containsKey(decodedToken); + private int addTokensAndGetCount(Integer maxTokenCount, List out, List ranks, ImmutableByteArray match, int length) { + int validRanks = initRanks(match, length, ranks); + int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, validRanks); + for (int start = 0, end = 1; end < ranks.size() && (maxTokenCount == null || out.size() < maxTokenCount); end++) { + if (ranks.get(end) != DUMMY_RANK) { + int token = encode(match, start, end); + assert token != MAX_RANK : "Token should not be MAX_RANK"; + out.add(token); + start = end; + } + } + return tokenCount; } - /** - * Encodes the given decoded token. - * - * @param decodedToken the decoded token - * @return the encoded token - * @throws IllegalArgumentException if the decoded token is not contained in this encoder - */ - public Integer encode(ImmutableByteArray decodedToken) { - Integer encoded = decodedToEncoded.get(decodedToken); - if (encoded == null) { - throw new IllegalArgumentException("Unknown token for encoding: " + decodedToken); + int initRanks(ImmutableByteArray piece, int tokenCount, List ranks) { + assert tokenCount > 1 : "Already filtered out"; + ranks.clear(); + int validRanks = 0; + for (int i = 0; i < tokenCount + 1; i++) { + int encoded = encode(piece, i, i + 2); + if (encoded != MAX_RANK) { + validRanks++; + } + ranks.add(encoded); } + return validRanks; + } + + int mergeBytesAndGetTokenCount(ImmutableByteArray piece, int length, List ranks, int validRanks) { + assert true; + while (true) { + if (validRanks == 0) { + assert getMinRankIndex(ranks) < 0; + break; + } + int minRankIndex = getMinRankIndex(ranks); + assert minRankIndex >= 0; + + int previousIndex = getPreviousIndex(ranks, minRankIndex - 1); + int nextIndex = 
getNextIndex(ranks, minRankIndex + 1); + int nextNextIndex = getNextIndex(ranks, nextIndex + 1); + int nextNextNextIndex = getNextIndex(ranks, nextNextIndex + 1); + + if (previousIndex >= 0) { + assert ranks.get(previousIndex) != DUMMY_RANK; + int newRank = encode(piece, previousIndex, nextNextIndex); + int oldRank = ranks.set(previousIndex, newRank); + if ((newRank == MAX_RANK) != (oldRank == MAX_RANK)) { + validRanks -= (newRank == MAX_RANK) ? 1 : -1; + } + } + assert ranks.get(minRankIndex) != DUMMY_RANK; + int newRank = encode(piece, minRankIndex, nextNextNextIndex); + int oldRank = ranks.set(minRankIndex, newRank); + if ((newRank == MAX_RANK) != (oldRank == MAX_RANK)) { + validRanks--; + } + + int oldDeletedRank = ranks.set(nextIndex, DUMMY_RANK); + if (oldDeletedRank != MAX_RANK) { + validRanks--; + } - return encoded; + length--; + } + return length; + } + + int encode(ImmutableByteArray payload) { + if (payload.length() < encoders.length) { + Map encoder = encoders[payload.length()]; + return encoder == null ? MAX_RANK : encoder.getOrDefault(payload, MAX_RANK); + } else { + return MAX_RANK; + } + } + + int encode(ImmutableByteArray piece, int start, int end) { + if (end > piece.length()) { + return MAX_RANK; + } else if (end - start == piece.length()) { + return encode(piece); + } else { + return encode(piece.getBytesBetween(start, end)); + } } - /** - * Encodes the given decoded token if it is contained in this encoder. Otherwise, - * an empty optional is returned. 
- * - * @param decodedToken the decoded token - * @return the encoded token or an empty optional - */ - public Optional encodeIfPresent(ImmutableByteArray decodedToken) { - return Optional.ofNullable(decodedToEncoded.get(decodedToken)); + public int length() { + return length; } -} +} \ No newline at end of file From f73b823ce400b94c7f00c9aaef12f7ee7d0570d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 22 Dec 2023 15:00:40 +0100 Subject: [PATCH 03/14] Optimize countTokens MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 7.242 ± 0.027 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 6.885 ± 0.049 s/op --- .../knuddels/jtokkit/GptBytePairEncoding.java | 41 +-- .../com/knuddels/jtokkit/TokenEncoder.java | 24 +- .../com/knuddels/jtokkit/api/Encoding.java | 330 ++++++++---------- .../knuddels/jtokkit/api/EncodingResult.java | 72 ++-- .../jtokkit/BaseEncodingRegistryTest.java | 5 - 5 files changed, 229 insertions(+), 243 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 5ff296ab..87619a35 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -19,6 +19,7 @@ * Implementation of the byte pair encoding algorithm as used by the OpenAI tiktoken tokenizer.
*/ class GptBytePairEncoding implements Encoding { + private final String name; private final Pattern pattern; private final TokenEncoder encoder; @@ -41,44 +42,43 @@ class GptBytePairEncoding implements Encoding { @Override public List encode(String text) { - return encodeInternal(text, null).getTokens(); + return encode(text, Integer.MAX_VALUE).getTokens(); } @Override public EncodingResult encode(String text, int maxTokenCount) { - return encodeInternal(text, maxTokenCount); + return encodeInternal(text, maxTokenCount, true); } - private EncodingResult encodeInternal(String text, Integer maxTokenCount) { + private EncodingResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) { if (text == null) { - return new EncodingResult(emptyList(), false); + return new EncodingResult(emptyList(), -1, false); } specialEncoder.checkForSpecialTokens(text); - return encodeOrdinaryInternal(text, maxTokenCount); + return encodeOrdinaryInternal(text, maxTokenCount, keepEncodings); } @Override public List encodeOrdinary(String text) { - return encodeOrdinaryInternal(text, null).getTokens(); + return encodeOrdinary(text, Integer.MAX_VALUE).getTokens(); } @Override public EncodingResult encodeOrdinary(String text, int maxTokenCount) { - return encodeOrdinaryInternal(text, maxTokenCount); + return encodeOrdinaryInternal(text, maxTokenCount, true); } - private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokenCount) { + private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) { if (text == null) { - return new EncodingResult(emptyList(), false); + return new EncodingResult(emptyList(), -1, false); } List out = new ArrayList<>(); - int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, out); - assert maxTokenCount != null || tokenCount == out.size(); + int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, keepEncodings, out); - if (maxTokenCount != null) { + if (keepEncodings && 
maxTokenCount != Integer.MAX_VALUE) { // Make sure we didn't break the multibyte character for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { int size = out.size() - tokensToRemove; @@ -89,32 +89,27 @@ private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokenCount String decoded = decode(tokens); if (text.startsWith(decoded)) { // If decoded text is equal to the head of the original text, we can safely return the tokens - return new EncodingResult(tokens, text.length() > decoded.length()); + return new EncodingResult(tokens, -1, text.length() > decoded.length()); } } } - return new EncodingResult(out, false); + return new EncodingResult(out, tokenCount, false); } - int encodeOrdinaryInternal(String text, Integer maxTokenCount, List out) { + int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { int tokenCount = 0; List ranks = new ArrayList<>(); // reused to avoid allocations - for (Matcher matcher = pattern.matcher(text); (maxTokenCount == null || tokenCount < maxTokenCount) && matcher.find(); ) { + for (Matcher matcher = pattern.matcher(text); tokenCount < maxTokenCount && matcher.find(); ) { byte[] bytes = matcher.group().getBytes(UTF_8); - tokenCount += encoder.addTokensAndGetCount(maxTokenCount, bytes, out, ranks); + tokenCount += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, bytes, out, ranks); } return tokenCount; } @Override public int countTokens(String text) { - return encode(text).size(); - } - - @Override - public int countTokensOrdinary(String text) { - return encodeOrdinary(text).size(); + return encodeInternal(text, Integer.MAX_VALUE, false).getTokenCount(); } @Override diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 03998fde..8665aa55 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ 
-91,27 +91,31 @@ public static int getPreviousIndex(List ranks, int previousIndex) { return previousIndex; } - public int addTokensAndGetCount(Integer maxTokenCount, byte[] utf8Bytes, List out, List ranks) { + public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] utf8Bytes, List out, List ranks) { ImmutableByteArray match = ImmutableByteArray.from(utf8Bytes); int encoded = encode(match); if (encoded != MAX_RANK) { - out.add(encoded); + if (keepEncodings) { + out.add(encoded); + } return 1; } else { int length = match.length(); - return addTokensAndGetCount(maxTokenCount, out, ranks, match, length); + return addTokensAndGetCount(maxTokenCount, keepEncodings, out, ranks, match, length); } } - private int addTokensAndGetCount(Integer maxTokenCount, List out, List ranks, ImmutableByteArray match, int length) { + private int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, List out, List ranks, ImmutableByteArray match, int length) { int validRanks = initRanks(match, length, ranks); int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, validRanks); - for (int start = 0, end = 1; end < ranks.size() && (maxTokenCount == null || out.size() < maxTokenCount); end++) { - if (ranks.get(end) != DUMMY_RANK) { - int token = encode(match, start, end); - assert token != MAX_RANK : "Token should not be MAX_RANK"; - out.add(token); - start = end; + if (keepEncodings) { + for (int start = 0, end = 1; end < ranks.size() && out.size() < maxTokenCount; end++) { + if (ranks.get(end) != DUMMY_RANK) { + int token = encode(match, start, end); + assert token != MAX_RANK : "Token should not be MAX_RANK"; + out.add(token); + start = end; + } } } return tokenCount; diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java b/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java index abd8d076..9fe467cd 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java +++ 
b/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java @@ -4,187 +4,165 @@ public interface Encoding { - /** - * Encodes the given text into a list of token ids. - *

- * Special tokens are artificial tokens used to unlock capabilities from a model, - * such as fill-in-the-middle. There is currently no support for parsing special tokens - * in a text, so if the text contains special tokens, this method will throw an - * {@link UnsupportedOperationException}. - *

- * If you want to encode special tokens as ordinary text, use {@link #encodeOrdinary(String)}. - *

-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.encode("hello world");
-	 * // returns [15339, 1917]
-	 *
-	 * encoding.encode("hello <|endoftext|> world");
-	 * // raises an UnsupportedOperationException
-	 * 
- * - * @param text the text to encode - * @return the list of token ids - * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now - */ - List encode(String text); + /** + * Encodes the given text into a list of token ids. + *

+ * Special tokens are artificial tokens used to unlock capabilities from a model, + * such as fill-in-the-middle. There is currently no support for parsing special tokens + * in a text, so if the text contains special tokens, this method will throw an + * {@link UnsupportedOperationException}. + *

+ * If you want to encode special tokens as ordinary text, use {@link #encodeOrdinary(String)}. + *

+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.encode("hello world");
+     * // returns [15339, 1917]
+     *
+     * encoding.encode("hello <|endoftext|> world");
+     * // raises an UnsupportedOperationException
+     * 
+ * + * @param text the text to encode + * @return the list of token ids + * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now + */ + List encode(String text); - /** - * Encodes the given text into a list of token ids. - *

- * Special tokens are artificial tokens used to unlock capabilities from a model, - * such as fill-in-the-middle. There is currently no support for parsing special tokens - * in a text, so if the text contains special tokens, this method will throw an - * {@link UnsupportedOperationException}. - *

- * If you want to encode special tokens as ordinary text, use {@link #encodeOrdinary(String, int)}. - *

- * This method will truncate the list of token ids if the number of tokens exceeds the - * given maxTokens parameter. Note that it will try to keep characters together, that are encoded into - * multiple tokens. For example, if the text contains a character which is encoded into 3 tokens, - * and due to the maxTokens parameter the last token of the character is truncated, the first two - * tokens of the character will also be truncated. Therefore, the actual number of tokens may be - * less than the given maxTokens parameter. - *

-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.encode("hello world", 100);
-	 * // returns [15339, 1917]
-	 *
-	 * encoding.encode("hello <|endoftext|> world", 100);
-	 * // raises an UnsupportedOperationException
-	 * 
- * - * @param text the text to encode - * @param maxTokens the maximum number of tokens to encode - * @return the {@link EncodingResult} containing a list of token ids and whether the tokens were truncated due to the maxTokens parameter - * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now - */ - EncodingResult encode(String text, int maxTokens); + /** + * Encodes the given text into a list of token ids. + *

+ * Special tokens are artificial tokens used to unlock capabilities from a model, + * such as fill-in-the-middle. There is currently no support for parsing special tokens + * in a text, so if the text contains special tokens, this method will throw an + * {@link UnsupportedOperationException}. + *

+ * If you want to encode special tokens as ordinary text, use {@link #encodeOrdinary(String, int)}. + *

+ * This method will truncate the list of token ids if the number of tokens exceeds the + * given maxTokens parameter. Note that it will try to keep characters together, that are encoded into + * multiple tokens. For example, if the text contains a character which is encoded into 3 tokens, + * and due to the maxTokens parameter the last token of the character is truncated, the first two + * tokens of the character will also be truncated. Therefore, the actual number of tokens may be + * less than the given maxTokens parameter. + *

+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.encode("hello world", 100);
+     * // returns [15339, 1917]
+     *
+     * encoding.encode("hello <|endoftext|> world", 100);
+     * // raises an UnsupportedOperationException
+     * 
+ * + * @param text the text to encode + * @param maxTokens the maximum number of tokens to encode + * @return the {@link EncodingResult} containing a list of token ids and whether the tokens were truncated due to the maxTokens parameter + * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now + */ + EncodingResult encode(String text, int maxTokens); - /** - * Encodes the given text into a list of token ids, ignoring special tokens. - *

- * This method does not throw an exception if the text contains special tokens, but instead - * encodes them as if they were ordinary text. - *

-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.encodeOrdinary("hello world");
-	 * // returns [15339, 1917]
-	 *
-	 * encoding.encodeOrdinary("hello <|endoftext|> world");
-	 * // returns [15339, 83739, 8862, 728, 428, 91, 29, 1917]
-	 * 
- * - * @param text the text to encode - * @return the list of token ids - */ - List encodeOrdinary(String text); + /** + * Encodes the given text into a list of token ids, ignoring special tokens. + *

+ * This method does not throw an exception if the text contains special tokens, but instead + * encodes them as if they were ordinary text. + *

+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.encodeOrdinary("hello world");
+     * // returns [15339, 1917]
+     *
+     * encoding.encodeOrdinary("hello <|endoftext|> world");
+     * // returns [15339, 83739, 8862, 728, 428, 91, 29, 1917]
+     * 
+ * + * @param text the text to encode + * @return the list of token ids + */ + List encodeOrdinary(String text); - /** - * Encodes the given text into a list of token ids, ignoring special tokens. - *

- * This method does not throw an exception if the text contains special tokens, but instead - * encodes them as if they were ordinary text. - *

- * It will truncate the list of token ids if the number of tokens exceeds the - * given maxTokens parameter. Note that it will try to keep characters together, that are encoded into - * multiple tokens. For example, if the text contains a character which is encoded into 3 tokens, - * and due to the maxTokens parameter the last token of the character is truncated, the first two - * tokens of the character will also be truncated. Therefore, the actual number of tokens may be - * less than the given maxTokens parameter. - *

-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.encodeOrdinary("hello world", 100);
-	 * // returns [15339, 1917]
-	 *
-	 * encoding.encodeOrdinary("hello <|endoftext|> world", 100);
-	 * // returns [15339, 83739, 8862, 728, 428, 91, 29, 1917]
-	 * 
- * - * @param text the text to encode - * @param maxTokens the maximum number of tokens to encode - * @return the {@link EncodingResult} containing a list of token ids and whether the tokens were truncated due to the maxTokens parameter - */ - EncodingResult encodeOrdinary(String text, int maxTokens); + /** + * Encodes the given text into a list of token ids, ignoring special tokens. + *

+ * This method does not throw an exception if the text contains special tokens, but instead + * encodes them as if they were ordinary text. + *

+ * It will truncate the list of token ids if the number of tokens exceeds the + * given maxTokens parameter. Note that it will try to keep characters together, that are encoded into + * multiple tokens. For example, if the text contains a character which is encoded into 3 tokens, + * and due to the maxTokens parameter the last token of the character is truncated, the first two + * tokens of the character will also be truncated. Therefore, the actual number of tokens may be + * less than the given maxTokens parameter. + *

+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.encodeOrdinary("hello world", 100);
+     * // returns [15339, 1917]
+     *
+     * encoding.encodeOrdinary("hello <|endoftext|> world", 100);
+     * // returns [15339, 83739, 8862, 728, 428, 91, 29, 1917]
+     * 
+ * + * @param text the text to encode + * @param maxTokens the maximum number of tokens to encode + * @return the {@link EncodingResult} containing a list of token ids and whether the tokens were truncated due to the maxTokens parameter + */ + EncodingResult encodeOrdinary(String text, int maxTokens); - /** - * Encodes the given text into a list of token ids and returns the amount of tokens. - * This is a convenience method for {@link #encode(String)}, if all you want is to - * know the amount of tokens. It is not more performant than {@link #encode(String)}, - * so prefer to use {@link #encode(String)} if you actually need the tokens. - *
-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.countTokens("hello world");
-	 * // returns 2
-	 *
-	 * encoding.countTokens("hello <|endoftext|> world");
-	 * // raises an UnsupportedOperationException
-	 * 
- * - * @param text the text to count tokens for - * @return the amount of tokens - * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now - */ - int countTokens(String text); + /** + * Encodes the given text into a list of token ids and returns the amount of tokens. + * It is more performant than {@link #encode(String)}. + *
+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.countTokens("hello world");
+     * // returns 2
+     *
+     * encoding.countTokens("hello <|endoftext|> world");
+     * // raises an UnsupportedOperationException
+     * 
+ * + * @param text the text to count tokens for + * @return the amount of tokens + * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now + */ + int countTokens(String text); - /** - * Encodes the given text into a list of token ids and returns the amount of tokens. - * This is a convenience method for {@link #encodeOrdinary(String)}, if all you want is to - * know the amount of tokens. It is not more performant than {@link #encodeOrdinary(String)}, - * so prefer to use {@link #encodeOrdinary(String)} if you actually need the tokens. - *
-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.countTokensOrdinary("hello world");
-	 * // returns 2
-	 *
-	 * encoding.countTokensOrdinary("hello <|endoftext|> world");
-	 * // returns 8
-	 * 
- * - * @param text the text to count tokens for - * @return the amount of tokens - * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now - */ - int countTokensOrdinary(String text); + /** + * Decodes the given list of token ids into a text. + *
+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.decode(List.of(15339, 1917));
+     * // returns "hello world"
+     *
+     * encoding.decode(List.of(15339, 1917, Integer.MAX_VALUE));
+     * // raises an IllegalArgumentException
+     * 
+ * + * @param tokens the list of token ids + * @return the decoded text + * @throws IllegalArgumentException if the list contains invalid token ids + */ + String decode(List tokens); - /** - * Decodes the given list of token ids into a text. - *
-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.decode(List.of(15339, 1917));
-	 * // returns "hello world"
-	 *
-	 * encoding.decode(List.of(15339, 1917, Integer.MAX_VALUE));
-	 * // raises an IllegalArgumentException
-	 * 
- * - * @param tokens the list of token ids - * @return the decoded text - * @throws IllegalArgumentException if the list contains invalid token ids - */ - String decode(List tokens); + /** + * Decodes the given list of token ids into a byte array. + *
+     * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
+     * encoding.decodeBytes(List.of(15339, 1917));
+     * // returns [104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]
+     *
+     * encoding.decodeBytes(List.of(15339, 1917, Integer.MAX_VALUE));
+     * // raises an IllegalArgumentException
+     * 
+ * + * @param tokens the list of token ids + * @return the decoded byte array + * @throws IllegalArgumentException if the list contains invalid token ids + */ + byte[] decodeBytes(List tokens); - /** - * Decodes the given list of token ids into a byte array. - *
-	 * Encoding encoding = EncodingRegistry.getEncoding(EncodingType.CL100K_BASE);
-	 * encoding.decodeBytes(List.of(15339, 1917));
-	 * // returns [104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]
-	 *
-	 * encoding.decodeBytes(List.of(15339, 1917, Integer.MAX_VALUE));
-	 * // raises an IllegalArgumentException
-	 * 
- * - * @param tokens the list of token ids - * @return the decoded byte array - * @throws IllegalArgumentException if the list contains invalid token ids - */ - byte[] decodeBytes(List tokens); - - /** - * Returns the name of this encoding. This is the name which is used to identify - * the encoding and must be unique for registration in the {@link EncodingRegistry}. - * - * @return the name of this encoding - */ - String getName(); + /** + * Returns the name of this encoding. This is the name which is used to identify + * the encoding and must be unique for registration in the {@link EncodingRegistry}. + * + * @return the name of this encoding + */ + String getName(); } diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java index 62613e3d..f2713af3 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java @@ -1,38 +1,52 @@ package com.knuddels.jtokkit.api; + import java.util.List; /** * The result of encoding operation. 
*/ public final class EncodingResult { - private final List tokens; - private final boolean truncated; - - public EncodingResult(final List tokens, final boolean truncated) { - this.tokens = tokens; - this.truncated = truncated; - } - - /** - * @return the list of token ids - */ - public List getTokens() { - return tokens; - } - - /** - * @return true if the token list was truncated because the maximum token length was exceeded - */ - public boolean isTruncated() { - return truncated; - } - - @Override - public String toString() { - return "EncodingResult{" - + "tokens=" + tokens - + ", truncated=" + truncated - + '}'; - } + private final List tokens; + private final boolean truncated; + private int tokenCount; + + public EncodingResult(List tokens, int tokenCount, boolean truncated) { + this.tokens = tokens; + this.tokenCount = tokenCount; + this.truncated = truncated; + } + + /** + * @return the list of token ids + */ + public List getTokens() { + if (tokens.size() != getTokenCount()) { + throw new IllegalStateException("Token count does not match token list size (tokenCount=" + tokenCount + ", tokens size=" + tokens.size() + ")"); + } + return tokens; + } + + public int getTokenCount() { + if (tokenCount < 0) { + tokenCount = tokens.size(); + } + return tokenCount; + } + + /** + * @return true if the token list was truncated because the maximum token length was exceeded + */ + public boolean isTruncated() { + return truncated; + } + + @Override + public String toString() { + return "EncodingResult{" + + "tokens=" + getTokens() + + ", tokenCount=" + getTokenCount() + + ", truncated=" + truncated + + '}'; + } } diff --git a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java index 876b59dc..0c9fca76 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java @@ -167,11 +167,6 @@ 
public int countTokens(String text) { return 0; } - @Override - public int countTokensOrdinary(String text) { return 0; } - - @Override - public String decode(List tokens) { return null; From 2b8efc94c0aee2f04cfe422cb933bfe5f2b896ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 22 Dec 2023 16:31:47 +0100 Subject: [PATCH 04/14] Add addTokensAndGetCountLarge with linear byte pair merge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We're storing the ranks in a red-black tree of trees. Getting the minimum rank is basically constant time (we group by the rank itself, since a rank can occur multiple times, and pop the first entry, which represents the first occurrence). After a merge we remove the node (also a basically constant-time operation). We also count the remaining valid ranks for the stopping condition. To know the previous and next values, we store everything in a RankNode that we update after finding the minimum via the tree.
Before: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 7.372 ± 0.063 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 6.885 ± 0.049 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 7.846 ± 0.051 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 7.850 ± 0.051 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 7.006 ± 0.066 s/op After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 4.592 ± 0.055 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 4.215 ± 0.036 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.598 ± 0.063 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.569 ± 0.044 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 5.178 ± 0.128 s/op --- .../com/knuddels/jtokkit/TokenEncoder.java | 69 ++++++----- .../knuddels/jtokkit/TokenEncoderLarge.java | 109 ++++++++++++++++++ .../java/com/knuddels/jtokkit/Cl100kTest.java | 26 +++-- 3 files changed, 167 insertions(+), 37 deletions(-) create mode 100644 lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 8665aa55..7941ec4a 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -1,19 +1,22 @@ package com.knuddels.jtokkit; - import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import static com.knuddels.jtokkit.TokenEncoderLarge.addTokensAndGetCountLarge; + final class TokenEncoder { public static final int DUMMY_RANK = Integer.MAX_VALUE; public static final int MAX_RANK = Integer.MAX_VALUE - 1; private final Map[] encoders; + private int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD; private int length = 0; public 
TokenEncoder(Map encoder) { if (!encoder.isEmpty()) { + VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500")); TreeMap> tempEncoders = new TreeMap<>(); encoder.forEach((k, v) -> { length++; @@ -101,13 +104,32 @@ public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] return 1; } else { int length = match.length(); - return addTokensAndGetCount(maxTokenCount, keepEncodings, out, ranks, match, length); + if (length < VERY_LARGE_TOKENIZER_BYTE_THRESHOLD) { + return addTokensAndGetCountSmall(maxTokenCount, keepEncodings, out, ranks, match, length); + } else { + return addTokensAndGetCountLarge(this, maxTokenCount, keepEncodings, out, match, length); + } } } - private int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, List out, List ranks, ImmutableByteArray match, int length) { - int validRanks = initRanks(match, length, ranks); - int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, validRanks); + private int addTokensAndGetCountSmall(int maxTokenCount, boolean keepEncodings, List out, List ranks, ImmutableByteArray match, int length) { + assert length > 1 : "Already filtered out"; + ranks.clear(); + + int validRanks = 0; + int minRankIndex = -1; + for (int i = 0, minRank = MAX_RANK; i < length + 1; i++) { + int encoded = encode(match, i, i + 2); + if (encoded != MAX_RANK) { + validRanks++; + if (encoded < minRank) { + minRankIndex = i; + minRank = encoded; + } + } + ranks.add(encoded); + } + int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, validRanks, minRankIndex); if (keepEncodings) { for (int start = 0, end = 1; end < ranks.size() && out.size() < maxTokenCount; end++) { if (ranks.get(end) != DUMMY_RANK) { @@ -121,28 +143,8 @@ private int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, List< return tokenCount; } - int initRanks(ImmutableByteArray piece, int tokenCount, List ranks) { - assert tokenCount > 1 : 
"Already filtered out"; - ranks.clear(); - int validRanks = 0; - for (int i = 0; i < tokenCount + 1; i++) { - int encoded = encode(piece, i, i + 2); - if (encoded != MAX_RANK) { - validRanks++; - } - ranks.add(encoded); - } - return validRanks; - } - - int mergeBytesAndGetTokenCount(ImmutableByteArray piece, int length, List ranks, int validRanks) { - assert true; - while (true) { - if (validRanks == 0) { - assert getMinRankIndex(ranks) < 0; - break; - } - int minRankIndex = getMinRankIndex(ranks); + int mergeBytesAndGetTokenCount(ImmutableByteArray piece, int length, List ranks, int validRanks, int minRankIndex) { + while (validRanks > 0) { assert minRankIndex >= 0; int previousIndex = getPreviousIndex(ranks, minRankIndex - 1); @@ -171,17 +173,24 @@ int mergeBytesAndGetTokenCount(ImmutableByteArray piece, int length, List encoder = encoders[payload.length()]; - return encoder == null ? MAX_RANK : encoder.getOrDefault(payload, MAX_RANK); - } else { - return MAX_RANK; + if (encoder != null) { + Integer result = encoder.get(payload); + if (result != null) { + return result; + } + } } + return MAX_RANK; } int encode(ImmutableByteArray piece, int start, int end) { diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java new file mode 100644 index 00000000..f1061a6a --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -0,0 +1,109 @@ +package com.knuddels.jtokkit; + + +import java.util.List; +import java.util.TreeMap; + +import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; + +final class TokenEncoderLarge { + static int addTokensAndGetCountLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, ImmutableByteArray match, int length) { + assert length > 1 : "Already filtered out"; + + TreeMap> rankMap = new TreeMap<>(); + + RankNode head = null; + RankNode prevNode = null; + int validRanks = 0; + for (int i = 0; i < length + 1; 
i++) { + int encoded = tokenEncoder.encode(match, i, i + 2); + if (encoded != MAX_RANK) { + validRanks++; + } + RankNode node = new RankNode(encoded, i); + if (head == null) { + head = node; + } else { + prevNode.next = node; + node.prev = prevNode; + } + prevNode = node; + + rankMap.computeIfAbsent(encoded, k -> new TreeMap<>()).put(i, node); + } + + while (validRanks > 0) { + RankNode minNode = rankMap.firstEntry().getValue().firstEntry().getValue(); + assert minNode.rank != MAX_RANK; + + RankNode previousNode = minNode.prev; + RankNode nextNode = minNode.next; + RankNode nextNextNode = nextNode != null ? nextNode.next : null; + RankNode nextNextNextNode = nextNextNode != null ? nextNextNode.next : null; + + if (previousNode != null) { + int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode != null ? nextNextNode.index : Integer.MAX_VALUE); + if ((newRank == MAX_RANK) != (previousNode.rank == MAX_RANK)) { + validRanks -= (newRank == MAX_RANK) ? 1 : -1; + } + removeNode(rankMap, previousNode); + previousNode.rank = newRank; + rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(previousNode.index, previousNode); + } + + int newRank = tokenEncoder.encode(match, minNode.index, nextNextNextNode != null ? 
nextNextNextNode.index : Integer.MAX_VALUE); + if ((newRank == MAX_RANK) != (minNode.rank == MAX_RANK)) { + validRanks--; + } + removeNode(rankMap, minNode); + minNode.rank = newRank; + rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(minNode.index, minNode); + + minNode.next = nextNextNode; + if (nextNode != null) { + if (nextNextNode != null) { + nextNextNode.prev = minNode; + } + if (nextNode.rank != MAX_RANK) { + validRanks--; + } + removeNode(rankMap, nextNode); + } + + length--; + } + assert rankMap.firstEntry().getValue().firstEntry().getValue().rank == MAX_RANK; + + if (keepEncodings) { + while (head.next != null && out.size() < maxTokenCount) { + int token = tokenEncoder.encode(match, head.index, head.next.index); + assert token != MAX_RANK : "Token should not be MAX_RANK"; + out.add(token); + head = head.next; + } + } + + return length; + } + + static void removeNode(TreeMap> rankMap, RankNode nextNode) { + TreeMap nodeMap = rankMap.get(nextNode.rank); + if (nodeMap.size() == 1) { + assert nodeMap.containsKey(nextNode.index); + rankMap.remove(nextNode.rank); + } else { + nodeMap.remove(nextNode.index); + } + } + + private static class RankNode { + int rank; + int index; + RankNode prev, next; + + RankNode(int rank, int index) { + this.rank = rank; + this.index = index; + } + } +} \ No newline at end of file diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index d64b87fc..f25048a0 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -5,6 +5,7 @@ import org.junit.jupiter.api.Test; import java.util.List; +import java.util.Map; import java.util.TreeMap; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.IntStream; @@ -45,7 +46,7 @@ void measureEncodingSpeeds() { var measurements = new TreeMap(); var iterations = 20; - for (var i = 1.0; i < 3_000; i = Math.max(i + 1, i * 1.01)) 
{ + for (var i = 1.0; i < 2_000; i = Math.max(i + 1, i * 1.01)) { while (input.length() < i) { input.append("a"); } @@ -144,6 +145,11 @@ void testEdgeCaseRoundTrips() throws Exception { @Test void testRoundTripWithRandomStrings() throws Exception { + System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(Integer.MAX_VALUE)); + var arrayEncoder = EncodingFactory.cl100kBase(); + + System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(0)); + var mapEncoder = EncodingFactory.cl100kBase(); var singleTokenStrings = getAllTokens(); IntStream.range(0, 10_000).parallel().forEach(i -> { String testString; @@ -153,13 +159,19 @@ void testRoundTripWithRandomStrings() throws Exception { var maxTokenCount = rand().nextInt(1, 2 * testString.length()); - var actualTokens = ENCODING.encode(testString); - var decodedTokens = ENCODING.decode(actualTokens); - assertEquals(testString, decodedTokens, decodedTokens); + var encoders = Map.of(arrayEncoder, "arrayEncoder", mapEncoder, "mapEncoder"); + for (Encoding encoder : encoders.keySet()) { +// System.out.println("Validating `" + normalizeStringForTesting(testString) + "` with " + encoders.get(encoder) + " and maxTokenCount = " + maxTokenCount); + var actualTokens = encoder.encode(testString); + assertEquals(actualTokens.size(), encoder.countTokens(testString)); - var actualTrimmedTokens = ENCODING.encode(testString, maxTokenCount).getTokens(); - var decodedTrimmedTokens = ENCODING.decode(actualTrimmedTokens); - assertTrue(testString.startsWith(decodedTrimmedTokens)); + var decodedTokens = encoder.decode(actualTokens); + assertEquals(testString, decodedTokens, decodedTokens); + + var actualTrimmedTokens = encoder.encode(testString, maxTokenCount).getTokens(); + var decodedTrimmedTokens = encoder.decode(actualTrimmedTokens); + assertTrue(testString.startsWith(decodedTrimmedTokens)); + } }); } From 5b2872508893d745c4992b47bda69f481ce5335c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: 
Mon, 25 Dec 2023 12:19:58 +0200 Subject: [PATCH 05/14] Optimize decodeBytes --- .../com/knuddels/jtokkit/GptBytePairEncoding.java | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 87619a35..1e8c99ee 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -4,6 +4,7 @@ import com.knuddels.jtokkit.api.EncodingResult; import com.knuddels.jtokkit.api.GptBytePairEncodingParams; +import java.io.ByteArrayOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -119,19 +120,14 @@ public String decode(List tokens) { @Override public byte[] decodeBytes(List tokens) { - List out = new ArrayList<>(); + ByteArrayOutputStream out = new ByteArrayOutputStream(10 * tokens.size()); for (int token : tokens) { byte[] decodedToken = decodeToken(token); for (byte b : decodedToken) { - out.add(b); + out.write(b); } } - - byte[] outArray = new byte[out.size()]; - for (int i = 0; i < out.size(); i++) { - outArray[i] = out.get(i); - } - return outArray; + return out.toByteArray(); } @Override From e98e30d49b69d3b949765884a0aee6add115e9a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Mon, 25 Dec 2023 12:43:13 +0200 Subject: [PATCH 06/14] Remove cloning from ImmutableByteArray and rename it to ByteArrayWrapper --- .../knuddels/jtokkit/ByteArrayWrapper.java | 74 ++++++++++++ .../knuddels/jtokkit/ImmutableByteArray.java | 113 ------------------ .../com/knuddels/jtokkit/TokenEncoder.java | 18 +-- .../knuddels/jtokkit/TokenEncoderLarge.java | 2 +- .../jtokkit/ByteArrayWrapperTest.java | 38 ++++++ .../jtokkit/ImmutableByteArrayTest.java | 76 ------------ 6 files changed, 122 insertions(+), 199 deletions(-) create mode 100644 lib/src/main/java/com/knuddels/jtokkit/ByteArrayWrapper.java delete 
mode 100644 lib/src/main/java/com/knuddels/jtokkit/ImmutableByteArray.java create mode 100644 lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java delete mode 100644 lib/src/test/java/com/knuddels/jtokkit/ImmutableByteArrayTest.java diff --git a/lib/src/main/java/com/knuddels/jtokkit/ByteArrayWrapper.java b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayWrapper.java new file mode 100644 index 00000000..cd4b9aea --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayWrapper.java @@ -0,0 +1,74 @@ +package com.knuddels.jtokkit; + +import java.util.Arrays; + +class ByteArrayWrapper { + private final byte[] array; + + /* + * Creates a new instance of ByteArrayWrapper from the given array. + * The given array is not copied, so every calling method in this class must make sure + * to never pass an array which reference leaked to the outside. Since some of our + * construction methods already create new arrays, we do not want to copy here in this + * constructor again. + */ + ByteArrayWrapper(byte[] array) { + this.array = array; + } + + /** + * Returns the length of this array. + * + * @return the length of this array. + */ + public int length() { + return array.length; + } + + /** + * Returns the bytes of this array from startIndex (inclusive) to endIndex (exclusive). The returned array is a copy + * of the original array. 
+ * + * @param startIndex the index from which to start copying (inclusive) + * @param endIndex the index at which to stop copying (exclusive) + * @return a new {@link ByteArrayWrapper} containing the bytes from startIndex (inclusive) to endIndex (exclusive) + * @throws IllegalArgumentException if startIndex is out of bounds, endIndex is out of bounds or endIndex is less than + * startIndex + */ + public ByteArrayWrapper getBytesBetween(int startIndex, int endIndex) { + if (startIndex < 0 || startIndex >= array.length) { + throw new IndexOutOfBoundsException("startIndex out of bounds: " + startIndex + " (" + this + ")"); + } else if (endIndex < 0 || endIndex > array.length) { + throw new IndexOutOfBoundsException("endIndex out of bounds: " + endIndex + " (" + this + ")"); + } else if (startIndex >= endIndex) { + throw new IllegalArgumentException("startIndex must be less than endIndex: " + startIndex + " >= " + endIndex); + } + + int length = endIndex - startIndex; + byte[] result = new byte[length]; + System.arraycopy(array, startIndex, result, 0, length); + return new ByteArrayWrapper(result); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other == null || getClass() != other.getClass()) { + return false; + } + + ByteArrayWrapper that = (ByteArrayWrapper) other; + return Arrays.equals(array, that.array); + } + + @Override + public int hashCode() { + return Arrays.hashCode(array); + } + + @Override + public String toString() { + return Arrays.toString(array); + } +} diff --git a/lib/src/main/java/com/knuddels/jtokkit/ImmutableByteArray.java b/lib/src/main/java/com/knuddels/jtokkit/ImmutableByteArray.java deleted file mode 100644 index b0b4fb8c..00000000 --- a/lib/src/main/java/com/knuddels/jtokkit/ImmutableByteArray.java +++ /dev/null @@ -1,113 +0,0 @@ -package com.knuddels.jtokkit; - -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Objects; - -final class 
ImmutableByteArray { - private final byte[] array; - - /** - * Creates a new instance of {@link ImmutableByteArray} from the given {@code string}. - * - * @param string the string to convert to a byte array - * @return a new {@link ImmutableByteArray} containing the bytes of the given string - */ - public static ImmutableByteArray from(final String string) { - Objects.requireNonNull(string, "String must not be null"); - return new ImmutableByteArray(string.getBytes(StandardCharsets.UTF_8)); - } - - /** - * Creates a new instance of {@link ImmutableByteArray} from the given {@code array}. - * - * @param array the array to copy - * @return a new {@link ImmutableByteArray} containing the bytes of the given array - */ - public static ImmutableByteArray from(final byte[] array) { - Objects.requireNonNull(array, "Byte array must not be null"); - return new ImmutableByteArray(array.clone()); - } - - /* - * Creates a new instance of ImmutableByteArray from the given array. - * The given array is not copied, so every calling method in this class must make sure - * to never pass an array which reference leaked to the outside. Since some of our - * construction methods already create new arrays, we do not want to copy here in this - * constructor again. - */ - private ImmutableByteArray(final byte[] array) { - this.array = array; - } - - /** - * Returns the length of this array. - * - * @return the length of this array. - */ - public int length() { - return array.length; - } - - /** - * Returns the bytes of this array from startIndex (inclusive) to endIndex (exclusive). The returned array is a copy - * of the original array. 
- * - * @param startIndex the index from which to start copying (inclusive) - * @param endIndex the index at which to stop copying (exclusive) - * @return a new {@link ImmutableByteArray} containing the bytes from startIndex (inclusive) to endIndex (exclusive) - * @throws IllegalArgumentException if startIndex is out of bounds, endIndex is out of bounds or endIndex is less than - * startIndex - */ - public ImmutableByteArray getBytesBetween(final int startIndex, final int endIndex) { - if (startIndex < 0 || startIndex >= array.length) { - throw new IndexOutOfBoundsException("startIndex out of bounds: " + startIndex + " (" + this + ")"); - } - - if (endIndex < 0 || endIndex > array.length) { - throw new IndexOutOfBoundsException("endIndex out of bounds: " + endIndex + " (" + this + ")"); - } - - if (startIndex >= endIndex) { - throw new IllegalArgumentException("startIndex must be less than endIndex: " + startIndex + " >= " + endIndex); - } - - final int length = endIndex - startIndex; - final byte[] result = new byte[length]; - System.arraycopy(array, startIndex, result, 0, length); - return new ImmutableByteArray(result); - } - - /** - * Returns a copy of the raw array backing this {@link ImmutableByteArray}. 
- * - * @return a copy of the raw array backing this {@link ImmutableByteArray} - */ - public byte[] getRawArray() { - return array.clone(); - } - - @Override - public boolean equals(final Object other) { - if (this == other) { - return true; - } - - if (other == null || getClass() != other.getClass()) { - return false; - } - - final ImmutableByteArray that = (ImmutableByteArray) other; - return Arrays.equals(array, that.array); - } - - @Override - public int hashCode() { - return Arrays.hashCode(array); - } - - @Override - public String toString() { - return Arrays.toString(array); - } -} diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 7941ec4a..16d10aa9 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -10,17 +10,17 @@ final class TokenEncoder { public static final int DUMMY_RANK = Integer.MAX_VALUE; public static final int MAX_RANK = Integer.MAX_VALUE - 1; - private final Map[] encoders; + private final Map[] encoders; private int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD; private int length = 0; public TokenEncoder(Map encoder) { if (!encoder.isEmpty()) { VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500")); - TreeMap> tempEncoders = new TreeMap<>(); + TreeMap> tempEncoders = new TreeMap<>(); encoder.forEach((k, v) -> { length++; - ImmutableByteArray key = ImmutableByteArray.from(k); + ByteArrayWrapper key = new ByteArrayWrapper(k); tempEncoders.computeIfAbsent(k.length, integer -> new ConcurrentHashMap<>()).put(key, v); }); //noinspection unchecked @@ -95,7 +95,7 @@ public static int getPreviousIndex(List ranks, int previousIndex) { } public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] utf8Bytes, List out, List ranks) { - ImmutableByteArray match = ImmutableByteArray.from(utf8Bytes); + 
ByteArrayWrapper match = new ByteArrayWrapper(utf8Bytes); int encoded = encode(match); if (encoded != MAX_RANK) { if (keepEncodings) { @@ -112,7 +112,7 @@ public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] } } - private int addTokensAndGetCountSmall(int maxTokenCount, boolean keepEncodings, List out, List ranks, ImmutableByteArray match, int length) { + private int addTokensAndGetCountSmall(int maxTokenCount, boolean keepEncodings, List out, List ranks, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; ranks.clear(); @@ -143,7 +143,7 @@ private int addTokensAndGetCountSmall(int maxTokenCount, boolean keepEncodings, return tokenCount; } - int mergeBytesAndGetTokenCount(ImmutableByteArray piece, int length, List ranks, int validRanks, int minRankIndex) { + int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List ranks, int validRanks, int minRankIndex) { while (validRanks > 0) { assert minRankIndex >= 0; @@ -180,9 +180,9 @@ int mergeBytesAndGetTokenCount(ImmutableByteArray piece, int length, List encoder = encoders[payload.length()]; + Map encoder = encoders[payload.length()]; if (encoder != null) { Integer result = encoder.get(payload); if (result != null) { @@ -193,7 +193,7 @@ int encode(ImmutableByteArray payload) { return MAX_RANK; } - int encode(ImmutableByteArray piece, int start, int end) { + int encode(ByteArrayWrapper piece, int start, int end) { if (end > piece.length()) { return MAX_RANK; } else if (end - start == piece.length()) { diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index f1061a6a..5473b0a7 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -7,7 +7,7 @@ import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; final class TokenEncoderLarge { - static int 
addTokensAndGetCountLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, ImmutableByteArray match, int length) { + static int addTokensAndGetCountLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; TreeMap> rankMap = new TreeMap<>(); diff --git a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java new file mode 100644 index 00000000..053b2f51 --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayWrapperTest.java @@ -0,0 +1,38 @@ +package com.knuddels.jtokkit; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ByteArrayWrapperTest { + @Test + public void getBytesBetweenReturnsCorrectSliceOfArray() { + final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + + assertEquals(new ByteArrayWrapper(new byte[]{4, 5, 6}), byteArray.getBytesBetween(3, 6)); + } + + @Test + public void getBytesBetweenThrowsWhenInclusiveStartIndexOutOfBounds() { + final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + + assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(-1, 6)); + assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(9, 10)); + } + + @Test + public void getBytesBetweenThrowsWhenExclusiveEndIndexOutOfBounds() { + final ByteArrayWrapper byteArray = new ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + + assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(0, 7)); + assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(0, -1)); + } + + @Test + public void getBytesBetweenThrowsWhenStartIndexIsGreaterThanEndIndex() { + final ByteArrayWrapper byteArray = new 
ByteArrayWrapper(new byte[]{1, 2, 3, 4, 5, 6}); + + assertThrows(IllegalArgumentException.class, () -> byteArray.getBytesBetween(3, 2)); + } +} diff --git a/lib/src/test/java/com/knuddels/jtokkit/ImmutableByteArrayTest.java b/lib/src/test/java/com/knuddels/jtokkit/ImmutableByteArrayTest.java deleted file mode 100644 index 4504049f..00000000 --- a/lib/src/test/java/com/knuddels/jtokkit/ImmutableByteArrayTest.java +++ /dev/null @@ -1,76 +0,0 @@ -package com.knuddels.jtokkit; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; - -class ImmutableByteArrayTest { - - @Test - void canBeUsedAsKeyInMap() { - var key1 = ImmutableByteArray.from("1, 2, 3"); - var key2 = ImmutableByteArray.from("1, 2, 3"); - - assertEquals(key1, key2); - assertEquals(key1.hashCode(), key2.hashCode()); - } - - @Test - void canNotBeMutatedWhenUsingByteArrayConstructor() { - var bytes = new byte[]{1, 2, 3}; - var byteArray = ImmutableByteArray.from(bytes); - - bytes[0] = 4; - - assertNotEquals(byteArray, ImmutableByteArray.from(bytes)); - assertEquals(byteArray, ImmutableByteArray.from(new byte[]{1, 2, 3})); - } - - @Test - void canNotBeMutatedWhenUsingGetRawArray() { - var byteArray = ImmutableByteArray.from("1, 2, 3"); - var bytes = byteArray.getRawArray(); - - bytes[0] = 4; - - assertNotEquals(byteArray, ImmutableByteArray.from(bytes)); - assertEquals(byteArray, ImmutableByteArray.from("1, 2, 3")); - } - - @Test - void getLengthIsCorrect() { - var byteArray = ImmutableByteArray.from("1, 2, 3"); - - assertEquals(7, byteArray.length()); - } - - @Test - void getBytesBetweenReturnsCorrectSliceOfArray() { - var byteArray = ImmutableByteArray.from(new byte[]{1, 2, 3, 4, 5, 6}); - - assertEquals(ImmutableByteArray.from(new byte[]{4, 5, 6}), byteArray.getBytesBetween(3, 6)); - } - - @Test - void getBytesBetweenThrowsWhenInclusiveStartIndexOutOfBounds() { - var byteArray = ImmutableByteArray.from(new byte[]{1, 2, 3, 4, 5, 6}); - - 
assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(-1, 6)); - assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(9, 10)); - } - - @Test - void getBytesBetweenThrowsWhenExclusiveEndIndexOutOfBounds() { - var byteArray = ImmutableByteArray.from(new byte[]{1, 2, 3, 4, 5, 6}); - - assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(0, 7)); - assertThrows(IndexOutOfBoundsException.class, () -> byteArray.getBytesBetween(0, -1)); - } - - @Test - void getBytesBetweenThrowsWhenStartIndexIsGreaterThanEndIndex() { - var byteArray = ImmutableByteArray.from(new byte[]{1, 2, 3, 4, 5, 6}); - - assertThrows(IllegalArgumentException.class, () -> byteArray.getBytesBetween(3, 2)); - } -} From 91c1430c3dde414cf1e3079eee0932e39340c459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Thu, 28 Dec 2023 19:34:53 +0200 Subject: [PATCH 07/14] Add ensureCapacity to calculateTokensSmall to minimize copies --- .../java/com/knuddels/jtokkit/GptBytePairEncoding.java | 2 +- .../main/java/com/knuddels/jtokkit/TokenEncoder.java | 10 ++++++---- .../java/com/knuddels/jtokkit/TokenEncoderLarge.java | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 1e8c99ee..851a3e62 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -100,7 +100,7 @@ private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, bo int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { int tokenCount = 0; - List ranks = new ArrayList<>(); // reused to avoid allocations + ArrayList ranks = new ArrayList<>(); // reused to avoid allocations for (Matcher matcher = pattern.matcher(text); tokenCount < maxTokenCount && matcher.find(); ) { byte[] bytes 
= matcher.group().getBytes(UTF_8); tokenCount += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, bytes, out, ranks); diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 16d10aa9..9b21f05f 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -94,7 +94,7 @@ public static int getPreviousIndex(List ranks, int previousIndex) { return previousIndex; } - public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] utf8Bytes, List out, List ranks) { + int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] utf8Bytes, List out, ArrayList ranks) { ByteArrayWrapper match = new ByteArrayWrapper(utf8Bytes); int encoded = encode(match); if (encoded != MAX_RANK) { @@ -105,16 +105,17 @@ public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] } else { int length = match.length(); if (length < VERY_LARGE_TOKENIZER_BYTE_THRESHOLD) { - return addTokensAndGetCountSmall(maxTokenCount, keepEncodings, out, ranks, match, length); + return calculateTokensSmall(maxTokenCount, keepEncodings, out, ranks, match, length); } else { - return addTokensAndGetCountLarge(this, maxTokenCount, keepEncodings, out, match, length); + return calculateTokensLarge(this, maxTokenCount, keepEncodings, out, match, length); } } } - private int addTokensAndGetCountSmall(int maxTokenCount, boolean keepEncodings, List out, List ranks, ByteArrayWrapper match, int length) { + private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List out, ArrayList ranks, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; ranks.clear(); + ranks.ensureCapacity(length + 1); int validRanks = 0; int minRankIndex = -1; @@ -144,6 +145,7 @@ private int addTokensAndGetCountSmall(int maxTokenCount, boolean keepEncodings, } int 
mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List ranks, int validRanks, int minRankIndex) { + assert getMinRankIndex(ranks) == minRankIndex; while (validRanks > 0) { assert minRankIndex >= 0; diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index 5473b0a7..d818e162 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -7,7 +7,7 @@ import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; final class TokenEncoderLarge { - static int addTokensAndGetCountLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, ByteArrayWrapper match, int length) { + static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; TreeMap> rankMap = new TreeMap<>(); From fe46f897c5b5128dd27346ecdd0dee38820a867a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Thu, 28 Dec 2023 19:35:30 +0200 Subject: [PATCH 08/14] Use simple HashMap instead of ConcurrentHashMap It's simpler and in the current implementation it's basically just as fast. 
--- .../com/knuddels/jtokkit/SpecialEncoder.java | 4 ++-- .../com/knuddels/jtokkit/TokenEncoder.java | 23 +++++++------------ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java index 1dd1fbad..a0077ff0 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/SpecialEncoder.java @@ -1,7 +1,7 @@ package com.knuddels.jtokkit; +import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import static java.nio.charset.StandardCharsets.UTF_8; @@ -11,7 +11,7 @@ final class SpecialEncoder { private final Map encodedToDecoded; public SpecialEncoder(Map encoder) { - this.encodedToDecoded = new ConcurrentHashMap<>(encoder.size()); + this.encodedToDecoded = new HashMap<>(encoder.size()); for (Map.Entry entry : encoder.entrySet()) { String key = entry.getKey(); Integer value = entry.getValue(); diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 9b21f05f..4fb6d0b6 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -1,11 +1,8 @@ package com.knuddels.jtokkit; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; -import java.util.concurrent.ConcurrentHashMap; +import java.util.*; -import static com.knuddels.jtokkit.TokenEncoderLarge.addTokensAndGetCountLarge; +import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge; final class TokenEncoder { public static final int DUMMY_RANK = Integer.MAX_VALUE; @@ -21,10 +18,10 @@ public TokenEncoder(Map encoder) { encoder.forEach((k, v) -> { length++; ByteArrayWrapper key = new ByteArrayWrapper(k); - tempEncoders.computeIfAbsent(k.length, integer -> new ConcurrentHashMap<>()).put(key, v); + 
tempEncoders.computeIfAbsent(k.length, integer -> new HashMap<>()).put(key, v); }); //noinspection unchecked - encoders = new ConcurrentHashMap[tempEncoders.lastKey() + 1]; + encoders = new Map[tempEncoders.lastKey() + 1]; tempEncoders.forEach((k, v) -> encoders[k] = v); } else { //noinspection unchecked @@ -32,7 +29,7 @@ public TokenEncoder(Map encoder) { } } - public static int getMinRankIndex(List ranks) { + private static int getMinRankIndex(List ranks) { int minRankIndex = -1; int minRank = MAX_RANK; @@ -80,14 +77,14 @@ public static int getMinRankIndex(List ranks) { return minRankIndex; } - public static int getNextIndex(List ranks, int nextIndex) { + private static int getNextIndex(List ranks, int nextIndex) { while (nextIndex < ranks.size() && ranks.get(nextIndex) == DUMMY_RANK) { nextIndex++; } return nextIndex; } - public static int getPreviousIndex(List ranks, int previousIndex) { + private static int getPreviousIndex(List ranks, int previousIndex) { while (previousIndex >= 0 && ranks.get(previousIndex) == DUMMY_RANK) { previousIndex--; } @@ -182,7 +179,7 @@ int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List return length; } - int encode(ByteArrayWrapper payload) { + private int encode(ByteArrayWrapper payload) { if (payload.length() < encoders.length) { Map encoder = encoders[payload.length()]; if (encoder != null) { @@ -204,8 +201,4 @@ int encode(ByteArrayWrapper piece, int start, int end) { return encode(piece.getBytesBetween(start, end)); } } - - public int length() { - return length; - } } \ No newline at end of file From 17a83be454a0777ec6095f8398b4df6ab04dc1a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 29 Dec 2023 20:03:07 +0200 Subject: [PATCH 09/14] Iterate over all occurrences of the found minimum token in the TokenEncoderLarge --- .../com/knuddels/jtokkit/TokenEncoder.java | 2 - .../knuddels/jtokkit/TokenEncoderLarge.java | 82 +++++++++++-------- .../java/com/knuddels/jtokkit/Cl100kTest.java | 3 +- 3 
files changed, 48 insertions(+), 39 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 4fb6d0b6..d8ee63f8 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -9,14 +9,12 @@ final class TokenEncoder { public static final int MAX_RANK = Integer.MAX_VALUE - 1; private final Map[] encoders; private int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD; - private int length = 0; public TokenEncoder(Map encoder) { if (!encoder.isEmpty()) { VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500")); TreeMap> tempEncoders = new TreeMap<>(); encoder.forEach((k, v) -> { - length++; ByteArrayWrapper key = new ByteArrayWrapper(k); tempEncoders.computeIfAbsent(k.length, integer -> new HashMap<>()).put(key, v); }); diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index d818e162..419d4d48 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -33,46 +33,49 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo } while (validRanks > 0) { - RankNode minNode = rankMap.firstEntry().getValue().firstEntry().getValue(); - assert minNode.rank != MAX_RANK; - - RankNode previousNode = minNode.prev; - RankNode nextNode = minNode.next; - RankNode nextNextNode = nextNode != null ? nextNode.next : null; - RankNode nextNextNextNode = nextNextNode != null ? nextNextNode.next : null; - - if (previousNode != null) { - int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode != null ? nextNextNode.index : Integer.MAX_VALUE); - if ((newRank == MAX_RANK) != (previousNode.rank == MAX_RANK)) { - validRanks -= (newRank == MAX_RANK) ? 
1 : -1; + TreeMap minNodes = rankMap.firstEntry().getValue(); + for (int i = 0; i < minNodes.size(); i++) { + RankNode minNode = minNodes.firstEntry().getValue(); + assert minNode.rank != MAX_RANK; + + RankNode previousNode = minNode.prev; + RankNode nextNode = minNode.next; + RankNode nextNextNode = nextNode != null ? nextNode.next : null; + RankNode nextNextNextNode = nextNextNode != null ? nextNextNode.next : null; + + if (previousNode != null) { + int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode != null ? nextNextNode.index : Integer.MAX_VALUE); + if ((newRank == MAX_RANK) != (previousNode.rank == MAX_RANK)) { + validRanks -= (newRank == MAX_RANK) ? 1 : -1; + } + removeNode(rankMap.get(previousNode.rank), rankMap, previousNode); + previousNode.rank = newRank; + rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(previousNode.index, previousNode); } - removeNode(rankMap, previousNode); - previousNode.rank = newRank; - rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(previousNode.index, previousNode); - } - int newRank = tokenEncoder.encode(match, minNode.index, nextNextNextNode != null ? nextNextNextNode.index : Integer.MAX_VALUE); - if ((newRank == MAX_RANK) != (minNode.rank == MAX_RANK)) { - validRanks--; - } - removeNode(rankMap, minNode); - minNode.rank = newRank; - rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(minNode.index, minNode); - - minNode.next = nextNextNode; - if (nextNode != null) { - if (nextNextNode != null) { - nextNextNode.prev = minNode; - } - if (nextNode.rank != MAX_RANK) { + int newRank = tokenEncoder.encode(match, minNode.index, nextNextNextNode != null ? 
nextNextNextNode.index : Integer.MAX_VALUE); + if ((newRank == MAX_RANK) != (minNode.rank == MAX_RANK)) { validRanks--; } - removeNode(rankMap, nextNode); - } + removeNode(minNodes, rankMap, minNode); + minNode.rank = newRank; + rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(minNode.index, minNode); + + minNode.next = nextNextNode; + if (nextNode != null) { + if (nextNextNode != null) { + nextNextNode.prev = minNode; + } + if (nextNode.rank != MAX_RANK) { + validRanks--; + } + removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); + } - length--; + length--; + } } - assert rankMap.firstEntry().getValue().firstEntry().getValue().rank == MAX_RANK; + assert rankMap.firstEntry().getValue().values().iterator().next().rank == MAX_RANK; if (keepEncodings) { while (head.next != null && out.size() < maxTokenCount) { @@ -86,8 +89,7 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo return length; } - static void removeNode(TreeMap> rankMap, RankNode nextNode) { - TreeMap nodeMap = rankMap.get(nextNode.rank); + static void removeNode(TreeMap nodeMap, TreeMap> rankMap, RankNode nextNode) { if (nodeMap.size() == 1) { assert nodeMap.containsKey(nextNode.index); rankMap.remove(nextNode.rank); @@ -105,5 +107,13 @@ private static class RankNode { this.rank = rank; this.index = index; } + + @Override + public String toString() { + return "RankNode{" + + "rank=" + rank + + ", index=" + index + + '}'; + } } } \ No newline at end of file diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index f25048a0..665d2f08 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -130,7 +130,8 @@ void testEdgeCaseRoundTrips() throws Exception { "෫𞅄", "𬕹\n ", " 😈b\n\uD844\uDDAE'ſ\uD84F\uDDB8\uD84C\uDD2CƘ淚", - "𗭾 󻥹\n\uD875\uDDB0蛇" + "𗭾 󻥹\n\uD875\uDDB0蛇", + "こんにちは世界" ); IntStream.range(0, 
testStrings.size()).forEachOrdered(i -> { From 78974d4ba4c395a3ef329454791106acd41abdec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 29 Dec 2023 20:19:59 +0200 Subject: [PATCH 10/14] Move decodeToken to TokenEncoder --- .../com/knuddels/jtokkit/GptBytePairEncoding.java | 7 +------ .../main/java/com/knuddels/jtokkit/TokenEncoder.java | 11 +++++++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 851a3e62..2562bb91 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -6,9 +6,7 @@ import java.io.ByteArrayOutputStream; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -25,7 +23,6 @@ class GptBytePairEncoding implements Encoding { private final Pattern pattern; private final TokenEncoder encoder; private final SpecialEncoder specialEncoder; - private final Map encodedToDecoded; /** * Creates a new instance of {@link GptBytePairEncoding}. 
@@ -37,8 +34,6 @@ class GptBytePairEncoding implements Encoding { this.pattern = params.getPattern(); this.encoder = new TokenEncoder(params.getEncoder()); this.specialEncoder = new SpecialEncoder(params.getSpecialTokensEncoder()); - this.encodedToDecoded = new HashMap<>(params.getEncoder().size()); - params.getEncoder().forEach((k, v) -> encodedToDecoded.put(v, k)); } @Override @@ -136,7 +131,7 @@ public String getName() { } private byte[] decodeToken(int token) { - byte[] decodedToken = encodedToDecoded.computeIfAbsent(token, specialEncoder::decodeIfPresent); + byte[] decodedToken = encoder.decodeToken(token, specialEncoder); return requireNonNull(decodedToken, "Unknown token for decoding: " + token); } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index d8ee63f8..c99f4314 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -3,11 +3,14 @@ import java.util.*; import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge; +import static java.util.Collections.emptyMap; final class TokenEncoder { public static final int DUMMY_RANK = Integer.MAX_VALUE; public static final int MAX_RANK = Integer.MAX_VALUE - 1; private final Map[] encoders; + private final Map decoder; + private int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD; public TokenEncoder(Map encoder) { @@ -21,9 +24,13 @@ public TokenEncoder(Map encoder) { //noinspection unchecked encoders = new Map[tempEncoders.lastKey() + 1]; tempEncoders.forEach((k, v) -> encoders[k] = v); + + this.decoder = new HashMap<>(encoder.size()); + encoder.forEach((k, v) -> decoder.put(v, k)); } else { //noinspection unchecked encoders = new Map[0]; // for testing + this.decoder = emptyMap(); } } @@ -199,4 +206,8 @@ int encode(ByteArrayWrapper piece, int start, int end) { return encode(piece.getBytesBetween(start, end)); } } + + public byte[] decodeToken(int 
token, SpecialEncoder specialEncoder) { + return decoder.computeIfAbsent(token, specialEncoder::decodeIfPresent); + } } \ No newline at end of file From ad7292fdf1921b6bde5a1dc9d0176379c285f446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Fri, 29 Dec 2023 21:21:42 +0200 Subject: [PATCH 11/14] Add Cl100kLargeTokenizerTest to run every Cl100k tokenizer test with the large tokenizer as well --- .../com/knuddels/jtokkit/TokenEncoder.java | 11 ++++--- .../java/com/knuddels/jtokkit/Cl100kTest.java | 5 ++-- .../jtokkit/reference/Cl100kBaseTest.java | 22 ++++++++------ .../reference/Cl100kLargeTokenizerTest.java | 30 +++++++++++++++++++ 4 files changed, 53 insertions(+), 15 deletions(-) create mode 100644 lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index c99f4314..9aefc200 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -3,11 +3,14 @@ import java.util.*; import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge; +import static java.lang.Integer.MAX_VALUE; +import static java.lang.Integer.parseInt; import static java.util.Collections.emptyMap; -final class TokenEncoder { - public static final int DUMMY_RANK = Integer.MAX_VALUE; - public static final int MAX_RANK = Integer.MAX_VALUE - 1; +public final class TokenEncoder { + public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; + public static final int DUMMY_RANK = MAX_VALUE; + public static final int MAX_RANK = MAX_VALUE - 1; private final Map[] encoders; private final Map decoder; @@ -15,7 +18,7 @@ final class TokenEncoder { public TokenEncoder(Map encoder) { if (!encoder.isEmpty()) { - VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = 
Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500")); + VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500")); TreeMap> tempEncoders = new TreeMap<>(); encoder.forEach((k, v) -> { ByteArrayWrapper key = new ByteArrayWrapper(k); diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index 665d2f08..474c8bc8 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -10,6 +10,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.stream.IntStream; +import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY; import static java.lang.Character.*; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.Collectors.joining; @@ -146,10 +147,10 @@ void testEdgeCaseRoundTrips() throws Exception { @Test void testRoundTripWithRandomStrings() throws Exception { - System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(Integer.MAX_VALUE)); + System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(Integer.MAX_VALUE)); var arrayEncoder = EncodingFactory.cl100kBase(); - System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(0)); + System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0)); var mapEncoder = EncodingFactory.cl100kBase(); var singleTokenStrings = getAllTokens(); IntStream.range(0, 10_000).parallel().forEach(i -> { diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java index d33c2005..a2108af1 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java @@ -14,6 +14,10 @@ class Cl100kBaseTest { private 
static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE); + Encoding getEncoding() { + return ENCODING; + } + @ParameterizedTest @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000) void cl100kBaseEncodesCorrectly( @@ -21,7 +25,7 @@ void cl100kBaseEncodesCorrectly( String output ) { var expected = TestUtils.parseEncodingString(output); - var actual = ENCODING.encode(input); + var actual = getEncoding().encode(input); assertEquals(expected, actual); } @@ -29,7 +33,7 @@ void cl100kBaseEncodesCorrectly( @ParameterizedTest @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000) void cl100kBaseEncodesStable(String input) { - var actual = ENCODING.decode(ENCODING.encode(input)); + var actual = getEncoding().decode(getEncoding().encode(input)); assertEquals(input, actual); } @@ -43,7 +47,7 @@ void cl100kBaseEncodesCorrectlyWithMaxTokensSet( ) { var expected = TestUtils.parseEncodingString(output); var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10); - var encodingResult = ENCODING.encode(input, 10); + var encodingResult = getEncoding().encode(input, 10); assertEquals(expectedWithMaxTokens, encodingResult.getTokens()); assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated()); @@ -52,7 +56,7 @@ void cl100kBaseEncodesCorrectlyWithMaxTokensSet( @ParameterizedTest @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000) void cl100kBaseEncodesStableWithMaxTokensSet(String input) { - var actual = ENCODING.decode(ENCODING.encode(input, 10).getTokens()); + var actual = getEncoding().decode(getEncoding().encode(input, 10).getTokens()); assertTrue(input.startsWith(actual)); } @@ -64,7 +68,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly( String output ) { var expected = TestUtils.parseEncodingString(output); - var 
actual = ENCODING.encodeOrdinary(input); + var actual = getEncoding().encodeOrdinary(input); assertEquals(expected, actual); } @@ -78,7 +82,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly( ) { var expected = TestUtils.parseEncodingString(output); var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10); - var encodingResult = ENCODING.encodeOrdinary(input, 10); + var encodingResult = getEncoding().encodeOrdinary(input, 10); assertEquals(expectedWithMaxTokens, encodingResult.getTokens()); assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated()); @@ -87,7 +91,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly( @ParameterizedTest @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000) void cl100kBaseEncodeOrdinaryEncodesStable(String input) { - var actual = ENCODING.decode(ENCODING.encodeOrdinary(input)); + var actual = getEncoding().decode(getEncoding().encodeOrdinary(input)); assertEquals(input, actual); } @@ -95,7 +99,7 @@ void cl100kBaseEncodeOrdinaryEncodesStable(String input) { @ParameterizedTest @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000) void cl100kBaseEncodeOrdinaryEncodesStableWithMaxTokensSet(String input) { - var actual = ENCODING.decode(ENCODING.encodeOrdinary(input, 10).getTokens()); + var actual = getEncoding().decode(getEncoding().encodeOrdinary(input, 10).getTokens()); assertTrue(input.startsWith(actual)); } @@ -103,7 +107,7 @@ void cl100kBaseEncodeOrdinaryEncodesStableWithMaxTokensSet(String input) { @Test void cl100kBaseEncodeOrdinaryEncodesSpecialTokensCorrectly() { var input = "Hello<|endoftext|>, <|fim_prefix|> <|fim_middle|> world <|fim_suffix|> ! 
<|endofprompt|>"; - var actual = ENCODING.decode(ENCODING.encodeOrdinary(input)); + var actual = getEncoding().decode(getEncoding().encodeOrdinary(input)); assertEquals(input, actual); } diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java new file mode 100644 index 00000000..c7ea36fa --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java @@ -0,0 +1,30 @@ +package com.knuddels.jtokkit.reference; + +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY; + +class Cl100kLargeTokenizerTest extends Cl100kBaseTest { + + public static Encoding ENCODING; + + @BeforeAll + static void beforeAll() { + System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0)); + ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE); + } + + @AfterAll + static void afterAll() { + System.clearProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY); + + } + + Encoding getEncoding() { + return ENCODING; + } +} From f95d368f3be2b2fac49c00695c75f48a474dd68f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Sat, 30 Dec 2023 09:27:57 +0200 Subject: [PATCH 12/14] Split Cl100kTest to LargeTokenizer version as well --- .../jtokkit/Cl100kLargeTokenizerTest.java | 29 +++++++++++++ .../java/com/knuddels/jtokkit/Cl100kTest.java | 42 ++++++++----------- 2 files changed, 47 insertions(+), 24 deletions(-) create mode 100644 lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java new 
file mode 100644 index 00000000..debe1c89 --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kLargeTokenizerTest.java @@ -0,0 +1,29 @@ +package com.knuddels.jtokkit; + +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY; + +class Cl100kLargeTokenizerTest extends Cl100kTest { + + public static Encoding ENCODING; + + @BeforeAll + static void beforeAll() { + System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0)); + ENCODING = EncodingFactory.cl100kBase(); + } + + @AfterAll + static void afterAll() { + System.clearProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY); + + } + + Encoding getEncoding() { + return ENCODING; + } +} diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index 474c8bc8..19bee341 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -1,16 +1,15 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingType; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.List; -import java.util.Map; import java.util.TreeMap; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.IntStream; -import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY; import static java.lang.Character.*; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.Collectors.joining; @@ -40,6 +39,10 @@ private static ThreadLocalRandom rand() { return ThreadLocalRandom.current(); } + Encoding getEncoding() { + return ENCODING; + } + @Disabled @Test void measureEncodingSpeeds() { @@ -54,12 +57,12 @@ void measureEncodingSpeeds() 
{ var inputString = input.toString(); for (var j = 0; j < 10 * iterations; j++) { - var warmup = ENCODING.encode(inputString); + var warmup = getEncoding().encode(inputString); assert !warmup.isEmpty(); } var startTime = System.nanoTime(); for (var j = 0; j < iterations; j++) { - var encodingResult = ENCODING.encode(inputString); + var encodingResult = getEncoding().encode(inputString); assert !encodingResult.isEmpty(); } var endTime = System.nanoTime(); @@ -112,6 +115,7 @@ void testEdgeCaseRoundTrips() throws Exception { "Hello \n\n World !", " It's 2:30pm;\n\n\n\nlet's eat, sleep , and code!", "'Thank God, here it is.' But when we took up the trunk...", + "What in the world are you doing???!!!", "user@example.com", "this is a 'quoted' word", "  a", @@ -139,41 +143,31 @@ void testEdgeCaseRoundTrips() throws Exception { var testString = testStrings.get(i); System.out.println("Validating `" + normalizeStringForTesting(testString) + "`"); - var actualTokens = ENCODING.encode(testString); - var decoded = ENCODING.decode(actualTokens); + var actualTokens = getEncoding().encode(testString); + var decoded = getEncoding().decode(actualTokens); assertEquals(testString, decoded, decoded); }); } @Test void testRoundTripWithRandomStrings() throws Exception { - System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(Integer.MAX_VALUE)); - var arrayEncoder = EncodingFactory.cl100kBase(); - - System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0)); - var mapEncoder = EncodingFactory.cl100kBase(); var singleTokenStrings = getAllTokens(); - IntStream.range(0, 10_000).parallel().forEach(i -> { + IntStream.range(0, 100_000).parallel().forEach(i -> { String testString; do { testString = generateRandomString(10, singleTokenStrings); } while (!UTF_8.newEncoder().canEncode(testString)); var maxTokenCount = rand().nextInt(1, 2 * testString.length()); + var actualTokens = getEncoding().encode(testString); + assertEquals(actualTokens.size(), 
getEncoding().countTokens(testString)); - var encoders = Map.of(arrayEncoder, "arrayEncoder", mapEncoder, "mapEncoder"); - for (Encoding encoder : encoders.keySet()) { -// System.out.println("Validating `" + normalizeStringForTesting(testString) + "` with " + encoders.get(encoder) + " and maxTokenCount = " + maxTokenCount); - var actualTokens = encoder.encode(testString); - assertEquals(actualTokens.size(), encoder.countTokens(testString)); + var decodedTokens = getEncoding().decode(actualTokens); + assertEquals(testString, decodedTokens, decodedTokens); - var decodedTokens = encoder.decode(actualTokens); - assertEquals(testString, decodedTokens, decodedTokens); - - var actualTrimmedTokens = encoder.encode(testString, maxTokenCount).getTokens(); - var decodedTrimmedTokens = encoder.decode(actualTrimmedTokens); - assertTrue(testString.startsWith(decodedTrimmedTokens)); - } + var actualTrimmedTokens = getEncoding().encode(testString, maxTokenCount).getTokens(); + var decodedTrimmedTokens = getEncoding().decode(actualTrimmedTokens); + assertTrue(testString.startsWith(decodedTrimmedTokens)); }); } From 5ee81a144672b8e858c1bf64f3d8d4db8419c97a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Sat, 30 Dec 2023 09:52:51 +0200 Subject: [PATCH 13/14] Fix Hungarian test prompt --- lib/src/test/resources/base_prompts.csv | 2 +- lib/src/test/resources/cl100k_base_encodings.csv | 2 +- lib/src/test/resources/p50k_base_encodings.csv | 2 +- lib/src/test/resources/p50k_edit_encodings.csv | 2 +- lib/src/test/resources/r50k_base_encodings.csv | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/src/test/resources/base_prompts.csv b/lib/src/test/resources/base_prompts.csv index f7008d6a..0489369a 100644 --- a/lib/src/test/resources/base_prompts.csv +++ b/lib/src/test/resources/base_prompts.csv @@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 977, 72, 4864, 259, 3415, 73, 30 Quel est ton livre préféré?,"[2232, 301, 1826, 8941, 56984, 
27389, 69, 68862, 30]" Qual é a tua cor favorita?,"[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]" Koja ti je omiljena boja?,"[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712, 5697, 30]" -Melyik a kedvenc étel?,"[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]" +Melyik a kedvenc ételed?,"[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839, 30]" Koji je tvoj omiljeni grad?,"[42, 28000, 4864, 259, 3415, 73, 8019, 321, 24041, 72, 6117, 30]" Quale è il tuo piatto preferito?,"[2232, 1604, 11676, 3900, 63258, 9115, 17173, 10932, 6491, 30]" Kurš ir tavs mīļākais TV šovs?,"[42, 324, 11906, 6348, 259, 39851, 296, 61711, 128, 120, 31757, 74, 2852, 6007, 37524, 869, 82, 30]" diff --git a/lib/src/test/resources/cl100k_base_encodings.csv b/lib/src/test/resources/cl100k_base_encodings.csv index fa656186..6ca65e20 100644 --- a/lib/src/test/resources/cl100k_base_encodings.csv +++ b/lib/src/test/resources/cl100k_base_encodings.csv @@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 977, 72, 4864, 259, 3415, 73, 30 Quel est ton livre préféré?,"[2232, 301, 1826, 8941, 56984, 27389, 69, 68862, 30]","[2232, 301, 1826, 8941, 56984, 27389, 69, 68862, 30]" Qual é a tua cor favorita?,"[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]","[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]" Koja ti je omiljena boja?,"[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712, 5697, 30]","[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712]" -Melyik a kedvenc étel?,"[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]","[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]" +Melyik a kedvenc ételed?,"[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839, 30]","[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839]" Koji je tvoj omiljeni grad?,"[42, 28000, 4864, 259, 3415, 73, 8019, 321, 24041, 72, 6117, 30]","[42, 28000, 4864, 259, 3415, 73, 8019, 321, 24041, 72]" Quale è il tuo piatto preferito?,"[2232, 1604, 11676, 3900, 63258, 9115, 17173, 10932, 6491, 30]","[2232, 1604, 11676, 3900, 
63258, 9115, 17173, 10932, 6491, 30]" Kurš ir tavs mīļākais TV šovs?,"[42, 324, 11906, 6348, 259, 39851, 296, 61711, 128, 120, 31757, 74, 2852, 6007, 37524, 869, 82, 30]","[42, 324, 11906, 6348, 259, 39851, 296, 61711, 128, 120]" diff --git a/lib/src/test/resources/p50k_base_encodings.csv b/lib/src/test/resources/p50k_base_encodings.csv index 1529211a..ca1dc49a 100644 --- a/lib/src/test/resources/p50k_base_encodings.csv +++ b/lib/src/test/resources/p50k_base_encodings.csv @@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 729, 72, 11223, 256, 13038, 73, Quel est ton livre préféré?,"[48, 2731, 1556, 5680, 17717, 260, 778, 2634, 69, 2634, 29350, 30]","[48, 2731, 1556, 5680, 17717, 260, 778, 2634, 69, 2634]" Qual é a tua cor favorita?,"[46181, 38251, 257, 256, 6413, 1162, 2661, 5350, 30]","[46181, 38251, 257, 256, 6413, 1162, 2661, 5350, 30]" Koja ti je omiljena boja?,"[48735, 6592, 46668, 11223, 267, 25433, 73, 8107, 1489, 6592, 30]","[48735, 6592, 46668, 11223, 267, 25433, 73, 8107, 1489, 6592]" -Melyik a kedvenc étel?,"[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125, 417, 30]","[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125]" +Melyik a kedvenc ételed?,"[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125, 18449, 30]","[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125]" Koji je tvoj omiljeni grad?,"[42, 31370, 11223, 256, 13038, 73, 267, 25433, 73, 43850, 3915, 30]","[42, 31370, 11223, 256, 13038, 73, 267, 25433, 73, 43850]" Quale è il tuo piatto preferito?,"[46181, 68, 6184, 101, 4229, 12777, 78, 31028, 45807, 4702, 10094, 30]","[46181, 68, 6184, 101, 4229, 12777, 78, 31028, 45807, 4702]" Kurš ir tavs mīļākais TV šovs?,"[42, 333, 32790, 4173, 256, 615, 82, 285, 18962, 128, 120, 10235, 4914, 271, 3195, 25370, 94, 709, 82, 30]","[42, 333, 32790, 4173, 256, 615, 82, 285, 18962]" diff --git a/lib/src/test/resources/p50k_edit_encodings.csv b/lib/src/test/resources/p50k_edit_encodings.csv index 3bb7431a..2179bc8e 100644 --- 
a/lib/src/test/resources/p50k_edit_encodings.csv +++ b/lib/src/test/resources/p50k_edit_encodings.csv @@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 729, 72, 11223, 256, 13038, 73, Quel est ton livre préféré?,"[48, 2731, 1556, 5680, 17717, 260, 778, 2634, 69, 2634, 29350, 30]","[48, 2731, 1556, 5680, 17717, 260, 778, 2634, 69, 2634]" Qual é a tua cor favorita?,"[46181, 38251, 257, 256, 6413, 1162, 2661, 5350, 30]","[46181, 38251, 257, 256, 6413, 1162, 2661, 5350, 30]" Koja ti je omiljena boja?,"[48735, 6592, 46668, 11223, 267, 25433, 73, 8107, 1489, 6592, 30]","[48735, 6592, 46668, 11223, 267, 25433, 73, 8107, 1489, 6592]" -Melyik a kedvenc étel?,"[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125, 417, 30]","[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125]" +Melyik a kedvenc ételed?,"[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125, 18449, 30]","[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125]" Koji je tvoj omiljeni grad?,"[42, 31370, 11223, 256, 13038, 73, 267, 25433, 73, 43850, 3915, 30]","[42, 31370, 11223, 256, 13038, 73, 267, 25433, 73, 43850]" Quale è il tuo piatto preferito?,"[46181, 68, 6184, 101, 4229, 12777, 78, 31028, 45807, 4702, 10094, 30]","[46181, 68, 6184, 101, 4229, 12777, 78, 31028, 45807, 4702]" Kurš ir tavs mīļākais TV šovs?,"[42, 333, 32790, 4173, 256, 615, 82, 285, 18962, 128, 120, 10235, 4914, 271, 3195, 25370, 94, 709, 82, 30]","[42, 333, 32790, 4173, 256, 615, 82, 285, 18962]" diff --git a/lib/src/test/resources/r50k_base_encodings.csv b/lib/src/test/resources/r50k_base_encodings.csv index e79cae81..9e858f13 100644 --- a/lib/src/test/resources/r50k_base_encodings.csv +++ b/lib/src/test/resources/r50k_base_encodings.csv @@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 729, 72, 11223, 256, 13038, 73, Quel est ton livre préféré?,"[48, 2731, 1556, 5680, 17717, 260, 778, 2634, 69, 2634, 29350, 30]","[48, 2731, 1556, 5680, 17717, 260, 778, 2634, 69, 2634]" Qual é a tua cor favorita?,"[46181, 
38251, 257, 256, 6413, 1162, 2661, 5350, 30]","[46181, 38251, 257, 256, 6413, 1162, 2661, 5350, 30]" Koja ti je omiljena boja?,"[48735, 6592, 46668, 11223, 267, 25433, 73, 8107, 1489, 6592, 30]","[48735, 6592, 46668, 11223, 267, 25433, 73, 8107, 1489, 6592]" -Melyik a kedvenc étel?,"[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125, 417, 30]","[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125]" +Melyik a kedvenc ételed?,"[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125, 18449, 30]","[5308, 306, 1134, 257, 479, 276, 574, 66, 220, 25125]" Koji je tvoj omiljeni grad?,"[42, 31370, 11223, 256, 13038, 73, 267, 25433, 73, 43850, 3915, 30]","[42, 31370, 11223, 256, 13038, 73, 267, 25433, 73, 43850]" Quale è il tuo piatto preferito?,"[46181, 68, 6184, 101, 4229, 12777, 78, 31028, 45807, 4702, 10094, 30]","[46181, 68, 6184, 101, 4229, 12777, 78, 31028, 45807, 4702]" Kurš ir tavs mīļākais TV šovs?,"[42, 333, 32790, 4173, 256, 615, 82, 285, 18962, 128, 120, 10235, 4914, 271, 3195, 25370, 94, 709, 82, 30]","[42, 333, 32790, 4173, 256, 615, 82, 285, 18962]" From 03fcd3bfc09ac0821d9976cc88b8dfece932e083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc?= Date: Sat, 30 Dec 2023 16:38:13 +0200 Subject: [PATCH 14/14] Don't remove same tokens during TokenEncoderLarge's identical tokens iteration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After: Benchmark (dataFolderPath) Mode Cnt Score Error Units SingleThreadedBenchmark.benchmarkCl100kBase data ss 10 4.547 ± 0.056 s/op SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount data ss 10 3.944 ± 0.031 s/op SingleThreadedBenchmark.benchmarkP50kBase data ss 10 5.427 ± 0.065 s/op SingleThreadedBenchmark.benchmarkP50kEdit data ss 10 5.375 ± 0.062 s/op SingleThreadedBenchmark.benchmarkR50kBase data ss 10 5.073 ± 0.063 s/op --- .../com/knuddels/jtokkit/DataDownloader.java | 4 +-- .../knuddels/jtokkit/TokenEncoderLarge.java | 33 ++++++++++++------- 
.../java/com/knuddels/jtokkit/Cl100kTest.java | 2 +- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java index ca7f5a6e..64a435d8 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/DataDownloader.java @@ -150,7 +150,7 @@ public static void main(String[] args) throws Exception { "'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'x", "x", "123", - "ő", + "a", ",.!?:; \n", "\n \n", " ", @@ -177,7 +177,7 @@ public static void main(String[] args) throws Exception { } var totalSize = calculateTotalFileSize(rootFolder); - if (totalSize != 99_945_290) { + if (totalSize != 99_925_295) { throw new AssertionError("Total size did not match expected value, actual: " + totalSize); } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index 419d4d48..ada9b364 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -2,6 +2,7 @@ import java.util.List; +import java.util.Map.Entry; import java.util.TreeMap; import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; @@ -33,10 +34,12 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo } while (validRanks > 0) { - TreeMap minNodes = rankMap.firstEntry().getValue(); - for (int i = 0; i < minNodes.size(); i++) { - RankNode minNode = minNodes.firstEntry().getValue(); - assert minNode.rank != MAX_RANK; + TreeMap minNodes = rankMap.pollFirstEntry().getValue(); + int firstIndex; + for (Entry entry = minNodes.firstEntry(); entry != null; entry = minNodes.ceilingEntry(firstIndex)) { + RankNode minNode = entry.getValue(); + int minRank = minNode.rank; + assert minRank != MAX_RANK; RankNode previousNode = minNode.prev; RankNode 
nextNode = minNode.next; @@ -45,19 +48,22 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo if (previousNode != null) { int newRank = tokenEncoder.encode(match, previousNode.index, nextNextNode != null ? nextNextNode.index : Integer.MAX_VALUE); - if ((newRank == MAX_RANK) != (previousNode.rank == MAX_RANK)) { - validRanks -= (newRank == MAX_RANK) ? 1 : -1; + if (previousNode.rank != newRank) { + if ((newRank == MAX_RANK) != (previousNode.rank == MAX_RANK)) { + validRanks -= (newRank == MAX_RANK) ? 1 : -1; + } + assert previousNode.rank != minRank; + removeNode(rankMap.get(previousNode.rank), rankMap, previousNode); + previousNode.rank = newRank; + rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(previousNode.index, previousNode); } - removeNode(rankMap.get(previousNode.rank), rankMap, previousNode); - previousNode.rank = newRank; - rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(previousNode.index, previousNode); } int newRank = tokenEncoder.encode(match, minNode.index, nextNextNextNode != null ? 
nextNextNextNode.index : Integer.MAX_VALUE); - if ((newRank == MAX_RANK) != (minNode.rank == MAX_RANK)) { + if (newRank == MAX_RANK) { validRanks--; } - removeNode(minNodes, rankMap, minNode); + firstIndex = minNode.index + 1; minNode.rank = newRank; rankMap.computeIfAbsent(newRank, k -> new TreeMap<>()).put(minNode.index, minNode); @@ -68,8 +74,11 @@ static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, bo } if (nextNode.rank != MAX_RANK) { validRanks--; + if (nextNode.rank != minRank) { + removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); + } } - removeNode(rankMap.get(nextNode.rank), rankMap, nextNode); + firstIndex = nextNode.index + 1; } length--; diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java index 19bee341..e96d3b1b 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java @@ -1,7 +1,6 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; -import com.knuddels.jtokkit.api.EncodingType; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -107,6 +106,7 @@ void testEdgeCaseRoundTrips() throws Exception { "Many words map to one token, but some don't: indivisible.\n\nUnicode characters like emojis may be split into many tokens containing the underlying bytes: \uD83E\uDD1A\uD83C\uDFFE\n\nSequences of characters commonly found next to each other may be grouped together: 1234567890", "I paid $123,456 to 9876543210 people!", "Mixed script: 你好 world! 🌍", + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "Unicode snowman: ☃️", "I'm: 0\n", "We'll meet at 3 o'clock.",