Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tokenizer] Adds includeTokenTypes for all translators #3035

Merged
merged 1 commit into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator<String[], C
private String maskToken;
private long maskTokenId;
private int topK;
private boolean includeTokenTypes;
private Batchifier batchifier;

FillMaskBatchTranslator(
HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
String maskToken,
int topK,
boolean includeTokenTypes,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.maskToken = maskToken;
this.topK = topK;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
Expand All @@ -52,7 +58,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) throws Transl
for (int i = 0; i < encodings.length; ++i) {
long[] indices = encodings[i].getIds();
maskIndices[i] = FillMaskTranslator.getMaskIndex(indices, maskToken, maskTokenId);
batch[i] = encodings[i].toNDList(manager, false);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ public class FillMaskTranslator implements Translator<String, Classifications> {
private String maskToken;
private long maskTokenId;
private int topK;
private boolean includeTokenTypes;
private Batchifier batchifier;

FillMaskTranslator(
HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
String maskToken,
int topK,
boolean includeTokenTypes,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.maskToken = maskToken;
this.topK = topK;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
Expand All @@ -61,7 +67,7 @@ public NDList processInput(TranslatorContext ctx, String input) throws Translate
long[] indices = encoding.getIds();
int maskIndex = getMaskIndex(indices, maskToken, maskTokenId);
ctx.setAttachment("maskIndex", maskIndex);
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -75,7 +81,8 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
@Override
public FillMaskBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new FillMaskBatchTranslator(tokenizer, maskToken, topK, batchifier);
return new FillMaskBatchTranslator(
tokenizer, maskToken, topK, includeTokenTypes, batchifier);
}

static int getMaskIndex(long[] indices, String maskToken, long maskTokenId)
Expand Down Expand Up @@ -139,6 +146,7 @@ public static final class Builder {
private HuggingFaceTokenizer tokenizer;
private String maskedToken = "[MASK]";
private int topK = 5;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
Expand Down Expand Up @@ -167,6 +175,17 @@ public Builder optTopK(int topK) {
return this;
}

/**
* Sets whether token types should be included for the {@link Translator}.
*
* <p>If this method is never called, the flag keeps its field default of {@code false}.
*
* @param includeTokenTypes {@code true} to include token types, {@code false} to omit them
* @return this builder
*/
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
this.includeTokenTypes = includeTokenTypes;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -186,6 +205,7 @@ public Builder optBatchifier(Batchifier batchifier) {
public void configure(Map<String, ?> arguments) {
// Each option is looked up in the arguments map by key and applied through its opt* setter.
optMaskToken(ArgumentsUtil.stringValue(arguments, "maskToken", "[MASK]"));
optTopK(ArgumentsUtil.intValue(arguments, "topK", 5));
// No explicit fallback is passed for this key; presumably a missing entry yields false
// -- TODO confirm against ArgumentsUtil.booleanValue.
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
// The batchifier is configured by name; "stack" is presumably the fallback when the key is
// absent -- verify against ArgumentsUtil.stringValue.
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -197,7 +217,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public FillMaskTranslator build() throws IOException {
return new FillMaskTranslator(tokenizer, maskedToken, topK, batchifier);
return new FillMaskTranslator(
tokenizer, maskedToken, topK, includeTokenTypes, batchifier);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ public class TextClassificationBatchTranslator
implements NoBatchifyTranslator<String[], Classifications[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TextClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TextClassificationBatchTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -56,7 +59,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) {
Encoding[] encodings = tokenizer.batchEncode(inputs);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, false);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@
public class TextClassificationTranslator implements Translator<String, Classifications> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TextClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TextClassificationTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -63,7 +66,7 @@ public void prepare(TranslatorContext ctx) throws IOException {
@Override
public NDList processInput(TranslatorContext ctx, String input) {
Encoding encoding = tokenizer.encode(input);
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -76,7 +79,7 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
@Override
public TextClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new TextClassificationBatchTranslator(tokenizer, batchifier);
return new TextClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier);
}

static Classifications toClassifications(PretrainedConfig config, NDList list) {
Expand Down Expand Up @@ -127,12 +130,24 @@ public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arg
public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
}

/**
* Sets whether token types should be included for the {@link Translator}.
*
* <p>If this method is never called, the flag keeps its field default of {@code false}.
*
* @param includeTokenTypes {@code true} to include token types, {@code false} to omit them
* @return this builder
*/
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
this.includeTokenTypes = includeTokenTypes;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -150,6 +165,7 @@ public Builder optBatchifier(Batchifier batchifier) {
* @param arguments the model arguments
*/
public void configure(Map<String, ?> arguments) {
// No explicit fallback is passed for this key; presumably a missing entry yields false
// -- TODO confirm against ArgumentsUtil.booleanValue.
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
// The batchifier is configured by name; "stack" is presumably the fallback when the key is
// absent -- verify against ArgumentsUtil.stringValue.
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -161,7 +177,7 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public TextClassificationTranslator build() throws IOException {
return new TextClassificationTranslator(tokenizer, batchifier);
return new TextClassificationTranslator(tokenizer, includeTokenTypes, batchifier);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ public class TokenClassificationBatchTranslator
implements NoBatchifyTranslator<String[], NamedEntity[][]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TokenClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TokenClassificationBatchTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -58,7 +61,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) {
ctx.setAttachment("encodings", encodings);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, false);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@
public class TokenClassificationTranslator implements Translator<String, NamedEntity[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TokenClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TokenClassificationTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -65,7 +68,7 @@ public void prepare(TranslatorContext ctx) throws IOException {
public NDList processInput(TranslatorContext ctx, String input) {
Encoding encoding = tokenizer.encode(input);
ctx.setAttachment("encoding", encoding);
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -79,7 +82,7 @@ public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) {
@Override
public TokenClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new TokenClassificationBatchTranslator(tokenizer, batchifier);
return new TokenClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier);
}

/**
Expand Down Expand Up @@ -139,12 +142,24 @@ static NamedEntity[] toNamedEntities(Encoding encoding, NDList list, PretrainedC
public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
}

/**
* Sets whether token types should be included for the {@link Translator}.
*
* <p>If this method is never called, the flag keeps its field default of {@code false}.
*
* @param includeTokenTypes {@code true} to include token types, {@code false} to omit them
* @return this builder
*/
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
this.includeTokenTypes = includeTokenTypes;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -162,6 +177,7 @@ public Builder optBatchifier(Batchifier batchifier) {
* @param arguments the model arguments
*/
public void configure(Map<String, ?> arguments) {
// No explicit fallback is passed for this key; presumably a missing entry yields false
// -- TODO confirm against ArgumentsUtil.booleanValue.
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
// The batchifier is configured by name; "stack" is presumably the fallback when the key is
// absent -- verify against ArgumentsUtil.stringValue.
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -173,7 +189,7 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public TokenClassificationTranslator build() throws IOException {
return new TokenClassificationTranslator(tokenizer, batchifier);
return new TokenClassificationTranslator(tokenizer, includeTokenTypes, batchifier);
}
}
}
Loading