Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2) Optimize byte pair merge for small and big character sequences - 8.2s to 3.9s #76

Merged
merged 14 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public static void main(String[] args) throws Exception {
"'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'x",
"x",
"123",
"ő",
"a",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testing a slightly different setup here since aaaaaa is still a single token, so this makes it a more difficult case

",.!?:; \n",
"\n \n",
" ",
Expand All @@ -177,7 +177,7 @@ public static void main(String[] args) throws Exception {
}

var totalSize = calculateTotalFileSize(rootFolder);
if (totalSize != 99_945_290) {
if (totalSize != 99_925_295) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reran the before/after with these data

throw new AssertionError("Total size did not match expected value, actual: " + totalSize);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import org.openjdk.jmh.annotations.Benchmark;

import java.util.List;

public class SingleThreadedBenchmark extends AbstractBenchmark {

@Benchmark
public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) {
tox-p marked this conversation as resolved.
Show resolved Hide resolved
var result = 0;
var encoding = state.cl100kBase;
for (var fileContent : state.fileContents) {
result += encoding.countTokens(fileContent);
}
return result;
}

@Override
protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
return fileContents.stream()
Expand Down
74 changes: 74 additions & 0 deletions lib/src/main/java/com/knuddels/jtokkit/ByteArrayWrapper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.knuddels.jtokkit;

import java.util.Arrays;

class ByteArrayWrapper {
private final byte[] array;

/*
* Creates a new instance of ByteArrayWrapper from the given array.
* The given array is not copied, so every calling method in this class must make sure
* to never pass an array which reference leaked to the outside. Since some of our
* construction methods already create new arrays, we do not want to copy here in this
* constructor again.
*/
ByteArrayWrapper(byte[] array) {
tox-p marked this conversation as resolved.
Show resolved Hide resolved
this.array = array;
}

/**
* Returns the length of this array.
*
* @return the length of this array.
*/
public int length() {
return array.length;
}

/**
* Returns the bytes of this array from startIndex (inclusive) to endIndex (exclusive). The returned array is a copy
* of the original array.
*
* @param startIndex the index from which to start copying (inclusive)
* @param endIndex the index at which to stop copying (exclusive)
* @return a new {@link ByteArrayWrapper} containing the bytes from startIndex (inclusive) to endIndex (exclusive)
* @throws IllegalArgumentException if startIndex is out of bounds, endIndex is out of bounds or endIndex is less than
* startIndex
*/
public ByteArrayWrapper getBytesBetween(int startIndex, int endIndex) {
if (startIndex < 0 || startIndex >= array.length) {
throw new IndexOutOfBoundsException("startIndex out of bounds: " + startIndex + " (" + this + ")");
} else if (endIndex < 0 || endIndex > array.length) {
throw new IndexOutOfBoundsException("endIndex out of bounds: " + endIndex + " (" + this + ")");
} else if (startIndex >= endIndex) {
throw new IllegalArgumentException("startIndex must be less than endIndex: " + startIndex + " >= " + endIndex);
}

int length = endIndex - startIndex;
byte[] result = new byte[length];
System.arraycopy(array, startIndex, result, 0, length);
return new ByteArrayWrapper(result);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other == null || getClass() != other.getClass()) {
return false;
}

ByteArrayWrapper that = (ByteArrayWrapper) other;
return Arrays.equals(array, that.array);
}

@Override
public int hashCode() {
return Arrays.hashCode(array);
}

@Override
public String toString() {
return Arrays.toString(array);
}
}
Loading