Skip to content

Commit

Permalink
[tokenizers] Fixes memory leak when there is overflowing tokens (#3317)
Browse files Browse the repository at this point in the history
* If you call TokenizersLibrary.LIB.getOverflowing you must also clean up all overflow encodings.
If withOverflowingTokens was false, no Encodings were generated, leaving JNI Encoding handles that would never be properly deleted.

This introduces a new native method that reports the number of overflowing encodings without allocating any JNI resources.
TokenizersLibrary.LIB.getOverflowing(encoding) is now called only if withOverflowingTokens is true.
  • Loading branch information
baldersheim authored Jul 10, 2024
1 parent f50900a commit f5c9a82
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
11 changes: 11 additions & 0 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,17 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
array
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getOverflowCount(
_: JNIEnv,
_: JObject,
handle: jlong,
) -> jint {
let encoding = cast_handle::<Encoding>(handle);
let count = encoding.get_overflowing().len();
count as jint
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getOverflowing<
'local,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,10 +608,11 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);

long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
boolean exceedMaxLength = overflowingHandles.length > 0;
int overFlowCount = TokenizersLibrary.LIB.getOverflowCount(encoding);
boolean exceedMaxLength = overFlowCount > 0;
Encoding[] overflowing;
if (withOverflowingTokens) {
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i], true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public native long[] batchEncodePair(

/**
 * Returns native handles for the overflowing encodings of the given encoding.
 *
 * <p>Every handle returned here refers to a native Encoding that the caller is responsible
 * for cleaning up; handles that are never converted and released leak native memory. Use
 * {@link #getOverflowCount(long)} when only the number of overflowing encodings is needed.
 *
 * @param encoding handle of the native Encoding to query
 * @return one handle per overflowing encoding
 */
public native long[] getOverflowing(long encoding);

/**
 * Returns how many overflowing encodings are attached to the given encoding.
 *
 * <p>Only the count is reported; no native Encoding handles are handed back, so there is
 * nothing for the caller to release afterwards.
 *
 * @param encoding handle of the native Encoding to query
 * @return the number of overflowing encodings
 */
public native int getOverflowCount(long encoding);

/**
 * Native decode call.
 *
 * <p>NOTE(review): the native implementation is not visible here; presumably this turns the
 * given token ids back into text using the tokenizer behind the handle, and the exact effect
 * of {@code addSpecialTokens} should be confirmed against the native side.
 *
 * @param tokenizer handle of the native tokenizer
 * @param ids token ids to decode
 * @param addSpecialTokens flag forwarded to the native decode call
 * @return the resulting string
 */
public native String decode(long tokenizer, long[] ids, boolean addSpecialTokens);

/**
 * Native query for the tokenizer's truncation strategy.
 *
 * <p>NOTE(review): the set of possible return values is defined by the native implementation,
 * which is not visible here — confirm before relying on specific strings.
 *
 * @param tokenizer handle of the native tokenizer
 * @return the truncation strategy as a string
 */
public native String getTruncationStrategy(long tokenizer);
Expand Down

0 comments on commit f5c9a82

Please sign in to comment.