Skip to content

Commit

Permalink
Add Cl100kLargeTokenizerTest to run every Cl100k tokenizer test with the large tokenizer as well
Browse files · Browse the repository at this point in the history
  • Loading branch information
Lőrinc committed Dec 29, 2023
1 parent d51cd37 commit 059f95f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 15 deletions.
11 changes: 7 additions & 4 deletions lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
import java.util.*;

import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge;
import static java.lang.Integer.MAX_VALUE;
import static java.lang.Integer.parseInt;
import static java.util.Collections.emptyMap;

final class TokenEncoder {
public static final int DUMMY_RANK = Integer.MAX_VALUE;
public static final int MAX_RANK = Integer.MAX_VALUE - 1;
public final class TokenEncoder {
public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD";
public static final int DUMMY_RANK = MAX_VALUE;
public static final int MAX_RANK = MAX_VALUE - 1;
private final Map<ByteArrayWrapper, Integer>[] encoders;
private final Map<Integer, byte[]> decoder;

private int VERY_LARGE_TOKENIZER_BYTE_THRESHOLD;

public TokenEncoder(Map<byte[], Integer> encoder) {
if (!encoder.isEmpty()) {
VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500"));
VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500"));
TreeMap<Integer, Map<ByteArrayWrapper, Integer>> tempEncoders = new TreeMap<>();
encoder.forEach((k, v) -> {
ByteArrayWrapper key = new ByteArrayWrapper(k);
Expand Down
5 changes: 3 additions & 2 deletions lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;
import static java.lang.Character.*;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.joining;
Expand Down Expand Up @@ -146,10 +147,10 @@ void testEdgeCaseRoundTrips() throws Exception {

@Test
void testRoundTripWithRandomStrings() throws Exception {
System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(Integer.MAX_VALUE));
System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(Integer.MAX_VALUE));
var arrayEncoder = EncodingFactory.cl100kBase();

System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(0));
System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0));
var mapEncoder = EncodingFactory.cl100kBase();
var singleTokenStrings = getAllTokens();
IntStream.range(0, 10_000).parallel().forEach(i -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,26 @@ class Cl100kBaseTest {

private static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);

// Extension hook: subclasses may return a differently configured Encoding so
// every parameterized test in this class runs against it; the base class
// tests the default CL100K_BASE encoding.
Encoding getEncoding() {
return ENCODING;
}

@ParameterizedTest
@CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
void cl100kBaseEncodesCorrectly(
String input,
String output
) {
var expected = TestUtils.parseEncodingString(output);
var actual = ENCODING.encode(input);
var actual = getEncoding().encode(input);

assertEquals(expected, actual);
}

@ParameterizedTest
@CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
void cl100kBaseEncodesStable(String input) {
var actual = ENCODING.decode(ENCODING.encode(input));
var actual = getEncoding().decode(getEncoding().encode(input));

assertEquals(input, actual);
}
Expand All @@ -43,7 +47,7 @@ void cl100kBaseEncodesCorrectlyWithMaxTokensSet(
) {
var expected = TestUtils.parseEncodingString(output);
var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10);
var encodingResult = ENCODING.encode(input, 10);
var encodingResult = getEncoding().encode(input, 10);

assertEquals(expectedWithMaxTokens, encodingResult.getTokens());
assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated());
Expand All @@ -52,7 +56,7 @@ void cl100kBaseEncodesCorrectlyWithMaxTokensSet(
@ParameterizedTest
@CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
void cl100kBaseEncodesStableWithMaxTokensSet(String input) {
var actual = ENCODING.decode(ENCODING.encode(input, 10).getTokens());
var actual = getEncoding().decode(getEncoding().encode(input, 10).getTokens());

assertTrue(input.startsWith(actual));
}
Expand All @@ -64,7 +68,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly(
String output
) {
var expected = TestUtils.parseEncodingString(output);
var actual = ENCODING.encodeOrdinary(input);
var actual = getEncoding().encodeOrdinary(input);

assertEquals(expected, actual);
}
Expand All @@ -78,7 +82,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly(
) {
var expected = TestUtils.parseEncodingString(output);
var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10);
var encodingResult = ENCODING.encodeOrdinary(input, 10);
var encodingResult = getEncoding().encodeOrdinary(input, 10);

assertEquals(expectedWithMaxTokens, encodingResult.getTokens());
assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated());
Expand All @@ -87,23 +91,23 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly(
@ParameterizedTest
@CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
void cl100kBaseEncodeOrdinaryEncodesStable(String input) {
var actual = ENCODING.decode(ENCODING.encodeOrdinary(input));
var actual = getEncoding().decode(getEncoding().encodeOrdinary(input));

assertEquals(input, actual);
}

@ParameterizedTest
@CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
void cl100kBaseEncodeOrdinaryEncodesStableWithMaxTokensSet(String input) {
var actual = ENCODING.decode(ENCODING.encodeOrdinary(input, 10).getTokens());
var actual = getEncoding().decode(getEncoding().encodeOrdinary(input, 10).getTokens());

assertTrue(input.startsWith(actual));
}

@Test
void cl100kBaseEncodeOrdinaryEncodesSpecialTokensCorrectly() {
var input = "Hello<|endoftext|>, <|fim_prefix|> <|fim_middle|> world <|fim_suffix|> ! <|endofprompt|>";
var actual = ENCODING.decode(ENCODING.encodeOrdinary(input));
var actual = getEncoding().decode(getEncoding().encodeOrdinary(input));

assertEquals(input, actual);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.knuddels.jtokkit.reference;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;

import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;

class Cl100kLargeTokenizerTest extends Cl100kBaseTest {

public static Encoding ENCODING;

@BeforeAll
static void beforeAll() {
System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0));
ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
}

@AfterAll
static void afterAll() {
System.clearProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY);

}

Encoding getEncoding() {
return ENCODING;
}
}

0 comments on commit 059f95f

Please sign in to comment.