diff --git a/README.md b/README.md index 7fe6e4e3..d2929400 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![javadoc](https://javadoc.io/badge2/com.knuddels/jtokkit/javadoc.svg)](https://javadoc.io/doc/com.knuddels/jtokkit) Welcome to JTokkit, a Java tokenizer library designed for use with OpenAI models. + ```java EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); @@ -20,6 +21,7 @@ enc = registry.getEncodingForModel(ModelType.TEXT_EMBEDDING_ADA_002); For a quick getting started, see our [documentation](https://jtokkit.knuddels.de/). ## 📖 Introduction + JTokkit aims to be a fast and efficient tokenizer designed for use in natural language processing tasks using the OpenAI models. It provides an easy-to-use interface for tokenizing input text, for example for counting required tokens @@ -42,7 +44,6 @@ and `cl100k_base` ✅ Fast and efficient performance - 🔨 Handling of special tokens during encoding (not started) ## 📊 Performance @@ -54,6 +55,7 @@ JTokkit is between 2-3 times faster than a comparable tokenizer. For details on the benchmark, see the [benchmark](benchmark) directory. ## 🛠️ Installation + You can install JTokkit by adding the following dependency to your Maven project: ```xml @@ -73,6 +75,7 @@ dependencies { ``` ## 🔰 Getting Started + To use JTokkit, simply create a new `EncodingRegistry` and use `getEncoding` to retrieve the encoding you want to use. You can then use the `encode` and `decode` methods to encode and decode text. @@ -80,9 +83,9 @@ retrieve the encoding you want to use. You can then use the `encode` and ```java EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE); -List encoded = enc.encode("This is a sample sentence."); +IntArrayList encoded = enc.encode("This is a sample sentence."); // encoded = [2028, 374, 264, 6205, 11914, 13] - + String decoded = enc.decode(encoded); // decoded = "This is a sample sentence." @@ -100,12 +103,15 @@ You may want to extend JTokkit to support custom encodings. To do so, you have t options: 1. Implement the `Encoding` interface and register it with the `EncodingRegistry` + ```java EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); Encoding customEncoding = new CustomEncoding(); registry.registerEncoding(customEncoding); ``` + 2. Add new parameters for use with the existing BPE algorithm + ```java EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); GptBytePairEncodingParams params = new GptBytePairEncodingParams( @@ -122,6 +128,7 @@ them by using `registry.getEncoding("custom-name")`. See the JavaDoc for more details. ## 📄 License + JTokkit is licensed under the MIT License. See the [LICENSE](https://github.com/knuddelsgmbh/jtokkit/blob/main/LICENSE) file for more information. diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java index 7e2c338c..014b3789 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java @@ -1,6 +1,7 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; import org.openjdk.jmh.annotations.Benchmark; import java.util.List; @@ -34,5 +35,5 @@ public Object benchmarkCl100kBase(BenchmarkingState state) { * @param fileContents the file contents to encode * @return a list of encoded token lists */ - protected abstract List> encodeAll(Encoding encoding, List fileContents); + protected abstract List encodeAll(Encoding encoding, List fileContents); } diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java index a0cc3cfa..80353a54 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractMultiThreadedBenchmark.java @@ -1,46 +1,48 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; + import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.stream.Collectors; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.TearDown; @State(Scope.Thread) public abstract class AbstractMultiThreadedBenchmark extends AbstractBenchmark { - private final int threads; - private ExecutorService executor; + private final int threads; + private ExecutorService executor; - public AbstractMultiThreadedBenchmark(final int threads) { - this.threads = threads; - } + public AbstractMultiThreadedBenchmark(int threads) { + this.threads = threads; + } - @Setup - public void setup() { - executor = Executors.newFixedThreadPool(threads); - } + @Setup + public void setup() { + executor = Executors.newFixedThreadPool(threads); + } - @TearDown - public void tearDown() { - executor.shutdown(); - } + @TearDown + public void tearDown() { + executor.shutdown(); + } - @Override - protected List> encodeAll(final Encoding encoding, final List fileContents) { - final var futures = fileContents.stream() - .map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor)) - .collect(Collectors.toList()); + @Override + protected List encodeAll(Encoding encoding, List fileContents) { + var futures = fileContents.stream() + .map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor)) + .toList(); - CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); + CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); - return futures.stream() - .map(CompletableFuture::join) - .collect(Collectors.toList()); - } + return futures.stream() + .map(CompletableFuture::join) + .collect(Collectors.toList()); + } } diff --git a/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java b/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java index 896757af..c7ce4014 100644 --- a/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java +++ b/benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java @@ -1,6 +1,7 @@ package com.knuddels.jtokkit; import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.IntArrayList; import org.openjdk.jmh.annotations.Benchmark; import java.util.List; @@ -18,7 +19,7 @@ public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) { } @Override - protected List> encodeAll(final Encoding encoding, final List fileContents) { + protected List encodeAll(Encoding encoding, List fileContents) { return fileContents.stream() .map(encoding::encode) .toList(); diff --git a/docs/docs/getting-started/usage.md b/docs/docs/getting-started/usage.md index 152119b2..7eecd98a 100644 --- a/docs/docs/getting-started/usage.md +++ b/docs/docs/getting-started/usage.md @@ -9,9 +9,12 @@ To use JTokkit, first create a new `EncodingRegistry`: EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); ``` -Make sure to keep a reference to the registry, as the creation of the registry is expensive. Creating the registry loads the vocabularies from the classpath. The registry itself handles caching of the loaded encodings. It is thread-safe and can safely be used concurrently by multiple components. +Make sure to keep a reference to the registry, as the creation of the registry is expensive. Creating the registry loads +the vocabularies from the classpath. The registry itself handles caching of the loaded encodings. It is thread-safe and +can safely be used concurrently by multiple components. -If you do not want to automatically load all vocabularies of all encodings on registry creation, you can use the following lazy loading registry. +If you do not want to automatically load all vocabularies of all encodings on registry creation, you can use the +following lazy loading registry. ```java EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); @@ -45,7 +48,7 @@ Optional encoding = registry.getEncodingForModel("gpt_4"); You can use an `Encoding` to encode and decode text: ```java -List encoded = encoding.encode("This is a sample sentence."); +IntArrayList encoded = encoding.encode("This is a sample sentence."); // encoded = [2028, 374, 264, 6205, 11914, 13] String decoded = encoding.decode(encoded); @@ -56,7 +59,9 @@ The encoding is also fully thread-safe and can be used concurrently by multiple :::info -Note that the library currently does not support encoding of special tokens. Special tokens are artificial tokens used to unlock capabilities from a model, such as fill-in-the-middle. If the `Encoding#encode` method encounters a special token in the input text, it will throw an `UnsupportedOperationException`. +Note that the library currently does not support encoding of special tokens. Special tokens are artificial tokens used +to unlock capabilities from a model, such as fill-in-the-middle. If the `Encoding#encode` method encounters a special +token in the input text, it will throw an `UnsupportedOperationException`. If you want to encode special tokens as if they were normal text, you can use `Encoding#encodeOrdinary` instead: @@ -72,7 +77,8 @@ encoding.encodeOrdinary("hello <|endoftext|> world"); ## Counting tokens -If all you want is the amount of tokens the text encodes to, you can use the shorthand method `Encoding#countTokens` or `Encoding#countTokensOrdinary`: +If all you want is the amount of tokens the text encodes to, you can use the shorthand method `Encoding#countTokens` +or `Encoding#countTokensOrdinary`: ```java int tokenCount = encoding.countTokens("This is a sample sentence."); @@ -84,16 +90,19 @@ int tokenCount = encoding.countTokensOrdinary("hello <|endoftext|> world"); ## Encoding text with truncation -If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded tokens to the specified length. They will automatically handle unicode characters that were split in half by the truncation by removing those tokens from the end of the list. +If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can +use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded +tokens to the specified length. They will automatically handle unicode characters that were split in half by the +truncation by removing those tokens from the end of the list. ```java -List encoded = encoding.encode("This is a sample sentence.", 3); +IntArrayList encoded = encoding.encode("This is a sample sentence.", 3); // encoded = [2028, 374, 264] String decoded = encoding.decode(encoded); // decoded = "This is a" -List encoded = encoding.encode("I love 🍕", 4); +IntArrayList encoded = encoding.encode("I love 🍕", 4); // encoded = [40, 3021] String decoded = encoding.decode(encoded); diff --git a/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java index a4341bba..3c8b944f 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java +++ b/lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java @@ -4,11 +4,14 @@ public class ByteArrayList { private byte[] array; - private int size; + private int size = 0; public ByteArrayList() { - array = new byte[10]; - size = 0; + this(10); + } + + public ByteArrayList(int size) { + array = new byte[size]; } public void clear() { @@ -22,17 +25,65 @@ public void add(byte element) { array[size++] = element; } + public byte get(int index) { + return array[index]; + } + + public int set(int index, byte element) { + int old = array[index]; + array[index] = element; + return old; + } + private void resize() { - byte[] newArray = new byte[array.length * 2]; - System.arraycopy(array, 0, newArray, 0, size); + ensureCapacity(Math.max(1, array.length) * 2); + } + + public void ensureCapacity(int targetSize) { + if (targetSize <= size) { + return; + } + byte[] newArray = new byte[targetSize]; + if (size > 0) { + System.arraycopy(array, 0, newArray, 0, size); + } array = newArray; } - int length() { + public int size() { return size; } - public byte[] toByteArray() { + public boolean isEmpty() { + return size == 0; + } + + public byte[] toArray() { return Arrays.copyOf(array, size); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o == null || getClass() != o.getClass()) { + return false; + } + ByteArrayList that = (ByteArrayList) o; + for (int i = 0; i < size; i++) { + if (array[i] != that.array[i]) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < size; i++) { + result = 31 * result + array[i]; + } + return result; + } } diff --git a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java index bdddfce3..cad8e718 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java +++ b/lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java @@ -2,6 +2,7 @@ import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.GptBytePairEncodingParams; +import com.knuddels.jtokkit.api.IntArrayList; import java.io.BufferedReader; import java.io.IOException; @@ -180,11 +181,11 @@ public Cl100kGptBytePairEncoding(GptBytePairEncodingParams params) { } @Override - int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { + int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) { int[] tokenCount = {0}; - ArrayList ranks = new ArrayList<>(); + IntArrayList ranks = new IntArrayList(); Cl100kParser.split(text, utf8BytesList -> { - tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toByteArray(), out, ranks); + tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toArray(), out, ranks); return tokenCount[0] >= maxTokenCount; }); return tokenCount[0]; diff --git a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java index 34da2537..5a1b4aa5 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java @@ -3,15 +3,12 @@ import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingResult; import com.knuddels.jtokkit.api.GptBytePairEncodingParams; +import com.knuddels.jtokkit.api.IntArrayList; -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; -import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -37,7 +34,7 @@ class GptBytePairEncoding implements Encoding { } @Override - public List encode(String text) { + public IntArrayList encode(String text) { return encode(text, Integer.MAX_VALUE).getTokens(); } @@ -48,7 +45,7 @@ public EncodingResult encode(String text, int maxTokenCount) { private EncodingResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) { if (text == null) { - return new EncodingResult(emptyList(), -1, false); + return new EncodingResult(new IntArrayList(0), -1, false); } specialEncoder.checkForSpecialTokens(text); @@ -57,7 +54,7 @@ private EncodingResult encodeInternal(String text, int maxTokenCount, boolean ke } @Override - public List encodeOrdinary(String text) { + public IntArrayList encodeOrdinary(String text) { return encodeOrdinary(text, Integer.MAX_VALUE).getTokens(); } @@ -68,17 +65,17 @@ public EncodingResult encodeOrdinary(String text, int maxTokenCount) { private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) { if (text == null) { - return new EncodingResult(emptyList(), -1, false); + return new EncodingResult(new IntArrayList(0), -1, false); } - List out = new ArrayList<>(); + IntArrayList out = new IntArrayList(); int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, keepEncodings, out); if (keepEncodings && maxTokenCount != Integer.MAX_VALUE) { // Make sure we didn't break the multibyte character for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) { int size = out.size() - tokensToRemove; - ArrayList tokens = new ArrayList<>(size); + IntArrayList tokens = new IntArrayList(size); for (int i = 0; i < size; i++) { tokens.add(out.get(i)); } @@ -93,9 +90,9 @@ private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, bo return new EncodingResult(out, tokenCount, false); } - int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List out) { + int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) { int tokenCount = 0; - ArrayList ranks = new ArrayList<>(); // reused to avoid allocations + IntArrayList ranks = new IntArrayList(); // reused to avoid allocations for (Matcher matcher = pattern.matcher(text); tokenCount < maxTokenCount && matcher.find(); ) { byte[] bytes = matcher.group().getBytes(UTF_8); tokenCount += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, bytes, out, ranks); @@ -109,20 +106,20 @@ public int countTokens(String text) { } @Override - public String decode(List tokens) { + public String decode(IntArrayList tokens) { return new String(decodeBytes(tokens), UTF_8); } @Override - public byte[] decodeBytes(List tokens) { - ByteArrayOutputStream out = new ByteArrayOutputStream(10 * tokens.size()); - for (int token : tokens) { - byte[] decodedToken = decodeToken(token); + public byte[] decodeBytes(IntArrayList tokens) { + ByteArrayList out = new ByteArrayList(10 * tokens.size()); + for (int i = 0; i < tokens.size(); i++) { + byte[] decodedToken = decodeToken(tokens.get(i)); for (byte b : decodedToken) { - out.write(b); + out.add(b); } } - return out.toByteArray(); + return out.toArray(); } @Override diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java index 60d3d9f0..081d9567 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java @@ -1,6 +1,10 @@ package com.knuddels.jtokkit; -import java.util.*; +import com.knuddels.jtokkit.api.IntArrayList; + +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge; import static java.lang.Integer.MAX_VALUE; @@ -9,7 +13,7 @@ public final class TokenEncoder { public static final int MAX_RANK = MAX_VALUE - 1; - static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; + public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD"; static final int DUMMY_RANK = MAX_VALUE; private final Map[] encoders; private final Map decoder; @@ -37,7 +41,7 @@ public final class TokenEncoder { } } - private static int getMinRankIndex(List ranks) { + private static int getMinRankIndex(IntArrayList ranks) { int minRankIndex = -1; int minRank = MAX_RANK; @@ -85,21 +89,21 @@ private static int getMinRankIndex(List ranks) { return minRankIndex; } - private static int getNextIndex(List ranks, int nextIndex) { + private static int getNextIndex(IntArrayList ranks, int nextIndex) { while (nextIndex < ranks.size() && ranks.get(nextIndex) == DUMMY_RANK) { nextIndex++; } return nextIndex; } - private static int getPreviousIndex(List ranks, int previousIndex) { + private static int getPreviousIndex(IntArrayList ranks, int previousIndex) { while (previousIndex >= 0 && ranks.get(previousIndex) == DUMMY_RANK) { previousIndex--; } return previousIndex; } - int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteArray, List out, ArrayList ranks) { + int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteArray, IntArrayList out, IntArrayList ranks) { ByteArrayWrapper match = new ByteArrayWrapper(byteArray); int encoded = encode(match); if (encoded != MAX_RANK) { @@ -117,7 +121,7 @@ int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteAr } } - private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List out, ArrayList ranks, ByteArrayWrapper match, int length) { + private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntArrayList out, IntArrayList ranks, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; ranks.clear(); ranks.ensureCapacity(length + 1); @@ -149,7 +153,7 @@ private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List< return tokenCount; } - int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List ranks, int validRanks, int minRankIndex) { + int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, IntArrayList ranks, int validRanks, int minRankIndex) { assert getMinRankIndex(ranks) == minRankIndex; while (validRanks > 0) { assert minRankIndex >= 0; diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java index ada9b364..4df67ff5 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java +++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoderLarge.java @@ -1,14 +1,15 @@ package com.knuddels.jtokkit; -import java.util.List; +import com.knuddels.jtokkit.api.IntArrayList; + import java.util.Map.Entry; import java.util.TreeMap; import static com.knuddels.jtokkit.TokenEncoder.MAX_RANK; final class TokenEncoderLarge { - static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, List out, ByteArrayWrapper match, int length) { + static int calculateTokensLarge(TokenEncoder tokenEncoder, int maxTokenCount, boolean keepEncodings, IntArrayList out, ByteArrayWrapper match, int length) { assert length > 1 : "Already filtered out"; TreeMap> rankMap = new TreeMap<>(); diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java b/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java index 9fe467cd..728faa5f 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/Encoding.java @@ -1,7 +1,5 @@ package com.knuddels.jtokkit.api; -import java.util.List; - public interface Encoding { /** @@ -26,7 +24,7 @@ public interface Encoding { * @return the list of token ids * @throws UnsupportedOperationException if the text contains special tokens which are not supported for now */ - List encode(String text); + IntArrayList encode(String text); /** * Encodes the given text into a list of token ids. @@ -77,7 +75,7 @@ public interface Encoding { * @param text the text to encode * @return the list of token ids */ - List encodeOrdinary(String text); + IntArrayList encodeOrdinary(String text); /** * Encodes the given text into a list of token ids, ignoring special tokens. @@ -139,7 +137,7 @@ public interface Encoding { * @return the decoded text * @throws IllegalArgumentException if the list contains invalid token ids */ - String decode(List tokens); + String decode(IntArrayList tokens); /** * Decodes the given list of token ids into a byte array. @@ -156,7 +154,7 @@ public interface Encoding { * @return the decoded byte array * @throws IllegalArgumentException if the list contains invalid token ids */ - byte[] decodeBytes(List tokens); + byte[] decodeBytes(IntArrayList tokens); /** * Returns the name of this encoding. This is the name which is used to identify diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java index 7fb4dc1a..3870ae37 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java +++ b/lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java @@ -1,17 +1,15 @@ package com.knuddels.jtokkit.api; -import java.util.List; - /** * The result of encoding operation. */ public final class EncodingResult { - private final List tokens; + private final IntArrayList tokens; private final boolean truncated; private int tokenCount; - public EncodingResult(List tokens, int tokenCount, boolean truncated) { + public EncodingResult(IntArrayList tokens, int tokenCount, boolean truncated) { this.tokens = tokens; this.tokenCount = tokenCount; this.truncated = truncated; @@ -20,7 +18,7 @@ public EncodingResult(List tokens, int tokenCount, boolean truncated) { /** * @return the list of token ids */ - public List getTokens() { + public IntArrayList getTokens() { if (tokens.size() != getTokenCount()) { throw new IllegalStateException("Token count does not match token list size (tokenCount=" + tokenCount + ", tokens size=" + tokens.size() + ")"); } diff --git a/lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java b/lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java new file mode 100644 index 00000000..df3cbd27 --- /dev/null +++ b/lib/src/main/java/com/knuddels/jtokkit/api/IntArrayList.java @@ -0,0 +1,89 @@ +package com.knuddels.jtokkit.api; + +import java.util.Arrays; + +public class IntArrayList { + private int[] array; + private int size = 0; + + public IntArrayList() { + this(10); + } + + public IntArrayList(int size) { + array = new int[size]; + } + + public void clear() { + size = 0; + } + + public void add(int element) { + if (size >= array.length) { + resize(); + } + array[size++] = element; + } + + public int get(int index) { + return array[index]; + } + + public int set(int index, int element) { + int old = array[index]; + array[index] = element; + return old; + } + + private void resize() { + ensureCapacity(Math.max(1, array.length) * 2); + } + + public void ensureCapacity(int targetSize) { + if (targetSize <= size) { + return; + } + int[] newArray = new int[targetSize]; + if (size > 0) { + System.arraycopy(array, 0, newArray, 0, size); + } + array = newArray; + } + + public int size() { + return size; + } + + public boolean isEmpty() { + return size == 0; + } + + public int[] toArray() { + return Arrays.copyOf(array, size); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o == null || getClass() != o.getClass()) { + return false; + } + IntArrayList that = (IntArrayList) o; + for (int i = 0; i < size; i++) { + if (array[i] != that.array[i]) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < size; i++) { + result = 31 * result + array[i]; + } + return result; + } +} diff --git a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java index 0c9fca76..0d7e6ce2 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java @@ -5,7 +5,6 @@ import org.junit.jupiter.api.Test; import java.util.Collections; -import java.util.List; import java.util.function.Consumer; import java.util.regex.Pattern; @@ -13,8 +12,8 @@ abstract class BaseEncodingRegistryTest { - protected final T registry; - protected final Consumer initializer; + protected T registry; + protected Consumer initializer; BaseEncodingRegistryTest(T registry) { this(registry, __ -> { @@ -143,7 +142,7 @@ void getEncodingReturnsEmptyOptionalForNonExistingEncodingName() { private static class DummyEncoding implements Encoding { @Override - public List encode(String text) { + public IntArrayList encode(String text) { return null; } @@ -153,7 +152,7 @@ public EncodingResult encode(String text, int maxTokens) { } @Override - public List encodeOrdinary(String text) { + public IntArrayList encodeOrdinary(String text) { return null; } @@ -168,12 +167,12 @@ public int countTokens(String text) { } @Override - public String decode(List tokens) { + public String decode(IntArrayList tokens) { return null; } @Override - public byte[] decodeBytes(List tokens) { + public byte[] decodeBytes(IntArrayList tokens) { return new byte[0]; } diff --git a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java index 012aa424..c788dcba 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/ByteArrayListTest.java @@ -5,7 +5,7 @@ import java.util.ArrayList; import java.util.Random; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; class ByteArrayListTest { @@ -14,37 +14,56 @@ private static byte randomByte(Random random) { } @Test - public void testArrayListOperations() { + void testArrayListOperations() { var byteArrayList = new ByteArrayList(); var standardList = new ArrayList(); var random = new Random(); + assertTrue(byteArrayList.isEmpty()); + for (var i = 0; i < 1_000; i++) { // Add - if (randomByte(random) % 2 == 0) { - var element = randomByte(random); - var lastIndex = standardList.size(); - byteArrayList.add(element); - standardList.add(element); - assertEquals(standardList.get(lastIndex), byteArrayList.toByteArray()[lastIndex]); + var element = randomByte(random); + byteArrayList.add(element); + standardList.add(element); + assertEquals(standardList.get(standardList.size() - 1), byteArrayList.get(byteArrayList.size() - 1)); + + // Set + if (!byteArrayList.isEmpty() && random.nextBoolean()) { + var randomIndex = random.nextInt(byteArrayList.size()); + var newElement = randomByte(random); + byteArrayList.set(randomIndex, newElement); + standardList.set(randomIndex, newElement); + assertEquals(standardList.get(randomIndex), byteArrayList.get(randomIndex)); } - // Size - assertEquals(standardList.size(), byteArrayList.length()); + // Size and IsEmpty + assertEquals(standardList.size(), byteArrayList.size()); + assertEquals(standardList.isEmpty(), byteArrayList.isEmpty()); // Clear if (randomByte(random) % 10 == 0) { byteArrayList.clear(); standardList.clear(); - assertEquals(standardList.size(), byteArrayList.length()); + assertEquals(standardList.size(), byteArrayList.size()); } } - assertEquals(standardList.size(), byteArrayList.length()); - var byteArray = byteArrayList.toByteArray(); + // Test toArray + var byteArray = byteArrayList.toArray(); assertEquals(standardList.size(), byteArray.length); - for (var i = 0; i < byteArrayList.length(); i++) { + for (var i = 0; i < byteArrayList.size(); i++) { assertEquals(standardList.get(i), byteArray[i]); } + + // Test Equals and HashCode + var anotherByteArrayList = new ByteArrayList(); + standardList.forEach(anotherByteArrayList::add); + + assertEquals(byteArrayList, anotherByteArrayList); + if (!byteArrayList.isEmpty()) { + assertNotEquals(byteArrayList, new ByteArrayList()); + } + assertEquals(byteArrayList.hashCode(), anotherByteArrayList.hashCode()); } } diff --git a/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java b/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java new file mode 100644 index 00000000..909f6a7f --- /dev/null +++ b/lib/src/test/java/com/knuddels/jtokkit/IntArrayListTest.java @@ -0,0 +1,66 @@ +package com.knuddels.jtokkit; + +import com.knuddels.jtokkit.api.IntArrayList; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.*; + +class IntArrayListTest { + + @Test + void testArrayListOperations() { + var byteArrayList = new IntArrayList(); + var standardList = new ArrayList(); + var random = new Random(); + + assertTrue(byteArrayList.isEmpty()); + + for (var i = 0; i < 1_000; i++) { + // Add + var element = random.nextInt(); + byteArrayList.add(element); + standardList.add(element); + assertEquals(standardList.get(standardList.size() - 1), byteArrayList.get(byteArrayList.size() - 1)); + + // Set + if (!byteArrayList.isEmpty() && random.nextBoolean()) { + var randomIndex = random.nextInt(byteArrayList.size()); + var newElement = random.nextInt(); + byteArrayList.set(randomIndex, newElement); + standardList.set(randomIndex, newElement); + assertEquals(standardList.get(randomIndex), byteArrayList.get(randomIndex)); + } + + // Size and IsEmpty + assertEquals(standardList.size(), byteArrayList.size()); + assertEquals(standardList.isEmpty(), byteArrayList.isEmpty()); + + // Clear + if (random.nextInt() % 10 == 0) { + byteArrayList.clear(); + standardList.clear(); + assertEquals(standardList.size(), byteArrayList.size()); + } + } + + // Test toArray + var byteArray = byteArrayList.toArray(); + assertEquals(standardList.size(), byteArray.length); + for (var i = 0; i < byteArrayList.size(); i++) { + assertEquals(standardList.get(i), byteArray[i]); + } + + // Test Equals and HashCode + var anotherIntArrayList = new IntArrayList(); + standardList.forEach(anotherIntArrayList::add); + + assertEquals(byteArrayList, anotherIntArrayList); + if (!byteArrayList.isEmpty()) { + assertNotEquals(byteArrayList, new IntArrayList()); + } + assertEquals(byteArrayList.hashCode(), anotherIntArrayList.hashCode()); + } +} diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java index c7ea36fa..9c3f2bfc 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java @@ -8,6 +8,7 @@ import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY; + class Cl100kLargeTokenizerTest extends Cl100kBaseTest { public static Encoding ENCODING; diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java b/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java index 3b8961f1..ae485d78 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java +++ b/lib/src/test/java/com/knuddels/jtokkit/reference/TestUtils.java @@ -1,17 +1,21 @@ package com.knuddels.jtokkit.reference; +import com.knuddels.jtokkit.api.IntArrayList; + import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; -class TestUtils { +public class TestUtils { - static List parseEncodingString(String encodingString) { - return Arrays.stream( + public static IntArrayList parseEncodingString(final String encodingString) { + List list = Arrays.stream( encodingString.substring(1, encodingString.length() - 1) .replaceAll(" ", "") .split(",") ).map(Integer::parseInt) - .collect(Collectors.toList()); + .toList(); + var result = new IntArrayList(list.size()); + list.forEach(result::add); + return result; } }