From 305ff598e3c746c31361c1519042bbe7ae8ea39d Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Tue, 25 Jul 2023 18:38:41 -0700 Subject: [PATCH 1/3] Restore unit test on contrastive and beam --- .../modality/nlp/generate/TextGenerator.java | 13 +- .../djl/onnxruntime/engine/OrtNDManager.java | 11 +- .../djl/onnxruntime/zoo/nlp/package-info.java | 18 ++ .../nlp/textgeneration/OrtGptTranslator.java | 91 ++++++++++ .../OrtGptTranslatorFactory.java | 57 ++++++ .../zoo/nlp/textgeneration/package-info.java | 15 ++ .../onnxruntime/engine/GptTranslatorTest.java | 91 ++++++++++ .../nlp/textgeneration/PtGptTranslator.java | 29 ++-- .../nlp/textgeneration/GptTranslatorTest.java | 4 +- ...orch_jni_PyTorchLibrary_torch_pointwise.cc | 22 +-- .../examples/inference/nlp/RollingBatch.java | 122 +++++++++++++ .../inference/nlp/TextGeneration.java | 162 +++++++++++++++++- .../inference/nlp/TextGenerationTest.java | 83 ++++++++- 13 files changed, 680 insertions(+), 38 deletions(-) create mode 100644 engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java create mode 100644 engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslator.java create mode 100644 engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslatorFactory.java create mode 100644 engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/package-info.java create mode 100644 engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java create mode 100644 examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java index 3d79168ae45..ed1667c1e28 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java @@ -535,13 +535,13 @@ private NDList prepareInput( } /** - * Forward function call to generate text. + * Generate function call to generate text. * * @param inputIds the input token ids * @return generated token ids * @throws TranslateException if prediction fails */ - public NDArray forward(NDArray inputIds) throws TranslateException { + public NDArray generate(NDArray inputIds) throws TranslateException { switch (searchName) { case "greedy": return greedySearch(inputIds); @@ -555,4 +555,13 @@ public NDArray forward(NDArray inputIds) throws TranslateException { + " contrastive}"); } } + + /** + * Gets the value of the positionOffset. 
+ * + * @return the value of positionOffset + */ + public NDArray getPositionOffset() { + return positionOffset; + } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index 77abc2146d3..aefd0e2158f 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -107,7 +107,16 @@ public OrtNDArray create(boolean[] data, Shape shape) { + " dimensions are not supported."); } - Object tensorIn = OrtUtil.reshape(data, sh); + Object tensorIn; + if (sh.length != 1) { + tensorIn = OrtUtil.reshape(data, sh); + } else { + // Work around the bug in OrtUtil.reshape() when sh.length == 1. + long[] shExpanded = {1, sh[0]}; + boolean[][] arrayIn = (boolean[][]) OrtUtil.reshape(data, shExpanded); + tensorIn = arrayIn[0]; + } + try { return new OrtNDArray(this, alternativeManager, OrtUtils.toTensor(env, tensorIn)); } catch (OrtException e) { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java new file mode 100644 index 00000000000..537a113b11c --- /dev/null +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** + * Contains supplemental classes for the {@link ai.djl.Application.NLP} models in the {@link + * ai.djl.onnxruntime.zoo}. + */ +package ai.djl.onnxruntime.zoo.nlp; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslator.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslator.java new file mode 100644 index 00000000000..6a0292034cf --- /dev/null +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslator.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */
+package ai.djl.onnxruntime.zoo.nlp.textgeneration;
+
+import ai.djl.modality.nlp.generate.CausalLMOutput;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+
+/** The {@link ai.djl.translate.Translator} for the ONNX Runtime GPT2 model. */
+public class OrtGptTranslator implements NoBatchifyTranslator<NDList, CausalLMOutput> {
+
+    private long kvDim;
+    private int numAttentionHeads;
+    private int numLayers;
+
+    /**
+     * Constructs a new instance of {@code OrtGptTranslator}.
+     *
+     * @param kvDim the kv dimension
+     * @param numAttentionHeads the number of attention heads
+     * @param numLayers the number of layers
+     */
+    public OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) {
+        this.kvDim = kvDim;
+        this.numAttentionHeads = numAttentionHeads;
+        this.numLayers = numLayers;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
+        // input = [inputIds, posIds, attnMask]
+        NDManager manager = ctx.getNDManager();
+        NDArray inputIds = input.get(0);
+        inputIds.setName("input_ids");
+
+        NDArray attentionMask = input.get(2);
+        attentionMask.setName("attention_mask");
+
+        NDList inputNew;
+        if (input.size() == 3) {
+            // pastKeyValue == null
+            NDArray useCacheBranch = manager.create(new boolean[] {false}, new Shape(1));
+            useCacheBranch.setName("use_cache_branch");
+            inputNew = new NDList(inputIds, attentionMask, useCacheBranch);
+            initialDummyPastKeyValues(inputIds, manager, inputNew);
+        } else {
+            NDArray useCacheBranch = manager.create(new boolean[] {true}, new Shape(1));
+            useCacheBranch.setName("use_cache_branch");
+            inputNew = new NDList(inputIds, attentionMask, useCacheBranch);
+            inputNew.addAll(input.subNDList(3));
+        }
+
+        int offset = 3;
+        for (int i = offset; i < numLayers * 2 + offset; i += 2) {
+            int order = (i - offset) / 2;
+            inputNew.get(i).setName(String.format("past_key_values.%s.key", order));
+            inputNew.get(i + 1).setName(String.format("past_key_values.%s.value", order));
+        }
+
+        return inputNew;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws Exception {
+        return new CausalLMOutput(output.get(0), output.subNDList(1));
+    }
+
+    private void initialDummyPastKeyValues(NDArray inputIds, NDManager manager, NDList list) {
+        long numBatch = inputIds.getShape().get(0);
+        for (int i = 0; i < numLayers * 2; ++i) {
+            NDArray array = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim));
+            list.add(array);
+        }
+    }
+}
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslatorFactory.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslatorFactory.java
new file mode 100644
index 00000000000..341d75c8e66
--- /dev/null
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/OrtGptTranslatorFactory.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.onnxruntime.zoo.nlp.textgeneration;
+
+import ai.djl.Model;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.ArgumentsUtil;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+import ai.djl.util.Pair;
+
+import java.lang.reflect.Type;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/** A {@link TranslatorFactory} that creates an {@link OrtGptTranslator} instance. */
+public class OrtGptTranslatorFactory implements TranslatorFactory {
+
+    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
+
+    static {
+        SUPPORTED_TYPES.add(new Pair<>(NDList.class, CausalLMOutput.class));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Set<Pair<Type, Type>> getSupportedTypes() {
+        return SUPPORTED_TYPES;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    @SuppressWarnings("unchecked")
+    public <I, O> Translator<I, O> newInstance(
+            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
+        if (!isSupported(input, output)) {
+            throw new IllegalArgumentException("Unsupported input/output types.");
+        }
+        long kvDim = ArgumentsUtil.longValue(arguments, "kvDim", 64);
+        int numAttentionHeads = ArgumentsUtil.intValue(arguments, "numAttentionHeads", 12);
+        int numLayers = ArgumentsUtil.intValue(arguments, "numLayers", 12);
+
+        return (Translator<I, O>) (new OrtGptTranslator(kvDim, numAttentionHeads, numLayers));
+    }
+}
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/package-info.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/package-info.java
new file mode 100644
index 00000000000..1f4e1061001
--- /dev/null
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/textgeneration/package-info.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains classes for the {@link ai.djl.Application.NLP#TEXT_GENERATION} models. */
+package ai.djl.onnxruntime.zoo.nlp.textgeneration;
diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java
new file mode 100644
index 00000000000..8e9b736feb3
--- /dev/null
+++ b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.onnxruntime.engine;
+
+import ai.djl.ModelException;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Block;
+import ai.djl.nn.LambdaBlock;
+import ai.djl.onnxruntime.zoo.nlp.textgeneration.OrtGptTranslatorFactory;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.TranslateException;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+public class GptTranslatorTest {
+
+    @Test
+    public void testGpt2() throws TranslateException, ModelException, IOException {
+        // This is a fake model that simulates language models like GPT2: NDList(inputIds, posIds,
+        // attnMask) -> NDList(logits(1), pastKv(12*2)[, hiddenStates(13)])
+        Block block =
+                new LambdaBlock(
+                        a -> {
+                            NDList list = new NDList(25);
+                            NDManager manager = a.getManager();
+                            long[][] logits = new long[4][50257];
+                            logits[3][257] = 1;
+                            NDArray arr = manager.create(logits).expandDims(0);
+                            list.add(arr);
+
+                            for (int i = 0; i < 12 * 2; ++i) {
+                                NDArray array = manager.zeros(new Shape(1, 12, 1, 64));
+                                list.add(array);
+                            }
+                            return list;
+                        },
+                        "model");
+
+        Path modelDir = Paths.get("build/text_generation");
+        Files.createDirectories(modelDir);
+
+        Criteria<NDList, CausalLMOutput> criteria =
+                Criteria.builder()
+                        .setTypes(NDList.class, CausalLMOutput.class)
+                        .optModelPath(modelDir)
+                        .optBlock(block)
+                        .optOption("hasParameter", "false")
+                        .optTranslatorFactory(new OrtGptTranslatorFactory())
+                        .build();
+
+        try (ZooModel<NDList, CausalLMOutput> model = criteria.loadModel();
+                Predictor<NDList, CausalLMOutput> predictor = model.newPredictor();
+                NDManager manager = NDManager.newBaseManager()) {
+            long[][] inputIds = {{29744, 28478, 5834, 318}};
+            int len = inputIds[0].length;
+            NDArray input = manager.create(inputIds);
+            NDArray positionIds = manager.arange(0, len, 1, DataType.INT64).expandDims(0);
+            NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64);
+            CausalLMOutput res = predictor.predict(new NDList(input, positionIds, attentionMask));
+            NDArray logits = res.getLogits();
+            long nextTokenId = logits.get(":, -1, :").argMax().getLong();
+            Assert.assertEquals(nextTokenId, 257);
+            NDList list = res.getPastKeyValuesList();
+            Assert.assertEquals(list.size(), 24);
+            Assert.assertEquals(res.getHiddenState().getShape().get(0), 1);
+        }
+    }
+}
diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java
index 544f6503f50..841460331cb 100644
--- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java
+++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java
@@ -22,7 +22,6 @@
 import ai.djl.translate.NoBatchifyTranslator;
 import
ai.djl.translate.TranslatorContext; -import java.util.Collections; import java.util.stream.Collectors; /** The {@link ai.djl.translate.Translator} for PyTorch GPT2 model. */ @@ -55,26 +54,22 @@ public NDList processInput(TranslatorContext ctx, NDList input) throws Exception if (input.size() == 3) { // In this case, input has null pastKeyValues. We prefix-append a dummy pastKeyValues, // which is treated as prefix padding, and set the corresponding attnMask to be zero. No - // need to shift the position ids. + // need to shift the position ids, since the starting position id, which is 0, won't + // change after appending the dummy kv_cache. ctx.setAttachment("useDummyPastKeyValues", Boolean.TRUE); // Pad the null pastKeyValues with dummy values - NDList pastKeyValues = initialDummyPastKeyValues(input.get(0), manager); - for (NDArray pkv : pastKeyValues) { - pkv.setName(tupleName); - input.add(pkv); - } + initialDummyPastKeyValues(input.get(0), manager, input); // Append zero to the attentionMask from left, corresponding to the padding long batchSize = input.get(0).getShape().get(0); NDArray attentionMask = manager.zeros(new Shape(batchSize, 1), DataType.INT64).concat(input.get(2), -1); input.set(2, attentionMask); - } else { - for (int i = 3; i < numLayers * 2 + 3; ++i) { - NDArray pkv = input.get(i); - pkv.setName(tupleName); - } + } + + for (int i = 3; i < numLayers * 2 + 3; ++i) { + input.get(i).setName(tupleName); } return input; @@ -111,11 +106,11 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws return new CausalLMOutput(logitsOutput, hiddenStatesOutput, pastKeyValuesOutput); } - private NDList initialDummyPastKeyValues(NDArray inputIds, NDManager manager) { + private void initialDummyPastKeyValues(NDArray inputIds, NDManager manager, NDList list) { long numBatch = inputIds.getShape().get(0); - NDArray dummyKV = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim)); - NDList pastKeyValues = new NDList(); - pastKeyValues.addAll(Collections.nCopies(2 * numLayers, dummyKV)); - return pastKeyValues; + for (int i = 0; i < numLayers * 2; ++i) { + NDArray array = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim)); + list.add(array); + } } } diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java index 68523be7802..95d2614f339 100644 --- a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java +++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java @@ -76,9 +76,9 @@ public void testGpt2() throws TranslateException, ModelException, IOException { long[][] inputIds = {{29744, 28478, 5834, 318}}; int len = inputIds[0].length; NDArray input = manager.create(inputIds); - NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64); NDArray positionIds = manager.arange(0, len, 1, DataType.INT64).expandDims(0); - CausalLMOutput res = predictor.predict(new NDList(input, attentionMask, positionIds)); + NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64); + CausalLMOutput res = predictor.predict(new NDList(input, positionIds, attentionMask)); NDArray logits = res.getLogits(); long nextTokenId = logits.get(":, -1, :").argMax().getLong(); Assert.assertEquals(nextTokenId, 257); diff --git 
a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index cfc4e97681a..28e40e916be 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -214,17 +214,17 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMinimum( } JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMedian( - JNIEnv* env, jobject jthis, jlong jself, jlong jdim, jboolean keep_dim) { - API_BEGIN() - const auto* self_ptr = reinterpret_cast(jself); - const auto result = self_ptr->median(jdim, keep_dim); - const auto* value_ptr = new torch::Tensor(std::get<0>(result)); - const auto* indices_ptr = new torch::Tensor(std::get<1>(result)); - std::vector vect; - vect.push_back(reinterpret_cast(value_ptr)); - vect.push_back(reinterpret_cast(indices_ptr)); - return djl::utils::jni::GetLongArrayFromVec(env, vect); - API_END_RETURN() + JNIEnv* env, jobject jthis, jlong jself, jlong jdim, jboolean keep_dim) { + API_BEGIN() + const auto* self_ptr = reinterpret_cast(jself); + const auto result = self_ptr->median(jdim, keep_dim); + const auto* value_ptr = new torch::Tensor(std::get<0>(result)); + const auto* indices_ptr = new torch::Tensor(std::get<1>(result)); + std::vector vect; + vect.push_back(reinterpret_cast(value_ptr)); + vect.push_back(reinterpret_cast(indices_ptr)); + return djl::utils::jni::GetLongArrayFromVec(env, vect); + API_END_RETURN() } JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAbs(JNIEnv* env, jobject jthis, jlong jhandle) { diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java b/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java new file mode 100644 index 00000000000..6aac42d6f2a --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.examples.inference.nlp; + +import ai.djl.MalformedModelException; +import ai.djl.examples.inference.ImageClassification; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.inference.Predictor; +import ai.djl.modality.nlp.generate.CausalLMOutput; +import ai.djl.modality.nlp.generate.ContrastiveSeqBatchScheduler; +import ai.djl.modality.nlp.generate.SearchConfig; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.DeferredTranslatorFactory; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Map; + +public final class RollingBatch { + + private static final Logger logger = LoggerFactory.getLogger(ImageClassification.class); + + private RollingBatch() {} + + public static void main(String[] args) + throws ModelNotFoundException, MalformedModelException, IOException, + TranslateException { + String[] ret = seqBatchSchedulerWithPyTorchContrastive(); + logger.info("{}", ret[0]); + } + + public static String[] seqBatchSchedulerWithPyTorchContrastive() + throws ModelNotFoundException, MalformedModelException, IOException, + TranslateException { + String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip"; + + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, CausalLMOutput.class) + .optModelUrls(url) + .optEngine("PyTorch") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .build(); + + String[] testResult = new String[5]; + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = model.getNDManager().newSubManager(); + HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("gpt2")) { + + SearchConfig config = new SearchConfig(); + config.setMaxSeqLength(30); + config.setAlpha(0.6f); + config.setK(5); + long padTokenId = 220; + config.setPadTokenId(padTokenId); + + ContrastiveSeqBatchScheduler scheduler = + new ContrastiveSeqBatchScheduler(predictor, config); + + // Initial input + String[] inputs1 = {"DeepMind Company is", "Memories follow me left and right. 
I can"}; + NDArray inputIds1 = + TextGeneration.encodeWithPadding(manager, tokenizer, inputs1, padTokenId); + NDArray batchUids1 = manager.create(new long[] {0, 1}); + + // Contains both initForward and seqBatcher merge + scheduler.addRequest(inputIds1, batchUids1); + + // Increment forward + scheduler.incrementForward(2); + + // Add more batch (longer) + String[] inputs2 = { + "When your legs don't work like they used to before And I can't sweep you" + " off", + "There's a time that I remember, when I did not know" + }; + NDArray inputIds2 = + TextGeneration.encodeWithPadding(manager, tokenizer, inputs2, padTokenId); + NDArray batchUids2 = manager.create(new long[] {2, 3}); + scheduler.addRequest(inputIds2, batchUids2); + scheduler.incrementForward(2); + + // Add more batch (shorter) + String[] inputs3 = {"A person gets sent back"}; + NDArray inputIds3 = + TextGeneration.encodeWithPadding(manager, tokenizer, inputs3, padTokenId); + NDArray batchUids3 = manager.create(new long[] {4}); + + scheduler.addRequest(inputIds3, batchUids3); + scheduler.incrementForward(config.getMaxSeqLength()); + + // Collect result + Map output = scheduler.collectResults(); + testResult[0] = tokenizer.decode(output.get(0L).toLongArray()); + testResult[1] = tokenizer.decode(output.get(1L).toLongArray()); + testResult[2] = tokenizer.decode(output.get(2L).toLongArray()); + testResult[3] = tokenizer.decode(output.get(3L).toLongArray()); + testResult[4] = tokenizer.decode(output.get(4L).toLongArray()); + } + return testResult; + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java index 4bab3a9cfb6..acbaa152f8c 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java @@ -23,6 +23,8 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; @@ -43,11 +45,15 @@ private TextGeneration() {} public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException { - String ret = generateTextWithPyTorch(); - logger.info("{}", ret); + String ret1 = generateTextWithPyTorchGreedy(); + logger.info("{}", ret1); + String[] ret2 = generateTextWithPyTorchContrastive(); + logger.info("{}", ret2[0]); + String[] ret3 = generateTextWithPyTorchBeam(); + logger.info("{}", ret3[0]); } - public static String generateTextWithPyTorch() + public static String generateTextWithPyTorchGreedy() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException { SearchConfig config = new SearchConfig(); @@ -75,9 +81,157 @@ public static String generateTextWithPyTorch() long[] inputIds = encoding.getIds(); NDArray inputIdArray = manager.create(inputIds).expandDims(0); - NDArray output = generator.greedySearch(inputIdArray); + NDArray output = generator.generate(inputIdArray); long[] outputIds = output.toLongArray(); return tokenizer.decode(outputIds); } } + + public static String[] generateTextWithPyTorchContrastive() + throws ModelNotFoundException, MalformedModelException, IOException, + TranslateException { + SearchConfig config = new SearchConfig(); + config.setMaxSeqLength(60); + long padTokenId = 220; + 
config.setPadTokenId(padTokenId); + + String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip"; + + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, CausalLMOutput.class) + .optModelUrls(url) + .optEngine("PyTorch") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .build(); + String[] inputs = {"DeepMind Company is", "Memories follow me left and right. I can"}; + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = model.getNDManager().newSubManager(); + HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("gpt2")) { + + TextGenerator generator = new TextGenerator(predictor, "contrastive", config); + NDArray inputIdArray = encodeWithPadding(manager, tokenizer, inputs, padTokenId); + + NDArray outputs = generator.generate(inputIdArray); + return decodeWithOffset(tokenizer, outputs, generator.getPositionOffset()); + } + } + + public static String[] generateTextWithPyTorchBeam() + throws ModelNotFoundException, MalformedModelException, IOException, + TranslateException { + SearchConfig config = new SearchConfig(); + config.setMaxSeqLength(60); + long padTokenId = 220; + config.setPadTokenId(padTokenId); + + String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip"; + + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, CausalLMOutput.class) + .optModelUrls(url) + .optEngine("PyTorch") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .build(); + String[] inputs = {"DeepMind Company is", "Memories follow me left and right. I can"}; + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = model.getNDManager().newSubManager(); + HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("gpt2")) { + + TextGenerator generator = new TextGenerator(predictor, "beam", config); + NDArray inputIdArray = encodeWithPadding(manager, tokenizer, inputs, padTokenId); + + NDArray outputs = generator.generate(inputIdArray); + return decodeWithOffset( + tokenizer, outputs, generator.getPositionOffset().repeat(0, config.getBeam())); + } + } + + public static String[] generateTextWithOnnxRuntimeBeam() + throws ModelNotFoundException, MalformedModelException, IOException, + TranslateException { + SearchConfig config = new SearchConfig(); + config.setMaxSeqLength(60); + long padTokenId = 220; + config.setPadTokenId(padTokenId); + + String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_onnx.zip"; + + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, CausalLMOutput.class) + .optModelUrls(url) + .optEngine("OnnxRuntime") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .build(); + String[] inputs = {"DeepMind Company is"}; + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = model.getNDManager().newSubManager(); + HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("gpt2")) { + + TextGenerator generator = new TextGenerator(predictor, "beam", config); + NDArray inputIdArray = encodeWithPadding(manager, tokenizer, inputs, padTokenId); + + NDArray outputs = generator.generate(inputIdArray); + return decodeWithOffset( + tokenizer, outputs, generator.getPositionOffset().repeat(0, config.getBeam())); + } + } + + public static NDArray encodeWithPadding( + NDManager manager, HuggingFaceTokenizer tokenizer, String[] inputs, long padTokenId) { + NDArray inputIdArray = null; + for 
(String input : inputs) { + long[] inputIds = tokenizer.encode(input).getIds(); + NDArray deltaInputIdArray = manager.create(inputIds).expandDims(0); + if (inputIdArray == null) { + inputIdArray = deltaInputIdArray; + } else { + if (inputIdArray.getShape().get(1) > deltaInputIdArray.getShape().get(1)) { + // pad deltaInputIdArray + long batchSize = deltaInputIdArray.getShape().get(0); + long deltaSeqLength = + inputIdArray.getShape().get(1) - deltaInputIdArray.getShape().get(1); + deltaInputIdArray = + manager.full( + new Shape(batchSize, deltaSeqLength), + padTokenId, + DataType.INT64) + .concat(deltaInputIdArray, 1); + } else if (inputIdArray.getShape().get(1) < deltaInputIdArray.getShape().get(1)) { + // pad inputIdArray + long batchSize = inputIdArray.getShape().get(0); + long deltaSeqLength = + deltaInputIdArray.getShape().get(1) - inputIdArray.getShape().get(1); + inputIdArray = + manager.full( + new Shape(batchSize, deltaSeqLength), + padTokenId, + DataType.INT64) + .concat(inputIdArray, 1); + } + inputIdArray = inputIdArray.concat(deltaInputIdArray, 0); + } + } + return inputIdArray; + } + + public static String[] decodeWithOffset( + HuggingFaceTokenizer tokenizer, NDArray outputIds, NDArray offset) { + long batchSize = outputIds.getShape().get(0); + String[] outputs = new String[(int) batchSize]; + for (int i = 0; i < batchSize; i++) { + long startIndex = offset.getLong(i); + long[] outputId = outputIds.get("{},{}:", i, startIndex).toLongArray(); + outputs[i] = tokenizer.decode(outputId); + } + return outputs; + } } diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java index a79573fcdab..d8393d6c07d 100644 --- a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java @@ -28,6 +28,20 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx TestRequirements.weekly(); TestRequirements.engine("PyTorch"); + // Beam with Ort + String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam(); + Assert.assertEquals( + output0[0], + "DeepMind Company is a global leader in the field of artificial intelligence and" + + " artificial intelligence research and development.\n" + + "\n" + + "Our mission is to provide the world with the best and brightest minds in the" + + " field of artificial intelligence and artificial intelligence research and" + + " development.\n" + + "\n" + + "Our mission is to provide the world with the best"); + + // Greedy String expected = "DeepMind Company is a global leader in the field of artificial" + " intelligence and artificial intelligence. We are a leading provider" @@ -35,7 +49,74 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx + " latest in advanced AI solutions for the automotive industry. We are" + " also a leading provider of advanced AI solutions for the automotive" + " industry, including the"; + Assert.assertEquals(TextGeneration.generateTextWithPyTorchGreedy(), expected); + + // Contrastive + String[] output1 = TextGeneration.generateTextWithPyTorchContrastive(); + Assert.assertEquals( + output1[0], + "DeepMind Company is a leading provider of advanced AI solutions for businesses," + + " government agencies and individuals. 
We offer a wide range of services" + + " including research, development, training, consulting, and" + + " support.<|endoftext|>This article is about the character. You may be" + + " looking for the original version"); + Assert.assertEquals( + output1[1], + "Memories follow me left and right. I can't remember the last time I saw her.\n" + + "\n" + + "\"What do you mean?\" asked my mother.\n" + + "\n" + + "\"I'm sorry, but I don't know what happened to her.\"\n" + + "\n" + + "\"Well, you're right. She was very"); + + // Beam + String[] output2 = TextGeneration.generateTextWithPyTorchBeam(); + Assert.assertEquals( + output2[0], + "DeepMind Company is a global leader in the field of artificial intelligence and" + + " artificial intelligence research and development.\n" + + "\n" + + "Our mission is to provide the world with the best and brightest minds in the" + + " field of artificial intelligence and artificial intelligence research and" + + " development.\n" + + "\n" + + "Our mission is to"); + Assert.assertEquals( + output2[3], + "Memories follow me left and right. I can't tell you how many times I've been told" + + " that I'm not a good person. I'm not a good person. I'm not a good person." + + " I'm not a good person. I'm not a good person. I'm not a"); + } - Assert.assertEquals(TextGeneration.generateTextWithPyTorch(), expected); + @Test + public void testSeqBatchScheduler() throws TranslateException, ModelException, IOException { + TestRequirements.weekly(); + TestRequirements.engine("PyTorch"); + String[] output = RollingBatch.seqBatchSchedulerWithPyTorchContrastive(); + Assert.assertEquals( + output[0], + "DeepMind Company is a leading provider of advanced AI solutions for businesses," + + " government agencies and individuals. We offer a wide range of services" + + " including research, development"); + Assert.assertEquals( + output[1], + "Memories follow me left and right. I can't wait to see what happens next.\n" + + "\n" + + "Advertisements<|endoftext|>"); + Assert.assertEquals( + output[2], + "When your legs don't work like they used to before And I can't sweep you off my" + + " feet, but I can help you out with your hair"); + Assert.assertEquals( + output[3], + "There's a time that I remember, when I did not know what to do with myself. I felt" + + " like I was going to die. 
I thought"); + Assert.assertEquals( + output[4], + "A person gets sent back to prison for life.\n" + + "\n" + + "But if you're lucky, you can escape from prison and live happily ever" + + " after.\n"); } } From 76f21b201ee3cf13a9e44dfe99ba4d52c4405f13 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 27 Jul 2023 15:29:03 -0700 Subject: [PATCH 2/3] fix --- .../test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java index 8e9b736feb3..ef385de92fc 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java +++ b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/GptTranslatorTest.java @@ -85,7 +85,6 @@ public void testGpt2() throws TranslateException, ModelException, IOException { Assert.assertEquals(nextTokenId, 257); NDList list = res.getPastKeyValuesList(); Assert.assertEquals(list.size(), 24); - Assert.assertEquals(res.getHiddenState().getShape().get(0), 1); } } } From dcf26b0b3625fb07843e0f5fff664cb8b491a1c8 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 12 Aug 2023 12:02:04 -0700 Subject: [PATCH 3/3] Remove unnecessary package-info file --- .../djl/onnxruntime/zoo/nlp/package-info.java | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java deleted file mode 100644 index 537a113b11c..00000000000 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/nlp/package-info.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ - -/** - * Contains supplemental classes for the {@link ai.djl.Application.NLP} models in the {@link - * ai.djl.onnxruntime.zoo}. - */ -package ai.djl.onnxruntime.zoo.nlp;