Skip to content

Commit

Permalink
[tokenizer] add optional tokenizerPath Prior to modelPath (#3120)
Browse files Browse the repository at this point in the history
* [tokenizer] add optional tokenizerPath Prior to modelPath

---------

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
ewan0x79 and frankfliu authored Apr 25, 2024
1 parent c03669f commit 6efe660
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -686,7 +687,6 @@ static PaddingStrategy fromValue(String value) {
/** The builder for creating huggingface tokenizer. */
public static final class Builder {

private Path tokenizerPath;
private NDManager manager;
private Map<String, String> options;

Expand Down Expand Up @@ -724,7 +724,7 @@ public Builder optTokenizerName(String tokenizerName) {
* @return this builder
*/
public Builder optTokenizerPath(Path tokenizerPath) {
this.tokenizerPath = tokenizerPath;
options.putIfAbsent("tokenizerPath", tokenizerPath.toString());
return this;
}

Expand Down Expand Up @@ -894,9 +894,11 @@ public HuggingFaceTokenizer build() throws IOException {
if (tokenizerName != null) {
return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options));
}
if (tokenizerPath == null) {
String path = options.get("tokenizerPath");
if (path == null) {
throw new IllegalArgumentException("Missing tokenizer path.");
}
Path tokenizerPath = Paths.get(path);
if (Files.isDirectory(tokenizerPath)) {
Path tokenizerFile = tokenizerPath.resolve("tokenizer.json");
if (Files.exists(tokenizerFile)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public void testCrossEncoderTranslator()
.optBlock(block)
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-cased")
.optArgument("tokenizerPath", modelDir)
.optOption("hasParameter", "false")
.optTranslatorFactory(new CrossEncoderTranslatorFactory())
.build();
Expand Down

0 comments on commit 6efe660

Please sign in to comment.