Skip to content

Commit

Permalink
Add IntArrayList to store tokens without boxing
Browse files Browse the repository at this point in the history
Before:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  3.263 ± 0.286   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  2.688 ± 0.054   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.335 ± 0.106   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.277 ± 0.067   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.002 ± 0.091   s/op

After:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  2.498 ± 0.019   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  2.223 ± 0.014   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  4.354 ± 0.122   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  4.341 ± 0.076   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  4.068 ± 0.020   s/op
  • Loading branch information
Lőrinc committed Dec 30, 2023
1 parent fec3eeb commit aee505c
Show file tree
Hide file tree
Showing 18 changed files with 364 additions and 116 deletions.
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
[![javadoc](https://javadoc.io/badge2/com.knuddels/jtokkit/javadoc.svg)](https://javadoc.io/doc/com.knuddels/jtokkit)

Welcome to JTokkit, a Java tokenizer library designed for use with OpenAI models.

```java
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE);
Expand All @@ -20,6 +21,7 @@ enc = registry.getEncodingForModel(ModelType.TEXT_EMBEDDING_ADA_002);
For a quick getting started, see our [documentation](https://jtokkit.knuddels.de/).

## 📖 Introduction

JTokkit aims to be a fast and efficient tokenizer designed for use in natural
language processing tasks using the OpenAI models. It provides an easy-to-use
interface for tokenizing input text, for example for counting required tokens
Expand All @@ -42,7 +44,6 @@ and `cl100k_base`

✅ Fast and efficient performance


🔨 Handling of special tokens during encoding (not started)

## 📊 Performance
Expand All @@ -54,6 +55,7 @@ JTokkit is between 2-3 times faster than a comparable tokenizer.
For details on the benchmark, see the [benchmark](benchmark) directory.

## 🛠️ Installation

You can install JTokkit by adding the following dependency to your Maven project:

```xml
Expand All @@ -73,16 +75,17 @@ dependencies {
```

## 🔰 Getting Started

To use JTokkit, simply create a new `EncodingRegistry` and use `getEncoding` to
retrieve the encoding you want to use. You can then use the `encode` and
`decode` methods to encode and decode text.

```java
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE);
List<Integer> encoded = enc.encode("This is a sample sentence.");
IntArrayList encoded = enc.encode("This is a sample sentence.");
// encoded = [2028, 374, 264, 6205, 11914, 13]

String decoded = enc.decode(encoded);
// decoded = "This is a sample sentence."

Expand All @@ -100,12 +103,15 @@ You may want to extend JTokkit to support custom encodings. To do so, you have t
options:

1. Implement the `Encoding` interface and register it with the `EncodingRegistry`

```java
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
Encoding customEncoding = new CustomEncoding();
registry.registerEncoding(customEncoding);
```

2. Add new parameters for use with the existing BPE algorithm

```java
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
GptBytePairEncodingParams params = new GptBytePairEncodingParams(
Expand All @@ -122,6 +128,7 @@ them by using `registry.getEncoding("custom-name")`. See the JavaDoc for more
details.

## 📄 License

JTokkit is licensed under the MIT License. See the
[LICENSE](https://github.com/knuddelsgmbh/jtokkit/blob/main/LICENSE) file
for more information.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import org.openjdk.jmh.annotations.Benchmark;

import java.util.List;
Expand Down Expand Up @@ -34,5 +35,5 @@ public Object benchmarkCl100kBase(BenchmarkingState state) {
* @param fileContents the file contents to encode
* @return a list of encoded token lists
*/
protected abstract List<List<Integer>> encodeAll(Encoding encoding, List<String> fileContents);
protected abstract List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents);
}
Original file line number Diff line number Diff line change
@@ -1,46 +1,48 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;

@State(Scope.Thread)
public abstract class AbstractMultiThreadedBenchmark extends AbstractBenchmark {

private final int threads;
private ExecutorService executor;
private final int threads;
private ExecutorService executor;

public AbstractMultiThreadedBenchmark(final int threads) {
this.threads = threads;
}
public AbstractMultiThreadedBenchmark(int threads) {
this.threads = threads;
}

@Setup
public void setup() {
executor = Executors.newFixedThreadPool(threads);
}
@Setup
public void setup() {
executor = Executors.newFixedThreadPool(threads);
}

@TearDown
public void tearDown() {
executor.shutdown();
}
@TearDown
public void tearDown() {
executor.shutdown();
}

@Override
protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
final var futures = fileContents.stream()
.map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor))
.collect(Collectors.toList());
@Override
protected List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents) {
var futures = fileContents.stream()
.map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor))
.toList();

CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();

return futures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList());
}
return futures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import org.openjdk.jmh.annotations.Benchmark;

import java.util.List;
Expand All @@ -18,7 +19,7 @@ public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) {
}

@Override
protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
protected List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents) {
return fileContents.stream()
.map(encoding::encode)
.toList();
Expand Down
25 changes: 17 additions & 8 deletions docs/docs/getting-started/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ To use JTokkit, first create a new `EncodingRegistry`:
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
```

Make sure to keep a reference to the registry, as the creation of the registry is expensive. Creating the registry loads the vocabularies from the classpath. The registry itself handles caching of the loaded encodings. It is thread-safe and can safely be used concurrently by multiple components.
Make sure to keep a reference to the registry, as the creation of the registry is expensive. Creating the registry loads
the vocabularies from the classpath. The registry itself handles caching of the loaded encodings. It is thread-safe and
can safely be used concurrently by multiple components.

If you do not want to automatically load all vocabularies of all encodings on registry creation, you can use the following lazy loading registry.
If you do not want to automatically load all vocabularies of all encodings on registry creation, you can use the
following lazy loading registry.

```java
EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
Expand Down Expand Up @@ -45,7 +48,7 @@ Optional<Encoding> encoding = registry.getEncodingForModel("gpt_4");
You can use an `Encoding` to encode and decode text:

```java
List<Integer> encoded = encoding.encode("This is a sample sentence.");
IntArrayList encoded = encoding.encode("This is a sample sentence.");
// encoded = [2028, 374, 264, 6205, 11914, 13]

String decoded = encoding.decode(encoded);
Expand All @@ -56,7 +59,9 @@ The encoding is also fully thread-safe and can be used concurrently by multiple

:::info

Note that the library currently does not support encoding of special tokens. Special tokens are artificial tokens used to unlock capabilities from a model, such as fill-in-the-middle. If the `Encoding#encode` method encounters a special token in the input text, it will throw an `UnsupportedOperationException`.
Note that the library currently does not support encoding of special tokens. Special tokens are artificial tokens used
to unlock capabilities from a model, such as fill-in-the-middle. If the `Encoding#encode` method encounters a special
token in the input text, it will throw an `UnsupportedOperationException`.

If you want to encode special tokens as if they were normal text, you can use `Encoding#encodeOrdinary` instead:

Expand All @@ -72,7 +77,8 @@ encoding.encodeOrdinary("hello <|endoftext|> world");

## Counting tokens

If all you want is the amount of tokens the text encodes to, you can use the shorthand method `Encoding#countTokens` or `Encoding#countTokensOrdinary`:
If all you want is the amount of tokens the text encodes to, you can use the shorthand method `Encoding#countTokens`
or `Encoding#countTokensOrdinary`:

```java
int tokenCount = encoding.countTokens("This is a sample sentence.");
Expand All @@ -84,16 +90,19 @@ int tokenCount = encoding.countTokensOrdinary("hello <|endoftext|> world");

## Encoding text with truncation

If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded tokens to the specified length. They will automatically handle unicode characters that were split in half by the truncation by removing those tokens from the end of the list.
If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can
use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded
tokens to the specified length. They will automatically handle unicode characters that were split in half by the
truncation by removing those tokens from the end of the list.

```java
List<Integer> encoded = encoding.encode("This is a sample sentence.", 3);
IntArrayList encoded = encoding.encode("This is a sample sentence.", 3);
// encoded = [2028, 374, 264]

String decoded = encoding.decode(encoded);
// decoded = "This is a"

List<Integer> encoded = encoding.encode("I love 🍕", 4);
IntArrayList encoded = encoding.encode("I love 🍕", 4);
// encoded = [40, 3021]

String decoded = encoding.decode(encoded);
Expand Down
65 changes: 58 additions & 7 deletions lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

public class ByteArrayList {
private byte[] array;
private int size;
private int size = 0;

public ByteArrayList() {
array = new byte[10];
size = 0;
this(10);
}

public ByteArrayList(int size) {
array = new byte[size];
}

public void clear() {
Expand All @@ -22,17 +25,65 @@ public void add(byte element) {
array[size++] = element;
}

public byte get(int index) {
return array[index];
}

public int set(int index, byte element) {
int old = array[index];
array[index] = element;
return old;
}

private void resize() {
byte[] newArray = new byte[array.length * 2];
System.arraycopy(array, 0, newArray, 0, size);
ensureCapacity(Math.max(1, array.length) * 2);
}

public void ensureCapacity(int targetSize) {
if (targetSize <= size) {
return;
}
byte[] newArray = new byte[targetSize];
if (size > 0) {
System.arraycopy(array, 0, newArray, 0, size);
}
array = newArray;
}

int length() {
public int size() {
return size;
}

public byte[] toByteArray() {
public boolean isEmpty() {
return size == 0;
}

public byte[] toArray() {
return Arrays.copyOf(array, size);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
} else if (o == null || getClass() != o.getClass()) {
return false;
}
ByteArrayList that = (ByteArrayList) o;
for (int i = 0; i < size; i++) {
if (array[i] != that.array[i]) {
return false;
}
}
return true;
}

@Override
public int hashCode() {
int result = 1;
for (int i = 0; i < size; i++) {
result = 31 * result + array[i];
}
return result;
}
}
7 changes: 4 additions & 3 deletions lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
import com.knuddels.jtokkit.api.IntArrayList;

import java.io.BufferedReader;
import java.io.IOException;
Expand Down Expand Up @@ -180,11 +181,11 @@ public Cl100kGptBytePairEncoding(GptBytePairEncodingParams params) {
}

@Override
int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List<Integer> out) {
int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) {
int[] tokenCount = {0};
ArrayList<Integer> ranks = new ArrayList<>();
IntArrayList ranks = new IntArrayList();
Cl100kParser.split(text, utf8BytesList -> {
tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toByteArray(), out, ranks);
tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toArray(), out, ranks);
return tokenCount[0] >= maxTokenCount;
});
return tokenCount[0];
Expand Down
Loading

0 comments on commit aee505c

Please sign in to comment.