Skip to content

Commit

Permalink
feat: Implement o200k_base encoding and support gpt-4o
Browse files Browse the repository at this point in the history
  • Loading branch information
chatanywhere committed May 27, 2024
1 parent 680ca02 commit 3099c9e
Show file tree
Hide file tree
Showing 4 changed files with 200,029 additions and 0 deletions.
29 changes: 29 additions & 0 deletions lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.regex.Pattern;

class EncodingFactory {
private static final Map<String, Integer> SPECIAL_TOKENS_O200K_BASE;
private static final Map<String, Integer> SPECIAL_TOKENS_CL100K_BASE;
private static final Map<String, Integer> SPECIAL_TOKENS_X50K_BASE;
private static final Map<String, Integer> SPECIAL_TOKENS_P50K_EDIT;
Expand Down Expand Up @@ -53,6 +54,13 @@ class EncodingFactory {
SPECIAL_TOKENS_CL100K_BASE = Collections.unmodifiableMap(map);
}

static {
Map<String, Integer> map = new HashMap<>();
map.put(ENDOFTEXT, 199999);
map.put(ENDOFPROMPT, 200018);
SPECIAL_TOKENS_O200K_BASE = Collections.unmodifiableMap(map);
}

private EncodingFactory() {
}

Expand Down Expand Up @@ -107,6 +115,27 @@ static Encoding cl100kBase() {
return new Cl100kGptBytePairEncoding(params);
}

/**
* Returns an {@link Encoding} instance for the o200k_base encoding.
*
* @return an {@link Encoding} instance for the o200k_base encoding
*/

static Encoding o200kBase() {
Map<byte[], Integer> mergeableRanks = loadMergeableRanks("/com/knuddels/jtokkit/o200k_base.tiktoken");
List<String> patStrList = new ArrayList<>();
patStrList.add("[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?");
patStrList.add("[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?");
patStrList.add("\\p{N}{1,3}");
patStrList.add(" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*");
patStrList.add("\\s*[\\r\\n]+");
patStrList.add("\\s+(?!\\S)");
patStrList.add("\\s+");
Pattern regex = compileRegex(patStrList.stream().map(String::valueOf).collect(Collectors.joining("|")), false);
GptBytePairEncodingParams params = new GptBytePairEncodingParams("o200k_base", regex, mergeableRanks, SPECIAL_TOKENS_O200K_BASE);
return fromParameters(params);
}

/**
* Returns an {@link Encoding} instance for the given GPT BytePairEncoding parameters.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public enum EncodingType {
P50K_BASE("p50k_base"),
P50K_EDIT("p50k_edit"),
CL100K_BASE("cl100k_base");
O200K_BASE("o200k_base");

private static final Map<String, EncodingType> nameToEncodingType = Arrays.stream(values())
.collect(Collectors.toMap(EncodingType::getName, Function.identity()));
Expand Down
1 change: 1 addition & 0 deletions lib/src/main/java/com/knuddels/jtokkit/api/ModelType.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
public enum ModelType {
// chat
GPT_4("gpt-4", EncodingType.CL100K_BASE, 8192),
GPT_4O("gpt-4o", EncodingType.O200K_BASE, 131072),
GPT_4_32K("gpt-4-32k", EncodingType.CL100K_BASE, 32768),
GPT_3_5_TURBO("gpt-3.5-turbo", EncodingType.CL100K_BASE, 4096),
GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k", EncodingType.CL100K_BASE, 16385),
Expand Down
Loading

0 comments on commit 3099c9e

Please sign in to comment.