Skip to content

Commit

Permalink
Fix issue with setPadding and setTruncation overriding configurations…
Browse files Browse the repository at this point in the history
… set in tokenizer.json (#2741)
  • Loading branch information
siddvenk authored Aug 10, 2023
1 parent ff66278 commit 17bfda1
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 9 deletions.
40 changes: 31 additions & 9 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,11 +661,24 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
_ => Err("strategy must be one of [longest, max_length]"),
};

let mut params = PaddingParams::default();
params.strategy = res_strategy.unwrap();
params.pad_to_multiple_of = Some(pad_to_multiple_of as usize);
let res_pad_to_multiple_of = match pad_to_multiple_of as usize {
0 => None,
val => Some(val)
};

let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_padding(Some(params));

if let Some(padding_params) = tokenizer.get_padding_mut() {
padding_params.strategy = res_strategy.unwrap();
padding_params.pad_to_multiple_of = res_pad_to_multiple_of;
} else {
let padding_params = PaddingParams {
strategy: res_strategy.unwrap(),
pad_to_multiple_of: res_pad_to_multiple_of,
..Default::default()
};
tokenizer.with_padding(Some(padding_params));
}
}

#[no_mangle]
Expand Down Expand Up @@ -697,13 +710,22 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
"ONLY_SECOND" => Ok(TruncationStrategy::OnlySecond),
_ => Err("strategy must be one of [longest_first, only_first, only_second]"),
};
let mut params = TruncationParams::default();
params.max_length = truncation_max_length as usize;
params.strategy = res_strategy.unwrap();
params.stride = truncation_stride as usize;

let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_truncation(Some(params));

if let Some(truncation_params) = tokenizer.get_truncation_mut() {
truncation_params.strategy = res_strategy.unwrap();
truncation_params.stride = truncation_stride as usize;
truncation_params.max_length = truncation_max_length as usize;
} else {
let truncation_params = TruncationParams {
strategy: res_strategy.unwrap(),
stride: truncation_stride as usize,
max_length: truncation_max_length as usize,
..Default::default()
};
tokenizer.with_truncation(Some(truncation_params));
}
}

#[no_mangle]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,4 +464,22 @@ public void testBatchProcessing() throws IOException {
Assert.assertEquals(outputs, outputsWithoutSpecialTokens);
}
}

@Test
public void testTokenizerWithPresetPaddingConfiguration() throws IOException {
try (HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder()
.optTokenizerPath(
Paths.get("src/test/resources/fake-tokenizer-with-padding/"))
.optMaxLength(8)
.optPadToMaxLength()
.build()) {
Encoding encoding = tokenizer.encode("test sentence");
String[] tokens = encoding.getTokens();
String[] expected = {
"<s>", "▁", "test", "▁sentence", "</s>", "<pad>", "<pad>", "<pad>"
};
Assert.assertEquals(tokens, expected);
}
}
}

Large diffs are not rendered by default.

0 comments on commit 17bfda1

Please sign in to comment.