From 059f95f61f18ce2f3021b118288c1e3d549b9bca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Fri, 29 Dec 2023 21:21:42 +0200
Subject: [PATCH] Add Cl100kLargeTokenizerTest to run every Cl100k tokenizer
 test with the large tokenizer as well

---
 .../com/knuddels/jtokkit/TokenEncoder.java    | 11 ++++---
 .../java/com/knuddels/jtokkit/Cl100kTest.java |  5 ++--
 .../jtokkit/reference/Cl100kBaseTest.java     | 22 ++++++++------
 .../reference/Cl100kLargeTokenizerTest.java   | 30 +++++++++++++++++++
 4 files changed, 53 insertions(+), 15 deletions(-)
 create mode 100644 lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java

diff --git a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java
index c99f431..9aefc20 100644
--- a/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java
+++ b/lib/src/main/java/com/knuddels/jtokkit/TokenEncoder.java
@@ -3,11 +3,14 @@
 import java.util.*;
 
 import static com.knuddels.jtokkit.TokenEncoderLarge.calculateTokensLarge;
+import static java.lang.Integer.MAX_VALUE;
+import static java.lang.Integer.parseInt;
 import static java.util.Collections.emptyMap;
 
-final class TokenEncoder {
-    public static final int DUMMY_RANK = Integer.MAX_VALUE;
-    public static final int MAX_RANK = Integer.MAX_VALUE - 1;
+public final class TokenEncoder {
+    public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD";
+    public static final int DUMMY_RANK = MAX_VALUE;
+    public static final int MAX_RANK = MAX_VALUE - 1;
     private final Map[] encoders;
     private final Map decoder;
 
@@ -15,7 +18,7 @@ final class TokenEncoder {
 
     public TokenEncoder(Map encoder) {
         if (!encoder.isEmpty()) {
-            VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500"));
+            VERY_LARGE_TOKENIZER_BYTE_THRESHOLD = parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500"));
             TreeMap> tempEncoders = new TreeMap<>();
             encoder.forEach((k, v) -> {
                 ByteArrayWrapper key = new ByteArrayWrapper(k);
diff --git a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java
index 665d2f0..474c8bc 100644
--- a/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java
+++ b/lib/src/test/java/com/knuddels/jtokkit/Cl100kTest.java
@@ -10,6 +10,7 @@
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.stream.IntStream;
 
+import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;
 import static java.lang.Character.*;
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.stream.Collectors.joining;
@@ -146,10 +147,10 @@ void testEdgeCaseRoundTrips() throws Exception {
 
     @Test
     void testRoundTripWithRandomStrings() throws Exception {
-        System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(Integer.MAX_VALUE));
+        System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(Integer.MAX_VALUE));
         var arrayEncoder = EncodingFactory.cl100kBase();
 
-        System.setProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", String.valueOf(0));
+        System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0));
         var mapEncoder = EncodingFactory.cl100kBase();
         var singleTokenStrings = getAllTokens();
         IntStream.range(0, 10_000).parallel().forEach(i -> {
diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java
index d33c200..a2108af 100644
--- a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java
+++ b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kBaseTest.java
@@ -14,6 +14,10 @@ class Cl100kBaseTest {
 
     private static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
 
+    Encoding getEncoding() {
+        return ENCODING;
+    }
+
     @ParameterizedTest
     @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
     void cl100kBaseEncodesCorrectly(
@@ -21,7 +25,7 @@ void cl100kBaseEncodesCorrectly(
         String output
     ) {
         var expected = TestUtils.parseEncodingString(output);
-        var actual = ENCODING.encode(input);
+        var actual = getEncoding().encode(input);
 
         assertEquals(expected, actual);
     }
@@ -29,7 +33,7 @@ void cl100kBaseEncodesCorrectly(
     @ParameterizedTest
     @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
     void cl100kBaseEncodesStable(String input) {
-        var actual = ENCODING.decode(ENCODING.encode(input));
+        var actual = getEncoding().decode(getEncoding().encode(input));
 
         assertEquals(input, actual);
     }
@@ -43,7 +47,7 @@ void cl100kBaseEncodesCorrectlyWithMaxTokensSet(
     ) {
         var expected = TestUtils.parseEncodingString(output);
         var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10);
-        var encodingResult = ENCODING.encode(input, 10);
+        var encodingResult = getEncoding().encode(input, 10);
 
         assertEquals(expectedWithMaxTokens, encodingResult.getTokens());
         assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated());
@@ -52,7 +56,7 @@ void cl100kBaseEncodesCorrectlyWithMaxTokensSet(
     @ParameterizedTest
     @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
     void cl100kBaseEncodesStableWithMaxTokensSet(String input) {
-        var actual = ENCODING.decode(ENCODING.encode(input, 10).getTokens());
+        var actual = getEncoding().decode(getEncoding().encode(input, 10).getTokens());
 
         assertTrue(input.startsWith(actual));
     }
@@ -64,7 +68,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly(
         String output
     ) {
         var expected = TestUtils.parseEncodingString(output);
-        var actual = ENCODING.encodeOrdinary(input);
+        var actual = getEncoding().encodeOrdinary(input);
 
         assertEquals(expected, actual);
     }
@@ -78,7 +82,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly(
     ) {
         var expected = TestUtils.parseEncodingString(output);
         var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10);
-        var encodingResult = ENCODING.encodeOrdinary(input, 10);
+        var encodingResult = getEncoding().encodeOrdinary(input, 10);
 
         assertEquals(expectedWithMaxTokens, encodingResult.getTokens());
         assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated());
@@ -87,7 +91,7 @@ void cl100kBaseEncodeOrdinaryEncodesCorrectly(
     @ParameterizedTest
     @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
     void cl100kBaseEncodeOrdinaryEncodesStable(String input) {
-        var actual = ENCODING.decode(ENCODING.encodeOrdinary(input));
+        var actual = getEncoding().decode(getEncoding().encodeOrdinary(input));
 
         assertEquals(input, actual);
     }
@@ -95,7 +99,7 @@ void cl100kBaseEncodeOrdinaryEncodesStable(String input) {
     @ParameterizedTest
     @CsvFileSource(resources = "/cl100k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
     void cl100kBaseEncodeOrdinaryEncodesStableWithMaxTokensSet(String input) {
-        var actual = ENCODING.decode(ENCODING.encodeOrdinary(input, 10).getTokens());
+        var actual = getEncoding().decode(getEncoding().encodeOrdinary(input, 10).getTokens());
 
         assertTrue(input.startsWith(actual));
     }
@@ -103,7 +107,7 @@ void cl100kBaseEncodeOrdinaryEncodesStableWithMaxTokensSet(String input) {
     @Test
     void cl100kBaseEncodeOrdinaryEncodesSpecialTokensCorrectly() {
         var input = "Hello<|endoftext|>, <|fim_prefix|> <|fim_middle|> world <|fim_suffix|> ! <|endofprompt|>";
-        var actual = ENCODING.decode(ENCODING.encodeOrdinary(input));
+        var actual = getEncoding().decode(getEncoding().encodeOrdinary(input));
 
         assertEquals(input, actual);
     }
diff --git a/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java
new file mode 100644
index 0000000..c7ea36f
--- /dev/null
+++ b/lib/src/test/java/com/knuddels/jtokkit/reference/Cl100kLargeTokenizerTest.java
@@ -0,0 +1,30 @@
+package com.knuddels.jtokkit.reference;
+
+import com.knuddels.jtokkit.Encodings;
+import com.knuddels.jtokkit.api.Encoding;
+import com.knuddels.jtokkit.api.EncodingType;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+
+import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;
+
+class Cl100kLargeTokenizerTest extends Cl100kBaseTest {
+
+    public static Encoding ENCODING;
+
+    @BeforeAll
+    static void beforeAll() {
+        System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0));
+        ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
+    }
+
+    @AfterAll
+    static void afterAll() {
+        System.clearProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY);
+
+    }
+
+    Encoding getEncoding() {
+        return ENCODING;
+    }
+}
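
Note (not part of the patch): a minimal sketch of the pattern this change enables, i.e. forcing the TokenEncoderLarge code path by setting the new VERY_LARGE_TOKENIZER_BYTE_THRESHOLD system property before an encoding is built, and checking it agrees with the default array-based path. The class name, method name, and sample input below are hypothetical; the API calls are the ones already used in the diff above.

package com.knuddels.jtokkit.reference;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;
import org.junit.jupiter.api.Test;

import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;
import static org.junit.jupiter.api.Assertions.assertEquals;

// Illustrative sketch only; not included in the patch.
class LargeTokenizerAgreementSketch {

    @Test
    void largeAndDefaultTokenizersAgree() {
        // A threshold of 0 routes every input through the large-tokenizer path.
        // The property is read when the encoding is constructed, so it must be set
        // before the registry is created.
        System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(0));
        Encoding large = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);

        // Integer.MAX_VALUE keeps the default array-based encoder for all inputs
        // (the property defaults to "500" when unset).
        System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, String.valueOf(Integer.MAX_VALUE));
        Encoding regular = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
        System.clearProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY);

        // Hypothetical sample input: both paths should produce identical tokens and round-trip.
        String input = "Hello world, this is a round-trip sanity check.";
        assertEquals(regular.encode(input), large.encode(input));
        assertEquals(input, large.decode(large.encode(input)));
    }
}

The new Cl100kLargeTokenizerTest achieves the same coverage more economically: by overriding getEncoding() in a subclass of Cl100kBaseTest, every existing reference test runs a second time against the large-tokenizer encoding without duplicating any test code.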