Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Restore Lm search unittest to recover coverage rate #2723

Merged
merged 3 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,13 @@ private NDList prepareInput(
}

/**
* Forward function call to generate text.
* Generate function call to generate text.
*
* @param inputIds the input token ids
* @return generated token ids
* @throws TranslateException if prediction fails
*/
public NDArray forward(NDArray inputIds) throws TranslateException {
public NDArray generate(NDArray inputIds) throws TranslateException {
switch (searchName) {
case "greedy":
return greedySearch(inputIds);
Expand All @@ -555,4 +555,13 @@ public NDArray forward(NDArray inputIds) throws TranslateException {
+ " contrastive}");
}
}

/**
* Gets the value of the positionOffset.
*
* @return the value of positionOffset
*/
public NDArray getPositionOffset() {
return positionOffset;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,16 @@ public OrtNDArray create(boolean[] data, Shape shape) {
+ " dimensions are not supported.");
}

Object tensorIn = OrtUtil.reshape(data, sh);
Object tensorIn;
if (sh.length != 1) {
tensorIn = OrtUtil.reshape(data, sh);
} else {
// Work around the bug in OrtUtil.reshape() when sh.length == 1.
long[] shExpanded = {1, sh[0]};
boolean[][] arrayIn = (boolean[][]) OrtUtil.reshape(data, shExpanded);
tensorIn = arrayIn[0];
}

try {
return new OrtNDArray(this, alternativeManager, OrtUtils.toTensor(env, tensorIn));
} catch (OrtException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.onnxruntime.zoo.nlp.textgeneration;

import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

/** The {@link ai.djl.translate.Translator} for PyTorch GPT2 model. */
public class OrtGptTranslator implements NoBatchifyTranslator<NDList, CausalLMOutput> {

private long kvDim;
private int numAttentionHeads;
private int numLayers;

/**
* Constructs a new instance of {@code PtGptTranslator}.
*
* @param kvDim the kv dimension
* @param numAttentionHeads the number of attention heads
* @param numLayers the number of layers
*/
public OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) {
this.kvDim = kvDim;
this.numAttentionHeads = numAttentionHeads;
this.numLayers = numLayers;
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
// input = [inputIds, posIds, attnMask]
NDManager manager = ctx.getNDManager();
NDArray inputIds = input.get(0);
inputIds.setName("input_ids");

NDArray attentionMask = input.get(2);
attentionMask.setName("attention_mask");

NDList inputNew;
if (input.size() == 3) {
// pastKeyValue == null
NDArray useCacheBranch = manager.create(new boolean[] {false}, new Shape(1));
useCacheBranch.setName("use_cache_branch");
inputNew = new NDList(inputIds, attentionMask, useCacheBranch);
initialDummyPastKeyValues(inputIds, manager, inputNew);
} else {
NDArray useCacheBranch = manager.create(new boolean[] {true}, new Shape(1));
useCacheBranch.setName("use_cache_branch");
inputNew = new NDList(inputIds, attentionMask, useCacheBranch);
inputNew.addAll(input.subNDList(3));
}

int offset = 3;
for (int i = offset; i < numLayers * 2 + offset; i += 2) {
int order = (i - offset) / 2;
inputNew.get(i).setName(String.format("past_key_values.%s.key", order));
inputNew.get(i + 1).setName(String.format("past_key_values.%s.value", order));
}

return inputNew;
}

/** {@inheritDoc} */
@Override
public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws Exception {
return new CausalLMOutput(output.get(0), output.subNDList(1));
}

private void initialDummyPastKeyValues(NDArray inputIds, NDManager manager, NDList list) {
long numBatch = inputIds.getShape().get(0);
for (int i = 0; i < numLayers * 2; ++i) {
NDArray array = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim));
list.add(array);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.onnxruntime.zoo.nlp.textgeneration;

import ai.djl.Model;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.lang.reflect.Type;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** An {@link TranslatorFactory} that creates a {@link OrtGptTranslator} instance. */
public class OrtGptTranslatorFactory implements TranslatorFactory {

private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();

static {
SUPPORTED_TYPES.add(new Pair<>(NDList.class, CausalLMOutput.class));
}

/** {@inheritDoc} */
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return SUPPORTED_TYPES;
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
if (!isSupported(input, output)) {
throw new IllegalArgumentException("Unsupported input/output types.");
}
long kvDim = ArgumentsUtil.longValue(arguments, "kvDim", 64);
int numAttentionHeads = ArgumentsUtil.intValue(arguments, "numAttentionHeads", 12);
int numLayers = ArgumentsUtil.intValue(arguments, "numLayers", 12);

return (Translator<I, O>) (new OrtGptTranslator(kvDim, numAttentionHeads, numLayers));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains classes for the {@link ai.djl.Application.NLP#TEXT_GENERATION} models. */
package ai.djl.onnxruntime.zoo.nlp.textgeneration;
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.onnxruntime.engine;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.onnxruntime.zoo.nlp.textgeneration.OrtGptTranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class GptTranslatorTest {

@Test
public void testGpt2() throws TranslateException, ModelException, IOException {
// This is a fake model that simulates language models like GPT2: NDList(inputIds, posIds,
// attnMask) -> NDList(logits(1), pastKv(12*2)[, hiddenStates(13)])
Block block =
new LambdaBlock(
a -> {
NDList list = new NDList(25);
NDManager manager = a.getManager();
long[][] logits = new long[4][50257];
logits[3][257] = 1;
NDArray arr = manager.create(logits).expandDims(0);
list.add(arr);

for (int i = 0; i < 12 * 2; ++i) {
NDArray array = manager.zeros(new Shape(1, 12, 1, 64));
list.add(array);
}
return list;
},
"model");

Path modelDir = Paths.get("build/text_generation");
Files.createDirectories(modelDir);

Criteria<NDList, CausalLMOutput> criteria =
Criteria.builder()
.setTypes(NDList.class, CausalLMOutput.class)
.optModelPath(modelDir)
.optBlock(block)
.optOption("hasParameter", "false")
.optTranslatorFactory(new OrtGptTranslatorFactory())
.build();

try (ZooModel<NDList, CausalLMOutput> model = criteria.loadModel();
Predictor<NDList, CausalLMOutput> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager()) {
long[][] inputIds = {{29744, 28478, 5834, 318}};
int len = inputIds[0].length;
NDArray input = manager.create(inputIds);
NDArray positionIds = manager.arange(0, len, 1, DataType.INT64).expandDims(0);
NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64);
CausalLMOutput res = predictor.predict(new NDList(input, positionIds, attentionMask));
NDArray logits = res.getLogits();
long nextTokenId = logits.get(":, -1, :").argMax().getLong();
Assert.assertEquals(nextTokenId, 257);
NDList list = res.getPastKeyValuesList();
Assert.assertEquals(list.size(), 24);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

import java.util.Collections;
import java.util.stream.Collectors;

/** The {@link ai.djl.translate.Translator} for PyTorch GPT2 model. */
Expand Down Expand Up @@ -55,26 +54,22 @@ public NDList processInput(TranslatorContext ctx, NDList input) throws Exception
if (input.size() == 3) {
// In this case, input has null pastKeyValues. We prefix-append a dummy pastKeyValues,
// which is treated as prefix padding, and set the corresponding attnMask to be zero. No
// need to shift the position ids.
// need to shift the position ids, since the starting position id, which is 0, won't
// change after appending the dummy kv_cache.
ctx.setAttachment("useDummyPastKeyValues", Boolean.TRUE);

// Pad the null pastKeyValues with dummy values
NDList pastKeyValues = initialDummyPastKeyValues(input.get(0), manager);
for (NDArray pkv : pastKeyValues) {
pkv.setName(tupleName);
input.add(pkv);
}
initialDummyPastKeyValues(input.get(0), manager, input);

// Append zero to the attentionMask from left, corresponding to the padding
long batchSize = input.get(0).getShape().get(0);
NDArray attentionMask =
manager.zeros(new Shape(batchSize, 1), DataType.INT64).concat(input.get(2), -1);
input.set(2, attentionMask);
} else {
for (int i = 3; i < numLayers * 2 + 3; ++i) {
NDArray pkv = input.get(i);
pkv.setName(tupleName);
}
}

for (int i = 3; i < numLayers * 2 + 3; ++i) {
input.get(i).setName(tupleName);
}

return input;
Expand Down Expand Up @@ -111,11 +106,11 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws
return new CausalLMOutput(logitsOutput, hiddenStatesOutput, pastKeyValuesOutput);
}

private NDList initialDummyPastKeyValues(NDArray inputIds, NDManager manager) {
private void initialDummyPastKeyValues(NDArray inputIds, NDManager manager, NDList list) {
long numBatch = inputIds.getShape().get(0);
NDArray dummyKV = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim));
NDList pastKeyValues = new NDList();
pastKeyValues.addAll(Collections.nCopies(2 * numLayers, dummyKV));
return pastKeyValues;
for (int i = 0; i < numLayers * 2; ++i) {
NDArray array = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim));
list.add(array);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ public void testGpt2() throws TranslateException, ModelException, IOException {
long[][] inputIds = {{29744, 28478, 5834, 318}};
int len = inputIds[0].length;
NDArray input = manager.create(inputIds);
NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64);
NDArray positionIds = manager.arange(0, len, 1, DataType.INT64).expandDims(0);
CausalLMOutput res = predictor.predict(new NDList(input, attentionMask, positionIds));
NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64);
CausalLMOutput res = predictor.predict(new NDList(input, positionIds, attentionMask));
NDArray logits = res.getLogits();
long nextTokenId = logits.get(":, -1, :").argMax().getLong();
Assert.assertEquals(nextTokenId, 257);
Expand Down
Loading
Loading