From 8fc0019a9f2fbfcce34f53b161c7e7513a6d0d16 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 11 Apr 2024 21:28:47 -0700 Subject: [PATCH] [api] Support "inputs" key in json request for text embedding translator --- .../java/ai/djl/modality/nlp/TextPrompt.java | 94 +++++++++++++++++++ .../TextClassificationServingTranslator.java | 13 ++- .../TextEmbeddingServingTranslator.java | 13 ++- .../TokenClassificationServingTranslator.java | 13 ++- .../TextEmbeddingTranslatorTest.java | 36 +++++++ 5 files changed, 148 insertions(+), 21 deletions(-) create mode 100644 api/src/main/java/ai/djl/modality/nlp/TextPrompt.java diff --git a/api/src/main/java/ai/djl/modality/nlp/TextPrompt.java b/api/src/main/java/ai/djl/modality/nlp/TextPrompt.java new file mode 100644 index 00000000000..dd1cef113bd --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/TextPrompt.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp; + +import ai.djl.modality.Input; +import ai.djl.translate.TranslateException; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; + +/** The input container for NLP text prompt. */ +public final class TextPrompt { + + private String text; + private String[] batch; + + private TextPrompt(String text) { + this.text = text; + } + + private TextPrompt(String[] batch) { + this.batch = batch; + } + + /** + * Returns if the prompt is a batch. + * + * @return {@code true} if the prompt is a batch + */ + public boolean isBatch() { + return batch != null; + } + + /** + * Returns the single prompt. + * + * @return the single prompt + */ + public String getText() { + return text; + } + + /** + * Returns the batch prompt. + * + * @return the batch prompt + */ + public String[] getBatch() { + return batch; + } + + /** + * Returns the {@code TextPrompt} from the {@link Input}. + * + * @param input the input object + * @return the {@code TextPrompt} from the {@link Input} + * @throws TranslateException if the input is invalid + */ + public static TextPrompt parseInput(Input input) throws TranslateException { + String contentType = input.getProperty("Content-Type", null); + String text = input.getData().getAsString(); + if (!"application/json".equals(contentType)) { + return new TextPrompt(text); + } + + try { + JsonElement element = JsonUtils.GSON.fromJson(text, JsonElement.class); + if (element != null && element.isJsonObject()) { + element = element.getAsJsonObject().get("inputs"); + } + if (element == null) { + throw new TranslateException("Missing \"inputs\" in json."); + } else if (element.isJsonArray()) { + String[] batch = JsonUtils.GSON.fromJson(element, String[].class); + return new TextPrompt(batch); + } else { + return new TextPrompt(element.getAsString()); + } + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java index 27e343120c4..cb265087557 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java @@ -15,6 +15,7 @@ import ai.djl.modality.Classifications; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.TextPrompt; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; @@ -22,7 +23,6 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; /** * A {@link Translator} that can handle generic text classification {@link Input} and {@link @@ -57,14 +57,13 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input data is empty."); } - String contentType = input.getProperty("Content-Type", null); - String text = input.getData().getAsString(); - if ("application/json".equals(contentType)) { + TextPrompt prompt = TextPrompt.parseInput(input); + if (prompt.isBatch()) { ctx.setAttachment("batch", Boolean.TRUE); - String[] inputs = JsonUtils.GSON.fromJson(text, String[].class); - return batchTranslator.processInput(ctx, inputs); + return batchTranslator.processInput(ctx, prompt.getBatch()); } - NDList ret = translator.processInput(ctx, text); + + NDList ret = translator.processInput(ctx, prompt.getText()); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java index 110f9e09fe5..6222dfc6f6e 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java @@ -14,6 +14,7 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.TextPrompt; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; @@ -21,7 +22,6 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; /** A {@link Translator} that can handle generic text embedding {@link Input} and {@link Output}. */ public class TextEmbeddingServingTranslator implements NoBatchifyTranslator { @@ -53,14 +53,13 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input data is empty."); } - String contentType = input.getProperty("Content-Type", null); - String text = input.getData().getAsString(); - if ("application/json".equals(contentType)) { + TextPrompt prompt = TextPrompt.parseInput(input); + if (prompt.isBatch()) { ctx.setAttachment("batch", Boolean.TRUE); - String[] inputs = JsonUtils.GSON.fromJson(text, String[].class); - return batchTranslator.processInput(ctx, inputs); + return batchTranslator.processInput(ctx, prompt.getBatch()); } - NDList ret = translator.processInput(ctx, text); + + NDList ret = translator.processInput(ctx, prompt.getText()); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java index 6f97964351f..e9c5751a324 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java @@ -14,6 +14,7 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.TextPrompt; import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; @@ -21,7 +22,6 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; /** * A {@link Translator} that can handle generic token classification {@link Input} and {@link @@ -56,14 +56,13 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception throw new TranslateException("Input data is empty."); } - String contentType = input.getProperty("Content-Type", null); - String text = input.getData().getAsString(); - if ("application/json".equals(contentType)) { + TextPrompt prompt = TextPrompt.parseInput(input); + if (prompt.isBatch()) { ctx.setAttachment("batch", Boolean.TRUE); - String[] inputs = JsonUtils.GSON.fromJson(text, String[].class); - return batchTranslator.processInput(ctx, inputs); + return batchTranslator.processInput(ctx, prompt.getBatch()); } - NDList ret = translator.processInput(ctx, text); + + NDList ret = translator.processInput(ctx, prompt.getText()); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java index 91a96fd3ec8..9affa1341f2 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java @@ -157,6 +157,16 @@ public void testTextEmbeddingTranslator() float[] res = JsonUtils.GSON.fromJson(out.getAsString(0), float[].class); Assert.assertEquals(res.length, 384); Assertions.assertAlmostEquals(res[0], 0.05103); + + input = new Input(); + Map map = new HashMap<>(); + map.put("inputs", text); + input.add(JsonUtils.GSON.toJson(map)); + input.addProperty("Content-Type", "application/json"); + out = predictor.predict(input); + res = (float[]) out.getData().getAsObject(); + Assert.assertEquals(res.length, 384); + Assertions.assertAlmostEquals(res[0], 0.05103); } try (Model model = Model.newInstance("test")) { @@ -237,6 +247,32 @@ public void testTextEmbeddingBatchTranslator() float[][] res = (float[][]) out.getData().getAsObject(); Assert.assertEquals(res[0].length, 384); Assertions.assertAlmostEquals(res[0][0], 0.05103); + + input = new Input(); + Map map = new HashMap<>(); + map.put("inputs", text); + input.add(JsonUtils.GSON.toJson(map)); + input.addProperty("Content-Type", "application/json"); + out = predictor.predict(input); + res = (float[][]) out.getData().getAsObject(); + Assert.assertEquals(res[0].length, 384); + Assertions.assertAlmostEquals(res[0][0], 0.05103); + + Assert.assertThrows( + () -> { + Input empty = new Input(); + empty.add(JsonUtils.GSON.toJson(new HashMap<>())); + empty.addProperty("Content-Type", "application/json"); + predictor.predict(empty); + }); + + Assert.assertThrows( + () -> { + Input empty = new Input(); + empty.add("{ \"invalid json\""); + empty.addProperty("Content-Type", "application/json"); + predictor.predict(empty); + }); } } }