Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow pytorch stream model loading #729

Merged
merged 2 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
with:
python-version: '3.x'
- name: Install CN fonts
run: apt-get update && apt-get install fonts-arphic-uming
run: sudo apt-get update && sudo apt-get install fonts-arphic-uming
- name: install Python Dependencies
run: pip3 install nbconvert==5.6.1 mkdocs mkdocs-exclude mknotebooks==0.4.1 mkdocs-material jupyter Pygments Markdown==3.2.2
- name: Install IJava kernel
Expand Down
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,15 @@ default boolean[] toBooleanArray() {
return ret;
}

/**
* Converts this {@code NDArray} to a String array.
*
* <p>This method is only applicable to the String typed NDArray and not for printing purpose
*
* @return Array of Strings
*/
String[] toStringArray();

/**
* Converts this {@code NDArray} to a Number array based on its {@link DataType}.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ default NDArray stopGradient() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default String[] toStringArray() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default ByteBuffer toByteBuffer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ static Engine newInstance() {
}

private Engine getAlternativeEngine() {
if (Boolean.getBoolean("ai.djl.dlr.disable_alternative")) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,18 @@ public boolean hasGradient() {
return hasGradient;
}

/** {@inheritDoc} */
@Override
public NDArray stopGradient() {
return manager.invoke("stop_gradient", this, null);
}

/** {@inheritDoc} */
@Override
public String[] toStringArray() {
throw new UnsupportedOperationException("String NDArray is not supported!");
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public int getRank() {
}

private Engine getAlternativeEngine() {
if (Boolean.getBoolean("ai.djl.onnx.disable_alternative")) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
package ai.djl.onnxruntime.engine;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrayAdapter;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.UUID;

Expand Down Expand Up @@ -117,20 +121,35 @@ public void detach() {
manager = OrtNDManager.getSystemManager();
}

/** {@inheritDoc} */
@Override
public String[] toStringArray() {
try {
return (String[]) tensor.getValue();
} catch (OrtException e) {
throw new EngineException(e);
}
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return tensor.getByteBuffer().order(ByteOrder.nativeOrder());
}

/** {@inheritDoc} */
@Override
public String toString() {
if (isClosed) {
return "This array is already closed";
}
return "ND: "
+ getShape()
+ ' '
+ getDevice()
+ ' '
+ getDataType()
+ '\n'
+ Arrays.toString(toArray());
String arrStr;
if (getDataType() == DataType.STRING) {
arrStr = Arrays.toString(toStringArray());
} else {
arrStr = Arrays.toString(toArray());
}
return "ND: " + getShape() + ' ' + getDevice() + ' ' + getDataType() + '\n' + arrStr;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ public static DataType toDataType(OnnxJavaType javaType) {
return DataType.BOOLEAN;
case UNKNOWN:
return DataType.UNKNOWN;
case STRING:
return DataType.STRING;
default:
throw new UnsupportedOperationException("type is not supported: " + javaType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException {
public void testStringTensor()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
System.setProperty("ai.djl.onnx.disable_alternative", "true");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about putting this functionality into Criteria?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot simply added to criteria, it is usually scoped into a series of operations

Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
Expand All @@ -82,12 +83,15 @@ public void testStringTensor()
.build();
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> predictor = model.newPredictor()) {
OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager();
OrtNDManager manager = (OrtNDManager) model.getNDManager();
lanking520 marked this conversation as resolved.
Show resolved Hide resolved
lanking520 marked this conversation as resolved.
Show resolved Hide resolved
NDArray stringNd =
manager.create(
new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"},
new Shape(1, 2));
predictor.predict(new NDList(stringNd));
NDList result = predictor.predict(new NDList(stringNd));
Assert.assertEquals(result.size(), 2);
Assert.assertEquals(result.get(0).toLongArray(), new long[] {1});
}
System.clearProperty("ai.djl.onnx.disable_alternative");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public int getRank() {
}

Engine getAlternativeEngine() {
if (Boolean.getBoolean("ai.djl.paddlepaddle.disable_alternative")) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.djl.util.PairList;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
Expand Down Expand Up @@ -101,6 +102,18 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
}

/**
* Load PyTorch model from {@link InputStream}.
*
* <p>Currently, only TorchScript file are supported
*
* @param modelStream the stream of the model file
* @throws IOException model loading error
*/
public void load(InputStream modelStream) throws IOException {
block = JniUtils.loadModule((PtNDManager) manager, modelStream, manager.getDevice(), false);
}

private Path findModelFile(String prefix) {
if (Files.isRegularFile(modelDir)) {
Path file = modelDir;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ public ByteBuffer toByteBuffer() {
return JniUtils.getByteBuffer(this);
}

/** {@inheritDoc} */
@Override
public String[] toStringArray() {
throw new UnsupportedOperationException("String NDArray is not supported!");
}

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.integration;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtModel;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.net.URL;
import org.testng.Assert;
import org.testng.annotations.Test;

public class PtModelTest {

@Test
public void testLoadFromStream() throws IOException, TranslateException {
URL url =
new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt");
try (PtModel model = (PtModel) Model.newInstance("test model")) {
model.load(url.openStream());
try (Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224));
NDArray result = predictor.predict(new NDList(array)).singletonOrThrow();
Assert.assertEquals(result.getShape(), new Shape(1, 1000));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ public boolean[] toBooleanArray() {
return result;
}

@Override
public String[] toStringArray() {
// TODO: Parse String Array from bytes[]
throw new UnsupportedOperationException(
"TensorFlow does not supporting printing String NDArray");
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ public int getRank() {
}

private Engine getAlternativeEngine() {
if (Boolean.getBoolean("ai.djl.tflite.disable_alternative")) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand All @@ -67,7 +70,7 @@ private Engine getAlternativeEngine() {
/** {@inheritDoc} */
@Override
public String getVersion() {
return "1.4.0";
return "2.4.1";
}

/** {@inheritDoc} */
Expand Down